package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.Triple;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/* loaded from: input_file:edu/stanford/nlp/parser/lexparser/ChineseSimWordAvgDepGrammar.class */
/**
 * A dependency grammar that extends {@link MLEDependencyGrammar} by mixing in statistics gathered
 * from "similar words".
 *
 * <p>Two similarity tables are loaded at construction time from the relative paths
 * {@code simWords/HeadArg.5} (similar heads) and {@code simWords/ArgHead.5} (similar args). Each
 * table maps a (word index, basic tag category) pair to a list of
 * (similar word index, basic tag category, score) triples.
 *
 * <p>{@link #scoreTB(IntDependency)} is the scoring entry point; it uses
 * {@link #probTBwithSimWords(IntDependency)}. The private {@code probSimilarWordAvg} variant is not
 * called from within this class.
 */
public class ChineseSimWordAvgDepGrammar extends MLEDependencyGrammar {
    private static final String argHeadFile = "simWords/ArgHead.5";
    private static final String headArgFile = "simWords/HeadArg.5";

    /** One line of a similarity file: {@code sim(word/tag:simWord/simTag)=score}. */
    private static final Pattern SIM_LINE_PATTERN = Pattern.compile("sim\\((.+)/(.+):(.+)/(.+)\\)=(.+)");

    /** Weight given to the similar-head estimate in {@link #probTBwithSimWords(IntDependency)}. */
    private static final double SIM_HEAD_WEIGHT = 17.7d;

    /** Weight given to the tag-only back-off estimate in {@link #probTBwithSimWords(IntDependency)}. */
    private static final double BACKOFF_WEIGHT = 35.4d;

    /** Multiplier applied to a similarity score before exponentiation in {@code probSimilarWordAvg}. */
    private static final double SIM_SCORE_SCALE = -50.0d;

    /** Smoothing mass given to the similar-word estimate in {@code probSimilarWordAvg}. */
    private double simSmooth;
    /** Similar-arg table, loaded from {@link #argHeadFile}. */
    private Map<Pair<Integer, String>, List<Triple<Integer, String, Double>>> simArgMap;
    /** Similar-head table, loaded from {@link #headArgFile}. */
    private Map<Pair<Integer, String>, List<Triple<Integer, String, Double>>> simHeadMap;
    /** Lexicon used by {@code probSimilarWordAvg}; {@code null} until {@link #setLex(Lexicon)} is called. */
    private Lexicon lex;
    /** Enables the trace output of the similar-head estimate (always {@code true} after construction). */
    private boolean debug;
    /** Enables the detailed count/probability trace (always {@code false} after construction). */
    private boolean verbose;
    /** Event counters maintained by {@code probSimilarWordAvg} and printed from {@link #finalize()}. */
    private Counter<String> statsCounter;

    /**
     * Creates the grammar and eagerly loads both similarity tables.
     *
     * @param treebankLangParserParams forwarded unchanged to the superclass constructor
     * @param z forwarded unchanged to the superclass constructor
     * @param z2 forwarded unchanged to the superclass constructor
     * @param z3 forwarded unchanged to the superclass constructor
     * @throws RuntimeException if either similarity file cannot be read
     */
    public ChineseSimWordAvgDepGrammar(TreebankLangParserParams treebankLangParserParams, boolean z, boolean z2, boolean z3) {
        super(treebankLangParserParams, z, z2, z3);
        this.simSmooth = 10.0d;
        this.debug = true;
        this.verbose = false;
        this.statsCounter = new Counter<>();
        this.simHeadMap = getMap(headArgFile);
        this.simArgMap = getMap(argHeadFile);
    }

    /**
     * Reads a similar-word file (UTF-8) into a map.
     *
     * <p>Every line is expected to look like {@code sim(word/tag:simWord/simTag)=score}. The key is
     * (index of {@code word}, {@code tag}); each value entry is
     * (index of {@code simWord}, {@code simTag}, {@code score}). Word indices come from
     * {@code wordNumberer()}. Lines that do not match are reported on {@code System.err} and
     * skipped.
     *
     * @param str path of the similarity file
     * @return the parsed map; never {@code null}
     * @throws RuntimeException if the file cannot be opened or read (the I/O failure is the cause)
     * @throws NumberFormatException if a score is not a parsable double
     */
    public static Map<Pair<Integer, String>, List<Triple<Integer, String, Double>>> getMap(String str) {
        Map<Pair<Integer, String>, List<Triple<Integer, String, Double>>> result = new HashMap<>();
        // try-with-resources: the reader (and the underlying file stream) is closed on every path.
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(str), "UTF-8"))) {
            String line;
            while ((line = reader.readLine()) != null) {
                Matcher matcher = SIM_LINE_PATTERN.matcher(line);
                if (!matcher.matches()) {
                    System.err.println("Ill-formed line in similar word map file: " + line);
                    continue;
                }
                Pair<Integer, String> key = new Pair<Integer, String>(
                        Integer.valueOf(wordNumberer().number(matcher.group(1))), matcher.group(2));
                double score = Double.parseDouble(matcher.group(5));
                List<Triple<Integer, String, Double>> similar = result.get(key);
                if (similar == null) {
                    similar = new ArrayList<>();
                    result.put(key, similar);
                }
                similar.add(new Triple<Integer, String, Double>(
                        Integer.valueOf(wordNumberer().number(matcher.group(3))), matcher.group(4), Double.valueOf(score)));
            }
            return result;
        } catch (IOException e) {
            // Keep the original message, but preserve the underlying cause for diagnosis.
            throw new RuntimeException("Problem reading similar words file!", e);
        }
    }

    /**
     * Scores a dependency as {@code Test.depWeight * log(p)}, where {@code p} is the similar-word
     * smoothed probability from {@link #probTBwithSimWords(IntDependency)}.
     *
     * @param intDependency the dependency to score
     * @return the weighted log probability
     */
    @Override // edu.stanford.nlp.parser.lexparser.MLEDependencyGrammar, edu.stanford.nlp.parser.lexparser.DependencyGrammar
    public double scoreTB(IntDependency intDependency) {
        return Test.depWeight * Math.log(probTBwithSimWords(intDependency));
    }

    /**
     * Sets the lexicon used by the similar-word averaging estimate.
     *
     * @param lexicon the lexicon to use
     */
    public void setLex(Lexicon lexicon) {
        this.lex = lexicon;
    }

    /**
     * Prints the collected statistics to {@code System.err} when the instance is finalized.
     *
     * @throws Throwable if the superclass finalizer throws
     */
    @Override
    protected void finalize() throws Throwable {
        super.finalize();
        System.err.println("SimWordAvg stats:");
        System.err.println(this.statsCounter);
    }

    /**
     * Computes the probability of a dependency, interpolating the observed count with an estimate
     * pooled over similar head words and with a tag-only back-off estimate.
     *
     * <p>The dependency object is used as a scratch key for the counter lookups: its fields and the
     * {@code word} fields of its head and arg are overwritten temporarily and put back before
     * returning. The one lasting change is that {@code leftHeaded} is forced to {@code false} when
     * the grammar is not directional.
     *
     * @param intDependency the dependency to evaluate
     * @return the probability; {@code NaN} and values below {@code 1.0E-40} are returned as 0
     */
    private double probTBwithSimWords(IntDependency intDependency) {
        if (!this.directional) {
            intDependency.leftHeaded = false;
        }
        if (this.verbose) {
            System.out.println("Generating " + intDependency);
        }
        // Remember everything that gets overwritten below so it can be restored.
        final short origDistance = intDependency.distance;
        final boolean origLeftHeaded = intDependency.leftHeaded;
        final int origHeadWord = intDependency.head.word;
        final int origArgWord = intDependency.arg.word;
        final IntTaggedWord origArg = intDependency.arg;
        final IntTaggedWord origHead = intDependency.head;

        double stopProb = getStopProb(intDependency);
        boolean headIsRoot = rootTW(intDependency.head);
        // NOTE(review): -2 is presumably the STOP word index and -1 the wildcard word index — confirm.
        if (intDependency.arg.word == -2) {
            return headIsRoot ? 0.0d : stopProb;
        }
        double pbGo = headIsRoot ? 1.0d : 1.0d - stopProb;

        // Counts conditioned on the head word (valence-binned distance).
        intDependency.distance = valenceBin(origDistance);
        double c_aTW_hTWd = this.argCounter.getCount(intDependency);
        intDependency.arg.word = -1;
        double c_aT_hTWd = this.argCounter.getCount(intDependency);
        intDependency.arg.word = origArgWord;
        intDependency.arg = wildTW;
        double c_hTWd = this.argCounter.getCount(intDependency);
        intDependency.arg = origArg;

        // Same counts with the head word wildcarded (head tag only).
        intDependency.head.word = -1;
        double c_aTW_hTd = this.argCounter.getCount(intDependency);
        intDependency.arg.word = -1;
        double c_aT_hTd = this.argCounter.getCount(intDependency);
        intDependency.arg.word = origArgWord;
        intDependency.arg = wildTW;
        double c_hTd = this.argCounter.getCount(intDependency);
        intDependency.arg = origArg;
        intDependency.head.word = origHeadWord;

        // Counts of the arg alone (wild head, no direction, no distance).
        intDependency.head = wildTW;
        intDependency.leftHeaded = false;
        intDependency.distance = (short) -1;
        double c_aTW = this.argCounter.getCount(intDependency);
        intDependency.arg.word = -1;
        double c_aT = this.argCounter.getCount(intDependency);
        intDependency.arg.word = origArgWord;

        // Restore the dependency.
        intDependency.leftHeaded = origLeftHeaded;
        intDependency.head = origHead;
        intDependency.distance = origDistance;

        double p_aTW_hTd = c_hTd > 0.0d ? c_aTW_hTd / c_hTd : 0.0d;
        double p_aT_hTd = c_hTd > 0.0d ? c_aT_hTd / c_hTd : 0.0d;
        double p_aTW_aT = c_aTW > 0.0d ? c_aTW / c_aT : 1.0d;
        double pb_aT_hTWd = (c_aT_hTWd + (this.smooth_aT_hTWd * p_aT_hTd)) / (c_hTWd + this.smooth_aT_hTWd);

        // Pool counts over every word listed as similar to the head (keeping the head's own tag).
        // NOTE(review): these lookups use the caller's original distance, whereas the lookups above
        // use valenceBin(distance) — confirm this is intended.
        List<Triple<Integer, String, Double>> similarHeads = this.simHeadMap.get(
                new Pair<Integer, String>(Integer.valueOf(intDependency.head.word), stringBasicCategory(intDependency.head.tag)));
        double cSim_aTW_hTd = 0.0d;
        double cSim_hTd = 0.0d;
        if (similarHeads != null) {
            for (Triple<Integer, String, Double> similarHead : similarHeads) {
                intDependency.arg = origArg;
                intDependency.head = origHead;
                intDependency.head.word = similarHead.first.intValue();
                cSim_aTW_hTd += this.argCounter.getCount(intDependency);
                intDependency.arg = wildTW;
                cSim_hTd += this.argCounter.getCount(intDependency);
            }
        }
        intDependency.arg = origArg;
        intDependency.head = origHead;
        // The loop above overwrote the head's word in place; put the real head word back so the
        // caller's dependency (and the trace output below) is not left pointing at a similar word.
        intDependency.head.word = origHeadWord;

        double pSim_aTW_hTd = cSim_hTd > 0.0d ? cSim_aTW_hTd / cSim_hTd : 0.0d;
        if (this.debug && pSim_aTW_hTd > 0.0d) {
            System.out.println(intDependency + "\t" + pSim_aTW_hTd);
        }
        double pb_aTW_hTWd = ((c_aTW_hTWd + (SIM_HEAD_WEIGHT * pSim_aTW_hTd)) + (BACKOFF_WEIGHT * p_aTW_hTd))
                / ((c_hTWd + SIM_HEAD_WEIGHT) + BACKOFF_WEIGHT);
        if (this.debug) {
            System.out.println(intDependency);
            System.out.println(c_aTW_hTWd + " + 17.7 * " + pSim_aTW_hTd + " + 35.4 * " + p_aTW_hTd);
            System.out.println("--------------------------------  = " + pb_aTW_hTWd);
            System.out.println(c_hTWd + " + 17.7 + 35.4");
            System.out.println();
        }
        double score = ((this.interp * pb_aTW_hTWd) + ((1.0d - this.interp) * p_aTW_aT * pb_aT_hTWd)) * pbGo;
        if (this.verbose) {
            NumberFormat numberInstance = NumberFormat.getNumberInstance();
            numberInstance.setMaximumFractionDigits(2);
            System.out.println("  c_aTW_hTWd: " + c_aTW_hTWd + "; c_aT_hTWd: " + c_aT_hTWd + "; c_hTWd: " + c_hTWd);
            System.out.println("  c_aTW_hTd: " + c_aTW_hTd + "; c_aT_hTd: " + c_aT_hTd + "; c_hTd: " + c_hTd);
            System.out.println("  Generated with pb_go_hTWds: " + numberInstance.format(pbGo) + " pb_aTW_hTWd: " + numberInstance.format(pb_aTW_hTWd) + " p_aTW_aT: " + numberInstance.format(p_aTW_aT) + " pb_aT_hTWd: " + numberInstance.format(pb_aT_hTWd));
            System.out.println("  NoDist score: " + score);
        }
        if (Test.prunePunc && pruneTW(origArg)) {
            return 1.0d;
        }
        if (Double.isNaN(score) || score < 1.0E-40d) {
            return 0.0d;
        }
        return score;
    }

    /**
     * Alternative estimate: smooths {@code probTB} with a similarity-weighted average of
     * {@code probTB} over similar args and/or similar heads. Requires {@link #setLex(Lexicon)} to
     * have been called whenever similar args exist for the dependency's arg. Not called from within
     * this class.
     *
     * @param intDependency the dependency to evaluate; its head and arg are replaced temporarily and
     *     restored before returning
     * @return the smoothed probability, or plain {@code probTB} when no similar-word evidence is
     *     available
     */
    private double probSimilarWordAvg(IntDependency intDependency) {
        double regProb = probTB(intDependency);
        this.statsCounter.incrementCount("total");
        List<Triple<Integer, String, Double>> similarArgs = this.simArgMap.get(
                new Pair<Integer, String>(Integer.valueOf(intDependency.arg.word), stringBasicCategory(intDependency.arg.tag)));
        List<Triple<Integer, String, Double>> similarHeads = this.simHeadMap.get(
                new Pair<Integer, String>(Integer.valueOf(intDependency.head.word), stringBasicCategory(intDependency.head.tag)));
        if (similarHeads == null && similarArgs == null) {
            return regProb;
        }
        double weightedSum = 0.0d;
        double weightTotal = 0.0d;
        if (similarHeads == null) {
            // Only similar args are known.
            IntTaggedWord savedArg = intDependency.arg;
            this.statsCounter.incrementCount("aSim");
            for (Triple<Integer, String, Double> simArg : similarArgs) {
                double weight = Math.exp(SIM_SCORE_SCALE * simArg.third.doubleValue());
                int numTags = tagNumberer().total();
                for (int argTag = 0; argTag < numTags; argTag++) {
                    if (!stringBasicCategory(argTag).equals(simArg.second)) {
                        continue;
                    }
                    intDependency.arg = new IntTaggedWord(simArg.first.intValue(), argTag);
                    double lexProb = Math.exp(this.lex.score(intDependency.arg, 0));
                    if (lexProb != 0.0d) {
                        weightedSum += (probTB(intDependency) * weight) / lexProb;
                        weightTotal += weight;
                    }
                }
            }
            intDependency.arg = savedArg;
        } else if (similarArgs == null) {
            // Only similar heads are known.
            IntTaggedWord savedHead = intDependency.head;
            this.statsCounter.incrementCount("hSim");
            for (Triple<Integer, String, Double> simHead : similarHeads) {
                double weight = Math.exp(SIM_SCORE_SCALE * simHead.third.doubleValue());
                int numTags = tagNumberer().total();
                for (int headTag = 0; headTag < numTags; headTag++) {
                    if (stringBasicCategory(headTag).equals(simHead.second)) {
                        intDependency.head = new IntTaggedWord(simHead.first.intValue(), headTag);
                        weightedSum += probTB(intDependency) * weight;
                        weightTotal += weight;
                    }
                }
            }
            intDependency.head = savedHead;
        } else {
            // Both similar heads and similar args are known: average over every combination.
            IntTaggedWord savedHead = intDependency.head;
            IntTaggedWord savedArg = intDependency.arg;
            this.statsCounter.incrementCount("hSim");
            this.statsCounter.incrementCount("aSim");
            this.statsCounter.incrementCount("aSim&hSim");
            for (Triple<Integer, String, Double> simArg : similarArgs) {
                int numTags = tagNumberer().total();
                for (int argTag = 0; argTag < numTags; argTag++) {
                    if (!stringBasicCategory(argTag).equals(simArg.second)) {
                        continue;
                    }
                    intDependency.arg = new IntTaggedWord(simArg.first.intValue(), argTag);
                    double lexProb = Math.exp(this.lex.score(intDependency.arg, 0));
                    if (lexProb == 0.0d) {
                        continue;
                    }
                    for (Triple<Integer, String, Double> simHead : similarHeads) {
                        for (int headTag = 0; headTag < numTags; headTag++) {
                            if (stringBasicCategory(headTag).equals(simHead.second)) {
                                // Use the head's own matching tag (headTag), not the arg's tag.
                                intDependency.head = new IntTaggedWord(simHead.first.intValue(), headTag);
                                double weight = Math.exp(SIM_SCORE_SCALE * simHead.third.doubleValue())
                                        * Math.exp(SIM_SCORE_SCALE * simArg.third.doubleValue());
                                weightedSum += (probTB(intDependency) * weight) / lexProb;
                                weightTotal += weight;
                            }
                        }
                    }
                }
            }
            intDependency.head = savedHead;
            intDependency.arg = savedArg;
        }
        if (weightTotal == 0.0d) {
            // No similar word contributed anything; 0/0 below would yield NaN, so fall back to the
            // plain estimate exactly as when no similarity lists exist.
            return regProb;
        }
        IntTaggedWord savedArg = intDependency.arg;
        intDependency.arg = wildTW;
        double headCount = this.argCounter.getCount(intDependency);
        intDependency.arg = savedArg;
        double simProb = similarArgs == null
                ? weightedSum / weightTotal
                : (Math.exp(this.lex.score(intDependency.arg, 0)) * weightedSum) / weightTotal;
        if (simProb == 0.0d) {
            this.statsCounter.incrementCount("simProbZero");
        }
        if (regProb == 0.0d) {
            this.statsCounter.incrementCount("regProbZero");
        }
        double smoothedProb = ((headCount * regProb) + (this.simSmooth * simProb)) / (headCount + this.simSmooth);
        if (smoothedProb == 0.0d) {
            this.statsCounter.incrementCount("smoothProbZero");
        }
        return smoothedProb;
    }

    /**
     * Returns the basic category (as computed by {@code tlp.basicCategory}) of the tag with the
     * given index in {@code tagNumberer()}.
     *
     * @param i tag index
     * @return the basic category string of that tag
     */
    private String stringBasicCategory(int i) {
        return this.tlp.basicCategory((String) tagNumberer().object(i));
    }

    static {
        // Requests that finalizers run at JVM exit so finalize() above prints the statistics.
        // NOTE(review): System.runFinalizersOnExit is deprecated as unsafe and is no longer available
        // in newer JDKs; it is a JVM-wide setting, so it is left untouched here — revisit before
        // upgrading the JDK.
        System.runFinalizersOnExit(true);
    }
}
