/*
 * Decompiled with CFR 0.152.
 */
package org.apache.uniffle.coordinator.strategy.assignment;

import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.uniffle.common.PartitionRange;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.coordinator.ClusterManager;
import org.apache.uniffle.coordinator.CoordinatorConf;
import org.apache.uniffle.coordinator.ServerNode;
import org.apache.uniffle.coordinator.strategy.assignment.AbstractAssignmentStrategy;
import org.apache.uniffle.coordinator.strategy.assignment.PartitionRangeAssignment;
import org.apache.uniffle.guava.annotations.VisibleForTesting;
import org.apache.uniffle.guava.collect.Sets;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Assignment strategy that ranks shuffle servers by available memory per (already assigned +
 * expected) partitions and hands partitions to the best-ranked candidates.
 *
 * <p>Only {@code partitionNumPerRange == 1} is supported. Per-server assignment counts are kept in
 * {@link #serverToPartitions} and are updated under the instance lock in
 * {@link #assign(int, int, int, Set, int, int, Set)}.
 */
public class PartitionBalanceAssignmentStrategy extends AbstractAssignmentStrategy {
    private static final Logger LOG = LoggerFactory.getLogger(PartitionBalanceAssignmentStrategy.class);
    private final ClusterManager clusterManager;
    /** Partition bookkeeping per server; replaced with a fresh map on every {@code assign} call. */
    private Map<ServerNode, PartitionAssignmentInfo> serverToPartitions = JavaUtils.newConcurrentMap();

    public PartitionBalanceAssignmentStrategy(ClusterManager clusterManager, CoordinatorConf conf) {
        super(conf);
        this.clusterManager = clusterManager;
    }

    /**
     * Assigns partitions without excluding any server; delegates to the overload that accepts an
     * exclusion set, passing an empty set.
     */
    @Override
    public PartitionRangeAssignment assign(int totalPartitionNum, int partitionNumPerRange, int replica, Set<String> requiredTags, int requiredShuffleServerNumber, int estimateTaskConcurrency) {
        return assign(totalPartitionNum, partitionNumPerRange, replica, requiredTags, requiredShuffleServerNumber, estimateTaskConcurrency, Sets.newConcurrentHashSet());
    }

    /**
     * Assigns partitions to shuffle servers, preferring servers with the most available memory per
     * (already assigned + expected) partitions, and records the new assignments per server.
     *
     * @param totalPartitionNum total number of partitions to assign
     * @param partitionNumPerRange partitions per range; must be {@code 1}
     * @param replica number of servers each partition is assigned to
     * @param requiredTags tags a server must carry to be considered
     * @param requiredShuffleServerNumber requested server count; only honoured when it is positive
     *     and smaller than the cluster's configured maximum
     * @param estimateTaskConcurrency forwarded to {@code getPartitionAssignment}
     * @param excludeServerNodes servers to leave out of the candidate list
     * @return the partition-range to server assignment
     * @throws RssException if {@code partitionNumPerRange != 1}, or if fewer servers than
     *     {@code replica} are available
     */
    @Override
    public PartitionRangeAssignment assign(int totalPartitionNum, int partitionNumPerRange, int replica, Set<String> requiredTags, int requiredShuffleServerNumber, int estimateTaskConcurrency, Set<String> excludeServerNodes) {
        if (partitionNumPerRange != 1) {
            throw new RssException("PartitionNumPerRange must be one");
        }
        SortedMap<PartitionRange, List<ServerNode>> assignments;
        synchronized (this) {
            List<ServerNode> nodes = clusterManager.getServerList(requiredTags, excludeServerNodes);
            serverToPartitions = refreshPartitionInfos(nodes);

            // NOTE(review): throws ArithmeticException if getShuffleNodesMax() is 0 — presumably
            // prevented by configuration validation; confirm.
            int averagePartitions = totalPartitionNum * replica / clusterManager.getShuffleNodesMax();
            final int assignPartitions = Math.max(averagePartitions, 1);
            // Descending: the server with the most memory per prospective partition comes first.
            nodes.sort(
                Comparator.comparingDouble((ServerNode node) -> memoryPerPartition(node, assignPartitions))
                    .reversed());

            if (nodes.isEmpty() || nodes.size() < replica) {
                throw new RssException("There isn't enough shuffle servers");
            }

            int assignmentMaxNum = clusterManager.getShuffleNodesMax();
            int expectNum = assignmentMaxNum;
            if (requiredShuffleServerNumber < assignmentMaxNum && requiredShuffleServerNumber > 0) {
                expectNum = requiredShuffleServerNumber;
            }
            if (nodes.size() < expectNum) {
                LOG.warn("Can't get expected servers [{}] and found only [{}]", expectNum, nodes.size());
                expectNum = nodes.size();
            }

            List<ServerNode> candidatesNodes = getCandidateNodes(nodes, expectNum);
            assignments = getPartitionAssignment(totalPartitionNum, partitionNumPerRange, replica, candidatesNodes, estimateTaskConcurrency);
            // Every (partition range, server) pair counts as one more partition on that server.
            assignments.values().stream()
                .flatMap(Collection::stream)
                .forEach(server -> serverToPartitions.get(server).incrementPartitionNum());
        }
        return new PartitionRangeAssignment(assignments);
    }

    /**
     * Builds the bookkeeping map for exactly the given servers. An existing entry is reused, with
     * its partition count reset when the server's timestamp is newer than the recorded one. A
     * server without an entry gets a fresh {@link PartitionAssignmentInfo}.
     *
     * <p>NOTE(review): servers absent from {@code nodes} (for example filtered out by tags or the
     * exclusion set) are not copied over, so their accumulated counts are dropped. Confirm this is
     * intended.
     */
    private Map<ServerNode, PartitionAssignmentInfo> refreshPartitionInfos(List<ServerNode> nodes) {
        Map<ServerNode, PartitionAssignmentInfo> newPartitionInfos = JavaUtils.newConcurrentMap();
        for (ServerNode node : nodes) {
            newPartitionInfos.computeIfAbsent(node, key -> {
                if (!serverToPartitions.containsKey(node)) {
                    return new PartitionAssignmentInfo();
                }
                PartitionAssignmentInfo partitionInfo = serverToPartitions.get(node);
                if (partitionInfo.getTimestamp() < node.getTimestamp()) {
                    partitionInfo.resetPartitionNum();
                    partitionInfo.setTimestamp(node.getTimestamp());
                }
                return partitionInfo;
            });
        }
        return newPartitionInfos;
    }

    /**
     * Ranking score for a server: its available memory divided by the partitions it already holds
     * plus {@code assignPartitions}. The node must already have an entry in
     * {@link #serverToPartitions}.
     */
    private double memoryPerPartition(ServerNode node, int assignPartitions) {
        PartitionAssignmentInfo partitionInfo = serverToPartitions.get(node);
        return (double) node.getAvailableMemory() / (double) (partitionInfo.getPartitionNum() + assignPartitions);
    }

    @VisibleForTesting
    Map<ServerNode, PartitionAssignmentInfo> getServerToPartitions() {
        return this.serverToPartitions;
    }

    /**
     * Mutable per-server counter of assigned partitions. It also stores the timestamp used to
     * decide when the counter is reset.
     */
    static class PartitionAssignmentInfo {
        int partitionNum = 0;
        long timestamp = System.currentTimeMillis();

        PartitionAssignmentInfo() {
        }

        public int getPartitionNum() {
            return this.partitionNum;
        }

        public void resetPartitionNum() {
            this.partitionNum = 0;
        }

        public void incrementPartitionNum() {
            ++this.partitionNum;
        }

        public void incrementPartitionNum(int val) {
            this.partitionNum += val;
        }

        public long getTimestamp() {
            return this.timestamp;
        }

        public void setTimestamp(long timestamp) {
            this.timestamp = timestamp;
        }
    }
}

