package edu.stanford.nlp.stats;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.math.SloppyMath;
import edu.stanford.nlp.trees.international.negra.NegraLabel;
import edu.stanford.nlp.util.BinaryHeapPriorityQueue;
import edu.stanford.nlp.util.EntryValueComparator;
import edu.stanford.nlp.util.Filter;
import edu.stanford.nlp.util.MapFactory;
import edu.stanford.nlp.util.MutableDouble;
import edu.stanford.nlp.util.PriorityQueue;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:edu/stanford/nlp/stats/Counter.class */
/**
 * A map from keys to {@code double} counts that also tracks the running sum of all counts.
 *
 * <p>Keys that have never been set have an implicit count of {@code 0.0}. Counts are stored as
 * {@link MutableDouble} boxes in a map created by a {@link MapFactory}. The mutators of this class
 * keep {@link #totalCount()} equal to the sum of the stored counts (up to floating-point rounding).
 * Mutating a value obtained from {@link #entrySet()} directly bypasses that bookkeeping.
 *
 * <p>To avoid allocating a new box on every update, {@code setCount}, {@code incrementCount} and
 * {@code logIncrementCount} recycle the box displaced from the map as a spare for the next update.
 * A {@link MutableDouble} obtained from {@link #entrySet()} may therefore be overwritten after its
 * key is updated; copy its value if you need to keep it. This class performs no synchronization.
 *
 * @param <E> the key type
 */
public class Counter<E> implements Serializable, GenericCounter<E>, Iterable<E> {
    /** Backing store from each key to its (mutable) count. */
    Map<E, MutableDouble> map;
    /** Factory that created {@link #map}; {@link #clone()} reuses it. */
    MapFactory<E, MutableDouble> mapFactory;
    /** Running sum of all counts in {@link #map}. */
    private double totalCount;

    /**
     * Orders arbitrary objects by {@link Object#hashCode()}, treating {@code null} as hash 0.
     * {@link #argmax()} and {@link #argmin()} use it as a deterministic tie-breaker.
     */
    private static final Comparator<Object> HASH_CODE_COMPARATOR = new Comparator<Object>() {
        @Override
        public int compare(Object a, Object b) {
            int ha = (a == null) ? 0 : a.hashCode();
            int hb = (b == null) ? 0 : b.hashCode();
            // Integer.compare rather than ha - hb: the subtraction overflows for large hash codes of
            // opposite sign, which would violate the Comparator contract.
            return Integer.compare(ha, hb);
        }
    };

    private static final long serialVersionUID = 4;

    /**
     * Spare box reused by the update methods to avoid an allocation per update. After each update
     * it holds the box that was displaced from the map (or null). It is lazily re-created, which
     * also covers the null left behind by deserialization of this transient field.
     */
    private transient MutableDouble tempMDouble;

    /** Creates an empty counter backed by {@link MapFactory#HASH_MAP_FACTORY}. */
    public Counter() {
        this(MapFactory.HASH_MAP_FACTORY);
    }

    /**
     * Creates an empty counter whose backing map is produced by {@code mapFactory}.
     *
     * @param mapFactory factory for the backing map; retained and returned by {@link #getMapFactory()}
     */
    public Counter(MapFactory<E, MutableDouble> mapFactory) {
        this.tempMDouble = null;
        this.mapFactory = mapFactory;
        this.map = mapFactory.newMap();
        this.totalCount = 0.0d;
    }

    /**
     * Creates a counter holding the same counts as {@code genericCounter}. The new counter is
     * always backed by the default hash-map factory; the source's factory is not copied.
     */
    public Counter(GenericCounter<E> genericCounter) {
        this();
        addAll(genericCounter);
    }

    /** Creates a hash-map-backed counter that counts each element of {@code collection} once per occurrence. */
    public Counter(Collection<E> collection) {
        this();
        addAll(collection);
    }

    /** Returns the factory that created this counter's backing map. */
    @Override // edu.stanford.nlp.stats.GenericCounter
    public MapFactory<E, MutableDouble> getMapFactory() {
        return this.mapFactory;
    }

    /** Returns the running sum of all counts. */
    public double totalCount() {
        return this.totalCount;
    }

    /** Iterates over the keys of this counter (same as {@code keySet().iterator()}). */
    @Override // java.lang.Iterable
    public Iterator<E> iterator() {
        return keySet().iterator();
    }

    /** Same as {@link #totalCount()}. */
    @Override // edu.stanford.nlp.stats.GenericCounter
    public double totalDoubleCount() {
        return totalCount();
    }

    /**
     * Sums the counts of the keys accepted by {@code filter}.
     *
     * @param filter selects which keys contribute to the sum
     * @return the sum of {@code getCount(key)} over accepted keys ({@code 0.0} if none)
     */
    public double totalCount(Filter<E> filter) {
        double sum = 0.0d;
        for (E key : this.map.keySet()) {
            if (filter.accept(key)) {
                sum += getCount(key);
            }
        }
        return sum;
    }

    /**
     * Treats every count as a log value and returns {@code ArrayMath.logSum} of all of them.
     */
    public double logSum() {
        double[] counts = new double[this.map.size()];
        int i = 0;
        for (E key : this.map.keySet()) {
            counts[i++] = getCount(key);
        }
        return ArrayMath.logSum(counts);
    }

    /** Subtracts {@link #logSum()} from every count (normalization in log space). */
    public void logNormalize() {
        incrementAll(-logSum());
    }

    /** Returns {@code totalCount() / size()}; this is NaN for an empty counter. */
    public double averageCount() {
        return totalCount() / this.map.size();
    }

    /** Returns the count of {@code key}, or {@code 0.0} if it has no entry. */
    @Override // edu.stanford.nlp.stats.GenericCounter
    public double getCount(E key) {
        MutableDouble value = this.map.get(key);
        if (value == null) {
            return 0.0d;
        }
        return value.doubleValue();
    }

    /** Returns {@code Double.toString(getCount(key))}. */
    @Override // edu.stanford.nlp.stats.GenericCounter
    public String getCountAsString(E key) {
        return Double.toString(getCount(key));
    }

    /** Returns {@code getCount(key) / totalCount()}. */
    public double getNormalizedCount(E key) {
        return getCount(key) / totalCount();
    }

    /**
     * Sets the count of {@code key} to {@code count}, replacing any previous value, and adjusts the
     * total accordingly.
     */
    public void setCount(E key, double count) {
        if (this.tempMDouble == null) {
            this.tempMDouble = new MutableDouble();
        }
        this.tempMDouble.set(count);
        // The box displaced from the map (or null) becomes the spare for the next update.
        this.tempMDouble = this.map.put(key, this.tempMDouble);
        this.totalCount += count;
        if (this.tempMDouble != null) {
            this.totalCount -= this.tempMDouble.doubleValue();
        }
    }

    /**
     * Sets the count of {@code key} to {@code Double.parseDouble(count)}.
     *
     * @throws NumberFormatException if {@code count} is not a parsable double
     */
    @Override // edu.stanford.nlp.stats.GenericCounter
    public void setCount(E key, String count) {
        setCount(key, Double.parseDouble(count));
    }

    /** Sets the count of every element of {@code collection} to {@code count}. */
    public void setCounts(Collection<E> collection, double count) {
        for (E key : collection) {
            setCount(key, count);
        }
    }

    /**
     * Adds {@code increment} to the count of {@code key}, creating the entry if needed.
     *
     * @return the new count of {@code key}
     */
    public double incrementCount(E key, double increment) {
        if (this.tempMDouble == null) {
            this.tempMDouble = new MutableDouble();
        }
        MutableDouble previous = this.map.put(key, this.tempMDouble);
        this.totalCount += increment;
        double updated = increment;
        if (previous != null) {
            updated += previous.doubleValue();
        }
        this.tempMDouble.set(updated);
        // Recycle the displaced box as the next spare.
        this.tempMDouble = previous;
        return updated;
    }

    /**
     * Combines {@code logIncrement} with the existing count of {@code key} using
     * {@code SloppyMath.logAdd}, or stores it as-is if {@code key} has no entry. The total is
     * adjusted by the linear difference between the new and old stored values.
     *
     * @return the new stored value for {@code key}
     */
    public double logIncrementCount(E key, double logIncrement) {
        if (this.tempMDouble == null) {
            this.tempMDouble = new MutableDouble();
        }
        MutableDouble previous = this.map.put(key, this.tempMDouble);
        double updated = logIncrement;
        if (previous != null) {
            updated = SloppyMath.logAdd(logIncrement, previous.doubleValue());
            this.totalCount += updated - previous.doubleValue();
        } else {
            this.totalCount += updated;
        }
        this.tempMDouble.set(updated);
        this.tempMDouble = previous;
        return updated;
    }

    /** Adds 1 to the count of {@code key}; returns the new count. */
    public double incrementCount(E key) {
        return incrementCount(key, 1.0d);
    }

    /** Adds {@code increment} to the count of each element of {@code collection} (once per occurrence). */
    public void incrementCounts(Collection<E> collection, double increment) {
        for (E key : collection) {
            incrementCount(key, increment);
        }
    }

    /** Adds 1 to the count of each element of {@code collection} (once per occurrence). */
    public void incrementCounts(Collection<E> collection) {
        incrementCounts(collection, 1.0d);
    }

    /** Adds {@code increment} to the count of every key currently present. */
    public void incrementAll(double increment) {
        for (MutableDouble value : this.map.values()) {
            value.set(value.doubleValue() + increment);
            this.totalCount += increment;
        }
    }

    /** Subtracts {@code decrement} from the count of {@code key}; returns the new count. */
    public double decrementCount(E key, double decrement) {
        return incrementCount(key, -decrement);
    }

    /** Subtracts 1 from the count of {@code key}; returns the new count. */
    public double decrementCount(E key) {
        return decrementCount(key, 1.0d);
    }

    /** Subtracts {@code decrement} from the count of each element of {@code collection}. */
    public void decrementCounts(Collection<E> collection, double decrement) {
        incrementCounts(collection, -decrement);
    }

    /** Subtracts 1 from the count of each element of {@code collection}. */
    public void decrementCounts(Collection<E> collection) {
        decrementCounts(collection, 1.0d);
    }

    /** Adds every count of {@code genericCounter} to this counter. */
    public void addAll(GenericCounter<E> genericCounter) {
        for (E key : genericCounter.keySet()) {
            incrementCount(key, genericCounter.getCount(key));
        }
    }

    /**
     * Adds {@code multiplier} times each count of {@code genericCounter} to this counter, removing
     * any touched key whose resulting count is exactly zero.
     */
    public void addMultiple(GenericCounter<E> genericCounter, double multiplier) {
        for (E key : genericCounter.keySet()) {
            incrementCount(key, genericCounter.getCount(key) * multiplier);
            if (getCount(key) == 0.0d) {
                remove(key);
            }
        }
    }

    /** Subtracts every count of {@code genericCounter} from this counter (zero counts are kept). */
    public void subtractAll(GenericCounter<E> genericCounter) {
        for (E key : genericCounter.keySet()) {
            incrementCount(key, -genericCounter.getCount(key));
        }
    }

    /**
     * Subtracts {@code multiplier} times each count of {@code genericCounter} from this counter,
     * removing any touched key whose resulting count is exactly zero.
     */
    public void subtractMultiple(GenericCounter<E> genericCounter, double multiplier) {
        for (E key : genericCounter.keySet()) {
            incrementCount(key, (-genericCounter.getCount(key)) * multiplier);
            if (getCount(key) == 0.0d) {
                remove(key);
            }
        }
    }

    /** Adds 1 to the count of each element of {@code collection} (once per occurrence). */
    public void addAll(Collection<E> collection) {
        for (E key : collection) {
            incrementCount(key);
        }
    }

    /** Multiplies every count by {@code multiplier}. */
    public void multiplyBy(double multiplier) {
        // setCount only replaces values of existing keys, so iterating keySet() here is safe.
        for (E key : this.map.keySet()) {
            setCount(key, getCount(key) * multiplier);
        }
    }

    /** Divides every count by {@code divisor}. */
    public void divideBy(double divisor) {
        for (E key : this.map.keySet()) {
            setCount(key, getCount(key) / divisor);
        }
    }

    /**
     * Divides each count of this counter by the corresponding count in {@code counter}. Keys that
     * are absent from {@code counter} divide by 0.0, yielding infinity or NaN.
     */
    public void divideBy(Counter<E> counter) {
        for (E key : this.map.keySet()) {
            setCount(key, getCount(key) / counter.getCount(key));
        }
    }

    /**
     * Subtracts every count of {@code genericCounter} from this counter.
     *
     * @param removeZeroes if true, removes each touched key whose resulting count is exactly zero
     */
    public void subtractAll(GenericCounter<E> genericCounter, boolean removeZeroes) {
        for (E key : genericCounter.keySet()) {
            decrementCount(key, genericCounter.getCount(key));
            if (removeZeroes && getCount(key) == 0.0d) {
                remove(key);
            }
        }
    }

    /** Returns whether {@code key} has an entry (even one with count zero). */
    @Override // edu.stanford.nlp.stats.GenericCounter
    public boolean containsKey(E key) {
        return this.map.containsKey(key);
    }

    /**
     * Removes the entry for {@code key} and subtracts its count from the total.
     *
     * @return the removed box, or {@code null} if {@code key} had no entry
     */
    public MutableDouble remove(E key) {
        MutableDouble removed = this.map.remove(key);
        if (removed != null) {
            this.totalCount -= removed.doubleValue();
        }
        return removed;
    }

    /** Removes the entry of every element of {@code collection}. */
    public void removeAll(Collection<E> collection) {
        for (E key : collection) {
            remove(key);
        }
    }

    /** Removes all entries and resets the total to zero. */
    public void clear() {
        this.map.clear();
        this.totalCount = 0.0d;
    }

    /** Returns the number of keys with an entry. */
    @Override // edu.stanford.nlp.stats.GenericCounter
    public int size() {
        return this.map.size();
    }

    /** Returns whether this counter has no entries. */
    public boolean isEmpty() {
        return size() == 0;
    }

    /** Returns the live key set of the backing map. */
    @Override // edu.stanford.nlp.stats.GenericCounter
    public Set<E> keySet() {
        return this.map.keySet();
    }

    /**
     * Returns the live entry set of the backing map. Mutating these values bypasses the total-count
     * bookkeeping. A value box may be recycled by later updates to its key.
     */
    public Set<Map.Entry<E, MutableDouble>> entrySet() {
        return this.map.entrySet();
    }

    /** Two counters are equal when their totals are equal and their backing maps are equal. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof Counter)) {
            return false;
        }
        Counter<?> other = (Counter<?>) obj;
        if (this.totalCount != other.totalCount) {
            return false;
        }
        return this.map.equals(other.map);
    }

    @Override
    public int hashCode() {
        return this.map.hashCode();
    }

    /** Returns the backing map's {@code toString()}. */
    @Override
    public String toString() {
        return this.map.toString();
    }

    /** Returns {@code asBinaryHeapPriorityQueue().toString(i)}. */
    public String toString(int i) {
        return asBinaryHeapPriorityQueue().toString(i);
    }

    /**
     * Formats the entries in map iteration order.
     *
     * @param numberFormat formats each count
     * @param prefix text written before the first entry
     * @param suffix text written after the last entry
     * @param keyValSeparator text written between a key and its count
     * @param itemSeparator text written between consecutive entries
     */
    public String toString(NumberFormat numberFormat, String prefix, String suffix,
                           String keyValSeparator, String itemSeparator) {
        StringBuilder sb = new StringBuilder();
        sb.append(prefix);
        Iterator<Map.Entry<E, MutableDouble>> it = this.map.entrySet().iterator();
        while (it.hasNext()) {
            Map.Entry<E, MutableDouble> entry = it.next();
            sb.append(entry.getKey());
            sb.append(keyValSeparator);
            sb.append(numberFormat.format(entry.getValue()));
            if (it.hasNext()) {
                sb.append(itemSeparator);
            }
        }
        sb.append(suffix);
        return sb.toString();
    }

    /**
     * Formats the entries as {@code {k1=v1, k2=v2}}, sorting keys by natural order when they are
     * mutually {@link Comparable}.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public String toString(NumberFormat numberFormat) {
        StringBuilder sb = new StringBuilder();
        sb.append("{");
        List<E> keys = new ArrayList<E>(this.map.keySet());
        try {
            Collections.sort((List) keys);
        } catch (Exception ignored) {
            // Deliberate best effort: keys that are not mutually Comparable are printed without a
            // guaranteed order instead of failing.
        }
        Iterator<E> it = keys.iterator();
        while (it.hasNext()) {
            E key = it.next();
            sb.append(key);
            sb.append("=");
            sb.append(numberFormat.format(this.map.get(key)));
            if (it.hasNext()) {
                sb.append(", ");
            }
        }
        sb.append("}");
        return sb.toString();
    }

    /**
     * Returns a new counter with the same counts, backed by a map from this counter's own
     * {@link #mapFactory}. The copy constructor would always use a hash map instead.
     */
    @Override
    public Object clone() {
        Counter<E> copy = new Counter<E>(this.mapFactory);
        copy.addAll(this);
        return copy;
    }

    /**
     * Divides every count by the total so the counts sum to 1.
     *
     * @throws RuntimeException if the total is zero, NaN or infinite
     */
    public void normalize() {
        double total = totalCount();
        if (total == 0.0d || Double.isNaN(total) || Double.isInfinite(total)) {
            throw new RuntimeException("Can't normalize with bad total: " + total);
        }
        for (E key : this.map.keySet()) {
            setCount(key, getCount(key) / total);
        }
    }

    /** Removes every entry whose count is exactly zero; the total is unaffected. */
    public void removeZeroCounts() {
        Iterator<E> it = this.map.keySet().iterator();
        while (it.hasNext()) {
            if (getCount(it.next()) == 0.0d) {
                it.remove();
            }
        }
    }

    /** Same as {@link #asBinaryHeapPriorityQueue()}. */
    public PriorityQueue<E> asPriorityQueue() {
        return asBinaryHeapPriorityQueue();
    }

    /** Returns a new priority queue holding each key with its count as priority. */
    public BinaryHeapPriorityQueue<E> asBinaryHeapPriorityQueue() {
        BinaryHeapPriorityQueue<E> queue = new BinaryHeapPriorityQueue<>();
        for (Map.Entry<E, MutableDouble> entry : this.map.entrySet()) {
            queue.add(entry.getKey(), entry.getValue().doubleValue());
        }
        return queue;
    }

    /** Returns the largest count, or negative infinity if the counter is empty. */
    public double max() {
        double best = Double.NEGATIVE_INFINITY;
        for (E key : this.map.keySet()) {
            best = Math.max(best, getCount(key));
        }
        return best;
    }

    /** Same as {@link #max()}. */
    @Override // edu.stanford.nlp.stats.GenericCounter
    public double doubleMax() {
        return max();
    }

    /** Returns the smallest count, or positive infinity if the counter is empty. */
    public double min() {
        double best = Double.POSITIVE_INFINITY;
        for (E key : this.map.keySet()) {
            best = Math.min(best, getCount(key));
        }
        return best;
    }

    /**
     * Returns a key with the largest count, or {@code null} if the counter is empty.
     *
     * @param comparator tie-breaker: among keys sharing the maximal count, the one that compares
     *     smallest is returned; if {@code null}, ties go to the key seen first in map iteration order
     */
    public E argmax(Comparator<E> comparator) {
        double best = Double.NEGATIVE_INFINITY;
        E bestKey = null;
        // Track "found" explicitly so a null key can be a legitimate winner.
        boolean found = false;
        for (E key : this.map.keySet()) {
            double count = getCount(key);
            if (!found || count > best
                    || (count == best && comparator != null && comparator.compare(key, bestKey) < 0)) {
                best = count;
                bestKey = key;
                found = true;
            }
        }
        return bestKey;
    }

    /**
     * Removes every key except the first {@code i} of {@code Counters.toSortedList(this)}.
     * NOTE(review): this relies on that list being ordered by decreasing count, which makes this
     * method keep the {@code i} highest counts — confirm against {@code Counters}.
     */
    @SuppressWarnings("unchecked")
    public void retainTop(int i) {
        if (size() - i <= 0) {
            return;
        }
        List<?> sorted = Counters.toSortedList(this);
        for (Object key : sorted.subList(i, sorted.size())) {
            remove((E) key);
        }
    }

    /** Returns a key with the largest count, breaking ties by smallest hash code. */
    @SuppressWarnings("unchecked")
    public E argmax() {
        return argmax((Comparator<E>) HASH_CODE_COMPARATOR);
    }

    /**
     * Returns a key with the smallest count, or {@code null} if the counter is empty.
     *
     * @param comparator tie-breaker: among keys sharing the minimal count, the one that compares
     *     smallest is returned; if {@code null}, ties go to the key seen first in map iteration order
     */
    public E argmin(Comparator<E> comparator) {
        double best = Double.POSITIVE_INFINITY;
        E bestKey = null;
        boolean found = false;
        for (E key : this.map.keySet()) {
            double count = getCount(key);
            if (!found || count < best
                    || (count == best && comparator != null && comparator.compare(key, bestKey) < 0)) {
                best = count;
                bestKey = key;
                found = true;
            }
        }
        return bestKey;
    }

    /** Returns a key with the smallest count, breaking ties by smallest hash code. */
    @SuppressWarnings("unchecked")
    public E argmin() {
        return argmin((Comparator<E>) HASH_CODE_COMPARATOR);
    }

    /** Returns a new set of the keys whose count is {@code >= threshold}. */
    public Set<E> keysAbove(double threshold) {
        Set<E> keys = new HashSet<E>();
        for (E key : this.map.keySet()) {
            if (getCount(key) >= threshold) {
                keys.add(key);
            }
        }
        return keys;
    }

    /** Returns a new set of the keys whose count is {@code <= threshold}. */
    public Set<E> keysBelow(double threshold) {
        Set<E> keys = new HashSet<E>();
        for (E key : this.map.keySet()) {
            if (getCount(key) <= threshold) {
                keys.add(key);
            }
        }
        return keys;
    }

    /** Returns a new set of the keys whose count is exactly {@code value}. */
    public Set<E> keysAt(double value) {
        Set<E> keys = new HashSet<E>();
        for (E key : this.map.keySet()) {
            if (getCount(key) == value) {
                keys.add(key);
            }
        }
        return keys;
    }

    /**
     * Returns a key comparator built by {@code EntryValueComparator} over the backing map. The
     * meaning of the flag is defined by {@code EntryValueComparator}.
     */
    @SuppressWarnings("unchecked")
    public Comparator<E> comparator(boolean z) {
        return new EntryValueComparator(this.map, z);
    }

    /**
     * Returns a key comparator built by {@code EntryValueComparator} over the backing map. The
     * meaning of the flags is defined by {@code EntryValueComparator}.
     */
    @SuppressWarnings("unchecked")
    public Comparator<E> comparator(boolean z, boolean z2) {
        return new EntryValueComparator(this.map, z, z2);
    }

    /** Same as {@code comparator(true)}. */
    @Override // edu.stanford.nlp.stats.GenericCounter
    public Comparator<E> comparator() {
        return comparator(true);
    }

    /**
     * Parses lines of the form {@code key<TAB>count}, separated by {@code '\n'}.
     *
     * @throws RuntimeException if a line does not split into exactly two tab-separated fields
     * @throws NumberFormatException if a count is not a parsable double
     */
    public static Counter<String> valueOf(String str) {
        Counter<String> counter = new Counter<String>();
        for (String line : str.split("\n")) {
            String[] fields = line.split("\t");
            if (fields.length != 2) {
                throw new RuntimeException("Got unsplittable line: \"" + line + "\"");
            }
            counter.setCount(fields[0], Double.parseDouble(fields[1]));
        }
        return counter;
    }

    /**
     * Like {@link #valueOf(String)}, but skips lines starting with {@code NegraLabel.FEATURE_SEP}.
     * NOTE(review): that constant is probably a decompiler-inlined comment marker — confirm the
     * intended comment prefix.
     */
    public static Counter<String> valueOfIgnoreComments(String str) {
        Counter<String> counter = new Counter<String>();
        for (String line : str.split("\n")) {
            if (line.startsWith(NegraLabel.FEATURE_SEP)) {
                continue;
            }
            String[] fields = line.split("\t");
            if (fields.length != 2) {
                throw new RuntimeException("Got unsplittable line: \"" + line + "\"");
            }
            counter.setCount(fields[0], Double.parseDouble(fields[1]));
        }
        return counter;
    }

    /**
     * Parses the {@code {k1=v1, k2=v2}} format. {@code "{}"} yields an empty counter. Keys may not
     * contain {@code "="} or {@code ", "}.
     *
     * @throws RuntimeException if {@code str} is not wrapped in braces or an entry is malformed
     * @throws NumberFormatException if a count is not a parsable double
     */
    public static Counter<String> fromString(String str) {
        Counter<String> counter = new Counter<String>();
        if (!str.startsWith("{") || !str.endsWith("}")) {
            throw new RuntimeException("invalid format: ||" + str + "||");
        }
        String body = str.substring(1, str.length() - 1);
        if (body.isEmpty()) {
            // "{}" is what toString() produces for an empty counter; splitting it would yield one
            // empty, unsplittable entry.
            return counter;
        }
        for (String item : body.split(", ")) {
            String[] fields = item.split("=");
            if (fields.length != 2) {
                throw new RuntimeException("Got unsplittable line: \"" + item + "\"");
            }
            counter.setCount(fields[0], Double.parseDouble(fields[1]));
        }
        return counter;
    }

    /**
     * Small manual self-check. If a file name is given, the counter is also round-tripped through
     * Java serialization in that file.
     */
    @SuppressWarnings("unchecked")
    public static void main(String[] args) throws Exception {
        Counter<String> counter = new Counter<String>();
        counter.setCount("p", 0.0d);
        counter.setCount("q", 2.0d);
        System.out.println(counter + " -> " + counter.totalCount() + " should be {p=0.0, q=2.0} -> 2.0");
        counter.incrementCount("p");
        System.out.println(counter + " -> " + counter.totalCount() + " should be {p=1.0, q=2.0} -> 3.0");
        counter.incrementCount("p", 2.0d);
        System.out.println(counter.min() + " " + counter.argmin() + " should be 2.0 q");
        counter.setCount("w", -5.0d);
        counter.setCount("x", -2.5d);
        List<String> keys = new ArrayList<String>(counter.keySet());
        Collections.sort(keys, counter.comparator(false, true));
        System.out.println(keys + " should be [w, p, x, q]");
        System.out.println(counter + " should be {p=3.0, q=2.0, w=-5.0, x=-2.5}");
        if (args.length > 0) {
            // try-with-resources so the streams are closed even if (de)serialization fails.
            try (ObjectOutputStream out = new ObjectOutputStream(
                    new BufferedOutputStream(new FileOutputStream(args[0])))) {
                out.writeObject(counter);
            }
            Counter<String> counter2;
            try (ObjectInputStream in = new ObjectInputStream(
                    new BufferedInputStream(new FileInputStream(args[0])))) {
                counter2 = (Counter<String>) in.readObject();
            }
            System.out.println(counter2 + " -> " + counter2.totalCount() + " should be same -> -2.5");
            System.out.println(counter2.min() + " " + counter2.argmin() + " should be -5 w");
            counter2.clear();
            System.out.println(counter2 + " -> " + counter2.totalCount() + " should be {} -> 0");
        }
    }
}
