/*
 * Decompiled with CFR 0.152.
 */
package com.googlecode.clearnlp.nlp;

import com.carrotsearch.hppc.ObjectIntOpenHashMap;
import com.googlecode.clearnlp.classification.model.StringModel;
import com.googlecode.clearnlp.classification.train.StringTrainSpace;
import com.googlecode.clearnlp.component.AbstractStatisticalComponent;
import com.googlecode.clearnlp.component.dep.CDEPBackParser;
import com.googlecode.clearnlp.component.dep.CDEPPassParser;
import com.googlecode.clearnlp.component.pos.CPOSBackTagger;
import com.googlecode.clearnlp.component.pos.CPOSTagger;
import com.googlecode.clearnlp.component.srl.CPredIdentifier;
import com.googlecode.clearnlp.component.srl.CRolesetClassifier;
import com.googlecode.clearnlp.component.srl.CSRLabeler;
import com.googlecode.clearnlp.component.srl.CSenseClassifier;
import com.googlecode.clearnlp.dependency.DEPTree;
import com.googlecode.clearnlp.engine.EngineProcess;
import com.googlecode.clearnlp.feature.xml.JointFtrXml;
import com.googlecode.clearnlp.nlp.AbstractNLP;
import com.googlecode.clearnlp.reader.JointReader;
import com.googlecode.clearnlp.util.UTFile;
import com.googlecode.clearnlp.util.UTInput;
import com.googlecode.clearnlp.util.UTXml;
import com.googlecode.clearnlp.util.map.Prob1DMap;
import java.io.BufferedOutputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.zip.ZipOutputStream;
import org.kohsuke.args4j.Option;
import org.w3c.dom.Element;

public class NLPTrain
extends AbstractNLP {
    /** Delimiter that separates the feature template paths passed with {@code -f}. */
    protected final String DELIM_FILES = ":";
    /** Path of the XML configuration file ({@code -c}). */
    @Option(name="-c", usage="configuration file (required)", required=true, metaVar="<filename>")
    protected String s_configFile;
    /** Feature template paths joined by {@link #DELIM_FILES} ({@code -f}). */
    @Option(name="-f", usage="feature template files delimited by ':' (required)", required=true, metaVar="<filename>")
    protected String s_featureFiles;
    /** Directory whose files are used as training data ({@code -i}). */
    @Option(name="-i", usage="input directory containing training files (required)", required=true, metaVar="<directory>")
    protected String s_trainDir;
    /** Path the trained models are written to ({@code -m}). */
    @Option(name="-m", usage="model file (output; required)", required=true, metaVar="<filename>")
    protected String s_modelFile;
    /**
     * Number of additional training rounds run after the initial one ({@code -n}).
     * The usage text states the default that the field initializer actually applies.
     */
    @Option(name="-n", usage="bootstrapping level (default: 0)", required=false, metaVar="<integer>")
    protected int n_boot = 0;
    /**
     * Training mode ({@code -z}). The usage text lists exactly the modes accepted by
     * the {@code getComponent} methods of this class; any other value is rejected with an
     * {@link IllegalArgumentException}.
     */
    @Option(name="-z", usage="mode (pos|dep|pred|role|sense_*|srl|pos_back|dep_back)", required=true, metaVar="<string>")
    protected String s_mode;
    /** Margin handed to the {@code pos_back} and {@code dep_back} components ({@code -margin}). */
    @Option(name="-margin", usage="margin between the 1st and 2nd predictions (default: 0.5)", required=false, metaVar="<double>")
    protected double d_margin = 0.5;
    /** Beam size handed to the {@code dep_back} component ({@code -beams}); an integer option. */
    @Option(name="-beams", usage="the size of beam (default: 0)", required=false, metaVar="<integer>")
    protected int n_beams = 0;

    /**
     * Creates a trainer with every option left at its field default; no arguments are
     * parsed and no training is started.
     */
    public NLPTrain() {
        super();
    }

    /**
     * Command-line constructor: hands {@code args} to {@code initArgs} and then runs
     * {@code train} with the option fields of this instance.
     *
     * <p>Any exception raised while splitting the feature file list or while training is
     * caught here and only its stack trace is printed; it is not rethrown.
     *
     * @param args the raw command-line arguments
     */
    public NLPTrain(String[] args) {
        this.initArgs(args);
        try {
            // Kept inside the try block so a failure here is reported like a training failure.
            String[] featureFiles = this.s_featureFiles.split(this.DELIM_FILES);
            this.train(this.s_configFile, featureFiles, this.s_trainDir, this.s_modelFile, this.s_mode);
        }
        catch (Exception failure) {
            failure.printStackTrace();
        }
    }

    /**
     * Trains a component for {@code mode} on every file in {@code trainDir} and writes its
     * models to {@code modelFile} as a zip stream.
     *
     * <p>The configuration stream is closed as soon as the document element has been read,
     * and the model stream is closed even when saving fails, so no file handle outlives
     * this call.
     *
     * @param configFile   path of the XML configuration file
     * @param featureFiles paths of the feature template files
     * @param trainDir     directory containing the training files
     * @param modelFile    path of the model file to write
     * @param mode         training mode, as accepted by {@code getComponent}
     * @throws Exception if a file cannot be opened, read or written, or training fails
     */
    public void train(String configFile, String[] featureFiles, String trainDir, String modelFile, String mode) throws Exception {
        Element eConfig;
        FileInputStream configIn = new FileInputStream(configFile);
        try {
            eConfig = UTXml.getDocumentElement(configIn);
        }
        finally {
            configIn.close();
        }
        JointFtrXml[] xmls = this.getFeatureTemplates(featureFiles);
        String[] trainFiles = UTFile.getSortedFileListBySize(trainDir, ".*", true);
        JointReader reader = this.getJointReader(UTXml.getFirstElementByTagName(eConfig, "reader"));
        // devId of -1 matches no file index, so every training file is used.
        AbstractStatisticalComponent component = this.getComponent(eConfig, reader, xmls, trainFiles, -1, mode);
        // The output file is only created once training has succeeded.
        ZipOutputStream zout = new ZipOutputStream(new BufferedOutputStream(new FileOutputStream(modelFile)));
        try {
            component.saveModels(zout);
        }
        finally {
            // Harmless if saveModels already closed the stream: a second close is a no-op.
            zout.close();
        }
    }

    /**
     * Trains and returns the component for {@code mode}.
     *
     * <p>For {@code pred} the model-training overload is called directly with no models,
     * no lexica and bootstrap round 0. For every other supported mode a component is first
     * built to collect lexica and is then passed on to the lexicon-collecting
     * {@code getTrainedComponent} overload.
     *
     * @param eConfig    root element of the configuration document
     * @param reader     reader used to iterate the training files
     * @param xmls       feature templates
     * @param trainFiles training file paths
     * @param devId      index in {@code trainFiles} to skip
     * @param mode       training mode
     * @return the trained component
     * @throws IllegalArgumentException if {@code mode} is not supported
     */
    protected AbstractStatisticalComponent getComponent(Element eConfig, JointReader reader, JointFtrXml[] xmls, String[] trainFiles, int devId, String mode) {
        if (mode.equals("pred")) {
            return this.getTrainedComponent(eConfig, xmls, trainFiles, null, null, mode, 0, devId);
        }

        AbstractStatisticalComponent collector;
        if (mode.equals("pos") || mode.equals("pos_back")) {
            // NOTE(review): pos_back also collects with the plain CPOSTagger here — confirm this is intended.
            Set<String> lowerForms = this.getLowerSimplifiedForms(reader, xmls[0], trainFiles, devId);
            collector = new CPOSTagger(xmls, lowerForms);
        }
        else if (mode.equals("dep")) {
            collector = new CDEPPassParser(xmls);
        }
        else if (mode.equals("role")) {
            collector = new CRolesetClassifier(xmls);
        }
        else if (mode.startsWith("sense")) {
            // Everything after the last underscore of the mode is handed to the classifier.
            String suffix = mode.substring(mode.lastIndexOf("_") + 1);
            collector = new CSenseClassifier(xmls, suffix);
        }
        else if (mode.equals("srl")) {
            collector = new CSRLabeler(xmls);
        }
        else if (mode.equals("dep_back")) {
            collector = new CDEPBackParser(xmls);
        }
        else {
            throw new IllegalArgumentException("The requested mode '" + mode + "' is not supported.");
        }
        return this.getTrainedComponent(eConfig, reader, xmls, trainFiles, collector, mode, devId);
    }

    /**
     * Builds the component for {@code mode} from already trained models and lexica.
     *
     * @param xmls   feature templates
     * @param models trained models
     * @param lexica lexica collected from the training data
     * @param mode   training mode
     * @return a new component wrapping {@code models} and {@code lexica}
     * @throws IllegalArgumentException if {@code mode} is not supported
     */
    protected AbstractStatisticalComponent getComponent(JointFtrXml[] xmls, StringModel[] models, Object[] lexica, String mode) {
        AbstractStatisticalComponent component;

        if (mode.equals("pos")) {
            component = new CPOSTagger(xmls, models, lexica);
        }
        else if (mode.equals("dep")) {
            component = new CDEPPassParser(xmls, models, lexica);
        }
        else if (mode.equals("pred")) {
            component = new CPredIdentifier(xmls, models, lexica);
        }
        else if (mode.equals("role")) {
            component = new CRolesetClassifier(xmls, models, lexica);
        }
        else if (mode.startsWith("sense")) {
            // Everything after the last underscore of the mode is handed to the classifier.
            String suffix = mode.substring(mode.lastIndexOf("_") + 1);
            component = new CSenseClassifier(xmls, models, lexica, suffix);
        }
        else if (mode.equals("srl")) {
            component = new CSRLabeler(xmls, models, lexica);
        }
        else if (mode.equals("pos_back")) {
            component = new CPOSBackTagger(xmls, models, lexica, this.d_margin);
        }
        else if (mode.equals("dep_back")) {
            component = new CDEPBackParser(xmls, models, lexica, this.d_margin, this.n_beams);
        }
        else {
            throw new IllegalArgumentException("The requested mode '" + mode + "' is not supported.");
        }
        return component;
    }

    /**
     * Collects lexica with {@code component} and then trains {@code n_boot + 1} rounds,
     * feeding the models of each round into the next one.
     *
     * <p>A negative {@code n_boot} would skip every round and yield {@code null} after the
     * lexicon pass has already run, so it is rejected up front instead.
     *
     * @param eConfig    root element of the configuration document
     * @param reader     reader used to iterate the training files
     * @param xmls       feature templates
     * @param trainFiles training file paths
     * @param component  component used only to collect lexica
     * @param mode       training mode
     * @param devId      index in {@code trainFiles} to skip
     * @return the component produced by the last training round
     * @throws IllegalArgumentException if the bootstrapping level is negative
     */
    protected AbstractStatisticalComponent getTrainedComponent(Element eConfig, JointReader reader, JointFtrXml[] xmls, String[] trainFiles, AbstractStatisticalComponent component, String mode, int devId) {
        if (this.n_boot < 0) {
            throw new IllegalArgumentException("The bootstrapping level must not be negative: " + this.n_boot);
        }
        Object[] lexica = this.getLexica(component, reader, xmls, trainFiles, devId);
        AbstractStatisticalComponent processor = null;
        // The first round starts without models; later rounds reuse the previous round's models.
        StringModel[] models = null;
        for (int boot = 0; boot <= this.n_boot; ++boot) {
            processor = this.getTrainedComponent(eConfig, xmls, trainFiles, models, lexica, mode, boot, devId);
            models = processor.getModels();
        }
        return processor;
    }

    /**
     * Loads one feature template per path in {@code featureFiles}, in the same order.
     *
     * <p>Each file stream is closed once its template has been constructed, including when
     * construction fails.
     *
     * @param featureFiles paths of the feature template files
     * @return the loaded templates, index-aligned with {@code featureFiles}
     * @throws Exception if a file cannot be opened or a template cannot be built
     */
    protected JointFtrXml[] getFeatureTemplates(String[] featureFiles) throws Exception {
        int size = featureFiles.length;
        JointFtrXml[] xmls = new JointFtrXml[size];
        for (int i = 0; i < size; ++i) {
            FileInputStream in = new FileInputStream(featureFiles[i]);
            try {
                // NOTE(review): assumes the constructor reads the stream completely — confirm.
                xmls[i] = new JointFtrXml(in);
            }
            finally {
                in.close();
            }
        }
        return xmls;
    }

    /**
     * Runs {@code component} over every tree of every training file except the one at
     * index {@code devId}, then returns the lexica the component reports.
     *
     * <p>Prints a header line, one dot per processed file and a closing newline to
     * standard output.
     *
     * @param component  component that processes each tree
     * @param reader     reader reopened for each training file
     * @param xmls       feature templates (not used by this implementation)
     * @param trainFiles training file paths
     * @param devId      index in {@code trainFiles} to skip
     * @return the result of {@code component.getLexica()}
     */
    protected Object[] getLexica(AbstractStatisticalComponent component, JointReader reader, JointFtrXml[] xmls, String[] trainFiles, int devId) {
        final int fileCount = trainFiles.length;
        System.out.println("Collecting lexica:");

        for (int fileId = 0; fileId < fileCount; fileId++) {
            if (fileId != devId) {
                reader.open(UTInput.createBufferedFileReader(trainFiles[fileId]));
                DEPTree current = reader.next();
                while (current != null) {
                    component.process(current);
                    current = reader.next();
                }
                reader.close();
                System.out.print(".");
            }
        }

        System.out.println();
        return component.getLexica();
    }

    /**
     * Collects the lower-cased simplified word forms of the training data and returns
     * those that pass the document-frequency cutoff of {@code xml}.
     *
     * <p>The form set is emptied for every file before being added to the count map, so
     * each file contributes a form at most once. The file at index {@code devId} is
     * skipped. Prints a header line, one dot per processed file and a closing newline to
     * standard output.
     *
     * @param reader     reader reopened for each training file
     * @param xml        template supplying the document-frequency cutoff
     * @param trainFiles training file paths
     * @param devId      index in {@code trainFiles} to skip
     * @return the forms selected by {@code Prob1DMap.toSet} for that cutoff
     */
    protected Set<String> getLowerSimplifiedForms(JointReader reader, JointFtrXml xml, String[] trainFiles, int devId) {
        HashSet<String> formsInFile = new HashSet<String>();
        final int fileCount = trainFiles.length;
        Prob1DMap documentCounts = new Prob1DMap();
        System.out.println("Collecting word-forms:");

        for (int fileId = 0; fileId < fileCount; fileId++) {
            if (fileId == devId) {
                continue;
            }
            reader.open(UTInput.createBufferedFileReader(trainFiles[fileId]));
            formsInFile.clear();

            for (DEPTree tree = reader.next(); tree != null; tree = reader.next()) {
                EngineProcess.normalizeForms(tree);
                final int nodeCount = tree.size();
                // Index 0 is skipped, as in the original loop.
                int nodeId = 1;
                while (nodeId < nodeCount) {
                    formsInFile.add(tree.get(nodeId).lowerSimplifiedForm);
                    nodeId++;
                }
            }

            reader.close();
            documentCounts.addAll(formsInFile);
            System.out.print(".");
        }

        System.out.println();
        int cutoff = xml.getDocumentFrequencyCutoff();
        return documentCounts.toSet(cutoff);
    }

    /**
     * Runs one training round: gathers the training spaces, builds one model per space
     * and wraps the models in a component for {@code mode}.
     *
     * <p>Each space is cleared right after its model has been built. For {@code role} and
     * {@code sense*} modes every model is requested with index 0; otherwise the index of
     * the space is used.
     *
     * @param eConfig    root element of the configuration document
     * @param xmls       feature templates
     * @param trainFiles training file paths
     * @param models     models of the previous round, or {@code null}
     * @param lexica     lexica collected from the training data, or {@code null}
     * @param mode       training mode; also the tag name looked up in {@code eConfig}
     * @param boot       bootstrap round
     * @param devId      index in {@code trainFiles} to skip
     * @return the component built from the newly trained models
     */
    protected AbstractStatisticalComponent getTrainedComponent(Element eConfig, JointFtrXml[] xmls, String[] trainFiles, StringModel[] models, Object[] lexica, String mode, int boot, int devId) {
        StringTrainSpace[] spaces = this.getStringTrainSpaces(eConfig, xmls, trainFiles, models, lexica, mode, boot, devId);
        Element eTrain = UTXml.getFirstElementByTagName(eConfig, mode);
        StringModel[] trained = new StringModel[spaces.length];

        int position = 0;
        for (StringTrainSpace space : spaces) {
            int modelIndex;
            if (mode.equals("role") || mode.startsWith("sense")) {
                modelIndex = 0;
            }
            else {
                modelIndex = position;
            }
            trained[position] = (StringModel)this.getModel(eTrain, space, modelIndex, boot);
            space.clear();
            position++;
        }
        return this.getComponent(xmls, trained, lexica, mode);
    }

    /**
     * Collects training instances from every training file except the one at index
     * {@code devId}, one task per file on a fixed-size thread pool, and merges the
     * per-file spaces into a single array of spaces.
     *
     * <p>The first file's spaces become the result; the spaces of all other files are
     * appended to them position by position and then cleared. Progress is printed to
     * standard output.
     *
     * @param eConfig    root element of the configuration document
     * @param xmls       feature templates
     * @param trainFiles training file paths
     * @param models     models of the previous round, or {@code null}
     * @param lexica     lexica collected from the training data, or {@code null}
     * @param mode       training mode; also the tag name looked up in {@code eConfig}
     * @param boot       bootstrap round
     * @param devId      index in {@code trainFiles} to skip
     * @return the merged training spaces
     * @throws IllegalArgumentException if no training file is left to process
     * @throws IllegalStateException    if the calling thread is interrupted while waiting
     */
    protected StringTrainSpace[] getStringTrainSpaces(Element eConfig, JointFtrXml[] xmls, String[] trainFiles, StringModel[] models, Object[] lexica, String mode, int boot, int devId) {
        Element eTrain = UTXml.getFirstElementByTagName(eConfig, mode);
        int size = trainFiles.length;
        int numThreads = this.getNumOfThreads(eTrain);
        ArrayList<StringTrainSpace[]> lSpaces = new ArrayList<StringTrainSpace[]>();
        ExecutorService executor = Executors.newFixedThreadPool(numThreads);
        System.out.println("Collecting training instances:");
        try {
            for (int i = 0; i < size; ++i) {
                if (devId == i) continue;
                StringTrainSpace[] fileSpaces = this.getStringTrainSpaces(xmls, lexica, mode, boot);
                lSpaces.add(fileSpaces);
                executor.execute(new TrainTask(eConfig, trainFiles[i], this.getComponent(xmls, fileSpaces, models, lexica, mode)));
            }
        }
        finally {
            // Always release the pool, even if creating a task failed part-way through.
            executor.shutdown();
        }
        try {
            executor.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
        }
        catch (InterruptedException e) {
            // Workers may still be filling their spaces, so merging now would be unsafe:
            // cancel them, restore the interrupt flag and abort.
            executor.shutdownNow();
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while collecting training instances.", e);
        }
        System.out.println();
        if (lSpaces.isEmpty()) {
            throw new IllegalArgumentException("No training files are available to collect instances from.");
        }
        StringTrainSpace[] first = lSpaces.get(0);
        int mSize = first.length;
        int nFiles = lSpaces.size();
        StringTrainSpace[] spaces = new StringTrainSpace[mSize];
        for (int i = 0; i < mSize; ++i) {
            spaces[i] = first[i];
            if (nFiles <= 1) continue;
            System.out.println("Merging training instances:");
            for (int j = 1; j < nFiles; ++j) {
                StringTrainSpace sp = lSpaces.get(j)[i];
                spaces[i].appendSpace(sp);
                // The appended space is no longer needed once merged.
                sp.clear();
                System.out.print(".");
            }
            System.out.println();
        }
        return spaces;
    }

    /**
     * Builds the component for {@code mode} that fills {@code spaces} with training
     * instances.
     *
     * <p>For {@code dep}, {@code srl}, {@code pos_back} and {@code dep_back} the
     * constructor that also takes {@code models} is used when {@code models} is not
     * {@code null}; the other modes ignore {@code models}.
     *
     * @param xmls   feature templates
     * @param spaces training spaces to be filled
     * @param models models of the previous round, or {@code null}
     * @param lexica lexica collected from the training data
     * @param mode   training mode
     * @return a new component for collecting training instances
     * @throws IllegalArgumentException if {@code mode} is not supported
     */
    protected AbstractStatisticalComponent getComponent(JointFtrXml[] xmls, StringTrainSpace[] spaces, StringModel[] models, Object[] lexica, String mode) {
        final boolean hasModels = models != null;

        if (mode.equals("pos")) {
            return new CPOSTagger(xmls, spaces, lexica);
        }
        else if (mode.equals("dep")) {
            if (hasModels) {
                return new CDEPPassParser(xmls, spaces, models, lexica);
            }
            return new CDEPPassParser(xmls, spaces, lexica);
        }
        else if (mode.equals("pred")) {
            return new CPredIdentifier(xmls, spaces, lexica);
        }
        else if (mode.equals("role")) {
            return new CRolesetClassifier(xmls, spaces, lexica);
        }
        else if (mode.startsWith("sense")) {
            // Everything after the last underscore of the mode is handed to the classifier.
            String suffix = mode.substring(mode.lastIndexOf("_") + 1);
            return new CSenseClassifier(xmls, spaces, lexica, suffix);
        }
        else if (mode.equals("srl")) {
            if (hasModels) {
                return new CSRLabeler(xmls, spaces, models, lexica);
            }
            return new CSRLabeler(xmls, spaces, lexica);
        }
        else if (mode.equals("pos_back")) {
            if (hasModels) {
                return new CPOSBackTagger(xmls, spaces, models, lexica, this.d_margin);
            }
            return new CPOSBackTagger(xmls, spaces, lexica, this.d_margin);
        }
        else if (mode.equals("dep_back")) {
            if (hasModels) {
                return new CDEPBackParser(xmls, spaces, models, lexica, this.d_margin, this.n_beams);
            }
            return new CDEPBackParser(xmls, spaces, lexica, this.d_margin, this.n_beams);
        }
        throw new IllegalArgumentException("The requested mode '" + mode + "' is not supported.");
    }

    /**
     * Creates the empty training spaces one file's instances are collected into.
     *
     * <ul>
     * <li>{@code role} and {@code sense*}: as many spaces as the map at {@code lexica[1]}
     *     has entries, all configured from {@code xmls[0]}.</li>
     * <li>{@code srl}: two spaces configured from {@code xmls[0]}.</li>
     * <li>{@code dep_back} with {@code boot > 0}: one space per template, using cutoff
     *     index 1.</li>
     * <li>anything else: one space per template, using cutoff index 0.</li>
     * </ul>
     *
     * @param xmls   feature templates
     * @param lexica lexica collected from the training data
     * @param mode   training mode
     * @param boot   bootstrap round
     * @return newly created training spaces
     */
    protected StringTrainSpace[] getStringTrainSpaces(JointFtrXml[] xmls, Object[] lexica, String mode, int boot) {
        if (mode.equals("role") || mode.startsWith("sense")) {
            ObjectIntOpenHashMap<?> sizeSource = (ObjectIntOpenHashMap<?>)lexica[1];
            return this.getStringTrainSpaces(xmls[0], sizeSource.size());
        }
        else if (mode.equals("srl")) {
            return this.getStringTrainSpaces(xmls[0], 2);
        }
        else if (boot > 0 && mode.equals("dep_back")) {
            return this.getStringTrainSpaces(xmls, 1);
        }
        else {
            return this.getStringTrainSpaces(xmls);
        }
    }

    /**
     * Creates one training space per template using cutoff index 0.
     *
     * @param xmls feature templates
     * @return one new training space per template
     */
    private StringTrainSpace[] getStringTrainSpaces(JointFtrXml[] xmls) {
        final int firstCutoffIndex = 0;
        return this.getStringTrainSpaces(xmls, firstCutoffIndex);
    }

    /**
     * Creates one training space per template, each configured with that template's label
     * and feature cutoffs at index {@code cIndex}.
     *
     * @param xmls   feature templates
     * @param cIndex cutoff index passed to the templates
     * @return new training spaces, index-aligned with {@code xmls}
     */
    private StringTrainSpace[] getStringTrainSpaces(JointFtrXml[] xmls, int cIndex) {
        StringTrainSpace[] result = new StringTrainSpace[xmls.length];
        int slot = 0;
        while (slot < result.length) {
            JointFtrXml template = xmls[slot];
            result[slot] = new StringTrainSpace(false, template.getLabelCutoff(cIndex), template.getFeatureCutoff(cIndex));
            slot++;
        }
        return result;
    }

    /**
     * Creates {@code size} training spaces that all use the label and feature cutoffs of
     * {@code xml} at index 0.
     *
     * @param xml  template supplying the cutoffs
     * @param size number of spaces to create
     * @return {@code size} new training spaces
     */
    private StringTrainSpace[] getStringTrainSpaces(JointFtrXml xml, int size) {
        StringTrainSpace[] result = new StringTrainSpace[size];
        int slot = 0;
        while (slot < size) {
            // The cutoffs are read again for every space, as the original loop does.
            int labelCutoff = xml.getLabelCutoff(0);
            int featureCutoff = xml.getFeatureCutoff(0);
            result[slot] = new StringTrainSpace(false, labelCutoff, featureCutoff);
            slot++;
        }
        return result;
    }

    /**
     * Program entry point.
     *
     * @param args the command-line arguments, passed unchanged to the constructor
     */
    public static void main(String[] args) {
        // The constructor itself parses the arguments and runs training.
        new NLPTrain(args);
    }

    /**
     * Task that feeds every tree of one training file to one component.
     *
     * <p>The file is opened when the task is constructed, not when it runs. The reader is
     * closed when the task ends, whether it finishes normally or {@code process} throws.
     */
    private class TrainTask
    implements Runnable {
        private final AbstractStatisticalComponent j_component;
        private final JointReader j_reader;

        /**
         * @param eConfig   root element of the configuration document; its {@code reader}
         *                  element configures the reader of this task
         * @param trainFile path of the training file to read
         * @param component component that processes each tree
         */
        public TrainTask(Element eConfig, String trainFile, AbstractStatisticalComponent component) {
            this.j_reader = NLPTrain.this.getJointReader(UTXml.getFirstElementByTagName(eConfig, "reader"));
            this.j_reader.open(UTInput.createBufferedFileReader(trainFile));
            this.j_component = component;
        }

        /** Processes all trees of the file, then prints one progress dot. */
        @Override
        public void run() {
            try {
                DEPTree tree;
                while ((tree = this.j_reader.next()) != null) {
                    this.j_component.process(tree);
                }
            }
            finally {
                // Release the file even when processing a tree fails.
                this.j_reader.close();
            }
            System.out.print(".");
        }
    }
}

