package edu.stanford.nlp.stats;

import edu.stanford.nlp.util.BinaryHeapPriorityQueue;
import edu.stanford.nlp.util.FixedPrioritiesPriorityQueue;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.MapFactory;
import edu.stanford.nlp.util.MutableDouble;
import edu.stanford.nlp.util.PriorityQueue;
import edu.stanford.nlp.util.Sets;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;

/**
 * Static utility methods for {@link GenericCounter} and {@link Counter} objects: set-like
 * combinations, vector arithmetic, information-theoretic measures, sampling, formatting and
 * simple file I/O.
 *
 * <p>NOTE(review): many methods read counts of keys that may be absent from one of the
 * counters; they assume {@code getCount} reports 0 for such keys. That behaviour belongs to the
 * counter implementation and is not visible here.
 */
public class Counters {
    /** Not instantiable; all members are static. */
    private Counters() {
    }

    /**
     * Returns a new counter into which both inputs have been added via {@link Counter#addAll}
     * (presumably summing the counts of shared keys).
     *
     * @param genericCounter the first counter
     * @param genericCounter2 the second counter
     * @return a fresh counter; neither input is modified
     */
    public static <E> Counter<E> union(GenericCounter<E> genericCounter, GenericCounter<E> genericCounter2) {
        Counter<E> counter = new Counter<>();
        counter.addAll(genericCounter);
        counter.addAll(genericCounter2);
        return counter;
    }

    /**
     * Returns a counter holding, for each key of either input, the smaller of its two counts.
     * Keys whose minimum count is not strictly positive are omitted.
     *
     * @param genericCounter the first counter
     * @param genericCounter2 the second counter
     * @return a fresh counter of element-wise minima
     */
    public static <E> Counter<E> intersection(GenericCounter<E> genericCounter, GenericCounter<E> genericCounter2) {
        Counter<E> counter = new Counter<>();
        for (E e : Sets.union(genericCounter.keySet(), genericCounter2.keySet())) {
            double min = Math.min(genericCounter.getCount(e), genericCounter2.getCount(e));
            if (min > 0.0d) {
                counter.setCount(e, min);
            }
        }
        return counter;
    }

    /**
     * Computes the weighted Jaccard coefficient. This is the sum over all keys of the smaller
     * count, divided by the sum over all keys of the larger count.
     *
     * @param genericCounter the first counter
     * @param genericCounter2 the second counter
     * @return the coefficient; NaN when the denominator is 0 (e.g. both counters empty)
     */
    public static <E> double jaccardCoefficient(GenericCounter<E> genericCounter, GenericCounter<E> genericCounter2) {
        double minSum = 0.0d;
        double maxSum = 0.0d;
        for (E e : Sets.union(genericCounter.keySet(), genericCounter2.keySet())) {
            double count = genericCounter.getCount(e);
            double count2 = genericCounter2.getCount(e);
            minSum += Math.min(count, count2);
            maxSum += Math.max(count, count2);
        }
        return minSum / maxSum;
    }

    /**
     * Returns the element-wise product of two counters, restricted to keys present in both.
     *
     * @param genericCounter the first counter
     * @param genericCounter2 the second counter
     * @return a fresh counter of products over the shared keys
     */
    public static <E> Counter<E> product(GenericCounter<E> genericCounter, GenericCounter<E> genericCounter2) {
        Counter<E> counter = new Counter<>();
        for (E e : Sets.intersection(genericCounter.keySet(), genericCounter2.keySet())) {
            counter.setCount(e, genericCounter.getCount(e) * genericCounter2.getCount(e));
        }
        return counter;
    }

    /**
     * Computes the dot product of two counters viewed as sparse vectors. The iteration is over
     * the keys of the first counter.
     *
     * @param genericCounter the first counter (iterated)
     * @param genericCounter2 the second counter (looked up)
     * @return the sum of products of matching counts
     * @throws RuntimeException if a participating count is NaN or infinite
     */
    public static <E> double dotProduct(GenericCounter<E> genericCounter, GenericCounter<E> genericCounter2) {
        double sum = 0.0d;
        for (E e : genericCounter.keySet()) {
            double count = genericCounter.getCount(e);
            if (Double.isNaN(count) || Double.isInfinite(count)) {
                throw new RuntimeException("bad value in first counter for key " + e + ": " + count);
            }
            if (count == 0.0d) {
                continue;
            }
            double count2 = genericCounter2.getCount(e);
            if (Double.isNaN(count2) || Double.isInfinite(count2)) {
                throw new RuntimeException("bad value: " + count2 + " (second counter, key " + e + ")");
            }
            if (count2 != 0.0d) {
                sum += count * count2;
            }
        }
        return sum;
    }

    /**
     * Returns |c1(k) - c2(k)| for every key of either counter. Keys with a zero difference are
     * omitted.
     *
     * @param genericCounter the first counter
     * @param genericCounter2 the second counter
     * @return a fresh counter of absolute differences
     */
    public static <E> Counter<E> absoluteDifference(GenericCounter<E> genericCounter, GenericCounter<E> genericCounter2) {
        Counter<E> counter = new Counter<>();
        for (E e : Sets.union(genericCounter.keySet(), genericCounter2.keySet())) {
            double abs = Math.abs(genericCounter.getCount(e) - genericCounter2.getCount(e));
            if (abs > 0.0d) {
                counter.setCount(e, abs);
            }
        }
        return counter;
    }

    /**
     * Returns c1(k) / c2(k) for every key of either counter. Zero denominators yield Infinity or
     * NaN, following IEEE-754 division.
     *
     * @param genericCounter the numerator counter
     * @param genericCounter2 the denominator counter
     * @return a fresh counter of quotients
     */
    public static <E> Counter<E> division(GenericCounter<E> genericCounter, GenericCounter<E> genericCounter2) {
        Counter<E> counter = new Counter<>();
        for (E e : Sets.union(genericCounter.keySet(), genericCounter2.keySet())) {
            counter.setCount(e, genericCounter.getCount(e) / genericCounter2.getCount(e));
        }
        return counter;
    }

    /**
     * Computes the base-2 entropy of the distribution obtained by dividing each count by
     * {@code totalDoubleCount()}. Zero counts are skipped.
     *
     * @param genericCounter the counter to measure
     * @return the entropy in bits
     */
    public static <E> double entropy(GenericCounter<E> genericCounter) {
        double total = genericCounter.totalDoubleCount();
        double log2 = Math.log(2.0d);
        double entropy = 0.0d;
        for (E e : genericCounter.keySet()) {
            double count = genericCounter.getCount(e);
            if (count != 0.0d) {
                double p = count / total;
                entropy -= p * (Math.log(p) / log2);
            }
        }
        return entropy;
    }

    /**
     * Computes sum over k of c1(k) * log2(c2(k) / total(c2)).
     *
     * <p>Two points to note:
     * <ul>
     *   <li>The weights are the raw counts of the first counter; they are not normalized.</li>
     *   <li>NOTE(review): the sum is not negated. For probabilities the result is therefore
     *       usually less than or equal to 0.</li>
     * </ul>
     *
     * @param genericCounter the weighting counter
     * @param genericCounter2 the counter normalized into a distribution
     * @return the weighted log2 sum; {@link Double#NEGATIVE_INFINITY} if the second counter has
     *     zero mass on a key where the first is non-zero
     */
    public static <E> double crossEntropy(GenericCounter<E> genericCounter, GenericCounter<E> genericCounter2) {
        double total = genericCounter2.totalDoubleCount();
        double log2 = Math.log(2.0d);
        double result = 0.0d;
        for (E e : genericCounter.keySet()) {
            double count = genericCounter.getCount(e);
            if (count == 0.0d) {
                continue;
            }
            double logFraction = Math.log(genericCounter2.getCount(e) / total);
            if (logFraction == Double.NEGATIVE_INFINITY) {
                return Double.NEGATIVE_INFINITY;
            }
            result += count * (logFraction / log2);
        }
        return result;
    }

    /**
     * Like {@link #crossEntropy(GenericCounter, GenericCounter)}, but the counts of
     * {@code counter} are used directly as probabilities without any normalization.
     *
     * @param genericCounter the weighting counter
     * @param counter a counter whose counts are treated as probabilities
     * @return the weighted log2 sum; {@link Double#NEGATIVE_INFINITY} on a zero probability
     */
    public static <E> double crossEntropy(GenericCounter<E> genericCounter, Counter<E> counter) {
        double log2 = Math.log(2.0d);
        double result = 0.0d;
        for (E e : genericCounter.keySet()) {
            double count = genericCounter.getCount(e);
            if (count == 0.0d) {
                continue;
            }
            double logProb = Math.log(counter.getCount(e));
            if (logProb == Double.NEGATIVE_INFINITY) {
                return Double.NEGATIVE_INFINITY;
            }
            result += count * (logProb / log2);
        }
        return result;
    }

    /**
     * Computes the KL divergence D(p || q) in bits. Here p and q are the two counters, each
     * normalized by its own {@code totalDoubleCount()}.
     *
     * <p>If q is zero where p is non-zero, the log term is +Infinity and so is the result.
     *
     * @param genericCounter the counter defining p
     * @param genericCounter2 the counter defining q
     * @return the divergence in bits
     */
    public static <E> double klDivergence(GenericCounter<E> genericCounter, GenericCounter<E> genericCounter2) {
        double total = genericCounter.totalDoubleCount();
        double total2 = genericCounter2.totalDoubleCount();
        double log2 = Math.log(2.0d);
        double result = 0.0d;
        for (E e : genericCounter.keySet()) {
            double count = genericCounter.getCount(e);
            if (count == 0.0d) {
                continue;
            }
            double p = count / total;
            double logRatio = Math.log(p / (genericCounter2.getCount(e) / total2));
            if (logRatio == Double.NEGATIVE_INFINITY) {
                return Double.NEGATIVE_INFINITY;
            }
            result += p * (logRatio / log2);
        }
        return result;
    }

    /**
     * Computes the Jensen-Shannon divergence. It is the mean of the KL divergences of each
     * counter from their element-wise average.
     *
     * @param genericCounter the first counter
     * @param genericCounter2 the second counter
     * @return the divergence in bits
     */
    public static <E> double jensenShannonDivergence(GenericCounter<E> genericCounter, GenericCounter<E> genericCounter2) {
        Counter<E> average = average(genericCounter, genericCounter2);
        return (klDivergence(genericCounter, average) + klDivergence(genericCounter2, average)) / 2.0d;
    }

    /**
     * Computes the skew divergence. This is the KL divergence of the first counter from the
     * mixture {@code d * c2 + (1 - d) * c1}.
     *
     * @param genericCounter the first counter
     * @param genericCounter2 the second counter
     * @param d the weight on the second counter in the mixture
     * @return the divergence in bits
     */
    public static <E> double skewDivergence(GenericCounter<E> genericCounter, GenericCounter<E> genericCounter2, double d) {
        return klDivergence(genericCounter, linearCombination(genericCounter2, d, genericCounter, 1.0d - d));
    }

    /**
     * Returns a copy of the counter scaled to unit Euclidean (L2) norm.
     *
     * <p>NOTE(review): a counter whose counts are all 0 is scaled by 1/0, giving NaN counts.
     *
     * @param genericCounter the counter to normalize
     * @return a fresh, scaled counter
     */
    public static <E> Counter<E> L2Normalize(GenericCounter<E> genericCounter) {
        double sumOfSquares = 0.0d;
        for (E e : genericCounter.keySet()) {
            double count = genericCounter.getCount(e);
            if (count != 0.0d) {
                sumOfSquares += count * count;
            }
        }
        return scale(genericCounter, 1.0d / Math.sqrt(sumOfSquares));
    }

    /**
     * Computes the cosine similarity of two counters viewed as sparse vectors.
     *
     * @param genericCounter the first counter
     * @param genericCounter2 the second counter
     * @return the cosine, or 0 if either counter has zero norm
     */
    public static <E> double cosine(GenericCounter<E> genericCounter, GenericCounter<E> genericCounter2) {
        double dot = 0.0d;
        double normSq = 0.0d;
        double normSq2 = 0.0d;
        for (E e : genericCounter.keySet()) {
            double count = genericCounter.getCount(e);
            if (count != 0.0d) {
                normSq += count * count;
                double count2 = genericCounter2.getCount(e);
                if (count2 != 0.0d) {
                    dot += count * count2;
                }
            }
        }
        for (E e : genericCounter2.keySet()) {
            double count2 = genericCounter2.getCount(e);
            if (count2 != 0.0d) {
                normSq2 += count2 * count2;
            }
        }
        if (normSq == 0.0d || normSq2 == 0.0d) {
            return 0.0d;
        }
        return dot / (Math.sqrt(normSq) * Math.sqrt(normSq2));
    }

    /**
     * Returns the element-wise mean (c1(k) + c2(k)) / 2 over the keys of either counter.
     *
     * @param genericCounter the first counter
     * @param genericCounter2 the second counter
     * @return a fresh counter of averages
     */
    public static <E> Counter<E> average(GenericCounter<E> genericCounter, GenericCounter<E> genericCounter2) {
        Counter<E> counter = new Counter<>();
        for (E e : Sets.union(genericCounter.keySet(), genericCounter2.keySet())) {
            counter.setCount(e, (genericCounter.getCount(e) + genericCounter2.getCount(e)) * 0.5d);
        }
        return counter;
    }

    /**
     * Returns {@code d * c1 + d2 * c2}, built by incrementing a fresh counter.
     *
     * @param genericCounter the first counter
     * @param d the weight of the first counter
     * @param genericCounter2 the second counter
     * @param d2 the weight of the second counter
     * @return a fresh counter holding the combination
     */
    public static <E> Counter<E> linearCombination(GenericCounter<E> genericCounter, double d, GenericCounter<E> genericCounter2, double d2) {
        Counter<E> counter = new Counter<>();
        for (E e : genericCounter.keySet()) {
            counter.incrementCount(e, genericCounter.getCount(e) * d);
        }
        for (E e : genericCounter2.keySet()) {
            counter.incrementCount(e, genericCounter2.getCount(e) * d2);
        }
        return counter;
    }

    /**
     * Adds exponentially distributed noise to every count. The noise is inverse-CDF sampled as
     * -ln(1 - U) and then scaled by {@code d}.
     *
     * @param genericCounter the counter to perturb; not modified
     * @param random the randomness source
     * @param d the scale of the noise
     * @return a fresh counter with the same map factory as the input
     */
    public static <E> Counter<E> perturbCounts(GenericCounter<E> genericCounter, Random random, double d) {
        Counter<E> counter = new Counter<>(genericCounter.getMapFactory());
        for (E e : genericCounter.keySet()) {
            counter.setCount(e, genericCounter.getCount(e) + ((-Math.log(1.0d - random.nextDouble())) * d));
        }
        return counter;
    }

    /**
     * Counts the occurrences of each element in a list.
     *
     * @param list the elements to count
     * @return a counter of occurrence counts
     */
    public static <E> Counter<E> createCounterFromList(List<E> list) {
        return createCounterFromCollection(list);
    }

    /**
     * Counts the occurrences of each element in a collection.
     *
     * @param collection the elements to count
     * @return a counter in which each element was incremented once per occurrence
     */
    public static <E> Counter<E> createCounterFromCollection(Collection<E> collection) {
        Counter<E> counter = new Counter<>();
        for (E e : collection) {
            counter.incrementCount(e);
        }
        return counter;
    }

    /**
     * Returns the keys sorted in reverse order of the counter's own {@code comparator()}.
     * Whether that ordering is by count depends on the implementation.
     *
     * @param genericCounter the counter whose keys are sorted
     * @return a fresh, mutable list of keys
     */
    public static <E> List<E> toSortedList(GenericCounter<E> genericCounter) {
        List<E> keys = new ArrayList<>(genericCounter.keySet());
        Collections.sort(keys, genericCounter.comparator());
        Collections.reverse(keys);
        return keys;
    }

    /**
     * Builds a priority queue whose priority for each key is its count.
     *
     * @param genericCounter the source counter
     * @return a fresh {@link BinaryHeapPriorityQueue}
     */
    public static <E> PriorityQueue<E> toPriorityQueue(GenericCounter<E> genericCounter) {
        BinaryHeapPriorityQueue<E> queue = new BinaryHeapPriorityQueue<>();
        for (E e : genericCounter.keySet()) {
            queue.add(e, genericCounter.getCount(e));
        }
        return queue;
    }

    /**
     * Prints the differences between two counters to {@code System.err}.
     *
     * @param genericCounter the first counter
     * @param genericCounter2 the second counter
     */
    public static <E> void printCounterComparison(GenericCounter<E> genericCounter, GenericCounter<E> genericCounter2) {
        printCounterComparison(genericCounter, genericCounter2, System.err);
    }

    /**
     * Prints every key whose counts differ by more than 1e-5. If the counters are
     * {@code equals}, a single "Counters are equal." line is printed instead.
     *
     * @param genericCounter the first counter
     * @param genericCounter2 the second counter
     * @param printStream where the report is written
     */
    public static <E> void printCounterComparison(GenericCounter<E> genericCounter, GenericCounter<E> genericCounter2, PrintStream printStream) {
        if (genericCounter.equals(genericCounter2)) {
            printStream.println("Counters are equal.");
            return;
        }
        for (E e : genericCounter.keySet()) {
            if (Math.abs(genericCounter.getCount(e) - genericCounter2.getCount(e)) > 1.0E-5d) {
                printStream.println("Counters differ on key " + e + "\t" + genericCounter.getCountAsString(e) + " vs. " + genericCounter2.getCountAsString(e));
            }
        }
        // Keys present only in the second counter were not visited by the loop above.
        Set<E> onlyInSecond = new HashSet<>(genericCounter2.keySet());
        onlyInSecond.removeAll(genericCounter.keySet());
        for (E e : onlyInSecond) {
            if (Math.abs(genericCounter.getCount(e) - genericCounter2.getCount(e)) > 1.0E-5d) {
                printStream.println("Counters differ on key " + e + "\t" + genericCounter.getCountAsString(e) + " vs. " + genericCounter2.getCountAsString(e));
            }
        }
    }

    /**
     * Returns a histogram of count values, i.e. how many keys have each count.
     *
     * @param genericCounter the counter to summarize
     * @return a counter keyed by count value
     */
    public static <E> Counter<Double> getCountCounts(GenericCounter<E> genericCounter) {
        Counter<Double> counter = new Counter<>();
        for (E e : genericCounter.keySet()) {
            counter.incrementCount(Double.valueOf(genericCounter.getCount(e)));
        }
        return counter;
    }

    /**
     * Returns a copy of the counter with every count multiplied by {@code d}.
     *
     * @param genericCounter the counter to scale; not modified
     * @param d the multiplier
     * @return a fresh counter with the same map factory as the input
     */
    public static <E> Counter<E> scale(GenericCounter<E> genericCounter, double d) {
        Counter<E> counter = new Counter<>(genericCounter.getMapFactory());
        for (E e : genericCounter.keySet()) {
            counter.setCount(e, genericCounter.getCount(e) * d);
        }
        return counter;
    }

    /**
     * Prints {@code key:count} lines to {@code System.out}, sorted by the keys' natural order.
     *
     * @param genericCounter the counter to print
     */
    public static <E extends Comparable<E>> void printCounterSortedByKeys(GenericCounter<E> genericCounter) {
        List<E> keys = new ArrayList<>(genericCounter.keySet());
        Collections.sort(keys);
        for (E key : keys) {
            System.out.println(key + ":" + genericCounter.getCountAsString(key));
        }
    }

    /**
     * Loads a {@link Counter} from a whitespace-separated "key count" text file.
     *
     * @param str the file path
     * @param cls the key class; it must have a public {@code (String)} constructor
     * @return the loaded counter
     * @throws RuntimeException wrapping any I/O, reflection or format error
     */
    public static <E> Counter<E> loadCounter(String str, Class<E> cls) throws RuntimeException {
        Counter<E> counter = new Counter<>();
        loadIntoCounter(str, cls, counter);
        return counter;
    }

    /**
     * Loads an {@link IntCounter} from a whitespace-separated "key count" text file.
     *
     * @param str the file path
     * @param cls the key class; it must have a public {@code (String)} constructor
     * @return the loaded counter
     * @throws Exception (in practice a RuntimeException) wrapping any load error
     */
    public static <E> IntCounter<E> loadIntCounter(String str, Class<E> cls) throws Exception {
        IntCounter<E> intCounter = new IntCounter<>();
        loadIntoCounter(str, cls, intCounter);
        return intCounter;
    }

    /**
     * Reads "key count" lines into {@code genericCounter}. Each key is built through
     * {@code cls}'s {@code (String)} constructor, and each count is passed on as a string.
     * Reading stops at end of file or at the first empty line.
     */
    @SuppressWarnings("unchecked")
    private static <E> void loadIntoCounter(String str, Class<?> cls, GenericCounter<E> genericCounter) throws RuntimeException {
        try {
            Constructor<?> constructor = cls.getConstructor(String.class);
            // try-with-resources closes the reader even when a line fails to parse.
            try (BufferedReader reader = new BufferedReader(new FileReader(str))) {
                for (String line = reader.readLine(); line != null && line.length() > 0; line = reader.readLine()) {
                    String[] fields = line.split("\\p{Space}+");
                    if (fields.length < 2) {
                        throw new IllegalArgumentException("Malformed counter line in " + str + ": \"" + line + "\"");
                    }
                    // The cast is unchecked: cls is assumed to be E's class, as the public callers guarantee.
                    genericCounter.setCount((E) constructor.newInstance(fields[0]), fields[1]);
                }
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Writes the counter as "key count" lines, the format read back by {@link #loadCounter}.
     *
     * @param genericCounter the counter to write
     * @param str the file path
     * @throws IOException if the file cannot be opened or a write fails
     */
    public static <E> void saveCounter(GenericCounter<E> genericCounter, String str) throws IOException {
        try (PrintWriter printWriter = new PrintWriter(new FileWriter(str))) {
            for (E e : genericCounter.keySet()) {
                printWriter.println(e + " " + genericCounter.getCountAsString(e));
            }
            // PrintWriter swallows IOExceptions; surface them so the throws clause is honoured.
            if (printWriter.checkError()) {
                throw new IOException("Error writing counter to " + str);
            }
        }
    }

    /**
     * Serializes the counter to {@code str} with Java object serialization.
     *
     * @param genericCounter the counter to write
     * @param str the file path
     * @throws IOException on any write error
     */
    public static void serializeCounter(GenericCounter genericCounter, String str) throws IOException {
        try (ObjectOutputStream out = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(str)))) {
            out.writeObject(genericCounter);
        }
    }

    /**
     * Deserializes a {@link Counter} written by {@link #serializeCounter}.
     *
     * <p>Security: native Java deserialization must not be applied to untrusted files.
     *
     * @param str the file path
     * @return the counter read from the file
     * @throws Exception on I/O, class-resolution or cast errors
     */
    public static Counter deserializeCounter(String str) throws Exception {
        try (ObjectInputStream in = new ObjectInputStream(new BufferedInputStream(new FileInputStream(str)))) {
            return (Counter) in.readObject();
        }
    }

    /**
     * Returns the string form of this counter's priority queue. The ordering shown depends on
     * the queue's {@code toString}.
     *
     * @param counter the counter to render
     * @return the rendered queue
     */
    public static <E> String toBiggestValuesFirstString(Counter<E> counter) {
        return toPriorityQueue(counter).toString();
    }

    /**
     * Renders at most {@code i} entries: those removed first from the counter's priority queue,
     * each keeping its own count.
     *
     * @param counter the counter to render
     * @param i the maximum number of entries
     * @return the rendered queue of the kept entries
     */
    public static <E> String toBiggestValuesFirstString(Counter<E> counter, int i) {
        PriorityQueue<E> queue = toPriorityQueue(counter);
        BinaryHeapPriorityQueue<E> largest = new BinaryHeapPriorityQueue<>();
        while (largest.size() < i && !queue.isEmpty()) {
            // Read the score BEFORE removing, so it belongs to the element being moved.
            double score = queue.getPriority(queue.getFirst());
            E first = queue.removeFirst();
            largest.changePriority(first, score);
        }
        return largest.toString();
    }

    /**
     * Renders every entry as one line, in the queue's sorted order.
     *
     * @param counter the counter to render
     * @return the rendered lines
     */
    public static <E> String toVerticalString(Counter<E> counter) {
        return toVerticalString(counter, Integer.MAX_VALUE);
    }

    /**
     * Renders at most {@code i} entries, one per line, using the format {@code "%g\t%s"}
     * (count, then key).
     *
     * @param counter the counter to render
     * @param i the maximum number of lines
     * @return the rendered lines
     */
    public static <E> String toVerticalString(Counter<E> counter, int i) {
        return toVerticalString(counter, i, "%g\t%s", false);
    }

    /**
     * Renders every entry with format {@code str}; the count is the first format argument.
     *
     * @param counter the counter to render
     * @param str the format string
     * @return the rendered lines
     */
    public static <E> String toVerticalString(Counter<E> counter, String str) {
        return toVerticalString(counter, Integer.MAX_VALUE, str, false);
    }

    /**
     * Renders at most {@code i} entries with format {@code str}; the count is the first format
     * argument.
     *
     * @param counter the counter to render
     * @param i the maximum number of lines
     * @param str the format string
     * @return the rendered lines
     */
    public static <E> String toVerticalString(Counter<E> counter, int i, String str) {
        return toVerticalString(counter, i, str, false);
    }

    /**
     * Renders at most {@code i} entries, in the order of the priority queue's
     * {@code toSortedList()}. The lines are separated by newlines, with no trailing newline.
     *
     * @param counter the counter to render
     * @param i the maximum number of lines
     * @param str the format string applied to each entry
     * @param z if true the format arguments are (key, count), otherwise (count, key)
     * @return the rendered lines
     */
    public static <E> String toVerticalString(Counter<E> counter, int i, String str, boolean z) {
        PriorityQueue<E> queue = toPriorityQueue(counter);
        List<E> sorted = queue.toSortedList();
        StringBuilder sb = new StringBuilder();
        Iterator<E> it = sorted.iterator();
        for (int written = 0; it.hasNext() && written < i; written++) {
            E key = it.next();
            Double priority = Double.valueOf(queue.getPriority(key));
            if (z) {
                sb.append(String.format(str, key, priority));
            } else {
                sb.append(String.format(str, priority, key));
            }
            if (it.hasNext()) {
                sb.append("\n");
            }
        }
        return sb.toString();
    }

    /**
     * Returns the element of {@code collection} with the highest count in {@code counter}. On
     * ties, the first such element in iteration order wins.
     *
     * @param counter the counter supplying scores
     * @param collection the candidate keys
     * @return the best candidate, or null if the collection is empty or no count exceeds
     *     negative infinity
     */
    public static <E> Object restrictedArgMax(Counter<E> counter, Collection<E> collection) {
        E best = null;
        double bestCount = Double.NEGATIVE_INFINITY;
        for (E candidate : collection) {
            double count = counter.getCount(candidate);
            if (count > bestCount) {
                bestCount = count;
                best = candidate;
            }
        }
        return best;
    }

    /**
     * Converts a dense array into a counter, using {@code index.get(i)} as the key for
     * position {@code i}. Zero entries are skipped.
     *
     * @param dArr the dense values
     * @param index the key for each position; must be at least as long as the array
     * @return a fresh counter of the non-zero entries
     * @throws IllegalArgumentException if the index is smaller than the array
     */
    public static <T> Counter<T> toCounter(double[] dArr, Index<T> index) {
        if (index.size() < dArr.length) {
            throw new IllegalArgumentException("Index not large enough to name all the array elements!");
        }
        Counter<T> counter = new Counter<>();
        for (int i = 0; i < dArr.length; i++) {
            if (dArr[i] != 0.0d) {
                counter.setCount(index.get(i), dArr[i]);
            }
        }
        return counter;
    }

    /**
     * Returns a copy of a two-dimensional counter with every inner count multiplied by
     * {@code d}.
     *
     * @param twoDimensionalCounter the counter to scale; not modified
     * @param d the multiplier
     * @return a fresh counter using the same outer and inner map factories
     */
    public static <T1, T2> TwoDimensionalCounter<T1, T2> scale(TwoDimensionalCounter<T1, T2> twoDimensionalCounter, double d) {
        TwoDimensionalCounter<T1, T2> scaled = new TwoDimensionalCounter<>(twoDimensionalCounter.getOuterMapFactory(), twoDimensionalCounter.getInnerMapFactory());
        for (T1 t1 : twoDimensionalCounter.firstKeySet()) {
            scaled.setCounter(t1, scale(twoDimensionalCounter.getCounter(t1), d));
        }
        return scaled;
    }

    /**
     * Samples a key with probability proportional to its count.
     *
     * <p>If rounding leaves the cumulative sum short of the drawn threshold, the first key is
     * returned. An empty counter throws from {@code iterator().next()}.
     *
     * @param counter the counter to sample from
     * @param random the randomness source
     * @return the sampled key
     */
    public static <T> T sample(Counter<T> counter, Random random) {
        double threshold = random.nextDouble() * counter.totalCount();
        double cumulative = 0.0d;
        for (T t : counter.keySet()) {
            cumulative += counter.getCount(t);
            if (cumulative >= threshold) {
                return t;
            }
        }
        return counter.keySet().iterator().next();
    }

    /**
     * Samples a key with probability proportional to its count, using a new {@link Random}.
     *
     * @param counter the counter to sample from
     * @return the sampled key
     */
    public static <T> T sample(Counter<T> counter) {
        return sample(counter, new Random());
    }

    /**
     * Returns a counter of {@code getNormalizedCount(k)} raised to the power {@code d}.
     *
     * @param counter the source counter
     * @param d the exponent
     * @return a fresh counter
     */
    public static <T> Counter<T> powNormalized(Counter<T> counter, double d) {
        Counter<T> result = new Counter<>();
        for (T t : counter.keySet()) {
            result.setCount(t, Math.pow(counter.getNormalizedCount(t), d));
        }
        return result;
    }

    /**
     * Returns a counter of every count raised to the power {@code d}.
     *
     * @param counter the source counter
     * @param d the exponent
     * @return a fresh counter
     */
    public static <T> Counter<T> pow(Counter<T> counter, double d) {
        Counter<T> result = new Counter<>();
        for (T t : counter.keySet()) {
            result.setCount(t, Math.pow(counter.getCount(t), d));
        }
        return result;
    }

    /**
     * Returns a counter of e raised to each count.
     *
     * @param counter the source counter
     * @return a fresh counter
     */
    public static <T> Counter<T> exp(Counter<T> counter) {
        Counter<T> result = new Counter<>();
        for (T t : counter.keySet()) {
            result.setCount(t, Math.exp(counter.getCount(t)));
        }
        return result;
    }

    /**
     * Returns a copy of {@code counter} with {@code counter2} subtracted via
     * {@code subtractAll(counter2, true)}. See {@link Counter#subtractAll} for the meaning of
     * the flag.
     *
     * @param counter the minuend; not modified
     * @param counter2 the subtrahend
     * @return a fresh counter holding the difference
     */
    public static <T> Counter<T> diff(Counter<T> counter, Counter<T> counter2) {
        Counter<T> result = new Counter<>(counter);
        result.subtractAll(counter2, true);
        return result;
    }

    /**
     * Returns a read-only view of {@code genericCounter}.
     *
     * <ul>
     *   <li>Reads delegate to the underlying counter.</li>
     *   <li>{@code setCount} throws {@link UnsupportedOperationException}.</li>
     *   <li>The key set is wrapped so it cannot be modified through the view.</li>
     * </ul>
     *
     * @param genericCounter the counter to wrap
     * @return an unmodifiable view that reflects later changes to the underlying counter
     */
    public static <T> GenericCounter<T> unmodifiableCounter(final GenericCounter<T> genericCounter) {
        return new GenericCounter<T>() {
            @Override
            public Comparator<T> comparator() {
                return genericCounter.comparator();
            }

            @Override
            public boolean containsKey(T t) {
                return genericCounter.containsKey(t);
            }

            @Override
            public double doubleMax() {
                return genericCounter.doubleMax();
            }

            @Override
            public double getCount(T t) {
                return genericCounter.getCount(t);
            }

            @Override
            public String getCountAsString(T t) {
                return genericCounter.getCountAsString(t);
            }

            @Override
            public MapFactory<T, MutableDouble> getMapFactory() {
                return genericCounter.getMapFactory();
            }

            @Override
            public Set<T> keySet() {
                // Without this wrapper, callers could remove keys from the wrapped counter.
                return Collections.unmodifiableSet(genericCounter.keySet());
            }

            @Override
            public void setCount(T t, String str) {
                throw new UnsupportedOperationException();
            }

            @Override
            public int size() {
                return genericCounter.size();
            }

            @Override
            public double totalDoubleCount() {
                return genericCounter.totalDoubleCount();
            }
        };
    }

    /**
     * Converts a priority queue into a counter of element priorities. The input is not
     * modified: the method drains a clone of it.
     *
     * @param fixedPrioritiesPriorityQueue the queue to convert
     * @return a counter incremented by each element's priority
     */
    public static <E> Counter<E> asCounter(FixedPrioritiesPriorityQueue<E> fixedPrioritiesPriorityQueue) {
        FixedPrioritiesPriorityQueue<E> copy = fixedPrioritiesPriorityQueue.clone();
        Counter<E> counter = new Counter<>();
        while (copy.hasNext()) {
            // Read the head's priority BEFORE next() removes it; otherwise each element would
            // receive its successor's priority.
            double priority = copy.getPriority();
            E element = copy.next();
            counter.incrementCount(element, priority);
        }
        return counter;
    }
}
