/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.rexp;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreProbability;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.BranchNode;
import org.dmg.pmml.tree.ClassifierNode;
import org.dmg.pmml.tree.LeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.SimpleNode;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.rexp.RBooleanVector;
import org.jpmml.rexp.RDoubleVector;
import org.jpmml.rexp.RExp;
import org.jpmml.rexp.RExpEncoder;
import org.jpmml.rexp.RExpUtil;
import org.jpmml.rexp.RGenericVector;
import org.jpmml.rexp.RIntegerVector;
import org.jpmml.rexp.RNumberVector;
import org.jpmml.rexp.RStringVector;
import org.jpmml.rexp.S4Object;
import org.jpmml.rexp.TreeModelConverter;

/**
 * Converts an S4 "binary tree" R object into a PMML {@link TreeModel}.
 *
 * <p>The conversion is split into two phases that share state through the
 * fields of this class:
 * <ol>
 *   <li>{@link #encodeSchema(RExpEncoder)} reads the {@code responses} and
 *       {@code tree} attributes, registers the label and the split variables
 *       with the encoder, and records the mining function and the feature
 *       positions.</li>
 *   <li>{@link #encodeModel(Schema)} walks the {@code tree} attribute again
 *       and builds the PMML node hierarchy, using the state recorded by the
 *       first phase.</li>
 * </ol>
 *
 * <p>NOTE(review): the attribute and element names read here ({@code responses},
 * {@code tree}, {@code psplit}, {@code ssplits}, ...) look like the node layout
 * of the R {@code party} package's {@code BinaryTree} class - confirm.
 */
public class BinaryTreeConverter
extends TreeModelConverter<S4Object> {
    /** Mining function of the model; {@code null} until {@link #encodeSchema(RExpEncoder)} has run. */
    private MiningFunction miningFunction = null;

    /**
     * Maps a split variable name to the order in which it was registered via
     * {@code encoder.addFeature(...)}; used as the index passed to
     * {@code schema.getFeature(int)}.
     */
    private final Map<String, Integer> featureIndexes = new LinkedHashMap<String, Integer>();

    /**
     * Creates a converter for the given object.
     *
     * @param binaryTree the S4 object to convert
     */
    public BinaryTreeConverter(S4Object binaryTree) {
        super(binaryTree);
    }

    /**
     * Registers the response variable as the label and every split variable
     * of the tree as a feature.
     *
     * @param encoder the encoder that collects data fields, the label and features
     */
    @Override
    public void encodeSchema(RExpEncoder encoder) {
        S4Object binaryTree = (S4Object)this.getObject();
        S4Object responses = (S4Object)binaryTree.getAttribute("responses");
        RGenericVector tree = binaryTree.getGenericAttribute("tree");
        this.encodeResponse(responses, encoder);
        this.encodeVariableList(tree, encoder);
    }

    /**
     * Builds the PMML tree model, including its output fields.
     *
     * <p>Classification models get probability output fields; both model
     * kinds get an entity id output field named {@code nodeId}.
     *
     * @param schema the schema holding the label and the features
     * @return the encoded tree model
     * @throws IllegalStateException if the mining function has not been
     *         determined yet, i.e. {@link #encodeSchema(RExpEncoder)} was not called first
     * @throws IllegalArgumentException if the mining function is neither
     *         regression nor classification
     */
    public TreeModel encodeModel(Schema schema) {
        // Switching on a null enum would fail with an opaque NullPointerException.
        if (this.miningFunction == null) {
            throw new IllegalStateException("The mining function is not known; encodeSchema(RExpEncoder) must be called before encodeModel(Schema)");
        }
        S4Object binaryTree = (S4Object)this.getObject();
        RGenericVector tree = binaryTree.getGenericAttribute("tree");
        Output output;
        switch (this.miningFunction) {
            case REGRESSION: {
                output = new Output();
                break;
            }
            case CLASSIFICATION: {
                CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
                output = ModelUtil.createProbabilityOutput(DataType.DOUBLE, (DiscreteLabel)categoricalLabel);
                break;
            }
            default: {
                throw new IllegalArgumentException("Unsupported mining function: " + this.miningFunction);
            }
        }
        output.addOutputFields(new OutputField[]{ModelUtil.createEntityIdField("nodeId", DataType.STRING)});
        return this.encodeTreeModel(tree, schema).setOutput(output);
    }

    /**
     * Registers the single response variable as the label and derives the
     * mining function from its {@code is_nominal} flag: {@code TRUE} means
     * classification, {@code FALSE} means regression.
     *
     * @param responses the {@code responses} attribute of the converted object
     * @param encoder the encoder that receives the label
     * @throws IllegalArgumentException if the {@code is_nominal} flag of the
     *         response variable is neither {@code TRUE} nor {@code FALSE}
     */
    private void encodeResponse(S4Object responses, RExpEncoder encoder) {
        RGenericVector variables = responses.getGenericAttribute("variables");
        RBooleanVector isNominal = responses.getBooleanAttribute("is_nominal");
        RGenericVector levels = responses.getGenericAttribute("levels");
        RStringVector variableNames = variables.names();
        // asScalar() presumably rejects anything but exactly one response variable - TODO confirm.
        String variableName = (String)variableNames.asScalar();
        Boolean categorical = (Boolean)isNominal.getElement(variableName);
        DataField dataField;
        if (Boolean.TRUE.equals(categorical)) {
            this.miningFunction = MiningFunction.CLASSIFICATION;
            RExp targetVariable = (RExp)variables.getElement(variableName);
            RStringVector targetVariableClass = targetVariable._class();
            RStringVector targetCategories = levels.getStringElement(variableName);
            dataField = encoder.createDataField(variableName, OpType.CATEGORICAL, RExpUtil.getDataType((String)targetVariableClass.asScalar()), targetCategories.getValues());
        } else if (Boolean.FALSE.equals(categorical)) {
            this.miningFunction = MiningFunction.REGRESSION;
            dataField = encoder.createDataField(variableName, OpType.CONTINUOUS, DataType.DOUBLE);
        } else {
            throw new IllegalArgumentException("Cannot tell whether response variable '" + variableName + "' is nominal: is_nominal is " + categorical);
        }
        encoder.setLabel(dataField);
    }

    /**
     * Recursively visits the tree (node first, then left subtree, then right
     * subtree) and registers the split variable of every non-terminal node as
     * a feature, unless the encoder already knows a data field of that name.
     *
     * <p>An integer split point makes the variable categorical, with the
     * categories taken from the split point's {@code levels} attribute; a
     * double split point makes it continuous.
     *
     * @param tree the node to visit
     * @param encoder the encoder that receives data fields and features
     * @throws IllegalArgumentException if a split point is neither an integer
     *         nor a double vector
     */
    private void encodeVariableList(RGenericVector tree, RExpEncoder encoder) {
        RBooleanVector terminal = tree.getBooleanElement("terminal");
        RGenericVector psplit = tree.getGenericElement("psplit");
        RGenericVector left = tree.getGenericElement("left");
        RGenericVector right = tree.getGenericElement("right");
        if (Boolean.TRUE.equals(terminal.asScalar())) {
            return;
        }
        RNumberVector<?> splitpoint = psplit.getNumericElement("splitpoint");
        RStringVector variableName = psplit.getStringElement("variableName");
        String name = (String)variableName.asScalar();
        DataField dataField = encoder.getDataField(name);
        if (dataField == null) {
            if (splitpoint instanceof RIntegerVector) {
                RStringVector levels = splitpoint.getStringAttribute("levels");
                dataField = encoder.createDataField(name, OpType.CATEGORICAL, null, levels.getValues());
            } else if (splitpoint instanceof RDoubleVector) {
                dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.DOUBLE);
            } else {
                throw new IllegalArgumentException("Unsupported split point type for variable '" + name + "': " + (splitpoint != null ? splitpoint.getClass().getName() : "null"));
            }
            encoder.addFeature((Field<?>)dataField);
            // The current map size is the position of the feature just added.
            // Assumes Schema#getFeature(int) follows the addFeature order - TODO confirm.
            this.featureIndexes.put(name, this.featureIndexes.size());
        }
        this.encodeVariableList(left, encoder);
        this.encodeVariableList(right, encoder);
    }

    /**
     * Encodes the whole tree under an always-true root predicate and wraps it
     * into a binary-split {@link TreeModel}.
     *
     * @param tree the root node of the R tree
     * @param schema the schema holding the label and the features
     * @return the tree model, without output fields
     */
    private TreeModel encodeTreeModel(RGenericVector tree, Schema schema) {
        Node root = this.encodeNode(tree, (Predicate)True.INSTANCE, schema);
        return new TreeModel(this.miningFunction, ModelUtil.createMiningSchema((Label)schema.getLabel()), root)
            .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);
    }

    /**
     * Recursively encodes one R tree node as a PMML node.
     *
     * <p>Terminal nodes become scored leaf nodes. Non-terminal nodes become
     * branch nodes with exactly two children: for a categorical feature the
     * children test membership in the categories selected by
     * {@link #selectValues(List, List, boolean)}; otherwise the left child
     * tests {@code <= splitpoint} and the right child tests {@code > splitpoint}.
     *
     * @param tree the R node to encode
     * @param predicate the predicate that leads to this node
     * @param schema the schema holding the label and the features
     * @return the encoded node, carrying the R node id as its id
     * @throws IllegalArgumentException if the node has a non-empty
     *         {@code ssplits} element, or if its split variable was not
     *         registered as a feature
     */
    private Node encodeNode(RGenericVector tree, Predicate predicate, Schema schema) {
        RIntegerVector nodeId = tree.getIntegerElement("nodeID");
        RBooleanVector terminal = tree.getBooleanElement("terminal");
        RGenericVector psplit = tree.getGenericElement("psplit");
        RGenericVector ssplits = tree.getGenericElement("ssplits");
        RDoubleVector prediction = tree.getDoubleElement("prediction");
        RGenericVector left = tree.getGenericElement("left");
        RGenericVector right = tree.getGenericElement("right");
        Integer id = (Integer)nodeId.asScalar();
        if (Boolean.TRUE.equals(terminal.asScalar())) {
            Node leaf = new LeafNode(null, predicate).setId((Object)id);
            return this.encodeScore(leaf, prediction, schema);
        }
        RNumberVector<?> splitpoint = psplit.getNumericElement("splitpoint");
        RStringVector variableName = psplit.getStringElement("variableName");
        if (!ssplits.isEmpty()) {
            throw new IllegalArgumentException("Node " + id + " has a non-empty 'ssplits' element, which is not supported");
        }
        String name = (String)variableName.asScalar();
        Integer index = this.featureIndexes.get(name);
        if (index == null) {
            throw new IllegalArgumentException("Split variable '" + name + "' of node " + id + " has not been registered as a feature");
        }
        Feature feature = schema.getFeature(index.intValue());
        Predicate leftPredicate;
        Predicate rightPredicate;
        if (feature instanceof CategoricalFeature) {
            CategoricalFeature categoricalFeature = (CategoricalFeature)feature;
            List<?> values = categoricalFeature.getValues();
            // Categorical features are created from integer split points (see
            // encodeVariableList), so the elements are expected to be Integer;
            // the wildcard type of the vector cannot express that.
            @SuppressWarnings("unchecked")
            List<Integer> splitValues = (List<Integer>)splitpoint.getValues();
            leftPredicate = this.createPredicate((Feature)categoricalFeature, BinaryTreeConverter.selectValues(values, splitValues, true));
            rightPredicate = this.createPredicate((Feature)categoricalFeature, BinaryTreeConverter.selectValues(values, splitValues, false));
        } else {
            ContinuousFeature continuousFeature = feature.toContinuousFeature();
            Number value = (Number)splitpoint.asScalar();
            leftPredicate = this.createSimplePredicate((Feature)continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
            rightPredicate = this.createSimplePredicate((Feature)continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
        }
        Node leftChild = this.encodeNode(left, leftPredicate, schema);
        Node rightChild = this.encodeNode(right, rightPredicate, schema);
        return new BranchNode(null, predicate).setId((Object)id).addNodes(leftChild, rightChild);
    }

    /**
     * Attaches the prediction of a terminal node to the given PMML node,
     * dispatching on the mining function.
     *
     * @param node the leaf node to score
     * @param probabilities the {@code prediction} element of the R node
     * @param schema the schema holding the label
     * @return the scored node; for classification this is a new node wrapping {@code node}
     * @throws IllegalArgumentException if the mining function is neither
     *         regression nor classification
     */
    private Node encodeScore(Node node, RDoubleVector probabilities, Schema schema) {
        switch (this.miningFunction) {
            case REGRESSION: {
                return BinaryTreeConverter.encodeRegressionScore(node, probabilities);
            }
            case CLASSIFICATION: {
                return BinaryTreeConverter.encodeClassificationScore(node, probabilities, schema);
            }
            default: {
                throw new IllegalArgumentException("Unsupported mining function: " + this.miningFunction);
            }
        }
    }

    /**
     * Picks the values that belong to one side of a categorical split.
     *
     * <p>{@code splits} is position-aligned with {@code values}: an indicator
     * of {@code 1} sends the value to the left child, {@code 0} sends it to
     * the right child. Any other indicator puts the value on neither side.
     *
     * @param values the candidate values
     * @param splits one indicator per candidate value
     * @param left {@code true} to select the left side, {@code false} for the right side
     * @return the selected values, in their original order
     * @throws IllegalArgumentException if the two lists differ in size
     */
    private static <E> List<E> selectValues(List<E> values, List<Integer> splits, boolean left) {
        if (values.size() != splits.size()) {
            throw new IllegalArgumentException("Expected " + values.size() + " split indicator(s), got " + splits.size());
        }
        int wanted = left ? 1 : 0;
        List<E> result = new ArrayList<E>();
        for (int i = 0; i < values.size(); ++i) {
            Integer split = splits.get(i);
            if (split.intValue() == wanted) {
                result.add(values.get(i));
            }
        }
        return result;
    }

    /**
     * Sets the single prediction value as the score of the node.
     *
     * @param node the leaf node to score
     * @param probabilities the prediction vector; read as a scalar
     * @return the same node
     */
    private static Node encodeRegressionScore(Node node, RDoubleVector probabilities) {
        Double prediction = (Double)probabilities.asScalar();
        node.setScore((Object)prediction);
        return node;
    }

    /**
     * Wraps the node into a {@link ClassifierNode}, adds one score
     * distribution entry per label category and sets the score to the
     * category with the highest probability. On ties the category that comes
     * first in the label wins, because only a strictly greater probability
     * replaces the current best.
     *
     * @param node the leaf node to score
     * @param probabilities one probability per label category, in label order
     * @param schema the schema holding the categorical label
     * @return the new classifier node
     */
    private static Node encodeClassificationScore(Node node, RDoubleVector probabilities, Schema schema) {
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
        SchemaUtil.checkSize(probabilities.size(), (DiscreteLabel)categoricalLabel);
        Node classifierNode = new ClassifierNode(node);
        // Lower-bounded wildcard: the list only needs to accept ScoreProbability elements.
        List<? super ScoreProbability> scoreDistributions = classifierNode.getScoreDistributions();
        Double maxProbability = null;
        for (int i = 0; i < categoricalLabel.size(); ++i) {
            Object value = categoricalLabel.getValue(i);
            Double probability = probabilities.getValue(i);
            if (maxProbability == null || maxProbability.compareTo(probability) < 0) {
                classifierNode.setScore(value);
                maxProbability = probability;
            }
            scoreDistributions.add(new ScoreProbability(value, null, (Number)probability));
        }
        return classifierNode;
    }
}

