/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.dataframe.process;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.ParentTaskAssigningClient;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
import org.elasticsearch.xpack.ml.dataframe.inference.InferenceRunner;
import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcess;
import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessConfig;
import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessFactory;
import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsResultProcessor;
import org.elasticsearch.xpack.ml.dataframe.process.DataFrameRowsJoiner;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.stats.DataCountsTracker;
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

public class AnalyticsProcessManager {
    private static final Logger LOGGER = LogManager.getLogger(AnalyticsProcessManager.class);
    private final Settings settings;
    private final Client client;
    private final ExecutorService executorServiceForJob;
    private final ExecutorService executorServiceForProcess;
    private final AnalyticsProcessFactory<AnalyticsResult> processFactory;
    private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<Long, ProcessContext>();
    private final DataFrameAnalyticsAuditor auditor;
    private final TrainedModelProvider trainedModelProvider;
    private final ModelLoadingService modelLoadingService;
    private final ResultsPersisterService resultsPersisterService;
    private final int numAllocatedProcessors;

    public AnalyticsProcessManager(Settings settings, Client client, ThreadPool threadPool, AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory, DataFrameAnalyticsAuditor auditor, TrainedModelProvider trainedModelProvider, ModelLoadingService modelLoadingService, ResultsPersisterService resultsPersisterService, int numAllocatedProcessors) {
        this(settings, client, threadPool.generic(), threadPool.executor("ml_job_comms"), analyticsProcessFactory, auditor, trainedModelProvider, modelLoadingService, resultsPersisterService, numAllocatedProcessors);
    }

    public AnalyticsProcessManager(Settings settings, Client client, ExecutorService executorServiceForJob, ExecutorService executorServiceForProcess, AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory, DataFrameAnalyticsAuditor auditor, TrainedModelProvider trainedModelProvider, ModelLoadingService modelLoadingService, ResultsPersisterService resultsPersisterService, int numAllocatedProcessors) {
        this.settings = Objects.requireNonNull(settings);
        this.client = Objects.requireNonNull(client);
        this.executorServiceForJob = Objects.requireNonNull(executorServiceForJob);
        this.executorServiceForProcess = Objects.requireNonNull(executorServiceForProcess);
        this.processFactory = Objects.requireNonNull(analyticsProcessFactory);
        this.auditor = Objects.requireNonNull(auditor);
        this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
        this.modelLoadingService = Objects.requireNonNull(modelLoadingService);
        this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
        this.numAllocatedProcessors = numAllocatedProcessors;
    }

    public void runJob(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory) {
        this.executorServiceForJob.execute(() -> {
            ProcessContext processContext = new ProcessContext(config);
            ConcurrentMap<Long, ProcessContext> concurrentMap = this.processContextByAllocation;
            synchronized (concurrentMap) {
                if (task.isStopping()) {
                    LOGGER.debug("[{}] task is stopping. Marking as complete before creating process context.", (Object)task.getParams().getId());
                    this.auditor.info(config.getId(), "Finished analysis");
                    task.markAsCompleted();
                    return;
                }
                if (this.processContextByAllocation.putIfAbsent(task.getAllocationId(), processContext) != null) {
                    task.setFailed((Exception)((Object)ExceptionsHelper.serverError((String)("[" + config.getId() + "] Could not create process as one already exists"))));
                    return;
                }
            }
            BytesReference state = this.getModelState(config);
            if (processContext.startProcess(dataExtractorFactory, task, state)) {
                this.executorServiceForProcess.execute(() -> ((AnalyticsResultProcessor)processContext.resultProcessor.get()).process((AnalyticsProcess)processContext.process.get()));
                this.executorServiceForProcess.execute(() -> this.processData(task, processContext, state));
            } else {
                this.processContextByAllocation.remove(task.getAllocationId());
                this.auditor.info(config.getId(), "Finished analysis");
                task.markAsCompleted();
            }
        });
    }

    @Nullable
    private BytesReference getModelState(DataFrameAnalyticsConfig config) {
        if (!config.getAnalysis().persistsState()) {
            return null;
        }
        try (ThreadContext.StoredContext ignore = this.client.threadPool().getThreadContext().stashWithOrigin("ml");){
            SearchResponse searchResponse = (SearchResponse)this.client.prepareSearch(new String[]{AnomalyDetectorsIndex.jobStateIndexPattern()}).setSize(1).setQuery((QueryBuilder)QueryBuilders.idsQuery().addIds(new String[]{config.getAnalysis().getStateDocId(config.getId())})).get();
            SearchHit[] hits = searchResponse.getHits().getHits();
            BytesReference bytesReference = hits.length == 0 ? null : hits[0].getSourceRef();
            return bytesReference;
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void processData(DataFrameAnalyticsTask task, ProcessContext processContext, BytesReference state) {
        LOGGER.info("[{}] Started loading data", (Object)processContext.config.getId());
        this.auditor.info(processContext.config.getId(), Messages.getMessage((String)"Started loading data"));
        ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(this.client, task.getParentTaskId());
        DataFrameAnalyticsConfig config = processContext.config;
        DataFrameDataExtractor dataExtractor = (DataFrameDataExtractor)processContext.dataExtractor.get();
        AnalyticsProcess process = (AnalyticsProcess)processContext.process.get();
        AnalyticsResultProcessor resultProcessor = (AnalyticsResultProcessor)processContext.resultProcessor.get();
        try {
            this.writeHeaderRecord(dataExtractor, process);
            this.writeDataRows(dataExtractor, process, task);
            process.writeEndOfDataMessage();
            process.flushStream();
            this.restoreState(task, config, state, process);
            LOGGER.info("[{}] Started analyzing", (Object)processContext.config.getId());
            this.auditor.info(processContext.config.getId(), Messages.getMessage((String)"Started analyzing"));
            LOGGER.info("[{}] Waiting for result processor to complete", (Object)config.getId());
            resultProcessor.awaitForCompletion();
            processContext.setFailureReason(resultProcessor.getFailure());
            LOGGER.info("[{}] Result processor has completed", (Object)config.getId());
            this.runInference(parentTaskClient, task, processContext, dataExtractor.getExtractedFields());
            processContext.statsPersister.persistWithRetry((ToXContentObject)task.getStatsHolder().getDataCountsTracker().report(config.getId()), DataCounts::documentId);
            this.refreshDest(parentTaskClient, config);
            this.refreshIndices(parentTaskClient, config.getId());
        }
        catch (Exception e) {
            if (task.isStopping()) {
                String errorMsg = new ParameterizedMessage("[{}] Error while processing data [{}]; task is stopping", (Object)config.getId(), (Object)e.getMessage()).getFormattedMessage();
                LOGGER.debug(errorMsg, (Throwable)e);
            } else {
                String errorMsg = new ParameterizedMessage("[{}] Error while processing data [{}]", (Object)config.getId(), (Object)e.getMessage()).getFormattedMessage();
                LOGGER.error(errorMsg, (Throwable)e);
                processContext.setFailureReason(errorMsg);
            }
        }
        finally {
            this.closeProcess(task);
            this.processContextByAllocation.remove(task.getAllocationId());
            LOGGER.debug("Removed process context for task [{}]; [{}] processes still running", (Object)config.getId(), (Object)this.processContextByAllocation.size());
            if (processContext.getFailureReason() == null) {
                LOGGER.info("[{}] Marking task completed", (Object)config.getId());
                this.auditor.info(config.getId(), "Finished analysis");
                task.markAsCompleted();
            } else {
                LOGGER.error("[{}] Marking task failed; {}", (Object)config.getId(), (Object)processContext.getFailureReason());
                task.setFailed((Exception)((Object)ExceptionsHelper.serverError((String)processContext.getFailureReason())));
            }
        }
    }

    private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process, DataFrameAnalyticsTask task) throws IOException {
        ProgressTracker progressTracker = task.getStatsHolder().getProgressTracker();
        DataCountsTracker dataCountsTracker = task.getStatsHolder().getDataCountsTracker();
        String[] record = new String[dataExtractor.getFieldNames().size() + 2];
        record[record.length - 1] = "";
        long totalRows = process.getConfig().rows();
        long rowsProcessed = 0L;
        while (dataExtractor.hasNext()) {
            Optional<List<DataFrameDataExtractor.Row>> rows = dataExtractor.next();
            if (!rows.isPresent()) continue;
            for (DataFrameDataExtractor.Row row : rows.get()) {
                if (row.shouldSkip()) {
                    dataCountsTracker.incrementSkippedDocsCount();
                    continue;
                }
                String[] rowValues = row.getValues();
                System.arraycopy(rowValues, 0, record, 0, rowValues.length);
                record[record.length - 2] = String.valueOf(row.getChecksum());
                if (!row.isTraining()) continue;
                dataCountsTracker.incrementTrainingDocsCount();
                process.writeRecord(record);
            }
            progressTracker.updateLoadingDataProgress((rowsProcessed += (long)rows.get().size()) >= totalRows ? 100 : (int)((double)rowsProcessed * 100.0 / (double)totalRows));
        }
    }

    private void writeHeaderRecord(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process) throws IOException {
        List<String> fieldNames = dataExtractor.getFieldNames();
        String[] headerRecord = new String[fieldNames.size() + 2];
        for (int i = 0; i < fieldNames.size(); ++i) {
            headerRecord[i] = fieldNames.get(i);
        }
        headerRecord[headerRecord.length - 2] = ".";
        headerRecord[headerRecord.length - 1] = ".";
        process.writeRecord(headerRecord);
    }

    private void restoreState(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, @Nullable BytesReference state, AnalyticsProcess<AnalyticsResult> process) {
        if (!config.getAnalysis().persistsState()) {
            LOGGER.debug("[{}] Analysis does not support state", (Object)config.getId());
            return;
        }
        if (state == null) {
            LOGGER.debug("[{}] No model state available to restore", (Object)config.getId());
            return;
        }
        LOGGER.debug("[{}] Restoring from previous model state", (Object)config.getId());
        this.auditor.info(config.getId(), "Restoring from previous model state");
        try (ThreadContext.StoredContext ignore = this.client.threadPool().getThreadContext().stashWithOrigin("ml");){
            process.restoreState(state);
        }
        catch (Exception e) {
            LOGGER.error((Message)new ParameterizedMessage("[{}] Failed to restore state", (Object)process.getConfig().jobId()), (Throwable)e);
            task.setFailed((Exception)((Object)ExceptionsHelper.serverError((String)("Failed to restore state: " + e.getMessage()))));
        }
    }

    private AnalyticsProcess<AnalyticsResult> createProcess(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, AnalyticsProcessConfig analyticsProcessConfig, @Nullable BytesReference state) {
        AnalyticsProcess<AnalyticsResult> process = this.processFactory.createAnalyticsProcess(config, analyticsProcessConfig, state, this.executorServiceForProcess, this.onProcessCrash(task));
        if (!process.isProcessAlive()) {
            throw ExceptionsHelper.serverError((String)"Failed to start data frame analytics process");
        }
        return process;
    }

    private Consumer<String> onProcessCrash(DataFrameAnalyticsTask task) {
        return reason -> {
            ProcessContext processContext = (ProcessContext)this.processContextByAllocation.get(task.getAllocationId());
            if (processContext != null) {
                processContext.setFailureReason((String)reason);
                processContext.stop();
            }
        };
    }

    private void runInference(ParentTaskAssigningClient parentTaskClient, DataFrameAnalyticsTask task, ProcessContext processContext, ExtractedFields extractedFields) {
        if (task.isStopping() || processContext.failureReason.get() != null) {
            return;
        }
        if (processContext.config.getAnalysis().supportsInference()) {
            this.refreshDest(parentTaskClient, processContext.config);
            InferenceRunner inferenceRunner = new InferenceRunner(this.settings, (Client)parentTaskClient, this.modelLoadingService, this.resultsPersisterService, task.getParentTaskId(), processContext.config, extractedFields, task.getStatsHolder().getProgressTracker(), task.getStatsHolder().getDataCountsTracker());
            processContext.setInferenceRunner(inferenceRunner);
            inferenceRunner.run(((AnalyticsResultProcessor)processContext.resultProcessor.get()).getLatestModelId());
        }
    }

    private void refreshDest(ParentTaskAssigningClient parentTaskClient, DataFrameAnalyticsConfig config) {
        ClientHelper.executeWithHeaders((Map)config.getHeaders(), (String)"ml", (Client)parentTaskClient, () -> (RefreshResponse)parentTaskClient.execute((ActionType)RefreshAction.INSTANCE, (ActionRequest)new RefreshRequest(new String[]{config.getDest().getIndex()})).actionGet());
    }

    private void refreshIndices(ParentTaskAssigningClient parentTaskClient, String jobId) {
        RefreshRequest refreshRequest = new RefreshRequest(new String[]{AnomalyDetectorsIndex.jobStateIndexPattern(), MlStatsIndex.indexPattern()});
        refreshRequest.indicesOptions(IndicesOptions.lenientExpandOpen());
        LOGGER.debug(() -> new ParameterizedMessage("[{}] Refreshing indices {}", (Object)jobId, (Object)Arrays.toString(refreshRequest.indices())));
        try (ThreadContext.StoredContext ignore = parentTaskClient.threadPool().getThreadContext().stashWithOrigin("ml");){
            parentTaskClient.admin().indices().refresh(refreshRequest).actionGet();
        }
    }

    private void closeProcess(DataFrameAnalyticsTask task) {
        String configId = task.getParams().getId();
        LOGGER.info("[{}] Closing process", (Object)configId);
        ProcessContext processContext = (ProcessContext)this.processContextByAllocation.get(task.getAllocationId());
        try {
            ((AnalyticsProcess)processContext.process.get()).close();
            LOGGER.info("[{}] Closed process", (Object)configId);
        }
        catch (Exception e) {
            LOGGER.error("[" + configId + "] Error closing data frame analyzer process", (Throwable)e);
            String errorMsg = new ParameterizedMessage("[{}] Error closing data frame analyzer process [{}]", (Object)configId, (Object)e.getMessage()).getFormattedMessage();
            processContext.setFailureReason(errorMsg);
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void stop(DataFrameAnalyticsTask task) {
        ProcessContext processContext;
        ConcurrentMap<Long, ProcessContext> concurrentMap = this.processContextByAllocation;
        synchronized (concurrentMap) {
            processContext = (ProcessContext)this.processContextByAllocation.get(task.getAllocationId());
        }
        if (processContext != null) {
            LOGGER.debug("[{}] Stopping process", (Object)task.getParams().getId());
            processContext.stop();
        } else {
            LOGGER.debug("[{}] No process context to stop", (Object)task.getParams().getId());
        }
    }

    int getProcessContextCount() {
        return this.processContextByAllocation.size();
    }

    class ProcessContext {
        private final DataFrameAnalyticsConfig config;
        private final SetOnce<AnalyticsProcess<AnalyticsResult>> process = new SetOnce();
        private final SetOnce<DataFrameDataExtractor> dataExtractor = new SetOnce();
        private final SetOnce<AnalyticsResultProcessor> resultProcessor = new SetOnce();
        private final SetOnce<InferenceRunner> inferenceRunner = new SetOnce();
        private final SetOnce<String> failureReason = new SetOnce();
        private final StatsPersister statsPersister;

        ProcessContext(DataFrameAnalyticsConfig config) {
            this.config = Objects.requireNonNull(config);
            this.statsPersister = new StatsPersister(config.getId(), AnalyticsProcessManager.this.resultsPersisterService, AnalyticsProcessManager.this.auditor);
        }

        String getFailureReason() {
            return (String)this.failureReason.get();
        }

        void setFailureReason(String failureReason) {
            if (failureReason == null) {
                return;
            }
            this.failureReason.trySet((Object)failureReason);
        }

        void setInferenceRunner(InferenceRunner inferenceRunner) {
            this.inferenceRunner.set((Object)inferenceRunner);
        }

        synchronized void stop() {
            LOGGER.debug("[{}] Stopping process", (Object)this.config.getId());
            if (this.dataExtractor.get() != null) {
                ((DataFrameDataExtractor)this.dataExtractor.get()).cancel();
            }
            if (this.resultProcessor.get() != null) {
                ((AnalyticsResultProcessor)this.resultProcessor.get()).cancel();
            }
            if (this.inferenceRunner.get() != null) {
                ((InferenceRunner)this.inferenceRunner.get()).cancel();
            }
            if (this.process.get() != null) {
                try {
                    ((AnalyticsProcess)this.process.get()).kill();
                }
                catch (IOException e) {
                    LOGGER.error((Message)new ParameterizedMessage("[{}] Failed to kill process", (Object)this.config.getId()), (Throwable)e);
                }
            }
        }

        synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtractorFactory, DataFrameAnalyticsTask task, @Nullable BytesReference state) {
            if (task.isStopping()) {
                return false;
            }
            this.dataExtractor.set((Object)dataExtractorFactory.newExtractor(false));
            AnalyticsProcessConfig analyticsProcessConfig = this.createProcessConfig((DataFrameDataExtractor)this.dataExtractor.get(), dataExtractorFactory.getExtractedFields());
            LOGGER.debug("[{}] creating analytics process with config [{}]", (Object)this.config.getId(), (Object)Strings.toString((ToXContent)analyticsProcessConfig));
            if (analyticsProcessConfig.rows() == 0L) {
                LOGGER.info("[{}] no data found to analyze. Will not start analytics native process.", (Object)this.config.getId());
                return false;
            }
            this.process.set((Object)AnalyticsProcessManager.this.createProcess(task, this.config, analyticsProcessConfig, state));
            this.resultProcessor.set((Object)this.createResultProcessor(task, dataExtractorFactory));
            return true;
        }

        private AnalyticsProcessConfig createProcessConfig(DataFrameDataExtractor dataExtractor, ExtractedFields extractedFields) {
            DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary();
            Set<String> categoricalFields = dataExtractor.getCategoricalFields(this.config.getAnalysis());
            int threads = Math.min(this.config.getMaxNumThreads(), AnalyticsProcessManager.this.numAllocatedProcessors);
            return new AnalyticsProcessConfig(this.config.getId(), dataSummary.rows, dataSummary.cols, this.config.getModelMemoryLimit(), threads, this.config.getDest().getResultsField(), categoricalFields, this.config.getAnalysis(), extractedFields);
        }

        private AnalyticsResultProcessor createResultProcessor(DataFrameAnalyticsTask task, DataFrameDataExtractorFactory dataExtractorFactory) {
            DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(this.config.getId(), AnalyticsProcessManager.this.settings, task.getParentTaskId(), dataExtractorFactory.newExtractor(true), AnalyticsProcessManager.this.resultsPersisterService);
            return new AnalyticsResultProcessor(this.config, dataFrameRowsJoiner, task.getStatsHolder(), AnalyticsProcessManager.this.trainedModelProvider, AnalyticsProcessManager.this.auditor, this.statsPersister, ((DataFrameDataExtractor)this.dataExtractor.get()).getExtractedFields());
        }
    }
}

