package com.xiaomi.ai.nlp.ml.infer;

import com.xiaomi.ai.nlp.ml.base.MLMath;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.regex.Pattern;

/* loaded from: classes3.dex */
/**
 * Inference-only multinomial (softmax) logistic regression model.
 *
 * <p>The model is read from a plain-text weight file with {@link #load(InputStream, Set)} and then
 * applied to sparse feature vectors with {@link #infer(Map)}. Until {@code load} has completed,
 * no labels are known and {@code infer} returns an empty list.
 *
 * <p>This class performs no synchronization; {@code load} mutates internal maps.
 * NOTE(review): if an instance is shared across threads, confirm that loading happens-before
 * every {@code infer} call.
 */
public class MultinomialLogisticRegression {

    /** Splits a model line into its feature name and label:weight tokens. */
    private static final Pattern WHITESPACE = Pattern.compile("\\s+");

    /** Feature name -> sparse per-label weights of that feature. */
    private final Map<String, LabelWeights> featureWeights = new HashMap<>();

    /** Label -> dense column index in the score array. */
    private final Map<String, Integer> labelToIndex = new HashMap<>();

    /** Dense column index -> label; inverse of {@link #labelToIndex}. */
    private final Map<Integer, String> indexToLabel = new HashMap<>();

    /** Weights of one feature, stored sparsely as label -> weight. */
    private static final class LabelWeights {

        private final Map<String, Double> weights = new HashMap<>();

        void put(String label, double weight) {
            weights.put(label, weight);
        }

        Map<String, Double> weights() {
            return weights;
        }
    }

    /**
     * A single (label, probability) pair produced by {@link #infer(Map)}.
     *
     * <p>The natural ordering is by <em>descending</em> probability. This ordering is not
     * consistent with {@code equals}, which is inherited identity equality.
     */
    public static class ProbInfo implements Comparable<ProbInfo> {

        private String label;

        private double prob;

        @Override
        public int compareTo(ProbInfo other) {
            // Arguments are swapped so that higher probabilities sort first.
            return Double.compare(other.prob, this.prob);
        }

        public String getLabel() {
            return label;
        }

        public double getProb() {
            return prob;
        }

        public void setLabel(String label) {
            this.label = label;
        }

        public void setProb(double prob) {
            this.prob = prob;
        }
    }

    /**
     * Computes the softmax probability of every known label for a sparse feature vector.
     *
     * <p>For each label, the score is the sum of {@code weight(feature, label) * value(feature)}
     * over the features of {@code map} that the model knows. Unknown features are ignored. Scores
     * are normalized with {@link MLMath#logSumExp}; each probability is
     * {@code exp(score - logSumExp(scores))}.
     *
     * @param map feature name -> feature value; values must be non-null
     * @return one {@link ProbInfo} per label of the set passed to {@link #load}, sorted by
     *     descending probability; empty if no labels are loaded
     * @throws IllegalArgumentException if a stored feature weight references a label that has no
     *     index
     */
    public List<ProbInfo> infer(Map<String, Double> map) {
        int labelCount = labelToIndex.size();
        if (labelCount == 0) {
            // Nothing to normalize; avoid calling logSumExp on an empty array.
            return new ArrayList<>();
        }

        double[] scores = new double[labelCount];
        for (Map.Entry<String, Double> feature : map.entrySet()) {
            LabelWeights weights = featureWeights.get(feature.getKey());
            if (weights == null) {
                continue; // feature not present in the model
            }
            double value = feature.getValue();
            for (Map.Entry<String, Double> labelWeight : weights.weights().entrySet()) {
                Integer index = labelToIndex.get(labelWeight.getKey());
                if (index == null) {
                    throw new IllegalArgumentException(
                            "label set error, new label find: " + labelWeight.getKey());
                }
                scores[index] += labelWeight.getValue() * value;
            }
        }

        double logNormalizer = MLMath.logSumExp(scores);
        List<ProbInfo> result = new ArrayList<>(labelCount);
        for (int i = 0; i < labelCount; i++) {
            ProbInfo info = new ProbInfo();
            info.setLabel(indexToLabel.get(i));
            info.setProb(Math.exp(scores[i] - logNormalizer));
            result.add(info);
        }
        Collections.sort(result);
        return result;
    }

    /**
     * Reads feature weights from {@code inputStream} (UTF-8) and indexes the labels of {@code set}.
     *
     * <p>Each non-blank line has the form {@code feature label1:w1 label2:w2 ...}, with tokens
     * separated by whitespace. Every label must belong to {@code set}. A later line for the same
     * feature replaces an earlier one. Label indices follow the iteration order of {@code set}
     * and are assigned only after the whole stream has been read successfully. The stream is
     * always closed, including when parsing fails.
     *
     * @param inputStream model weight stream; closed by this method
     * @param set the complete label set
     * @throws IllegalArgumentException if a line has no label weights, a token is not
     *     {@code label:weight}, or a label is not in {@code set}
     * @throws NumberFormatException if a weight is not a valid double (original cause attached)
     * @throws UncheckedIOException if reading or closing the stream fails
     */
    public void load(InputStream inputStream, Set<String> set) {
        Objects.requireNonNull(inputStream, "inputStream");
        Objects.requireNonNull(set, "set");
        try (BufferedReader reader =
                new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                parseLine(line.trim(), set);
            }
        } catch (IOException e) {
            throw new UncheckedIOException("failed to read model weights", e);
        }
        assignLabelIndices(set);
    }

    /** Parses one trimmed model line and stores its weights in {@link #featureWeights}. */
    private void parseLine(String line, Set<String> labels) {
        if (line.isEmpty()) {
            return;
        }
        String[] tokens = WHITESPACE.split(line);
        if (tokens.length < 2) {
            throw new IllegalArgumentException("feature weight not found: " + line);
        }
        LabelWeights weights = new LabelWeights();
        for (int i = 1; i < tokens.length; i++) {
            String[] pair = tokens[i].split(":");
            if (pair.length != 2) {
                throw new IllegalArgumentException("label weight format error: " + tokens[i]);
            }
            String label = pair[0];
            if (!labels.contains(label)) {
                throw new IllegalArgumentException("label wasn't in label set: " + label);
            }
            try {
                weights.put(label, Double.parseDouble(pair[1]));
            } catch (NumberFormatException e) {
                // NumberFormatException has no cause constructor; attach the cause explicitly.
                NumberFormatException wrapped =
                        new NumberFormatException("feature weight parse error: " + pair[1]);
                wrapped.initCause(e);
                throw wrapped;
            }
        }
        featureWeights.put(tokens[0], weights);
    }

    /** Assigns dense indices 0..n-1 to labels in the iteration order of {@code labels}. */
    private void assignLabelIndices(Set<String> labels) {
        int index = 0;
        for (String label : labels) {
            labelToIndex.put(label, index);
            indexToLabel.put(index, label);
            index++;
        }
    }
}
