/*
 * Decompiled with CFR 0.152.
 */
package com.googlecode.clearnlp.component.srl;

import com.carrotsearch.hppc.ObjectIntOpenHashMap;
import com.googlecode.clearnlp.classification.model.StringModel;
import com.googlecode.clearnlp.classification.prediction.StringPrediction;
import com.googlecode.clearnlp.classification.train.StringTrainSpace;
import com.googlecode.clearnlp.classification.vector.StringFeatureVector;
import com.googlecode.clearnlp.component.AbstractStatisticalComponent;
import com.googlecode.clearnlp.dependency.DEPArc;
import com.googlecode.clearnlp.dependency.DEPNode;
import com.googlecode.clearnlp.dependency.DEPTree;
import com.googlecode.clearnlp.feature.xml.FtrToken;
import com.googlecode.clearnlp.feature.xml.JointFtrXml;
import com.googlecode.clearnlp.util.UTInput;
import com.googlecode.clearnlp.util.UTOutput;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;

/**
 * Statistical component that assigns roleset IDs to predicate nodes of a {@link DEPTree}.
 *
 * <p>The roleset ID is stored in the node feature {@code "pb"}. Only nodes that already carry
 * that feature are treated as predicates (presumably set by an earlier predicate-identification
 * step; confirm against the pipeline). Its behaviour depends on the inherited {@code i_flag}:
 * <ul>
 *   <li>{@code 0}: collects, for every lemma, the set of gold roleset IDs seen in training.</li>
 *   <li>{@code 1}: adds a gold-labelled training instance per ambiguous predicate.</li>
 *   <li>{@code 2} / {@code 4}: predicts rolesets with the statistical models
 *       ({@code 2} additionally skips reading gold rolesets).</li>
 * </ul>
 * Lemmas seen with exactly one roleset are resolved by lookup. Lemmas seen with several rolesets
 * get their own statistical model. Unseen lemmas fall back to {@code lemma + ".01"}.
 */
public class CRolesetClassifier
extends AbstractStatisticalComponent {
    /** Zip entry holding the default configuration. */
    private static final String ENTRY_CONFIGURATION = "role_CONFIGURATION";
    /** Prefix of zip entries holding feature templates; a numeric index follows the prefix. */
    private static final String ENTRY_FEATURE = "role_FEATURE";
    /** Zip entry holding the roleset/lemma lexica. */
    private static final String ENTRY_LEXICA = "role_LEXICA";
    /** Prefix of zip entries holding statistical models; a numeric index follows the prefix. */
    private static final String ENTRY_MODEL = "role_MODEL";
    /** Node feature that marks a predicate and receives its roleset ID. */
    private static final String FEAT_ROLESET = "pb";
    /** Suffix appended to the lemma when the lemma was never seen during lexica collection. */
    private static final String DEFAULT_SENSE_SUFFIX = ".01";

    // Indices into the lexica array. These are compile-time constants. They are referenced by
    // simple name (never via "this.") so javac inlines them; that keeps them valid even if
    // initLexia is invoked from the superclass constructor before instance initialisers run.
    /** Index of the lemma-to-roleset map (unambiguous lemmas) in the lexica array. */
    protected final int LEXICA_ROLESETS = 0;
    /** Index of the lemma-to-model-id map (ambiguous lemmas) in the lexica array. */
    protected final int LEXICA_LEMMAS = 1;

    /** Lemma to the set of gold roleset IDs observed; only initialised by the collecting constructor. */
    protected Map<String, Set<String>> m_collect;
    /** Lemma to its single roleset ID, for lemmas observed with exactly one roleset. */
    protected Map<String, String> m_rolesets;
    /** Lemma to the index of the statistical model that disambiguates it. */
    protected ObjectIntOpenHashMap<String> m_lemmas;
    /** Gold roleset IDs indexed by node id (null where a node has none). */
    protected String[] g_rolesets;
    /** Index of the node currently being processed. */
    protected int i_pred;

    /**
     * Creates a component that collects lexica from training trees.
     *
     * @param xmls feature templates
     */
    public CRolesetClassifier(JointFtrXml[] xmls) {
        super(xmls);
        this.m_collect = new HashMap<String, Set<String>>();
    }

    /**
     * Creates a component that adds training instances to the given spaces.
     *
     * @param xmls feature templates
     * @param spaces training spaces, one per ambiguous lemma
     * @param lexica lexica previously produced by {@link #getLexica()}
     */
    public CRolesetClassifier(JointFtrXml[] xmls, StringTrainSpace[] spaces, Object[] lexica) {
        super(xmls, spaces, lexica);
    }

    /**
     * Creates a component that predicts with already trained models.
     *
     * @param xmls feature templates
     * @param models trained models, one per ambiguous lemma
     * @param lexica lexica previously produced by {@link #getLexica()}
     */
    public CRolesetClassifier(JointFtrXml[] xmls, StringModel[] models, Object[] lexica) {
        super(xmls, models, lexica);
    }

    /**
     * Creates a component from a serialized model archive.
     *
     * @param in zip stream written by {@link #saveModels(ZipOutputStream)}
     */
    public CRolesetClassifier(ZipInputStream in) {
        super(in);
    }

    /**
     * Installs the lexica produced by {@link #getLexica()}.
     *
     * @param lexica array holding the roleset map at {@code LEXICA_ROLESETS} and the lemma
     *     model-index map at {@code LEXICA_LEMMAS}
     */
    @Override
    @SuppressWarnings("unchecked") // element types are fixed by getLexica()
    protected void initLexia(Object[] lexica) {
        this.m_rolesets = (Map<String, String>) lexica[LEXICA_ROLESETS];
        this.m_lemmas = (ObjectIntOpenHashMap<String>) lexica[LEXICA_LEMMAS];
    }

    /**
     * Reads configuration, feature templates, lexica and models from a zip stream.
     * Unknown entries are ignored. Any exception is printed and swallowed (best-effort loading).
     *
     * @param zin zip stream written by {@link #saveModels(ZipOutputStream)}
     */
    @Override
    public void loadModels(ZipInputStream zin) {
        int fLen = ENTRY_FEATURE.length();
        int mLen = ENTRY_MODEL.length();
        this.f_xmls = new JointFtrXml[1];
        this.s_models = null;
        try {
            ZipEntry zEntry;
            while ((zEntry = zin.getNextEntry()) != null) {
                String entry = zEntry.getName();
                if (entry.equals(ENTRY_CONFIGURATION)) {
                    this.loadDefaultConfiguration(zin);
                } else if (entry.startsWith(ENTRY_FEATURE)) {
                    this.loadFeatureTemplates(zin, Integer.parseInt(entry.substring(fLen)));
                } else if (entry.startsWith(ENTRY_MODEL)) {
                    this.loadStatisticalModels(zin, Integer.parseInt(entry.substring(mLen)));
                } else if (entry.equals(ENTRY_LEXICA)) {
                    this.loadLexica(zin);
                }
            }
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Reads the roleset map followed by the lemma model-index map from the current zip entry.
     * The reader is intentionally not closed: closing it would close the enclosing zip stream.
     */
    private void loadLexica(ZipInputStream zin) throws Exception {
        BufferedReader fin = new BufferedReader(new InputStreamReader(zin));
        System.out.println("Loading lexica.");
        this.m_rolesets = UTInput.getStringMap(fin, " ");
        this.m_lemmas = UTInput.getStringIntOpenHashMap(fin, " ");
    }

    /**
     * Writes configuration, feature templates, lexica and models, then closes {@code zout}.
     * Any exception is printed and swallowed.
     *
     * @param zout destination archive
     */
    @Override
    public void saveModels(ZipOutputStream zout) {
        try {
            this.saveDefaultConfiguration(zout, ENTRY_CONFIGURATION);
            this.saveFeatureTemplates(zout, ENTRY_FEATURE);
            this.saveLexica(zout);
            this.saveStatisticalModels(zout, ENTRY_MODEL);
            zout.close();
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Writes the roleset map and then the lemma model-index map into a single zip entry. */
    private void saveLexica(ZipOutputStream zout) throws Exception {
        zout.putNextEntry(new ZipEntry(ENTRY_LEXICA));
        PrintStream fout = UTOutput.createPrintBufferedStream(zout);
        System.out.println("Saving lexica.");
        UTOutput.printMap(fout, this.m_rolesets, " ");
        fout.flush();
        UTOutput.printMap(fout, this.m_lemmas, " ");
        fout.flush();
        zout.closeEntry();
    }

    /**
     * Builds the lexica from the rolesets collected by {@link #addLexica()}.
     *
     * @return array holding the unambiguous lemma-to-roleset map at {@code LEXICA_ROLESETS} and
     *     the ambiguous lemma-to-model-index map at {@code LEXICA_LEMMAS}
     */
    @Override
    public Object[] getLexica() {
        Map<String, String> mRolesets = this.getRolesetMap();
        Object[] lexica = new Object[2];
        lexica[LEXICA_ROLESETS] = mRolesets;
        lexica[LEXICA_LEMMAS] = this.getLemmas(this.m_collect.keySet(), mRolesets);
        return lexica;
    }

    /** Returns lemma to roleset for every collected lemma observed with exactly one roleset. */
    private Map<String, String> getRolesetMap() {
        Map<String, String> map = new HashMap<String, String>();
        for (Map.Entry<String, Set<String>> entry : this.m_collect.entrySet()) {
            Set<String> rolesets = entry.getValue();
            if (rolesets.size() == 1) {
                map.put(entry.getKey(), rolesets.iterator().next());
            }
        }
        return map;
    }

    /**
     * Assigns consecutive model indices, starting at 0, to every lemma not resolved by
     * {@code mRolesets}. The order follows the iteration order of {@code sLemmas}.
     */
    private ObjectIntOpenHashMap<String> getLemmas(Set<String> sLemmas, Map<String, String> mRolesets) {
        ObjectIntOpenHashMap<String> map = new ObjectIntOpenHashMap<String>();
        int modelId = 0;
        for (String lemma : sLemmas) {
            if (!mRolesets.containsKey(lemma)) {
                map.put(lemma, modelId++);
            }
        }
        return map;
    }

    /**
     * Accumulates roleset accuracy for the current tree.
     *
     * @param counts {@code counts[0]} is increased by the number of nodes with a gold roleset;
     *     {@code counts[1]} by the number whose {@code "pb"} feature equals the gold roleset
     */
    @Override
    public void countAccuracy(int[] counts) {
        int correct = 0;
        int total = 0;
        for (int i = 1; i < this.t_size; ++i) {
            String gRoleset = this.g_rolesets[i];
            if (gRoleset == null) continue;
            ++total;
            if (gRoleset.equals(this.d_tree.get(i).getFeat(FEAT_ROLESET))) {
                ++correct;
            }
        }
        counts[0] += total;
        counts[1] += correct;
    }

    /** Processes one tree according to the current {@code i_flag}. */
    @Override
    public void process(DEPTree tree) {
        this.init(tree);
        this.processAux();
    }

    /**
     * Binds the component to {@code tree}. Gold roleset IDs are read unless {@code i_flag == 2}.
     */
    protected void init(DEPTree tree) {
        this.d_tree = tree;
        this.t_size = tree.size();
        if (this.i_flag != 2) {
            this.g_rolesets = this.d_tree.getRolesetIDs();
        }
        // NOTE(review): presumably populates the dependent lists used by getDependents() and
        // the lmd/rmd/lnd/rnd relations below; confirm in DEPTree.
        tree.setDependents();
    }

    /** Collects lexica when {@code i_flag == 0}; otherwise classifies the predicates. */
    protected void processAux() {
        if (this.i_flag == 0) {
            this.addLexica();
        } else {
            this.classify();
        }
    }

    /** Records the gold roleset of every node that has one under that node's lemma. */
    protected void addLexica() {
        for (this.i_pred = 1; this.i_pred < this.t_size; ++this.i_pred) {
            String roleset = this.g_rolesets[this.i_pred];
            String lemma = this.d_tree.get(this.i_pred).lemma;
            if (roleset == null) continue;
            Set<String> set = this.m_collect.get(lemma);
            if (set == null) {
                set = new HashSet<String>();
                this.m_collect.put(lemma, set);
            }
            set.add(roleset);
        }
    }

    /**
     * Overwrites the {@code "pb"} feature of every node that already carries one with the
     * resolved roleset ID.
     */
    protected void classify() {
        for (this.i_pred = 1; this.i_pred < this.t_size; ++this.i_pred) {
            DEPNode pred = this.d_tree.get(this.i_pred);
            if (pred.getFeat(FEAT_ROLESET) != null) {
                pred.addFeat(FEAT_ROLESET, this.resolveRoleset(pred));
            }
        }
    }

    /**
     * Resolves the roleset of the predicate at {@code i_pred}. An unambiguous lemma is resolved
     * by lookup, an ambiguous lemma by its model, and an unseen lemma gets {@code lemma + ".01"}.
     */
    private String resolveRoleset(DEPNode pred) {
        String roleset = this.m_rolesets.get(pred.lemma);
        if (roleset != null) {
            return roleset;
        }
        if (this.m_lemmas.containsKey(pred.lemma)) {
            return this.getLabel(this.m_lemmas.get(pred.lemma));
        }
        return pred.lemma + DEFAULT_SENSE_SUFFIX;
    }

    /**
     * Produces the label for the predicate at {@code i_pred} using the model at {@code modelId}.
     * With {@code i_flag == 1} it returns the gold label and adds a training instance. With
     * {@code i_flag == 2} or {@code 4} it returns the model's best prediction. For any other
     * flag it returns {@code null}.
     */
    protected String getLabel(int modelId) {
        StringFeatureVector vector = this.getFeatureVector(this.f_xmls[0]);
        String label = null;
        if (this.i_flag == 1) {
            label = this.getGoldLabel();
            this.s_spaces[modelId].addInstance(label, vector);
        } else if (this.i_flag == 2 || this.i_flag == 4) {
            label = this.getAutoLabel(vector, modelId);
        }
        return label;
    }

    /** Returns the gold roleset of the current predicate. */
    private String getGoldLabel() {
        return this.g_rolesets[this.i_pred];
    }

    /** Returns the best label predicted by model {@code modelId} for {@code vector}. */
    private String getAutoLabel(StringFeatureVector vector, int modelId) {
        StringPrediction p = this.s_models[modelId].predictBest(vector);
        return p.label;
    }

    /**
     * Extracts a single-valued feature. The supported fields are form ({@code f}), lemma
     * ({@code m}), POS ({@code p}), dependency label ({@code d}) and any field matching
     * {@code JointFtrXml.P_FEAT}.
     *
     * @return the value, or {@code null} if the node is missing or the field is unsupported
     */
    @Override
    protected String getField(FtrToken token) {
        DEPNode node = this.getNode(token);
        if (node == null) {
            return null;
        }
        if (token.isField("f")) {
            return node.form;
        }
        if (token.isField("m")) {
            return node.lemma;
        }
        if (token.isField("p")) {
            return node.pos;
        }
        if (token.isField("d")) {
            return node.getLabel();
        }
        Matcher m = JointFtrXml.P_FEAT.matcher(token.field);
        return m.find() ? node.getFeat(m.group(1)) : null;
    }

    /**
     * Extracts a multi-valued feature. Only {@code ds}, the set of dependency labels of the
     * node's dependents, is supported.
     *
     * @return the values, or {@code null} if unavailable
     */
    @Override
    protected String[] getFields(FtrToken token) {
        DEPNode node = this.getNode(token);
        if (node == null) {
            return null;
        }
        if (token.isField("ds")) {
            return this.getDeprelSet(node.getDependents());
        }
        return null;
    }

    /** Returns the distinct labels of {@code deps}, or {@code null} if there are none. */
    private String[] getDeprelSet(List<DEPArc> deps) {
        if (deps.isEmpty()) {
            return null;
        }
        Set<String> set = new HashSet<String>();
        for (DEPArc arc : deps) {
            set.add(arc.getLabel());
        }
        return set.toArray(new String[set.size()]);
    }

    /**
     * Returns the node addressed by {@code token}. The token's offset is applied first, then an
     * optional relation: head ({@code h}), left/right-most dependent ({@code lmd}/{@code rmd})
     * or left/right-nearest dependent ({@code lnd}/{@code rnd}).
     */
    private DEPNode getNode(FtrToken token) {
        DEPNode node = this.getNodeAux(token);
        if (node == null) {
            return null;
        }
        if (token.relation != null) {
            if (token.isRelation("h")) {
                node = node.getHead();
            } else if (token.isRelation("lmd")) {
                node = node.getLeftMostDependent();
            } else if (token.isRelation("rmd")) {
                node = node.getRightMostDependent();
            } else if (token.isRelation("lnd")) {
                node = node.getLeftNearestDependent();
            } else if (token.isRelation("rnd")) {
                node = node.getRightNearestDependent();
            }
        }
        return node;
    }

    /**
     * Returns the node at {@code i_pred + token.offset}, or {@code null} if that index falls
     * outside {@code (0, size)}.
     */
    private DEPNode getNodeAux(FtrToken token) {
        if (token.offset == 0) {
            return this.d_tree.get(this.i_pred);
        }
        int cIndex = this.i_pred + token.offset;
        if (0 < cIndex && cIndex < this.d_tree.size()) {
            return this.d_tree.get(cIndex);
        }
        return null;
    }
}

