/*
 * Decompiled with CFR 0.152.
 */
package org.apache.uniffle.shuffle.manager;

import io.grpc.stub.StreamObserver;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Supplier;
import org.apache.spark.shuffle.ShuffleHandleInfo;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.proto.RssProtos;
import org.apache.uniffle.proto.ShuffleManagerGrpc;
import org.apache.uniffle.shaded.com.google.common.collect.Sets;
import org.apache.uniffle.shuffle.manager.RssShuffleManagerInterface;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ShuffleManagerGrpcService
extends ShuffleManagerGrpc.ShuffleManagerImplBase {
    private static final Logger LOG = LoggerFactory.getLogger(ShuffleManagerGrpcService.class);
    // Fetch-failure bookkeeping, keyed by shuffle id (populated by reportShuffleFetchFailure).
    private final Map<Integer, RssShuffleStatus> shuffleStatus = JavaUtils.newConcurrentMap();
    // Write-failure bookkeeping, keyed by shuffle id (populated by reportShuffleWriteFailure).
    // NOTE(review): the identifier is misspelled ("Wrtie"); it is left as-is because other
    // members of this class reference it by this name.
    private final Map<Integer, ShuffleServerFailureRecord> shuffleWrtieStatus = JavaUtils.newConcurrentMap();
    // Shuffle manager that every RPC handler in this service delegates to.
    private final RssShuffleManagerInterface shuffleManager;

    /**
     * Creates the gRPC service backed by the given shuffle manager.
     *
     * @param shuffleManager the shuffle manager the RPC handlers query and update
     */
    public ShuffleManagerGrpcService(RssShuffleManagerInterface shuffleManager) {
        this.shuffleManager = shuffleManager;
    }

    /**
     * Handles a shuffle write failure report.
     *
     * <p>The report is rejected with {@code INVALID_REQUEST} when it comes from a different
     * application or from a stage attempt older than the one currently tracked. Otherwise the
     * failure counters of the reported shuffle servers are incremented and the reply tells the
     * caller whether the whole stage should be re-submitted.
     *
     * @param request the failure report (app id, shuffle id, stage attempt, failing servers)
     * @param responseObserver observer that receives exactly one reply and is then completed
     */
    @Override
    public void reportShuffleWriteFailure(RssProtos.ReportShuffleWriteFailureRequest request, StreamObserver<RssProtos.ReportShuffleWriteFailureResponse> responseObserver) {
        boolean reSubmitWholeStage;
        RssProtos.StatusCode code;
        String msg;
        String appId = request.getAppId();
        int shuffleId = request.getShuffleId();
        int stageAttemptNumber = request.getStageAttemptNumber();
        List<RssProtos.ShuffleServerId> shuffleServerIdsList = request.getShuffleServerIdsList();
        if (!appId.equals(this.shuffleManager.getAppId())) {
            msg = String.format("got a wrong shuffle write failure report from appId: %s, expected appId: %s", appId, this.shuffleManager.getAppId());
            LOG.warn(msg);
            code = RssProtos.StatusCode.INVALID_REQUEST;
            reSubmitWholeStage = false;
        } else {
            List<ShuffleServerInfo> shuffleServerInfos = ShuffleServerInfo.fromProto(shuffleServerIdsList);
            // The zero-initialised counter map is only needed when the record is first created,
            // so it is built inside the mapping function instead of on every request.
            ShuffleServerFailureRecord failureRecord = this.shuffleWrtieStatus.computeIfAbsent(shuffleId, key -> {
                Map<String, AtomicInteger> initialCounts = JavaUtils.newConcurrentMap();
                for (ShuffleServerInfo serverInfo : shuffleServerInfos) {
                    initialCounts.put(serverInfo.getId(), new AtomicInteger(0));
                }
                return new ShuffleServerFailureRecord(initialCounts, stageAttemptNumber);
            });
            // true means the report belongs to a stage attempt older than the tracked one.
            boolean staleReport = failureRecord.resetStageAttemptIfNecessary(stageAttemptNumber);
            if (staleReport) {
                msg = String.format("got an old stage(%d vs %d) shuffle write failure report, which should be impossible.", failureRecord.getStageAttempt(), stageAttemptNumber);
                LOG.warn(msg);
                code = RssProtos.StatusCode.INVALID_REQUEST;
                reSubmitWholeStage = false;
            } else {
                code = RssProtos.StatusCode.SUCCESS;
                boolean thresholdExceeded = failureRecord.incPartitionWriteFailure(stageAttemptNumber, shuffleServerInfos, this.shuffleManager);
                if (thresholdExceeded) {
                    reSubmitWholeStage = true;
                    msg = String.format("report shuffle write failure as maximum number(%d) of shuffle write is occurred", this.shuffleManager.getMaxFetchFailures());
                } else {
                    reSubmitWholeStage = false;
                    msg = "don't report shuffle write failure";
                }
            }
        }
        RssProtos.ReportShuffleWriteFailureResponse reply = RssProtos.ReportShuffleWriteFailureResponse.newBuilder().setStatus(code).setReSubmitWholeStage(reSubmitWholeStage).setMsg(msg).build();
        responseObserver.onNext(reply);
        responseObserver.onCompleted();
    }

    /**
     * Handles a shuffle fetch failure report.
     *
     * <p>Reports from another application, or from a stage attempt older than the tracked one,
     * are answered with {@code INVALID_REQUEST}. Otherwise the failure counter of the reported
     * partition is incremented and the reply asks for a whole-stage re-submit once that counter
     * reaches the shuffle manager's maximum number of fetch failures.
     *
     * @param request the failure report (app id, shuffle id, stage attempt id, partition id)
     * @param responseObserver observer that receives exactly one reply and is then completed
     */
    @Override
    public void reportShuffleFetchFailure(RssProtos.ReportShuffleFetchFailureRequest request, StreamObserver<RssProtos.ReportShuffleFetchFailureResponse> responseObserver) {
        final String appId = request.getAppId();
        final int stageAttempt = request.getStageAttemptId();
        final int partitionId = request.getPartitionId();

        // Start from the "rejected" outcome; only the accepted path below overrides it.
        RssProtos.StatusCode code = RssProtos.StatusCode.INVALID_REQUEST;
        boolean reSubmitWholeStage = false;
        String msg;

        if (appId.equals(this.shuffleManager.getAppId())) {
            RssShuffleStatus status = this.shuffleStatus.computeIfAbsent(
                request.getShuffleId(),
                key -> new RssShuffleStatus(this.shuffleManager.getPartitionNum((int)key), stageAttempt));
            if (status.resetStageAttemptIfNecessary(stageAttempt) >= 0) {
                code = RssProtos.StatusCode.SUCCESS;
                // NOTE(review): partitionId is used as an array index without a range check here.
                status.incPartitionFetchFailure(stageAttempt, partitionId);
                int failures = status.getPartitionFetchFailureNum(stageAttempt, partitionId);
                reSubmitWholeStage = failures >= this.shuffleManager.getMaxFetchFailures();
                msg = reSubmitWholeStage
                    ? String.format("report shuffle fetch failure as maximum number(%d) of shuffle fetch is occurred", this.shuffleManager.getMaxFetchFailures())
                    : "don't report shuffle fetch failure";
            } else {
                // A negative result means the report is for an attempt older than the tracked one.
                msg = String.format("got an old stage(%d vs %d) shuffle fetch failure report, which should be impossible.", status.getStageAttempt(), stageAttempt);
                LOG.warn(msg);
            }
        } else {
            msg = String.format("got a wrong shuffle fetch failure report from appId: %s, expected appId: %s", appId, this.shuffleManager.getAppId());
            LOG.warn(msg);
        }

        responseObserver.onNext(
            RssProtos.ReportShuffleFetchFailureResponse.newBuilder()
                .setStatus(code)
                .setReSubmitWholeStage(reSubmitWholeStage)
                .setMsg(msg)
                .build());
        responseObserver.onCompleted();
    }

    /**
     * Returns the partition-to-shuffle-server assignment and remote storage info of a shuffle.
     *
     * <p>When the shuffle manager has no handle info for the requested shuffle id, the reply
     * carries {@code INVALID_REQUEST} and an empty partition map.
     *
     * @param request the request carrying the shuffle id
     * @param responseObserver observer that receives exactly one reply and is then completed
     */
    @Override
    public void getPartitionToShufflerServer(RssProtos.PartitionToShuffleServerRequest request, StreamObserver<RssProtos.PartitionToShuffleServerResponse> responseObserver) {
        int shuffleId = request.getShuffleId();
        ShuffleHandleInfo handleInfo = this.shuffleManager.getShuffleHandleInfoByShuffleId(shuffleId);
        RssProtos.PartitionToShuffleServerResponse reply;
        if (handleInfo == null) {
            // Do not call putAllPartitionToShuffleServer(null): generated protobuf builders
            // reject a null map with a NullPointerException, which would prevent this reply
            // from ever being sent. Leaving the map untouched yields an empty map.
            LOG.warn("No shuffle handle info found for shuffleId: {}", shuffleId);
            reply = RssProtos.PartitionToShuffleServerResponse.newBuilder()
                .setStatus(RssProtos.StatusCode.INVALID_REQUEST)
                .build();
        } else {
            // Convert each partition's server list into its protobuf representation.
            Map<Integer, RssProtos.GetShuffleServerListResponse> protoPartitionToServers = JavaUtils.newConcurrentMap();
            for (Map.Entry<Integer, List<ShuffleServerInfo>> entry : handleInfo.getPartitionToServers().entrySet()) {
                List<RssProtos.ShuffleServerId> shuffleServerIds = ShuffleServerInfo.toProto(entry.getValue());
                protoPartitionToServers.put(
                    entry.getKey(),
                    RssProtos.GetShuffleServerListResponse.newBuilder().addAllServers(shuffleServerIds).build());
            }
            RemoteStorageInfo remoteStorage = handleInfo.getRemoteStorage();
            RssProtos.RemoteStorageInfo.Builder protoRemoteStorage = RssProtos.RemoteStorageInfo.newBuilder()
                .setPath(remoteStorage.getPath())
                .putAllConfItems(remoteStorage.getConfItems());
            reply = RssProtos.PartitionToShuffleServerResponse.newBuilder()
                .setStatus(RssProtos.StatusCode.SUCCESS)
                .putAllPartitionToShuffleServer(protoPartitionToServers)
                .setRemoteStorageInfo(protoRemoteStorage)
                .build();
        }
        responseObserver.onNext(reply);
        responseObserver.onCompleted();
    }

    /**
     * Asks the shuffle manager to reassign all shuffle servers for a whole stage and reports
     * back whether a reassignment was needed. The reply status is always {@code SUCCESS}.
     *
     * @param request the request carrying stage id, stage attempt, shuffle id and partition count
     * @param responseObserver observer that receives exactly one reply and is then completed
     */
    @Override
    public void reassignShuffleServers(RssProtos.ReassignServersRequest request, StreamObserver<RssProtos.ReassignServersReponse> responseObserver) {
        boolean needReassign = this.shuffleManager.reassignAllShuffleServersForWholeStage(
            request.getStageId(),
            request.getStageAttemptNumber(),
            request.getShuffleId(),
            request.getNumPartitions());
        responseObserver.onNext(
            RssProtos.ReassignServersReponse.newBuilder()
                .setStatus(RssProtos.StatusCode.SUCCESS)
                .setNeedReassign(needReassign)
                .build());
        responseObserver.onCompleted();
    }

    /**
     * Asks the shuffle manager for a replacement of a faulty shuffle server for the given
     * partitions and returns the replacement server. The reply status is always {@code SUCCESS}.
     *
     * @param request the request carrying shuffle id, partition ids and the faulty server id
     * @param responseObserver observer that receives exactly one reply and is then completed
     */
    @Override
    public void reassignFaultyShuffleServer(RssProtos.RssReassignFaultyShuffleServerRequest request, StreamObserver<RssProtos.RssReassignFaultyShuffleServerResponse> responseObserver) {
        int shuffleId = request.getShuffleId();
        ShuffleServerInfo replacement = this.shuffleManager.reassignFaultyShuffleServerForTasks(
            shuffleId,
            Sets.newHashSet(request.getPartitionIdsList()),
            request.getFaultyShuffleServerId());
        responseObserver.onNext(
            RssProtos.RssReassignFaultyShuffleServerResponse.newBuilder()
                .setStatus(RssProtos.StatusCode.SUCCESS)
                .setServer(ShuffleServerInfo.convertToShuffleServerId(replacement))
                .build());
        responseObserver.onCompleted();
    }

    /**
     * Drops all failure bookkeeping held for the given shuffle.
     *
     * <p>Both the fetch-failure and the write-failure records are removed; leaving the
     * write-failure record behind would keep it in memory for the lifetime of this service.
     *
     * @param shuffleId id of the shuffle whose records should be discarded
     */
    public void unregisterShuffle(int shuffleId) {
        this.shuffleStatus.remove(shuffleId);
        this.shuffleWrtieStatus.remove(shuffleId);
    }

    private static class RssShuffleStatus {
        private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
        private final ReentrantReadWriteLock.ReadLock readLock = this.lock.readLock();
        private final ReentrantReadWriteLock.WriteLock writeLock = this.lock.writeLock();
        private final int[] partitions;
        private int stageAttempt;

        private RssShuffleStatus(int partitionNum, int stageAttempt) {
            this.stageAttempt = stageAttempt;
            this.partitions = new int[partitionNum];
        }

        private <T> T withReadLock(Supplier<T> fn) {
            this.readLock.lock();
            try {
                T t = fn.get();
                return t;
            }
            finally {
                this.readLock.unlock();
            }
        }

        private <T> T withWriteLock(Supplier<T> fn) {
            this.writeLock.lock();
            try {
                T t = fn.get();
                return t;
            }
            finally {
                this.writeLock.unlock();
            }
        }

        public int getStageAttempt() {
            return this.withReadLock(() -> this.stageAttempt);
        }

        public int resetStageAttemptIfNecessary(int stageAttempt) {
            return this.withWriteLock(() -> {
                if (this.stageAttempt < stageAttempt) {
                    Arrays.fill(this.partitions, 0);
                    this.stageAttempt = stageAttempt;
                    return 1;
                }
                if (this.stageAttempt > stageAttempt) {
                    return -1;
                }
                return 0;
            });
        }

        public void incPartitionFetchFailure(int stageAttempt, int partition) {
            this.withWriteLock(() -> {
                if (this.stageAttempt == stageAttempt) {
                    this.partitions[partition] = this.partitions[partition] + 1;
                }
                return null;
            });
        }

        public int getPartitionFetchFailureNum(int stageAttempt, int partition) {
            return this.withReadLock(() -> {
                if (this.stageAttempt != stageAttempt) {
                    return 0;
                }
                return this.partitions[partition];
            });
        }
    }

    /**
     * Write-failure counters of a single shuffle for one stage attempt.
     *
     * <p>Holds one counter per shuffle server id plus the stage attempt the counters belong to.
     * Every access goes through a {@link ReentrantReadWriteLock}.
     */
    private static class ShuffleServerFailureRecord {
        private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
        private final ReentrantReadWriteLock.ReadLock readLock = this.lock.readLock();
        private final ReentrantReadWriteLock.WriteLock writeLock = this.lock.writeLock();
        // Failure count per shuffle server id for the tracked stage attempt.
        private final Map<String, AtomicInteger> shuffleServerFailureRecordCount;
        // Stage attempt the counters above belong to.
        private int stageAttemptNumber;

        /**
         * @param shuffleServerFailureRecordCount initial counters; the map is retained and
         *     mutated by this record, so it must be modifiable
         * @param stageAttemptNumber stage attempt the counters belong to
         */
        private ShuffleServerFailureRecord(Map<String, AtomicInteger> shuffleServerFailureRecordCount, int stageAttemptNumber) {
            this.shuffleServerFailureRecordCount = shuffleServerFailureRecordCount;
            this.stageAttemptNumber = stageAttemptNumber;
        }

        /** Runs {@code fn} while holding the read lock and returns its result. */
        private <T> T withReadLock(Supplier<T> fn) {
            this.readLock.lock();
            try {
                return fn.get();
            }
            finally {
                this.readLock.unlock();
            }
        }

        /** Runs {@code fn} while holding the write lock and returns its result. */
        private <T> T withWriteLock(Supplier<T> fn) {
            this.writeLock.lock();
            try {
                return fn.get();
            }
            finally {
                this.writeLock.unlock();
            }
        }

        /** Returns the stage attempt currently being tracked. */
        public int getStageAttempt() {
            return this.withReadLock(() -> this.stageAttemptNumber);
        }

        /**
         * Moves to a newer stage attempt if necessary, discarding the counters of the old one.
         *
         * @return {@code true} only when the given attempt is OLDER than the tracked one (a
         *     stale report); {@code false} when it is the tracked attempt or a newer one
         */
        public boolean resetStageAttemptIfNecessary(int stageAttemptNumber) {
            return this.withWriteLock(() -> {
                if (this.stageAttemptNumber < stageAttemptNumber) {
                    this.shuffleServerFailureRecordCount.clear();
                    this.stageAttemptNumber = stageAttemptNumber;
                    return false;
                }
                return this.stageAttemptNumber > stageAttemptNumber;
            });
        }

        /**
         * Records one write failure for each given shuffle server and checks whether the server
         * with the most failures has exceeded the allowed maximum.
         *
         * <p>Nothing is recorded when {@code stageAttemptNumber} is not the tracked attempt.
         *
         * @param stageAttemptNumber stage attempt the failures were observed in
         * @param shuffleServerInfos servers the write failed against
         * @param shuffleManager supplies the failure threshold and receives the faulty server id
         * @return {@code true} if some server's failure count exceeds
         *     {@code shuffleManager.getMaxFetchFailures()}; that server has then been passed to
         *     {@code shuffleManager.addFailuresShuffleServerInfos}
         */
        public boolean incPartitionWriteFailure(int stageAttemptNumber, List<ShuffleServerInfo> shuffleServerInfos, RssShuffleManagerInterface shuffleManager) {
            return this.withWriteLock(() -> {
                if (this.stageAttemptNumber != stageAttemptNumber) {
                    return false;
                }
                shuffleServerInfos.forEach(info -> this.shuffleServerFailureRecordCount.computeIfAbsent(info.getId(), k -> new AtomicInteger()).incrementAndGet());

                // Pick the server with the MOST failures. The previous ascending sort followed
                // by get(0) picked the server with the fewest, so one rarely-failing server
                // masked a genuinely faulty one and the wrong server got flagged.
                String worstServerId = null;
                int worstCount = -1;
                for (Map.Entry<String, AtomicInteger> entry : this.shuffleServerFailureRecordCount.entrySet()) {
                    int count = entry.getValue().get();
                    if (count > worstCount) {
                        worstCount = count;
                        worstServerId = entry.getKey();
                    }
                }
                // NOTE(review): the threshold is strict (>) here while fetch-failure handling
                // in this file uses >=; kept unchanged, confirm which one is intended.
                if (worstServerId != null && worstCount > shuffleManager.getMaxFetchFailures()) {
                    shuffleManager.addFailuresShuffleServerInfos(worstServerId);
                    return true;
                }
                return false;
            });
        }
    }
}

