package com.lombardisoftware.ai.dtree;

import com.lombardisoftware.ai.dtree.input.DataPoint;
import com.lombardisoftware.ai.dtree.input.DataSet;
import com.lombardisoftware.logger.WLELoggerConstants;
import com.lombardisoftware.utility.EqualityUtils;
import com.lombardisoftware.utility.collections.CollectionBuilder;
import com.lombardisoftware.utility.collections.ValueCountMap;
import com.lombardisoftware.utility.comparators.NaturalComparator;
import com.lombardisoftware.utility.comparators.NullAwareDelegatingComparator;
import com.lombardisoftware.utility.functions.UnaryFunction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.springframework.asm.Opcodes;

/* loaded from: input_file:lib/utility.jar:com/lombardisoftware/ai/dtree/DecisionTreeBuilder.class */
public class DecisionTreeBuilder {
    /** Default training ratio; equals the ratio value the two-argument {@code buildDecisionTree} overload supplies. */
    public static final float DEFAULT_TRAINING_RATIO = 0.632f;
    /**
     * Equals the possible-value count above which {@code buildNode} builds a continuous split
     * instead of a simple (one child per value) split.
     */
    public static final int MAX_CHILD_NODES = 10;
    /** Default recursion depth limit; equals the depth value the two-argument {@code buildDecisionTree} overload supplies. */
    public static final int DEFAULT_DEPTH_LIMIT = 20;
    /**
     * Shared logger, looked up by the WLE logger name with the WLE PII message file as its resource bundle name.
     * NOTE(review): the dotted message strings logged through it look like resource-bundle keys - confirm.
     */
    private static Logger logger = Logger.getLogger(WLELoggerConstants.WLE_LOGGER, WLELoggerConstants.WLE_PIIMESSAGE_FILE);
    /** Source class name passed to {@code logger.logp} calls. */
    private static final String CLASS_NAME = DecisionTreeBuilder.class.getName();
    /**
     * Comparator used to sort data points by property value and to order the TreeMaps behind the value counters.
     * NOTE(review): the exact placement of nulls is decided by NullAwareDelegatingComparator - confirm there.
     */
    private static final Comparator nullAwareNaturalComparator = new NullAwareDelegatingComparator(new NaturalComparator());
    /** Sentinel returned by {@code getExactAnswer} when a data point reaches a node that has no child for it. */
    private static final Object NO_ANSWER_GIVEN = new Object();

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:lib/utility.jar:com/lombardisoftware/ai/dtree/DecisionTreeBuilder$DataPointPropertyGetter.class */
    public static class DataPointPropertyGetter implements UnaryFunction {
        /** Name of the property read from every data point handed to {@link #execute(Object)}. */
        private final String splitPropertyName;

        /**
         * @param str name of the property to extract
         */
        DataPointPropertyGetter(String str) {
            splitPropertyName = str;
        }

        /**
         * Treats the argument as a {@link DataPoint} and returns its value for the configured property.
         *
         * @param obj the data point (cast unconditionally)
         * @return the data point's value for the configured property
         */
        @Override // com.lombardisoftware.utility.functions.UnaryFunction
        public Object execute(Object obj) {
            DataPoint point = (DataPoint) obj;
            return point.getPropertyValue(splitPropertyName);
        }
    }

    /* loaded from: input_file:lib/utility.jar:com/lombardisoftware/ai/dtree/DecisionTreeBuilder$PropertyValueComparator.class */
    public static class PropertyValueComparator implements Comparator {
        /** Property whose values are compared. */
        private String propertyName;
        /** Comparator that orders the extracted property values. */
        private Comparator delegate;

        /**
         * @param str        name of the property to compare on
         * @param comparator comparator applied to the two property values
         */
        public PropertyValueComparator(String str, Comparator comparator) {
            propertyName = str;
            delegate = comparator;
        }

        /**
         * Orders two {@link DataPoint}s by the configured property, delegating the value comparison.
         *
         * @param obj  first data point (cast unconditionally)
         * @param obj2 second data point (cast unconditionally)
         * @return whatever the delegate returns for the two property values
         */
        @Override // java.util.Comparator
        public int compare(Object obj, Object obj2) {
            Object leftValue = ((DataPoint) obj).getPropertyValue(propertyName);
            Object rightValue = ((DataPoint) obj2).getPropertyValue(propertyName);
            return delegate.compare(leftValue, rightValue);
        }
    }

    /** Private constructor: every member of this class is static, so no instance is ever needed. */
    private DecisionTreeBuilder() {
    }

    /**
     * Builds a decision tree with the default training ratio and depth limit, considering every
     * property name reported by the data set as a split candidate.
     *
     * @param dataSet source of data points and property descriptors
     * @param str     name of the property whose value the tree's answers are compared against
     * @return the built tree, or {@code null} when the five-argument overload returns {@code null}
     */
    public static DtreeNode buildDecisionTree(DataSet dataSet, String str) {
        Set candidateProperties = new TreeSet(dataSet.getPropertyNames());
        return buildDecisionTree(dataSet, str, DEFAULT_TRAINING_RATIO, DEFAULT_DEPTH_LIMIT, candidateProperties);
    }

    /**
     * Builds, prunes and scores a decision tree.
     *
     * <p>The data points are split at random (fixed seed 0, so the split is repeatable) into a
     * training list and a testing list; each point goes to training when the next random double is
     * below {@code f}. The tree is grown from the training list, pruned against the testing list,
     * and finally re-scored against all data points.
     *
     * @param dataSet source of data points and property descriptors
     * @param str     name of the property whose value the tree's answers are compared against
     * @param f       probability with which each data point is assigned to the training list
     * @param i       maximum depth passed to the recursive builder
     * @param set     names of the properties that may be used as split variables
     * @return the resulting tree, or {@code null} when there are no data points or when either the
     *         training or the testing list ends up empty
     */
    public static DtreeNode buildDecisionTree(DataSet dataSet, String str, float f, int i, Set set) {
        Random random = new Random(0L);
        List<DataPoint> dataPoints = dataSet.getDataPoints();
        if (dataPoints == null || dataPoints.isEmpty()) {
            logger.log(Level.INFO, "ai.dtree.input.DataPoint.emptyDataPointsList");
            return null;
        }
        int pointCount = dataPoints.size();
        // The capacities are only sizing hints. Clamp them so that a ratio outside [-0.1, 1.1]
        // cannot hand ArrayList a negative capacity (which throws IllegalArgumentException);
        // such a ratio then simply produces an empty list and the null return below.
        ArrayList trainingPoints = new ArrayList(Math.max(0, (int) (pointCount * (f + 0.1f))));
        ArrayList testingPoints = new ArrayList(Math.max(0, (int) (pointCount * (1.1f - f))));
        for (DataPoint dataPoint : dataPoints) {
            if (random.nextDouble() < f) {
                trainingPoints.add(dataPoint);
            } else {
                testingPoints.add(dataPoint);
            }
        }
        if (trainingPoints.isEmpty() || testingPoints.isEmpty()) {
            logger.log(Level.INFO, "ai.dtree.input.DataPoint.emptyTrainingTestingSets", new Object[]{trainingPoints, testingPoints});
            return null;
        }
        DtreeNode tree = buildDecisionTree(dataSet, trainingPoints, str, set, i);
        logTreeRules("Pre-pruned tree:", tree);
        // A non-null result from pruning replaces the root; null means the root is kept.
        DtreeNode prunedRoot = pruneDecisionTree(tree, str, testingPoints);
        if (prunedRoot != null) {
            tree = prunedRoot;
        }
        setTreeScores(tree, dataPoints, str);
        logTreeRules("After pruning and scoring:", tree);
        return tree;
    }

    /** Logs {@code header} followed by the tree's dumped rules, only when FINE logging is enabled. */
    private static void logTreeRules(String header, DtreeNode tree) {
        if (!logger.isLoggable(Level.FINE)) {
            return;
        }
        logger.logp(Level.FINE, CLASS_NAME, "buildDecisionTree", header);
        // 1024 is the initial buffer capacity (previously spelled as an unrelated bytecode flag constant).
        StringBuffer rules = new StringBuffer(1024);
        tree.dumpRules(rules);
        logger.logp(Level.FINE, CLASS_NAME, "buildDecisionTree", "\n" + rules.toString());
    }

    /**
     * Recursively grows a (sub)tree for the given data points.
     *
     * @param dataSet source of property descriptors
     * @param list    data points that reached this position in the tree
     * @param str     name of the property whose value the tree's answers are compared against
     * @param set     property names still available as split variables (copied, never modified)
     * @param i       remaining depth; at zero or below a leaf is returned
     * @return the node built for these data points
     */
    private static DtreeNode buildDecisionTree(DataSet dataSet, List list, String str, Set set, int i) {
        TreeSet remainingProperties = new TreeSet(set);
        if (i <= 0) {
            return buildLeafNode(list, str);
        }
        DtreeNode node = buildNode(dataSet, list, str, remainingProperties);
        if (node instanceof LeafNode) {
            return node;
        }
        Map pointsByChild = node.partitionDataPoints(list);
        // A simple split's variable is withdrawn from the candidates used below this node.
        if (node instanceof SimpleSplitNode) {
            remainingProperties.remove(node.getSplitVariable());
        }
        // Iterate over a snapshot because children may be replaced while looping.
        List childSnapshot = new ArrayList(node.getChildNodes());
        for (Object element : childSnapshot) {
            DtreeNode child = (DtreeNode) element;
            // Children that already classify all of their points correctly are left alone.
            if (child.getNumCorrectDataPoints() >= child.getNumDataPoints()) {
                continue;
            }
            List childPoints = (List) pointsByChild.get(child);
            DtreeNode subtree = buildDecisionTree(dataSet, childPoints, str, remainingProperties, i - 1);
            if (subtree.getNumCorrectDataPoints() > child.getNumCorrectDataPoints()) {
                node.replaceChildNode(child, subtree);
            }
        }
        return node;
    }

    /**
     * Picks the best-scoring node for the given data points: a plain leaf, or a split on one of the
     * candidate properties when that split scores strictly higher.
     *
     * @param dataSet source of property descriptors
     * @param list    data points to classify
     * @param str     name of the property whose value the tree's answers are compared against
     * @param set     candidate split property names
     * @return the highest-scoring node found (the leaf when no split beats it)
     */
    private static DtreeNode buildNode(DataSet dataSet, List list, String str, Set set) {
        DtreeNode best = buildLeafNode(list, str);
        double bestScore = best.getScore();
        for (Object element : set) {
            String propertyName = (String) element;
            DataSet.PropertyDescriptor descriptor = dataSet.getPropertyDescriptor(propertyName);
            if (descriptor == null) {
                logger.log(Level.WARNING, "ai.dtree.input.DataPoint.noPropertyDescriptor", new Object[]{propertyName});
                continue;
            }
            // Continuous properties, and those with more possible values than MAX_CHILD_NODES,
            // get a two-way threshold split; everything else gets one child per value.
            DtreeNode candidate;
            if (descriptor.isContinuous() || descriptor.getPossibleValues().size() > MAX_CHILD_NODES) {
                candidate = buildContinuousSplitNode(list, propertyName, str);
            } else {
                candidate = buildSimpleSplitNode(list, propertyName, str);
            }
            if (candidate == null) {
                continue;
            }
            double candidateScore = candidate.getScore();
            if (candidateScore > bestScore) {
                bestScore = candidateScore;
                best = candidate;
            }
        }
        if (logger.isLoggable(Level.FINE)) {
            logger.logp(Level.FINE, CLASS_NAME, "buildNode", "buildNode: selecting split on " + best.getSplitVariable() + " (" + bestScore + ")");
        }
        return best;
    }

    /**
     * Searches for the best two-way threshold split on one property.
     *
     * <p>The list is sorted in place by the split property. Walking it in order, each position where
     * the split value is non-null and differs from the previous one is scored as a threshold, using
     * running answer counts for the points before the position and for the points from it onward.
     *
     * @param list data points; reordered by this method
     * @param str  name of the property to split on
     * @param str2 name of the property whose value the tree's answers are compared against
     * @return the best split found, or {@code null} when no position scored above -1
     */
    private static ContinuousSplitNode buildContinuousSplitNode(List list, String str, String str2) {
        Collections.sort(list, new PropertyValueComparator(str, nullAwareNaturalComparator));
        ValueCountMap belowCounts = new ValueCountMap(new TreeMap(nullAwareNaturalComparator));
        ValueCountMap aboveCounts = new ValueCountMap(new TreeMap(nullAwareNaturalComparator));
        // Initially every point is on the "greater or equal" side.
        countValues(list, str2, aboveCounts);
        Object bestThreshold = null;
        double bestScore = -1.0d;
        ContinuousSplitNode bestNode = null;
        Object previousValue = null;
        int pointCount = list.size();
        for (int index = 0; index < pointCount; index++) {
            DataPoint point = (DataPoint) list.get(index);
            ContinuousSplitNode candidate = new ContinuousSplitNode();
            Object answer = point.getPropertyValue(str2);
            Object splitValue = point.getPropertyValue(str);
            boolean startsNewValue = splitValue != null
                    && (previousValue == null || ((Comparable) splitValue).compareTo(previousValue) != 0);
            if (startsNewValue) {
                int belowBest = belowCounts.get(belowCounts.getMaximumKey());
                int belowTotal = belowCounts.getTotal();
                int aboveBest = aboveCounts.get(aboveCounts.getMaximumKey());
                int aboveTotal = aboveCounts.getTotal();
                double candidateScore = computeScore(belowBest, belowTotal, aboveBest, aboveTotal);
                if (candidateScore > bestScore) {
                    bestScore = candidateScore;
                    bestThreshold = splitValue;
                    bestNode = candidate;
                    // Points before this index form the "less" child, the rest the "greater or equal" child.
                    LeafNode lessLeaf = buildLeafNode(list.subList(0, index), str2);
                    candidate.setLessChild(lessLeaf);
                    LeafNode greaterOrEqualLeaf = buildLeafNode(list.subList(index, list.size()), str2);
                    candidate.setGreaterOrEqualChild(greaterOrEqualLeaf);
                    candidate.setNumCorrectDataPoints(lessLeaf.getNumCorrectDataPoints() + greaterOrEqualLeaf.getNumCorrectDataPoints());
                    candidate.setNumDataPoints(list.size());
                }
            }
            previousValue = splitValue;
            // Move the current point's answer from the upper side to the lower side.
            belowCounts.increment(answer);
            aboveCounts.decrement(answer);
        }
        if (bestNode != null) {
            bestNode.setSplitVariable(str);
            bestNode.setSplitValue(bestThreshold);
            bestNode.setScore(bestScore);
        }
        return bestNode;
    }

    /**
     * Builds a split node with one leaf child per distinct value of the split property.
     *
     * @param list data points to partition
     * @param str  name of the property to split on
     * @param str2 name of the property whose value the tree's answers are compared against
     * @return the split node, scored from its children's correct/total counts
     */
    private static SimpleSplitNode buildSimpleSplitNode(List list, String str, String str2) {
        SimpleSplitNode node = new SimpleSplitNode();
        node.setSplitVariable(str);
        // Group the data points by their value for the split property.
        TreeMap pointsByValue = new TreeMap(nullAwareNaturalComparator);
        CollectionBuilder.buildMapOfLists(list, new DataPointPropertyGetter(str), pointsByValue);
        int correctTotal = 0;
        for (Object element : pointsByValue.entrySet()) {
            Map.Entry group = (Map.Entry) element;
            LeafNode child = buildLeafNode((List) group.getValue(), str2);
            node.addChild(group.getKey(), child);
            correctTotal += child.getNumCorrectDataPoints();
        }
        node.setNumDataPoints(list.size());
        node.setNumCorrectDataPoints(correctTotal);
        node.setScore(scoreSimpleSplitNode(node));
        return node;
    }

    /**
     * Scores a simple split from the correct and total data point counts of each of its children.
     *
     * @param simpleSplitNode node whose children supply the counts
     * @return the score computed over all children
     */
    private static double scoreSimpleSplitNode(SimpleSplitNode simpleSplitNode) {
        List children = simpleSplitNode.getChildNodes();
        int[] correctCounts = new int[children.size()];
        int[] totalCounts = new int[correctCounts.length];
        int slot = 0;
        for (Object element : children) {
            DtreeNode child = (DtreeNode) element;
            correctCounts[slot] = child.getNumCorrectDataPoints();
            totalCounts[slot] = child.getNumDataPoints();
            slot++;
        }
        return computeScore(correctCounts, totalCounts);
    }

    /**
     * Scores a single partition.
     *
     * @param i  number of correctly classified data points
     * @param i2 total number of data points
     * @return the score of the one-partition case
     */
    private static double computeScore(int i, int i2) {
        int[] correctCounts = {i};
        int[] totalCounts = {i2};
        return computeScore(correctCounts, totalCounts);
    }

    /**
     * Scores a two-way partition.
     *
     * @param i  correct count of the first partition
     * @param i2 total count of the first partition
     * @param i3 correct count of the second partition
     * @param i4 total count of the second partition
     * @return the combined score of both partitions
     */
    private static double computeScore(int i, int i2, int i3, int i4) {
        int[] correctCounts = {i, i3};
        int[] totalCounts = {i2, i4};
        return computeScore(correctCounts, totalCounts);
    }

    /**
     * Scores a set of partitions; currently always delegates to {@code computeScoreGini}.
     *
     * @param iArr  correct counts, one per partition
     * @param iArr2 total counts, one per partition (same indexing as {@code iArr})
     * @return the score over all partitions
     */
    private static double computeScore(int[] iArr, int[] iArr2) {
        double giniScore = computeScoreGini(iArr, iArr2);
        return giniScore;
    }

    /**
     * Computes the sum over all partitions of {@code (correct / total)^2 * total}.
     *
     * <p>The ratio is taken in floating point. Dividing the two {@code int} counts directly would
     * truncate every ratio to 0 or 1 and throw {@link ArithmeticException} for an empty partition.
     * With floating-point division a partition whose total is zero makes the result {@code NaN}
     * (or infinite when its correct count is non-zero); a {@code NaN} score never compares greater
     * than another score, so {@code >} comparisons against it reject that candidate.
     *
     * @param iArr  correct counts, one per partition
     * @param iArr2 total counts, one per partition; must be at least as long as {@code iArr}
     * @return the summed score; {@code 0.0} when {@code iArr} is empty
     */
    private static double computeScoreGini(int[] iArr, int[] iArr2) {
        double score = 0.0d;
        for (int index = 0; index < iArr.length; index++) {
            int total = iArr2[index];
            double fraction = (double) iArr[index] / total;
            score += fraction * fraction * total;
        }
        return score;
    }

    /**
     * Increments the counter once for each data point's value of the given property.
     *
     * @param list          data points to tally
     * @param str           name of the property whose values are counted
     * @param valueCountMap counter that receives one increment per data point
     */
    private static void countValues(List list, String str, ValueCountMap valueCountMap) {
        for (Object element : list) {
            DataPoint point = (DataPoint) element;
            valueCountMap.increment(point.getPropertyValue(str));
        }
    }

    /**
     * Builds a leaf whose answer is the counter's maximum key over the given data points
     * (presumably the most frequent value - confirm against ValueCountMap).
     *
     * @param list data points covered by the leaf
     * @param str  name of the property whose values are counted
     * @return a leaf carrying its answer, data point count, correct count and score
     */
    private static LeafNode buildLeafNode(List list, String str) {
        ValueCountMap answerCounts = new ValueCountMap(new TreeMap(nullAwareNaturalComparator));
        countValues(list, str, answerCounts);
        Object answer = answerCounts.getMaximumKey();
        LeafNode leaf = new LeafNode(answer);
        leaf.setNumDataPoints(list.size());
        leaf.setNumCorrectDataPoints(answerCounts.get(answer));
        // The score is derived from the counts just stored on the leaf.
        int correct = leaf.getNumCorrectDataPoints();
        int total = leaf.getNumDataPoints();
        leaf.setScore(computeScore(correct, total));
        return leaf;
    }

    /**
     * Prunes the tree starting at its root, i.e. with an empty list of ancestors.
     *
     * @param dtreeNode root of the tree to prune
     * @param str       name of the property whose value the tree's answers are compared against
     * @param list      data points used to evaluate the pruning
     * @return a replacement for the root, or {@code null} when the root is kept
     */
    private static DtreeNode pruneDecisionTree(DtreeNode dtreeNode, String str, List list) {
        List noAncestors = Collections.emptyList();
        return pruneDecisionTree(dtreeNode, str, list, noAncestors);
    }

    /**
     * Prunes the subtree rooted at {@code dtreeNode} bottom-up against the given data points.
     *
     * <p>Children are pruned first, each against the data points routed to it. The node itself is
     * then collapsed into a leaf when a single-answer leaf would classify at least as many of the
     * points correctly as the subtree does and none of the children recorded before pruning has a
     * higher correct/total ratio than that leaf would have. Otherwise the node's correct and total
     * counts are updated from these data points and it is kept.
     *
     * @param dtreeNode node to prune; leaves are never replaced
     * @param str       name of the property whose value the tree's answers are compared against
     * @param list      data points that reach this node
     * @param list2     ancestors of this node; only extended and handed to the recursive calls
     * @return the leaf that should replace {@code dtreeNode}, or {@code null} when it is kept
     */
    private static DtreeNode pruneDecisionTree(final DtreeNode dtreeNode, String str, List list, List list2) {
        if (dtreeNode instanceof LeafNode) {
            return null;
        }
        ArrayList path = new ArrayList(list2.size() + 1);
        path.addAll(list2);
        path.add(dtreeNode);

        // Route every data point to the child this node selects for it.
        HashMap pointsByChild = new HashMap((dtreeNode.getChildNodes().size() * 2) + 1);
        CollectionBuilder.buildMapOfLists(list, new UnaryFunction() {
            @Override // com.lombardisoftware.utility.functions.UnaryFunction
            public Object execute(Object obj) {
                // The captured final parameter is the node being pruned.
                return dtreeNode.getChildNode((DataPoint) obj);
            }
        }, pointsByChild);

        // Snapshot of the children taken before any of them is replaced.
        // NOTE(review): the ratio check below reads this snapshot, so it sees the pre-replacement
        // child objects - confirm that this is intended.
        List<DtreeNode> children = new ArrayList(dtreeNode.getChildNodes());
        for (DtreeNode child : children) {
            List childPoints = (List) pointsByChild.get(child);
            if (childPoints == null) {
                continue;
            }
            DtreeNode replacement = pruneDecisionTree(child, str, childPoints, path);
            if (replacement != null) {
                dtreeNode.replaceChildNode(child, replacement);
            }
        }

        // Correct count a single leaf would achieve on these points.
        ValueCountMap answerCounts = new ValueCountMap(new TreeMap(nullAwareNaturalComparator));
        countValues(list, str, answerCounts);
        int leafCorrect = answerCounts.get(answerCounts.getMaximumKey());

        // Correct count the (already pruned) subtree achieves on the same points.
        int subtreeCorrect = 0;
        for (Object element : list) {
            DataPoint point = (DataPoint) element;
            Object expected = dtreeNode.getExpectedAnswer(point);
            if (expected != null && expected.equals(point.getPropertyValue(str))) {
                subtreeCorrect++;
            }
        }

        if (leafCorrect >= subtreeCorrect) {
            // Ratios are compared in floating point; dividing the int counts directly would
            // truncate each ratio to 0 or 1 and throw on a zero total.
            double leafAccuracy = (double) leafCorrect / list.size();
            if (!hasChildMoreAccurateThan(children, leafAccuracy)) {
                return buildLeafNode(list, str);
            }
        }
        dtreeNode.setNumCorrectDataPoints(subtreeCorrect);
        dtreeNode.setNumDataPoints(list.size());
        return null;
    }

    /** Returns whether any of the given nodes has a correct/total ratio strictly above {@code accuracy}. */
    private static boolean hasChildMoreAccurateThan(List<DtreeNode> children, double accuracy) {
        for (DtreeNode child : children) {
            double childAccuracy = (double) child.getNumCorrectDataPoints() / child.getNumDataPoints();
            if (childAccuracy > accuracy) {
                return true;
            }
        }
        return false;
    }

    /**
     * Resets the tree's counters and recounts them from the given data points.
     *
     * <p>For every data point the tree's exact answer is compared with the point's actual value.
     * Each node on the point's path, from the root downward, gets its data point count incremented,
     * and its correct count too when the answer matched.
     *
     * @param dtreeNode root of the tree to score
     * @param list      data points to score against
     * @param str       name of the property holding each data point's actual value
     */
    private static void setTreeScores(DtreeNode dtreeNode, List list, String str) {
        clearTreeScores(dtreeNode);
        for (Object element : list) {
            DataPoint point = (DataPoint) element;
            boolean correct = EqualityUtils.objectsEqual(getExactAnswer(dtreeNode, point), point.getPropertyValue(str));
            // Walk the point's path and stop once there is no further child; without this exit
            // condition the walk would never terminate.
            // NOTE(review): relies on getChildNode returning null at the end of a path (e.g. on a leaf) - confirm.
            for (DtreeNode node = dtreeNode; node != null; node = node.getChildNode(point)) {
                node.setNumDataPoints(node.getNumDataPoints() + 1);
                if (correct) {
                    node.setNumCorrectDataPoints(node.getNumCorrectDataPoints() + 1);
                }
            }
        }
    }

    /**
     * Zeroes the correct count, data point count and score of a node and, recursively, of all its descendants.
     *
     * @param dtreeNode root of the subtree to reset
     */
    private static void clearTreeScores(DtreeNode dtreeNode) {
        dtreeNode.setNumCorrectDataPoints(0);
        dtreeNode.setNumDataPoints(0);
        dtreeNode.setScore(0.0d);
        for (Object child : dtreeNode.getChildNodes()) {
            clearTreeScores((DtreeNode) child);
        }
    }

    /**
     * Follows the data point down the tree until a leaf is reached and returns that leaf's answer.
     *
     * @param dtreeNode node to start from
     * @param dataPoint data point to classify
     * @return the leaf's answer, or {@code NO_ANSWER_GIVEN} when a node on the way has no child for the data point
     */
    private static Object getExactAnswer(DtreeNode dtreeNode, DataPoint dataPoint) {
        DtreeNode current = dtreeNode;
        while (!(current instanceof LeafNode)) {
            current = current.getChildNode(dataPoint);
            if (current == null) {
                return NO_ANSWER_GIVEN;
            }
        }
        return ((LeafNode) current).getAnswer();
    }
}
