/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.dataframe.analyses;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.FieldCardinalityConstraint;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.MapUtils;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.RequiredField;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Types;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

/**
 * Configuration of a classification data frame analysis.
 *
 * <p>Holds the dependent variable to predict, the boosted tree hyperparameters and the
 * classification-specific options (prediction field name, class assignment objective, number of
 * top classes, training percentage and randomization seed). All fields are final. Instances can
 * be parsed from XContent, written to / read from the transport stream, rendered back to
 * XContent, and turned into the parameter map and {@link InferenceConfig} derived for the analysis.
 */
public class Classification implements DataFrameAnalysis {

    public static final ParseField NAME = new ParseField("classification");
    public static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable");
    public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
    public static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective");
    public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
    public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
    public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");

    /** Suffix appended to a job id to form the id of this analysis' state document. */
    private static final String STATE_DOC_ID_SUFFIX = "_classification_state#1";
    private static final String NUM_CLASSES = "num_classes";

    /** Parser that ignores unknown fields. */
    private static final ConstructingObjectParser<Classification, Void> LENIENT_PARSER = createParser(true);
    /** Parser that rejects unknown fields. */
    private static final ConstructingObjectParser<Classification, Void> STRICT_PARSER = createParser(false);

    /** Maximum number of distinct values allowed for the dependent variable. */
    public static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;
    /** Minimum number of distinct values required for the dependent variable. */
    private static final long MIN_DEPENDENT_VARIABLE_CARDINALITY = 2L;

    /** Union of categorical, discrete numerical and boolean field types. */
    private static final Set<String> ALLOWED_DEPENDENT_VARIABLE_TYPES = Collections.unmodifiableSet(
        Stream.of(Types.categorical(), Types.discreteNumerical(), Types.bool())
            .flatMap(Collection::stream)
            .collect(Collectors.toSet()));
    private static final String PREDICTION_FIELD_TYPE = "prediction_field_type";

    private static final int DEFAULT_NUM_TOP_CLASSES = 2;
    private static final int MAX_NUM_TOP_CLASSES = 1000;
    private static final double MIN_TRAINING_PERCENT = 1.0;
    private static final double MAX_TRAINING_PERCENT = 100.0;
    private static final double DEFAULT_TRAINING_PERCENT = 100.0;

    private static final List<String> PROGRESS_PHASES = Collections.unmodifiableList(
        Arrays.asList("feature_selection", "coarse_parameter_search", "fine_tuning_parameters", "final_training"));

    private final String dependentVariable;
    private final BoostedTreeParams boostedTreeParams;
    private final String predictionFieldName;
    private final ClassAssignmentObjective classAssignmentObjective;
    private final int numTopClasses;
    private final double trainingPercent;
    private final long randomizeSeed;

    /**
     * Builds the XContent parser for this analysis.
     *
     * @param lenient whether unknown fields are ignored ({@code true}) or rejected ({@code false})
     * @return a parser producing {@link Classification} instances
     */
    private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) {
        // Constructor-argument indices follow declaration order below: a[0] is dependent_variable,
        // a[1]..a[6] are assumed to be the args declared by BoostedTreeParams.declareFields
        // (TODO confirm that it declares exactly those six), then a[7]..a[11] are the remaining
        // classification options.
        ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>(
            NAME.getPreferredName(),
            lenient,
            a -> new Classification(
                (String) a[0],
                new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6]),
                (String) a[7],
                (ClassAssignmentObjective) a[8],
                (Integer) a[9],
                (Double) a[10],
                (Long) a[11]));
        parser.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
        BoostedTreeParams.declareFields(parser);
        parser.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
        parser.declareString(
            ConstructingObjectParser.optionalConstructorArg(), ClassAssignmentObjective::fromString, CLASS_ASSIGNMENT_OBJECTIVE);
        parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
        parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
        parser.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
        return parser;
    }

    /**
     * Parses a classification analysis from XContent.
     *
     * @param parser              the XContent source
     * @param ignoreUnknownFields if {@code true} unknown fields are ignored, otherwise they are rejected
     * @return the parsed analysis
     */
    public static Classification fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
        return ignoreUnknownFields ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null);
    }

    /**
     * Creates a classification analysis, validating and defaulting the optional settings.
     *
     * @param dependentVariable        field to predict; must not be null
     * @param boostedTreeParams        boosted tree hyperparameters; must not be null
     * @param predictionFieldName      result field name; defaults to {@code <dependentVariable>_prediction}
     * @param classAssignmentObjective defaults to {@link ClassAssignmentObjective#MAXIMIZE_MINIMUM_RECALL}
     * @param numTopClasses            must be in [0, 1000]; defaults to 2
     * @param trainingPercent          must be in [1, 100] (NaN rejected); defaults to 100
     * @param randomizeSeed            defaults to a random long from {@link Randomness#get()}
     */
    public Classification(String dependentVariable,
                          BoostedTreeParams boostedTreeParams,
                          @Nullable String predictionFieldName,
                          @Nullable ClassAssignmentObjective classAssignmentObjective,
                          @Nullable Integer numTopClasses,
                          @Nullable Double trainingPercent,
                          @Nullable Long randomizeSeed) {
        if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > MAX_NUM_TOP_CLASSES)) {
            throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName());
        }
        // NaN compares false against both bounds, so it has to be rejected explicitly.
        if (trainingPercent != null
            && (trainingPercent.isNaN() || trainingPercent < MIN_TRAINING_PERCENT || trainingPercent > MAX_TRAINING_PERCENT)) {
            throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
        }
        this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE);
        this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, "boosted_tree_params");
        this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName;
        this.classAssignmentObjective = classAssignmentObjective == null
            ? ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL
            : classAssignmentObjective;
        this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses;
        this.trainingPercent = trainingPercent == null ? DEFAULT_TRAINING_PERCENT : trainingPercent;
        this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed;
    }

    /**
     * Creates a classification analysis with default boosted tree params and default options.
     *
     * @param dependentVariable field to predict
     */
    public Classification(String dependentVariable) {
        this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null);
    }

    /**
     * Reads a classification analysis from the transport stream, in the order written by
     * {@link #writeTo(StreamOutput)}.
     *
     * @param in the stream to read from
     * @throws IOException if reading fails
     */
    public Classification(StreamInput in) throws IOException {
        this.dependentVariable = in.readString();
        this.boostedTreeParams = new BoostedTreeParams(in);
        this.predictionFieldName = in.readOptionalString();
        if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
            this.classAssignmentObjective = in.readEnum(ClassAssignmentObjective.class);
        } else {
            // Streams from before 7.7.0 carry no objective; use the same default as the main constructor.
            this.classAssignmentObjective = ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL;
        }
        // Optional on the wire; fall back to the constructor default instead of NPE-ing on unboxing.
        Integer wireNumTopClasses = in.readOptionalVInt();
        this.numTopClasses = wireNumTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : wireNumTopClasses;
        this.trainingPercent = in.readDouble();
        if (in.getVersion().onOrAfter(Version.V_7_6_0)) {
            Long wireSeed = in.readOptionalLong();
            this.randomizeSeed = wireSeed == null ? Randomness.get().nextLong() : wireSeed;
        } else {
            // Streams from before 7.6.0 carry no seed.
            this.randomizeSeed = Randomness.get().nextLong();
        }
    }

    public String getDependentVariable() {
        return dependentVariable;
    }

    public BoostedTreeParams getBoostedTreeParams() {
        return boostedTreeParams;
    }

    public String getPredictionFieldName() {
        return predictionFieldName;
    }

    public ClassAssignmentObjective getClassAssignmentObjective() {
        return classAssignmentObjective;
    }

    public int getNumTopClasses() {
        return numTopClasses;
    }

    public double getTrainingPercent() {
        return trainingPercent;
    }

    public long getRandomizeSeed() {
        return randomizeSeed;
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    /**
     * Writes this analysis to the transport stream. The objective is only written to 7.7.0+
     * streams and the seed only to 7.6.0+ streams, mirroring {@link #Classification(StreamInput)}.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(dependentVariable);
        boostedTreeParams.writeTo(out);
        out.writeOptionalString(predictionFieldName);
        if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
            out.writeEnum(classAssignmentObjective);
        }
        out.writeOptionalVInt(numTopClasses);
        out.writeDouble(trainingPercent);
        if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
            out.writeOptionalLong(randomizeSeed);
        }
    }

    /**
     * Renders this analysis as an XContent object. The {@code version} param (default: current
     * version) controls whether {@code randomize_seed} is emitted (only for 7.6.0 and later).
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        Version version = Version.fromString(params.param("version", Version.CURRENT.toString()));
        builder.startObject();
        builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
        boostedTreeParams.toXContent(builder, params);
        builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective);
        builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
        // Only null if the stream constructor read it as absent.
        if (predictionFieldName != null) {
            builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
        }
        builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
        if (version.onOrAfter(Version.V_7_6_0)) {
            builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
        }
        builder.endObject();
        return builder;
    }

    /**
     * Builds the analysis parameter map. It includes the boosted tree params, the number of
     * classes (cardinality of the dependent variable) and, when it can be determined, the
     * prediction field type.
     *
     * <p>NOTE(review): {@code randomize_seed} is not included here; presumably it is supplied to
     * the process elsewhere — confirm.
     */
    @Override
    public Map<String, Object> getParams(DataFrameAnalysis.FieldInfo fieldInfo) {
        Map<String, Object> params = new HashMap<>();
        params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable);
        params.putAll(boostedTreeParams.getParams());
        params.put(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective);
        params.put(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
        if (predictionFieldName != null) {
            params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
        }
        String predictionFieldType =
            getPredictionFieldTypeParamString(getPredictionFieldType(fieldInfo.getTypes(dependentVariable)));
        if (predictionFieldType != null) {
            params.put(PREDICTION_FIELD_TYPE, predictionFieldType);
        }
        params.put(NUM_CLASSES, fieldInfo.getCardinality(dependentVariable));
        params.put(TRAINING_PERCENT.getPreferredName(), trainingPercent);
        return params;
    }

    /** Maps a {@link PredictionFieldType} to its parameter string, or {@code null} if there is none. */
    private static String getPredictionFieldTypeParamString(PredictionFieldType predictionFieldType) {
        if (predictionFieldType == null) {
            return null;
        }
        switch (predictionFieldType) {
            case NUMBER:
                return "int";
            case STRING:
                return "string";
            case BOOLEAN:
                return "bool";
            default:
                return null;
        }
    }

    /**
     * Infers the prediction field type from the dependent variable's mapped types.
     *
     * <p>Checks are applied in order: all categorical types yield STRING, then all boolean yield
     * BOOLEAN, then all discrete numerical yield NUMBER. Mixed or other types yield {@code null}.
     * Note that an empty set satisfies the first check and therefore yields STRING.
     *
     * @param dependentVariableTypes mapped types of the dependent variable, may be null
     * @return the inferred type, or {@code null}
     */
    public static PredictionFieldType getPredictionFieldType(Set<String> dependentVariableTypes) {
        if (dependentVariableTypes == null) {
            return null;
        }
        if (Types.categorical().containsAll(dependentVariableTypes)) {
            return PredictionFieldType.STRING;
        }
        if (Types.bool().containsAll(dependentVariableTypes)) {
            return PredictionFieldType.BOOLEAN;
        }
        if (Types.discreteNumerical().containsAll(dependentVariableTypes)) {
            return PredictionFieldType.NUMBER;
        }
        return null;
    }

    @Override
    public boolean supportsCategoricalFields() {
        return true;
    }

    /** The dependent variable accepts categorical, discrete numerical and boolean types; other fields only categorical. */
    @Override
    public Set<String> getAllowedCategoricalTypes(String fieldName) {
        if (dependentVariable.equals(fieldName)) {
            return ALLOWED_DEPENDENT_VARIABLE_TYPES;
        }
        return Types.categorical();
    }

    @Override
    public List<RequiredField> getRequiredFields() {
        return Collections.singletonList(new RequiredField(dependentVariable, ALLOWED_DEPENDENT_VARIABLE_TYPES));
    }

    /** The dependent variable must have between 2 and {@link #MAX_DEPENDENT_VARIABLE_CARDINALITY} distinct values. */
    @Override
    public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
        return Collections.singletonList(FieldCardinalityConstraint.between(
            dependentVariable, MIN_DEPENDENT_VARIABLE_CARDINALITY, MAX_DEPENDENT_VARIABLE_CARDINALITY));
    }

    /**
     * Returns mappings for result fields that must be mapped explicitly.
     *
     * <p>Always contains the feature importance mapping. When the dependent variable's mapping
     * (following one alias hop via its {@code path}) is found as an object, it is also reused for
     * the prediction field and {@code top_classes.class_name}.
     */
    @Override
    public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
        Map<String, Object> additionalProperties = new HashMap<>();
        additionalProperties.put(resultsFieldName + ".feature_importance", MapUtils.featureImportanceMapping());
        Object dependentVariableMapping = extractMapping(dependentVariable, mappingsProperties);
        if (!(dependentVariableMapping instanceof Map)) {
            return additionalProperties;
        }
        Map<?, ?> dependentVariableMappingAsMap = (Map<?, ?>) dependentVariableMapping;
        // An alias field: resolve the concrete field it points to.
        if ("alias".equals(dependentVariableMappingAsMap.get("type"))) {
            String path = (String) dependentVariableMappingAsMap.get("path");
            if (path == null) {
                // Cannot resolve an alias without a target path.
                return additionalProperties;
            }
            dependentVariableMapping = extractMapping(path, mappingsProperties);
        }
        // Re-check: the alias branch above may have replaced the mapping.
        if (!(dependentVariableMapping instanceof Map)) {
            return additionalProperties;
        }
        additionalProperties.put(resultsFieldName + "." + predictionFieldName, dependentVariableMapping);
        additionalProperties.put(resultsFieldName + ".top_classes.class_name", dependentVariableMapping);
        return additionalProperties;
    }

    /** Looks up a dotted field path in mapping properties, e.g. {@code a.b} becomes {@code a.properties.b}. */
    private static Object extractMapping(String path, Map<String, Object> mappingsProperties) {
        return XContentMapValues.extractValue(String.join(".properties.", path.split("\\.")), mappingsProperties);
    }

    @Override
    public boolean supportsMissingValues() {
        return true;
    }

    @Override
    public boolean persistsState() {
        return true;
    }

    @Override
    public String getStateDocId(String jobId) {
        return jobId + STATE_DOC_ID_SUFFIX;
    }

    @Override
    public List<String> getProgressPhases() {
        return PROGRESS_PHASES;
    }

    @Override
    public InferenceConfig inferenceConfig(DataFrameAnalysis.FieldInfo fieldInfo) {
        PredictionFieldType predictionFieldType = getPredictionFieldType(fieldInfo.getTypes(dependentVariable));
        return ClassificationConfig.builder()
            .setResultsField(predictionFieldName)
            .setNumTopClasses(numTopClasses)
            .setNumTopFeatureImportanceValues(getBoostedTreeParams().getNumTopFeatureImportanceValues())
            .setPredictionFieldType(predictionFieldType)
            .build();
    }

    @Override
    public boolean supportsInference() {
        return true;
    }

    /**
     * Extracts the job id from a state document id produced by {@link #getStateDocId(String)}.
     *
     * @return the prefix before the last occurrence of the state suffix, or {@code null} if the
     *         suffix is absent or there is no prefix
     */
    public static String extractJobIdFromStateDoc(String stateDocId) {
        int suffixIndex = stateDocId.lastIndexOf(STATE_DOC_ID_SUFFIX);
        return suffixIndex <= 0 ? null : stateDocId.substring(0, suffixIndex);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        Classification that = (Classification) o;
        // Double.compare keeps equals consistent with hashCode (which hashes the boxed Double).
        return Objects.equals(dependentVariable, that.dependentVariable)
            && Objects.equals(boostedTreeParams, that.boostedTreeParams)
            && Objects.equals(predictionFieldName, that.predictionFieldName)
            && classAssignmentObjective == that.classAssignmentObjective
            && numTopClasses == that.numTopClasses
            && Double.compare(trainingPercent, that.trainingPercent) == 0
            && randomizeSeed == that.randomizeSeed;
    }

    @Override
    public int hashCode() {
        return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, classAssignmentObjective,
            numTopClasses, trainingPercent, randomizeSeed);
    }

    /** Strategy used to assign a class from the predicted class probabilities. */
    public static enum ClassAssignmentObjective {
        MAXIMIZE_ACCURACY,
        MAXIMIZE_MINIMUM_RECALL;

        /**
         * Case-insensitively parses an objective name (upper-cased with {@link Locale#ROOT}).
         *
         * @throws IllegalArgumentException if the name matches no constant
         */
        public static ClassAssignmentObjective fromString(String value) {
            return ClassAssignmentObjective.valueOf(value.toUpperCase(Locale.ROOT));
        }

        /** Returns the lower-case constant name, the form accepted by {@link #fromString(String)}. */
        @Override
        public String toString() {
            return name().toLowerCase(Locale.ROOT);
        }
    }
}

