/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.anomaly.example;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.tribuo.ConfigurableDataSource;
import org.tribuo.DataSource;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.MutableDataset;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.anomaly.AnomalyFactory;
import org.tribuo.anomaly.Event;
import org.tribuo.impl.ArrayExample;
import org.tribuo.provenance.ConfiguredDataSourceProvenance;
import org.tribuo.provenance.DataSourceProvenance;

public final class GaussianAnomalyDataSource
implements ConfigurableDataSource<Event> {
    private static final AnomalyFactory factory = new AnomalyFactory();
    private static final String[] allFeatureNames = new String[]{"A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"};
    @Config(mandatory=true, description="The number of samples to draw.")
    private int numSamples;
    @Config(description="Means of the expected events.")
    private double[] expectedMeans = new double[]{1.0, 2.0, 1.0, 2.0, 5.0};
    @Config(description="Variances of the expected events.")
    private double[] expectedVariances = new double[]{1.0, 0.5, 0.25, 1.0, 0.1};
    @Config(description="Means of the anomalous events.")
    private double[] anomalousMeans = new double[]{-2.0, 2.0, -2.0, 2.0, -10.0};
    @Config(description="Variances of the anomalous events.")
    private double[] anomalousVariances = new double[]{1.0, 0.5, 0.25, 1.0, 0.1};
    @Config(description="The RNG seed.")
    private long seed = 12345L;
    @Config(mandatory=true, description="The fraction of anomalous events.")
    private float fractionAnomalous = 0.3f;
    private List<Example<Event>> examples;

    private GaussianAnomalyDataSource() {
    }

    public GaussianAnomalyDataSource(int numSamples, float fractionAnomalous, long seed) {
        this.numSamples = numSamples;
        this.fractionAnomalous = fractionAnomalous;
        this.seed = seed;
        this.postConfig();
    }

    public GaussianAnomalyDataSource(int numSamples, double[] expectedMeans, double[] expectedVariances, double[] anomalousMeans, double[] anomalousVariances, float fractionAnomalous, long seed) {
        this.numSamples = numSamples;
        this.expectedMeans = expectedMeans;
        this.expectedVariances = expectedVariances;
        this.anomalousMeans = anomalousMeans;
        this.anomalousVariances = anomalousVariances;
        this.fractionAnomalous = fractionAnomalous;
        this.seed = seed;
        this.postConfig();
    }

    public void postConfig() {
        if (this.numSamples < 1) {
            throw new PropertyException("", "numSamples", "numSamples must be positive, found " + this.numSamples);
        }
        if (this.expectedMeans.length > allFeatureNames.length || this.expectedMeans.length == 0) {
            throw new PropertyException("", "expectedMeans", "Must have 1-26 features, found " + this.expectedMeans.length);
        }
        if (this.expectedMeans.length != this.expectedVariances.length) {
            throw new PropertyException("", "expectedMeans", "Must supply the same number of expected means and variances. expectedMeans.length = " + this.expectedMeans.length + " expectedVariances.length = " + this.expectedVariances.length);
        }
        if (this.anomalousMeans.length != this.anomalousVariances.length) {
            throw new PropertyException("", "anomalousMeans", "Must supply the same number of anomalous means and variances. anomalousMeans.length = " + this.anomalousMeans.length + " anomalousVariances.length = " + this.anomalousVariances.length);
        }
        if (this.fractionAnomalous < 0.0f || this.fractionAnomalous > 1.0f) {
            throw new PropertyException("", "fractionAnomalous", "fractionAnomalous must be between 0.0 and 1.0, found " + this.fractionAnomalous);
        }
        if ((double)this.fractionAnomalous != 0.0 && this.anomalousMeans.length != this.expectedMeans.length) {
            throw new PropertyException("", "anomalousMeans", "When sampling anomalous data there must be the same number of anomalous features as expected features. anomalousMeans.length = " + this.anomalousMeans.length + ", expectedMeans.length = " + this.expectedMeans.length);
        }
        for (int i = 0; i < this.anomalousVariances.length; ++i) {
            if (this.anomalousVariances[i] < 1.0E-10) {
                throw new PropertyException("", "anomalousVariances", "Variances must be positive, found " + Arrays.toString(this.anomalousVariances));
            }
            if (!(this.expectedVariances[i] < 1.0E-10)) continue;
            throw new PropertyException("", "expectedVariances", "Variances must be positive, found " + Arrays.toString(this.expectedVariances));
        }
        String[] featureNames = Arrays.copyOf(allFeatureNames, this.expectedMeans.length);
        Random rng = new Random(this.seed);
        ArrayList<ArrayExample> examples = new ArrayList<ArrayExample>(this.numSamples);
        for (int i = 0; i < this.numSamples; ++i) {
            List<Feature> featureList;
            double draw = rng.nextDouble();
            if (draw < (double)this.fractionAnomalous) {
                featureList = GaussianAnomalyDataSource.generateFeatures(rng, featureNames, this.anomalousMeans, this.anomalousVariances);
                examples.add(new ArrayExample((Output)AnomalyFactory.ANOMALOUS_EVENT, featureList));
                continue;
            }
            featureList = GaussianAnomalyDataSource.generateFeatures(rng, featureNames, this.expectedMeans, this.expectedVariances);
            examples.add(new ArrayExample((Output)AnomalyFactory.EXPECTED_EVENT, featureList));
        }
        this.examples = Collections.unmodifiableList(examples);
    }

    public OutputFactory<Event> getOutputFactory() {
        return factory;
    }

    public DataSourceProvenance getProvenance() {
        return new GaussianAnomalyDataSourceProvenance(this);
    }

    public Iterator<Example<Event>> iterator() {
        return this.examples.iterator();
    }

    private static List<Feature> generateFeatures(Random rng, String[] names, double[] means, double[] variances) {
        if (names.length != means.length || names.length != variances.length) {
            throw new IllegalArgumentException("Names, means and variances must be the same length");
        }
        ArrayList<Feature> features = new ArrayList<Feature>();
        for (int i = 0; i < names.length; ++i) {
            double value = rng.nextGaussian() * Math.sqrt(variances[i]) + means[i];
            features.add(new Feature(names[i], value));
        }
        return features;
    }

    public static MutableDataset<Event> generateDataset(int numSamples, double[] expectedMeans, double[] expectedVariances, double[] anomalousMeans, double[] anomalousVariances, float fractionAnomalous, long seed) {
        GaussianAnomalyDataSource source = new GaussianAnomalyDataSource(numSamples, expectedMeans, expectedVariances, anomalousMeans, anomalousVariances, fractionAnomalous, seed);
        return new MutableDataset((DataSource)source);
    }

    public static final class GaussianAnomalyDataSourceProvenance
    extends SkeletalConfiguredObjectProvenance
    implements ConfiguredDataSourceProvenance {
        private static final long serialVersionUID = 1L;

        GaussianAnomalyDataSourceProvenance(GaussianAnomalyDataSource host) {
            super((Configurable)host, "DataSource");
        }

        public GaussianAnomalyDataSourceProvenance(Map<String, Provenance> map) {
            this(GaussianAnomalyDataSourceProvenance.extractProvenanceInfo(map));
        }

        private GaussianAnomalyDataSourceProvenance(SkeletalConfiguredObjectProvenance.ExtractedInfo info) {
            super(info);
        }

        protected static SkeletalConfiguredObjectProvenance.ExtractedInfo extractProvenanceInfo(Map<String, Provenance> map) {
            HashMap<String, Provenance> configuredParameters = new HashMap<String, Provenance>(map);
            String className = ((StringProvenance)ObjectProvenance.checkAndExtractProvenance(configuredParameters, (String)"class-name", StringProvenance.class, (String)GaussianAnomalyDataSourceProvenance.class.getSimpleName())).getValue();
            String hostTypeStringName = ((StringProvenance)ObjectProvenance.checkAndExtractProvenance(configuredParameters, (String)"host-short-name", StringProvenance.class, (String)GaussianAnomalyDataSourceProvenance.class.getSimpleName())).getValue();
            return new SkeletalConfiguredObjectProvenance.ExtractedInfo(className, hostTypeStringName, configuredParameters, Collections.emptyMap());
        }
    }
}

