/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.shuffle.writer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.spark.Aggregator;
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.shuffle.FetchFailedException;
import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.RssShuffleManager;
import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.spark.shuffle.RssSparkShuffleUtils;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
import org.apache.spark.shuffle.writer.AddBlockEvent;
import org.apache.spark.shuffle.writer.BufferManagerOptions;
import org.apache.spark.shuffle.writer.PartitionLengthStatistic;
import org.apache.spark.shuffle.writer.TaskAttemptAssignment;
import org.apache.spark.shuffle.writer.WriteBufferManager;
import org.apache.spark.storage.BlockManagerId;
import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.impl.TrackingBlockStatus;
import org.apache.uniffle.client.request.RssReassignOnBlockSendFailureRequest;
import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest;
import org.apache.uniffle.client.request.RssReportShuffleWriteMetricRequest;
import org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse;
import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
import org.apache.uniffle.client.response.RssReportShuffleWriteMetricResponse;
import org.apache.uniffle.common.ReceivingFailureServer;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssSendFailedException;
import org.apache.uniffle.common.exception.RssWaitFailedException;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.shaded.com.google.common.annotations.VisibleForTesting;
import org.apache.uniffle.shaded.com.google.common.collect.Lists;
import org.apache.uniffle.shaded.com.google.common.collect.Maps;
import org.apache.uniffle.shaded.com.google.common.collect.Sets;
import org.apache.uniffle.shaded.com.google.common.util.concurrent.Uninterruptibles;
import org.apache.uniffle.shaded.org.apache.commons.collections4.CollectionUtils;
import org.apache.uniffle.storage.util.StorageType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Function1;
import scala.Option;
import scala.Product2;
import scala.Tuple2;
import scala.collection.Iterator;

public class RssShuffleWriter<K, V, C>
extends ShuffleWriter<K, V> {
    private static final Logger LOG = LoggerFactory.getLogger(RssShuffleWriter.class);
    /** NOTE(review): not referenced in the lines reviewed here; presumably used when building the MapStatus in stop() — confirm. */
    private static final String DUMMY_HOST = "dummy_host";
    /** NOTE(review): not referenced in the lines reviewed here; presumably paired with DUMMY_HOST — confirm. */
    private static final int DUMMY_PORT = 99999;
    /** Fallback error text returned by getFirstBlockFailure when a failed block has no tracked status. */
    public static final String DEFAULT_ERROR_MESSAGE = "Default Error Message";
    private final String appId;
    private final int shuffleId;
    private final ShuffleHandleInfo shuffleHandleInfo;
    /** Not final: the @VisibleForTesting constructor assigns it after delegating to the shared constructor. */
    private WriteBufferManager bufferManager;
    /** Key used for the per-task lookups on shuffleManager (success/failed block ids, failure tracker). */
    private String taskId;
    private final int numMaps;
    private final ShuffleDependency<K, V, C> shuffleDependency;
    private final Partitioner partitioner;
    private final RssShuffleManager shuffleManager;
    /** True when the partitioner has more than one partition; otherwise getPartition always returns 0. */
    private final boolean shouldPartition;
    /** Maximum time (ms) checkBlockSendResult waits for all blocks to be acknowledged. */
    private final long sendCheckTimeout;
    /** Read from RSS_CLIENT_SEND_CHECK_INTERVAL_MS; not used in the lines reviewed here. */
    private final long sendCheckInterval;
    /** Passed to reportShuffleResult in stop(). */
    private final int bitmapSplitNum;
    /** Server -> partition id -> block ids sent there; filled by processShuffleBlockInfos, pruned by clearFailedBlockState, reported in stop(). */
    private Map<ShuffleServerInfo, Map<Integer, Set<Long>>> serverToPartitionToBlockIds;
    private final ShuffleWriteClient shuffleWriteClient;
    /** Servers that sendCommit asks to commit this shuffle's data. */
    private final Set<ShuffleServerInfo> shuffleServersForData;
    /** Incremented for each block whose resources are released in postBlockEvent's completion callback. */
    private final PartitionLengthStatistic partitionLengthStatistic;
    /** When true, writeImpl skips sendCommit. */
    protected final boolean isMemoryShuffleEnabled;
    /** Invoked with taskId when write() fails. */
    private final Function<String, Boolean> taskFailureCallback;
    /** Every block id handed to processShuffleBlockInfos (concurrent set). */
    private final Set<Long> blockIds = Sets.newConcurrentHashSet();
    private TaskContext taskContext;
    private SparkConf sparkConf;
    /** When true, failed blocks are reassigned and resent instead of failing the task on the first failure. */
    private boolean blockFailSentRetryEnabled;
    /** Retry count at which collectFailedBlocksToResend fast-fails instead of resending. */
    private int blockFailSentRetryMaxTimes = 1;
    protected final long taskAttemptId;
    protected final ShuffleWriteMetrics shuffleWriteMetrics;
    /** Receives a marker object each time an AddBlockEvent callback fires; checkBlockSendResult waits on it. */
    private final BlockingQueue<Object> finishEventQueue = new LinkedBlockingQueue<Object>();
    /** Current partition -> server assignment; updated after reassignments and partition splits. */
    private TaskAttemptAssignment taskAttemptAssignment;
    /** Failures with these status codes are never resent; they make collectFailedBlocksToResend fast-fail. */
    private static final Set<StatusCode> STATUS_CODE_WITHOUT_BLOCK_RESEND = Sets.newHashSet(StatusCode.NO_REGISTER);
    /** Supplies the client used to send reassign-on-failure requests. */
    private final Supplier<ShuffleManagerClient> managerClientSupplier;
    /** When true, write() calls throwFetchFailedIfNecessary on failure; also passed to reportShuffleResult. */
    private boolean enableWriteFailureRetry;
    /** Passed to reportShuffleResult in stop(). */
    private Set<ShuffleServerInfo> recordReportFailedShuffleservers;
    /** Total write time (ms) recorded at the end of writeImpl. */
    private long totalShuffleWriteMills = 0L;
    /** Time (ms) spent in the post-write checks, recorded in writeImpl. */
    private long checkSendResultMills = 0L;
    /** Set by write() when an RssException escapes writeImpl. */
    private boolean isShuffleWriteFailed = false;
    /** Message of the RssException that failed the write, if any. */
    private Optional<String> shuffleWriteFailureReason = Optional.empty();

    /**
     * Test-only constructor that injects a ready-made {@link WriteBufferManager} and an explicit
     * {@link ShuffleHandleInfo}.
     *
     * <p>The task-failure callback is a no-op that always answers {@code true}. The task attempt
     * assignment is built from the supplied handle info.
     */
    @VisibleForTesting
    public RssShuffleWriter(String appId, int shuffleId, String taskId, long taskAttemptId, WriteBufferManager bufferManager, ShuffleWriteMetrics shuffleWriteMetrics, RssShuffleManager shuffleManager, SparkConf sparkConf, ShuffleWriteClient shuffleWriteClient, Supplier<ShuffleManagerClient> managerClientSupplier, RssShuffleHandle<K, V, C> rssHandle, ShuffleHandleInfo shuffleHandleInfo, TaskContext context) {
        this(appId, shuffleId, taskId, taskAttemptId,
                shuffleWriteMetrics, shuffleManager, sparkConf, shuffleWriteClient,
                managerClientSupplier, rssHandle,
                (String ignoredTaskId) -> Boolean.TRUE,
                shuffleHandleInfo, context);
        this.taskAttemptAssignment = new TaskAttemptAssignment(taskAttemptId, shuffleHandleInfo);
        this.bufferManager = bufferManager;
    }

    /**
     * Shared constructor that every other constructor delegates to.
     *
     * <p>It copies the identifiers and collaborators and derives partitioning information from the
     * handle's dependency. It also reads the send-check, bitmap, storage-type, reassign and
     * write-failure-retry settings from {@code sparkConf}.
     *
     * <p>It does not create {@code bufferManager} or {@code taskAttemptAssignment}; the delegating
     * constructors set those.
     */
    private RssShuffleWriter(String appId, int shuffleId, String taskId, long taskAttemptId, ShuffleWriteMetrics shuffleWriteMetrics, RssShuffleManager shuffleManager, SparkConf sparkConf, ShuffleWriteClient shuffleWriteClient, Supplier<ShuffleManagerClient> managerClientSupplier, RssShuffleHandle<K, V, C> rssHandle, Function<String, Boolean> taskFailureCallback, ShuffleHandleInfo shuffleHandleInfo, TaskContext context) {
        LOG.info("RssShuffle start write taskAttemptId[{}] data with RssHandle[appId {}, shuffleId {}].", new Object[]{taskAttemptId, rssHandle.getAppId(), rssHandle.getShuffleId()});
        this.shuffleManager = shuffleManager;
        this.appId = appId;
        this.shuffleId = shuffleId;
        this.taskId = taskId;
        this.taskAttemptId = taskAttemptId;
        this.numMaps = rssHandle.getNumMaps();
        this.shuffleWriteMetrics = shuffleWriteMetrics;
        this.shuffleDependency = rssHandle.getDependency();
        this.partitioner = this.shuffleDependency.partitioner();
        // With a single partition, getPartition skips the partitioner and always returns 0.
        this.shouldPartition = this.partitioner.numPartitions() > 1;
        this.sendCheckTimeout = (Long)sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS);
        this.sendCheckInterval = (Long)sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS);
        this.bitmapSplitNum = (Integer)sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
        this.serverToPartitionToBlockIds = Maps.newHashMap();
        this.shuffleWriteClient = shuffleWriteClient;
        this.shuffleServersForData = shuffleHandleInfo.getServers();
        this.partitionLengthStatistic = new PartitionLengthStatistic(this.partitioner.numPartitions());
        // NOTE(review): the storage type is read by key with no default, so this presumably throws when
        // it is not configured; StorageType.valueOf also rejects names that are not enum constants.
        this.isMemoryShuffleEnabled = this.isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
        this.taskFailureCallback = taskFailureCallback;
        this.shuffleHandleInfo = shuffleHandleInfo;
        this.taskContext = context;
        this.sparkConf = sparkConf;
        this.managerClientSupplier = managerClientSupplier;
        // The reassign flag is read under the "spark."-prefixed key, falling back to the RssClientConf default.
        this.blockFailSentRetryEnabled = sparkConf.getBoolean("spark." + RssClientConf.RSS_CLIENT_REASSIGN_ENABLED.key(), RssClientConf.RSS_CLIENT_REASSIGN_ENABLED.defaultValue().booleanValue());
        this.blockFailSentRetryMaxTimes = RssSparkConfig.toRssConf(sparkConf).get(RssSparkConfig.RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES);
        this.enableWriteFailureRetry = RssSparkConfig.toRssConf(sparkConf).get(RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED);
        this.recordReportFailedShuffleservers = Sets.newConcurrentHashSet();
    }

    /**
     * Creates a writer whose {@link ShuffleManagerClient} supplier is taken from
     * {@code shuffleManager.getShuffleManagerClientSupplier()}.
     *
     * <p>Everything else behaves as in the constructor that receives the supplier explicitly.
     */
    public RssShuffleWriter(String appId, int shuffleId, String taskId, long taskAttemptId, ShuffleWriteMetrics shuffleWriteMetrics, RssShuffleManager shuffleManager, SparkConf sparkConf, ShuffleWriteClient shuffleWriteClient, RssShuffleHandle<K, V, C> rssHandle, Function<String, Boolean> taskFailureCallback, TaskContext context) {
        this(appId, shuffleId, taskId, taskAttemptId,
                shuffleWriteMetrics, shuffleManager, sparkConf, shuffleWriteClient,
                shuffleManager.getShuffleManagerClientSupplier(),
                rssHandle, taskFailureCallback, context);
    }

    /**
     * Creates a writer for the current task.
     *
     * <p>The shuffle handle info is fetched from {@code shuffleManager} for the task's stage id and
     * stage attempt number. This constructor then builds the task attempt assignment from that handle
     * and creates a {@link WriteBufferManager} configured from {@code sparkConf}.
     *
     * <p>The buffer manager receives {@code this::processShuffleBlockInfos} to send blocks and
     * {@code this::getPartitionAssignedServers} to look up target servers.
     */
    public RssShuffleWriter(String appId, int shuffleId, String taskId, long taskAttemptId, ShuffleWriteMetrics shuffleWriteMetrics, RssShuffleManager shuffleManager, SparkConf sparkConf, ShuffleWriteClient shuffleWriteClient, Supplier<ShuffleManagerClient> managerClientSupplier, RssShuffleHandle<K, V, C> rssHandle, Function<String, Boolean> taskFailureCallback, TaskContext context) {
        this(appId, shuffleId, taskId, taskAttemptId, shuffleWriteMetrics, shuffleManager, sparkConf, shuffleWriteClient, managerClientSupplier, rssHandle, taskFailureCallback, shuffleManager.getShuffleHandleInfo(context.stageId(), context.stageAttemptNumber(), rssHandle, true), context);
        // The assignment is set before the buffer manager is created, because the buffer manager is
        // handed this::getPartitionAssignedServers, which reads taskAttemptAssignment.
        this.taskAttemptAssignment = new TaskAttemptAssignment(taskAttemptId, this.shuffleHandleInfo);
        BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
        this.bufferManager = new WriteBufferManager(shuffleId, taskId, taskAttemptId, bufferOptions, rssHandle.getDependency().serializer(), context.taskMemoryManager(), shuffleWriteMetrics, RssSparkConfig.toRssConf(sparkConf), this::processShuffleBlockInfos, this::getPartitionAssignedServers, context.stageAttemptNumber());
    }

    /**
     * Looks up the servers currently assigned to a partition for this task attempt.
     *
     * @param partitionId partition whose servers are requested
     * @return the assigned servers, as reported by the task attempt assignment
     */
    @VisibleForTesting
    protected List<ShuffleServerInfo> getPartitionAssignedServers(int partitionId) {
        List<ShuffleServerInfo> assigned = this.taskAttemptAssignment.retrieve(partitionId);
        return assigned;
    }

    /**
     * Tells whether the configured storage type includes memory storage.
     *
     * @param storageType name of a {@link StorageType} constant
     * @return {@code StorageType.withMemory} for the parsed type
     */
    private boolean isMemoryShuffleEnabled(String storageType) {
        StorageType parsedType = StorageType.valueOf(storageType);
        return StorageType.withMemory(parsedType);
    }

    /**
     * Writes all records of this map task, delegating the actual work to {@code writeImpl}.
     *
     * <p>On any failure the following happens before the exception is rethrown unchanged:
     * <ul>
     *   <li>an {@link RssException} is recorded as the shuffle-write failure reason;</li>
     *   <li>the task-failure callback is invoked;</li>
     *   <li>when write-failure retry is enabled, {@code throwFetchFailedIfNecessary} gets a chance
     *       to convert the failure.</li>
     * </ul>
     */
    public void write(Iterator<Product2<K, V>> records) {
        try {
            this.writeImpl(records);
        } catch (Exception failure) {
            boolean rssFailure = failure instanceof RssException;
            if (rssFailure) {
                this.isShuffleWriteFailed = true;
                this.shuffleWriteFailureReason = Optional.ofNullable(failure.getMessage());
            }
            this.taskFailureCallback.apply(this.taskId);
            if (this.enableWriteFailureRetry) {
                this.throwFetchFailedIfNecessary(failure, Sets.newConcurrentHashSet());
            }
            throw failure;
        }
    }

    /**
     * Buffers and sends every record of this map task, then verifies and commits the result.
     *
     * <p>Steps:
     * <ol>
     *   <li>Records are added to {@code bufferManager}, after an optional map-side combine. Any
     *       blocks it returns are sent as they appear.</li>
     *   <li>The remaining buffers are flushed.</li>
     *   <li>The method verifies that no buffer is left, that the record count matches, that every
     *       block was acknowledged, and that the block counts agree.</li>
     *   <li>It commits, unless memory storage is enabled, and records the write metrics.</li>
     * </ol>
     *
     * @param records the map task's output records
     * @throws RssSendFailedException if buffers remain unflushed or record/block counts disagree
     * @throws RssWaitFailedException if blocks are not acknowledged within {@code sendCheckTimeout} ms
     */
    protected void writeImpl(Iterator<Product2<K, V>> records) {
        long writeDurationMs;
        List<ShuffleBlockInfo> shuffleBlockInfos;
        boolean isCombine = this.shuffleDependency.mapSideCombine();
        // Raw types below are decompiler artifacts of the Scala iterator/aggregator interop.
        Iterator iterator = records;
        if (isCombine) {
            if (RssSparkConfig.toRssConf(this.sparkConf).get(RssSparkConfig.RSS_CLIENT_MAP_SIDE_COMBINE_ENABLED).booleanValue()) {
                // Merge values per key on the map side.
                iterator = ((Aggregator)this.shuffleDependency.aggregator().get()).combineValuesByKey(records, this.taskContext);
            } else {
                // Only wrap each value with createCombiner; values are not merged here
                // (presumably merging is left to the reduce side — confirm).
                Function1 combiner = ((Aggregator)this.shuffleDependency.aggregator().get()).createCombiner();
                iterator = records.map(x -> new Tuple2(x._1(), combiner.apply(x._2())));
            }
        }
        long recordCount = 0L;
        while (iterator.hasNext()) {
            ++recordCount;
            // Surface (or, with resend enabled, resend) send failures while records are still produced.
            this.checkDataIfAnyFailure();
            Product2 record = (Product2)iterator.next();
            Object key = record._1();
            int partition = this.getPartition(key);
            // addRecord may hand back blocks that are ready to be sent.
            shuffleBlockInfos = this.bufferManager.addRecord(partition, record._1(), record._2());
            if (shuffleBlockInfos == null || shuffleBlockInfos.isEmpty()) continue;
            this.processShuffleBlockInfos(shuffleBlockInfos);
        }
        long start = System.currentTimeMillis();
        // Flush what is still buffered (presumably 1.0 means clear all buffers — confirm).
        shuffleBlockInfos = this.bufferManager.clear(1.0);
        if (shuffleBlockInfos != null && !shuffleBlockInfos.isEmpty()) {
            this.processShuffleBlockInfos(shuffleBlockInfos);
        }
        long checkStartTs = System.currentTimeMillis();
        this.checkAllBufferSpilled();
        this.checkSentRecordCount(recordCount);
        // Pass a copy: checkBlockSendResult removes acknowledged ids from the set it receives,
        // and checkSentBlockCount still needs the full blockIds afterwards.
        this.checkBlockSendResult(new HashSet<Long>(this.blockIds));
        this.checkSentBlockCount();
        this.bufferManager.getShuffleServerPushCostTracker().statistics();
        long commitStartTs = System.currentTimeMillis();
        long checkDuration = commitStartTs - checkStartTs;
        if (!this.isMemoryShuffleEnabled) {
            this.sendCommit();
        }
        // Write time = buffer manager's own write time + everything from the final flush through the commit.
        this.totalShuffleWriteMills = writeDurationMs = this.bufferManager.getWriteTime() + (System.currentTimeMillis() - start);
        this.checkSendResultMills = checkDuration;
        this.shuffleWriteMetrics.incWriteTime(TimeUnit.MILLISECONDS.toNanos(writeDurationMs));
        LOG.info("Finish write shuffle for appId[" + this.appId + "], shuffleId[" + this.shuffleId + "], taskId[" + this.taskId + "] with write " + writeDurationMs + " ms, include checkSendResult[" + checkDuration + "], commit[" + (System.currentTimeMillis() - commitStartTs) + "], " + this.bufferManager.getManagerCostInfo());
    }

    /**
     * Fails the write if the buffer manager still holds any buffer after the final flush.
     *
     * @throws RssSendFailedException when at least one buffer remains
     */
    private void checkAllBufferSpilled() {
        int remainingBuffers = this.bufferManager.getBuffers().size();
        if (remainingBuffers <= 0) {
            return;
        }
        throw new RssSendFailedException("Potential data loss due to existing remaining data buffers that are not flushed. This should not happen.");
    }

    /**
     * Fails the write if the number of records read differs from the number the buffer manager
     * accounted for.
     *
     * @param recordCount number of records consumed from the input iterator
     * @throws RssSendFailedException on a mismatch
     */
    private void checkSentRecordCount(long recordCount) {
        long bufferedRecords = this.bufferManager.getRecordCount();
        if (bufferedRecords == recordCount) {
            return;
        }
        throw new RssSendFailedException("Potential record loss may have occurred while preparing to send blocks for task[" + this.taskId + "]");
    }

    /**
     * Cross-checks three block counts and fails the write if they disagree:
     * <ul>
     *   <li>blocks handed to {@code processShuffleBlockInfos};</li>
     *   <li>distinct blocks recorded per server;</li>
     *   <li>blocks counted by the buffer manager.</li>
     * </ul>
     *
     * @throws RssException if the server bookkeeping map is missing
     * @throws RssSendFailedException if the three counts are not all equal
     */
    private void checkSentBlockCount() {
        long expected = this.blockIds.size();
        long bufferManagerTracked = this.bufferManager.getBlockCount();
        if (this.serverToPartitionToBlockIds == null) {
            throw new RssException("serverToPartitionToBlockIds should not be null");
        }
        // A block sent to several servers must only be counted once.
        Set<Long> distinctServerBlockIds = new HashSet<>();
        for (Map<Integer, Set<Long>> partitionToIds : this.serverToPartitionToBlockIds.values()) {
            for (Set<Long> idsOfPartition : partitionToIds.values()) {
                distinctServerBlockIds.addAll(idsOfPartition);
            }
        }
        long serverTracked = distinctServerBlockIds.size();
        boolean consistent = expected == serverTracked && expected == bufferManagerTracked;
        if (!consistent) {
            throw new RssSendFailedException("Potential block loss may occur for task[" + this.taskId + "]. BlockId number expected: " + expected + ", serverTracked: " + serverTracked + ", bufferManagerTracked: " + bufferManagerTracked);
        }
    }

    /**
     * Returns a new, empty array on every call; this writer does not report per-partition lengths
     * through this method.
     *
     * @return an empty {@code long[]}
     */
    public long[] getPartitionLengths() {
        long[] noLengths = {};
        return noLengths;
    }

    /**
     * Records the given blocks in the writer's bookkeeping and posts them for sending.
     *
     * <p>Every block id is added to {@code blockIds}. It is also filed under each of the block's
     * target servers and its partition in {@code serverToPartitionToBlockIds}.
     *
     * @param shuffleBlockInfoList blocks to send; {@code null} or empty is a no-op
     * @return one future per posted send event, or an empty list when nothing was posted
     */
    @VisibleForTesting
    protected List<CompletableFuture<Long>> processShuffleBlockInfos(List<ShuffleBlockInfo> shuffleBlockInfoList) {
        if (shuffleBlockInfoList == null || shuffleBlockInfoList.isEmpty()) {
            return Collections.emptyList();
        }
        for (ShuffleBlockInfo blockInfo : shuffleBlockInfoList) {
            long blockId = blockInfo.getBlockId();
            this.blockIds.add(blockId);
            int partitionId = blockInfo.getPartitionId();
            for (ShuffleServerInfo server : blockInfo.getShuffleServerInfos()) {
                Map<Integer, Set<Long>> partitionToIds = this.serverToPartitionToBlockIds.computeIfAbsent(server, ignoredServer -> Maps.newHashMap());
                partitionToIds.computeIfAbsent(partitionId, ignoredPartition -> Sets.newHashSet()).add(blockId);
            }
        }
        return this.postBlockEvent(shuffleBlockInfoList);
    }

    /**
     * Groups the blocks into send events via the buffer manager and submits each event to the
     * shuffle manager.
     *
     * <p>Each block gets a completion callback that releases the block's buffer resources and
     * counts it in {@code partitionLengthStatistic}. With resend enabled, this only happens on
     * success, so a failed block stays available for resending.
     *
     * <p>Each event also gets a callback that drops a marker into {@code finishEventQueue}.
     * {@code checkBlockSendResult} waits on that queue.
     *
     * @param shuffleBlockInfoList blocks to send
     * @return the futures returned by {@code shuffleManager.sendData}, one per event
     */
    protected List<CompletableFuture<Long>> postBlockEvent(List<ShuffleBlockInfo> shuffleBlockInfoList) {
        List<CompletableFuture<Long>> sendFutures = new ArrayList<>();
        for (AddBlockEvent addBlockEvent : this.bufferManager.buildBlockEvents(shuffleBlockInfoList)) {
            for (ShuffleBlockInfo blockInfo : addBlockEvent.getShuffleDataInfoList()) {
                blockInfo.withCompletionCallback((completedBlock, isSuccessful) -> {
                    if (!this.blockFailSentRetryEnabled || isSuccessful) {
                        this.bufferManager.releaseBlockResource(completedBlock);
                        this.partitionLengthStatistic.inc(completedBlock);
                    }
                });
            }
            addBlockEvent.addCallback(() -> {
                if (!this.finishEventQueue.add(new Object())) {
                    LOG.error("Add event " + addBlockEvent + " to finishEventQueue fail");
                }
            });
            sendFutures.add(this.shuffleManager.sendData(addBlockEvent));
        }
        return sendFutures;
    }

    /**
     * Waits until every block id currently tracked by this writer has been acknowledged.
     *
     * <p>{@code checkBlockSendResult} removes acknowledged ids from the set it receives. A snapshot
     * copy is therefore passed, as {@code writeImpl} does, so that {@code blockIds} keeps every
     * sent id. {@code checkSentBlockCount} compares against the full set.
     *
     * @throws RssWaitFailedException if some blocks are not acknowledged before the send-check timeout
     */
    protected void internalCheckBlockSendResult() {
        this.checkBlockSendResult(new HashSet<Long>(this.blockIds));
    }

    /**
     * Waits until every id in {@code blockIds} is reported as successfully sent for this task.
     *
     * <p>Each round does three things:
     * <ol>
     *   <li>clears {@code finishEventQueue};</li>
     *   <li>checks for send failures, which may throw or trigger a resend;</li>
     *   <li>removes acknowledged ids from {@code blockIds}. The set passed in is mutated.</li>
     * </ol>
     * If ids remain, it re-checks immediately when a send callback fired in the meantime. Otherwise
     * it waits for the next callback until {@code sendCheckTimeout} ms have elapsed.
     *
     * <p>An interrupt while waiting is remembered and the wait continues. The interrupt status is
     * restored on exit.
     *
     * @param blockIds ids still awaiting acknowledgement; acknowledged ids are removed from it
     * @throws RssWaitFailedException if some ids are still unacknowledged at the deadline
     */
    @VisibleForTesting
    protected void checkBlockSendResult(Set<Long> blockIds) {
        boolean interrupted = false;
        try {
            long remainingMs = this.sendCheckTimeout;
            long deadline = System.currentTimeMillis() + remainingMs;
            while (true) {
                try {
                    this.finishEventQueue.clear();
                    this.checkDataIfAnyFailure();
                    Set<Long> acknowledged = this.shuffleManager.getSuccessBlockIds(this.taskId);
                    blockIds.removeAll(acknowledged);
                    if (blockIds.isEmpty()) {
                        break;
                    }
                    if (!this.finishEventQueue.isEmpty()) {
                        // A send finished while we were checking: re-check without waiting.
                        continue;
                    }
                    remainingMs = Math.max(deadline - System.currentTimeMillis(), 0L);
                    if (this.finishEventQueue.poll(remainingMs, TimeUnit.MILLISECONDS) == null) {
                        // Deadline reached without any further send completing.
                        break;
                    }
                } catch (InterruptedException e) {
                    interrupted = true;
                }
            }
            if (!blockIds.isEmpty()) {
                String errorMsg = "Timeout: Task[" + this.taskId + "] failed because " + blockIds.size() + " blocks can't be sent to shuffle server in " + this.sendCheckTimeout + " ms.";
                LOG.error(errorMsg);
                throw new RssWaitFailedException(errorMsg);
            }
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Reacts to block send failures recorded so far.
     *
     * <p>With resend enabled, failed blocks are collected and resent. Otherwise the first failure
     * aborts the write.
     *
     * @throws RssSendFailedException when resend is disabled and a block failed, or when resending fast-fails
     */
    protected void checkDataIfAnyFailure() {
        if (this.blockFailSentRetryEnabled) {
            this.collectFailedBlocksToResend();
            return;
        }
        String firstFailure = this.getFirstBlockFailure();
        if (firstFailure == null) {
            return;
        }
        throw new RssSendFailedException("Fail to send the block. Error: " + firstFailure);
    }

    /**
     * Describes the first failed block of this task, if any.
     *
     * <p>The description is the status-code name of the first tracked status of one failed block,
     * or {@code DEFAULT_ERROR_MESSAGE} when that block has no tracked status. The number of failed
     * blocks and the faulty servers are logged.
     *
     * @return the failure description, or {@code null} when no block has failed
     */
    private String getFirstBlockFailure() {
        Set<Long> failedBlockIds = this.shuffleManager.getFailedBlockIds(this.taskId);
        if (failedBlockIds.isEmpty()) {
            return null;
        }
        Long firstFailedBlockId = failedBlockIds.iterator().next();
        List<TrackingBlockStatus> statuses = this.shuffleManager.getBlockIdsFailedSendTracker(this.taskId).getFailedBlockStatus(firstFailedBlockId);
        String errorMsg = CollectionUtils.isNotEmpty(statuses)
                ? statuses.get(0).getStatusCode().name()
                : DEFAULT_ERROR_MESSAGE;
        LOG.error("Errors on sending blocks for task[{}]. {} blocks can't be sent to remote servers: {}", new Object[]{this.taskId, failedBlockIds.size(), this.shuffleManager.getBlockIdsFailedSendTracker(this.taskId).getFaultyShuffleServers()});
        return errorMsg;
    }

    /**
     * Handles blocks that failed to send when block resend is enabled.
     *
     * <p>First processes the partition-split requests held by the failed-block tracker, then
     * inspects every failed block. The task fast-fails if either of these holds for any block:
     * <ul>
     *   <li>it already reached {@code blockFailSentRetryMaxTimes} retries;</li>
     *   <li>it failed with a status code in {@code STATUS_CODE_WITHOUT_BLOCK_RESEND}.</li>
     * </ul>
     * On fast-fail, the first tracked block of every failed id has its completion callback run as
     * successful, and then the method throws. Otherwise all failed statuses are reassigned and
     * resent.
     *
     * <p>NOTE(review): the decompiler (CFR) reported "Removed try catching itself - possible
     * behaviour change" for this method; verify against the original source.
     *
     * @throws RssSendFailedException when fast-failing
     */
    private void collectFailedBlocksToResend() {
        List<TrackingBlockStatus> failedBlockStatus;
        // Redundant with checkDataIfAnyFailure, which only calls this when resend is enabled.
        if (!this.blockFailSentRetryEnabled) {
            return;
        }
        FailedBlockSendTracker failedTracker = this.shuffleManager.getBlockIdsFailedSendTracker(this.taskId);
        if (failedTracker == null) {
            return;
        }
        this.reassignOnPartitionNeedSplit(failedTracker);
        Set<Long> failedBlockIds = failedTracker.getFailedBlockIds();
        if (CollectionUtils.isEmpty(failedBlockIds)) {
            return;
        }
        boolean isFastFail = false;
        HashSet<TrackingBlockStatus> resendCandidates = new HashSet<TrackingBlockStatus>();
        for (Long blockId : failedBlockIds) {
            // Decompiler artifact: `list` and `failedBlockStatus` are the same status list, locked while inspected.
            // NOTE(review): presumably other threads append to this list concurrently — confirm.
            List<TrackingBlockStatus> list = failedBlockStatus = failedTracker.getFailedBlockStatus(blockId);
            synchronized (list) {
                // Highest retry count among this block's failed attempts.
                // NOTE(review): .get() throws NoSuchElementException if the status list is empty.
                int retryIndex = failedBlockStatus.stream().map(x -> x.getShuffleBlockInfo().getRetryCnt()).max(Comparator.comparing(Integer::valueOf)).get();
                if (retryIndex >= this.blockFailSentRetryMaxTimes) {
                    LOG.error("Partial blocks for taskId: [{}] retry exceeding the max retry times: [{}]. Fast fail! faulty server list: {}", new Object[]{this.taskId, this.blockFailSentRetryMaxTimes, failedBlockStatus.stream().map(x -> x.getShuffleServerInfo()).collect(Collectors.toSet())});
                    isFastFail = true;
                    // Leaves the loop over failedBlockIds entirely.
                    break;
                }
                for (TrackingBlockStatus status : failedBlockStatus) {
                    StatusCode code = status.getStatusCode();
                    if (!STATUS_CODE_WITHOUT_BLOCK_RESEND.contains((Object)code)) continue;
                    LOG.error("Partial blocks for taskId: [{}] failed on the illegal status code: [{}] without resend on server: {}", new Object[]{this.taskId, code, status.getShuffleServerInfo()});
                    isFastFail = true;
                    // Leaves only the status loop. The statuses are still collected below, but
                    // isFastFail makes the method throw before anything is resent.
                    break;
                }
                resendCandidates.addAll(failedBlockStatus);
            }
        }
        if (isFastFail) {
            for (Long blockId : failedBlockIds) {
                failedBlockStatus = failedTracker.getFailedBlockStatus(blockId);
                if (!CollectionUtils.isNotEmpty(failedBlockStatus)) continue;
                TrackingBlockStatus blockStatus = failedBlockStatus.get(0);
                // Complete as successful so the block's resources are released (presumably by the
                // completion callback registered in postBlockEvent — confirm).
                blockStatus.getShuffleBlockInfo().executeCompletionCallback(true);
            }
            throw new RssSendFailedException("Errors on resending the blocks data to the remote shuffle-server.");
        }
        this.reassignAndResendBlocks(resendCandidates);
    }

    /**
     * Processes the partition-split requests tracked by {@code failedTracker}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Drain the tracked partitions ({@code removeAllTrackedPartitions}) and record each
     *       (partition, server) pair once, with status {@code SUCCESS}.</li>
     *   <li>Let {@code taskAttemptAssignment.updatePartitionSplitAssignment} absorb what it can.</li>
     *   <li>Send a partition-split reassignment request only for the partitions it did not
     *       absorb.</li>
     *   <li>Log the resulting assignment.</li>
     * </ol>
     *
     * <p>The decompiled body used raw {@code HashMap}/{@code List}/{@code Map.Entry} types, which do
     * not compile: lambdas called {@code getServerId()} on {@code Object}, and a raw
     * {@code computeIfAbsent} result was assigned to {@code List}. The generic types are restored
     * here.
     *
     * @param failedTracker tracker for this task's failed blocks and split partitions
     * @throws RssException if the reassignment request fails
     */
    private void reassignOnPartitionNeedSplit(FailedBlockSendTracker failedTracker) {
        Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers = new HashMap<>();
        failedTracker.removeAllTrackedPartitions().forEach(partitionStatus -> {
            List<ReceivingFailureServer> servers = failurePartitionToServers.computeIfAbsent(partitionStatus.getPartitionId(), x -> new ArrayList<>());
            String serverId = partitionStatus.getShuffleServerInfo().getId();
            // Record each server at most once per partition.
            if (!servers.stream().map(x -> x.getServerId()).collect(Collectors.toSet()).contains(serverId)) {
                servers.add(new ReceivingFailureServer(serverId, StatusCode.SUCCESS));
            }
        });
        if (failurePartitionToServers.isEmpty()) {
            return;
        }
        Map<Integer, List<ReceivingFailureServer>> partitionToServersReassignList = new HashMap<>();
        for (Map.Entry<Integer, List<ReceivingFailureServer>> entry : failurePartitionToServers.entrySet()) {
            int partitionId = entry.getKey();
            List<ReceivingFailureServer> failureServers = entry.getValue();
            // A true result means the split needs no reassignment request (per the skip log below,
            // presumably because the partition was already load balanced — confirm).
            boolean handledLocally = this.taskAttemptAssignment.updatePartitionSplitAssignment(partitionId, failureServers.stream().map(x -> ShuffleServerInfo.from(x.getServerId())).collect(Collectors.toList()));
            if (!handledLocally) {
                partitionToServersReassignList.put(partitionId, failureServers);
            }
        }
        if (partitionToServersReassignList.isEmpty()) {
            LOG.info("[Partition split] Skip the following partition split request (maybe has been load balanced). partitionIds: {}", failurePartitionToServers.keySet());
            return;
        }
        this.doReassignOnBlockSendFailure(partitionToServersReassignList, true);
        LOG.info("========================= Partition Split Result =========================");
        for (Map.Entry<Integer, List<ReceivingFailureServer>> entry : partitionToServersReassignList.entrySet()) {
            LOG.info("partitionId:{}. {} -> {}", new Object[]{entry.getKey(), entry.getValue().stream().map(x -> x.getServerId()).collect(Collectors.toList()), this.taskAttemptAssignment.retrieve(entry.getKey())});
        }
        LOG.info("==========================================================================");
    }

    /**
     * Requests new servers for the given failed partitions through the shuffle manager client.
     *
     * <p>The returned handle is applied to {@code taskAttemptAssignment`, and the new assignment of
     * each affected partition is logged.
     *
     * @param failurePartitionToServers failed partition id -> servers it failed on
     * @param partitionSplit whether the request stems from a partition split
     * @throws RssException wrapping any failure, including a non-SUCCESS response
     */
    private void doReassignOnBlockSendFailure(Map<Integer, List<ReceivingFailureServer>> failurePartitionToServers, boolean partitionSplit) {
        LOG.info("Initiate reassignOnBlockSendFailure of taskId[{}]. partition split: {}. failure partition servers: {}. ", new Object[]{this.taskAttemptId, partitionSplit, failurePartitionToServers});
        String executorId = SparkEnv.get().executorId();
        long contextTaskAttemptId = this.taskContext.taskAttemptId();
        int stageId = this.taskContext.stageId();
        int stageAttemptNum = this.taskContext.stageAttemptNumber();
        try {
            RssReassignOnBlockSendFailureResponse response = this.managerClientSupplier.get().reassignOnBlockSendFailure(
                    new RssReassignOnBlockSendFailureRequest(this.shuffleId, failurePartitionToServers, executorId, contextTaskAttemptId, stageId, stageAttemptNum, partitionSplit));
            if (response.getStatusCode() != StatusCode.SUCCESS) {
                throw new RssException(String.format("Reassign failed. statusCode: %s, msg: %s", response.getStatusCode(), response.getMessage()));
            }
            this.taskAttemptAssignment.update(MutableShuffleHandleInfo.fromProto(response.getHandle()));
            Map<Integer, List<String>> latestAssignment = new HashMap<>();
            for (Integer partitionId : failurePartitionToServers.keySet()) {
                List<String> serverIds = new ArrayList<>();
                for (ShuffleServerInfo server : this.taskAttemptAssignment.retrieve(partitionId)) {
                    serverIds.add(server.getId());
                }
                latestAssignment.put(partitionId, serverIds);
            }
            LOG.info("Succeed to reassign that the latest assignment is {}", latestAssignment);
        } catch (Exception e) {
            throw new RssException("Errors on reassign on block send failure. failure partition->servers : " + failurePartitionToServers, e);
        }
    }

    /**
     * Reassigns servers for the given failed block statuses and resubmits the blocks.
     *
     * <p>Statuses are grouped by partition and then by the server they failed on, keeping one
     * status per server. A reassignment is requested only for (partition, server) pairs whose
     * server is still the partition's first assigned server. Presumably this avoids asking again
     * for a server that was already replaced.
     *
     * <p>Afterwards, for every block:
     * <ul>
     *   <li>its tracking state is cleared;</li>
     *   <li>its retry count is incremented;</li>
     *   <li>it is pointed at the partition's current first assigned server;</li>
     *   <li>it is handed back to {@code processShuffleBlockInfos}.</li>
     * </ul>
     *
     * @param blocks failed block statuses to resend
     * @throws RssException if a block's current first assigned server is still the server it failed on
     */
    private void reassignAndResendBlocks(Set<TrackingBlockStatus> blocks) {
        ArrayList<ShuffleBlockInfo> resendCandidates = Lists.newArrayList();
        Map<Integer, List<TrackingBlockStatus>> partitionedFailedBlocks = blocks.stream().collect(Collectors.groupingBy(d -> d.getShuffleBlockInfo().getPartitionId()));
        HashMap<Integer, List<ReceivingFailureServer>> failurePartitionToServers = new HashMap<Integer, List<ReceivingFailureServer>>();
        for (Map.Entry<Integer, List<TrackingBlockStatus>> entry : partitionedFailedBlocks.entrySet()) {
            int partitionId = entry.getKey();
            List<TrackingBlockStatus> partitionBlocks = entry.getValue();
            // One representative status per failing server (the first in encounter order).
            Map<ShuffleServerInfo, TrackingBlockStatus> serverBlocks = partitionBlocks.stream().collect(Collectors.groupingBy(d -> d.getShuffleServerInfo())).entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, x -> (TrackingBlockStatus)((List)x.getValue()).stream().findFirst().get()));
            for (Map.Entry<ShuffleServerInfo, TrackingBlockStatus> blockStatusEntry : serverBlocks.entrySet()) {
                List<ShuffleServerInfo> servers;
                ShuffleServerInfo replacement;
                String latestServerId;
                String serverId = blockStatusEntry.getKey().getId();
                // Only the first assigned server (first replica) is compared.
                // NOTE(review): get(0) throws IndexOutOfBoundsException if the partition has no assigned servers.
                if (!serverId.equals(latestServerId = (replacement = (servers = this.getPartitionAssignedServers(partitionId)).get(0)).getId())) continue;
                StatusCode code = blockStatusEntry.getValue().getStatusCode();
                failurePartitionToServers.computeIfAbsent(partitionId, x -> new ArrayList()).add(new ReceivingFailureServer(serverId, code));
            }
        }
        if (!failurePartitionToServers.isEmpty()) {
            this.doReassignOnBlockSendFailure(failurePartitionToServers, false);
        }
        for (TrackingBlockStatus blockStatus : blocks) {
            ShuffleBlockInfo block = blockStatus.getShuffleBlockInfo();
            List<ShuffleServerInfo> servers = this.getPartitionAssignedServers(block.getPartitionId());
            ShuffleServerInfo replacement = servers.get(0);
            // The reassignment did not yield a different server for this partition.
            if (blockStatus.getShuffleServerInfo().getId().equals(replacement.getId())) {
                LOG.warn("PartitionId:{} has the following assigned servers: {}. But currently the replacement server:{} is the same with previous one!", new Object[]{block.getPartitionId(), this.taskAttemptAssignment.list(block.getPartitionId()), replacement});
                throw new RssException("No available replacement server for: " + blockStatus.getShuffleServerInfo().getId());
            }
            // Clear before reassigning: clearFailedBlockState uses the block's current (failed) servers.
            this.clearFailedBlockState(block);
            // Decompiler artifact: newBlock aliases block, so the same ShuffleBlockInfo is mutated and resent.
            ShuffleBlockInfo newBlock = block;
            newBlock.incrRetryCnt();
            newBlock.reassignShuffleServers(Arrays.asList(replacement));
            resendCandidates.add(newBlock);
        }
        this.processShuffleBlockInfos(resendCandidates);
        LOG.info("Failed blocks have been resent to data pusher queue since reassignment has been finished successfully");
    }

    /**
     * Forgets everything recorded about a block before it is resent.
     *
     * <p>The block id is removed from the failed-send tracker, from the entries of each of the
     * block's current servers in {@code serverToPartitionToBlockIds}, and from {@code blockIds}.
     *
     * @param block the block about to be resent
     */
    private void clearFailedBlockState(ShuffleBlockInfo block) {
        long blockId = block.getBlockId();
        int partitionId = block.getPartitionId();
        this.shuffleManager.getBlockIdsFailedSendTracker(this.taskId).remove(blockId);
        for (ShuffleServerInfo server : block.getShuffleServerInfos()) {
            this.serverToPartitionToBlockIds.get(server).get(partitionId).remove(blockId);
        }
        this.blockIds.remove(blockId);
    }

    /**
     * Asks the shuffle servers holding this shuffle's data to commit it, and waits for the result.
     *
     * <p>The commit call runs on a dedicated single-thread executor. This thread polls it with a
     * back-off that starts at 200 ms and doubles up to 5000 ms, logging progress on every round.
     * The executor is shut down on every exit path.
     *
     * <p>An interrupt while reading the result is treated as a task kill. The method logs it and
     * returns, and the thread's interrupt status is restored so the signal is not lost.
     *
     * @throws RssException "Failed to commit task to shuffle server" if the servers report failure
     *     (no longer re-wrapped), or "Exception happened when get commit status" if the commit call fails
     */
    @VisibleForTesting
    protected void sendCommit() {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        try {
            Future<Boolean> future = executor.submit(() -> this.shuffleWriteClient.sendCommit(this.shuffleServersForData, this.appId, this.shuffleId, this.numMaps));
            int maxWait = 5000;
            int currentWait = 200;
            long start = System.currentTimeMillis();
            while (!future.isDone()) {
                LOG.info("Wait commit to shuffle server for task[" + this.taskAttemptId + "] cost " + (System.currentTimeMillis() - start) + " ms");
                Uninterruptibles.sleepUninterruptibly(currentWait, TimeUnit.MILLISECONDS);
                currentWait = Math.min(currentWait * 2, maxWait);
            }
            boolean committed;
            try {
                committed = future.get();
            } catch (InterruptedException ie) {
                // Keep the task-kill signal visible to callers instead of silently clearing it.
                Thread.currentThread().interrupt();
                LOG.warn("Ignore the InterruptedException which should be caused by internal killed");
                return;
            } catch (Exception e) {
                throw new RssException("Exception happened when get commit status", e);
            }
            // Checked outside the try so this specific failure is not re-wrapped by the catch above.
            if (!committed) {
                throw new RssException("Failed to commit task to shuffle server");
            }
        } finally {
            executor.shutdown();
        }
    }

    /**
     * Resolves the reduce partition for a record key.
     *
     * @param key the record key
     * @param <T> the key type
     * @return the partitioner's partition for {@code key}, or 0 when partitioning is disabled
     */
    @VisibleForTesting
    protected <T> int getPartition(T key) {
        return this.shouldPartition ? this.partitioner.getPartition(key) : 0;
    }

    /**
     * Finishes this writer and, on success, produces the map output status for Spark.
     *
     * <p>When {@code success} is true, the following happens:
     * <ul>
     *   <li>the per-server block ids are reported through {@code reportShuffleResult};</li>
     *   <li>the report duration is added to the shuffle write time;</li>
     *   <li>a {@link MapStatus} is built from the partition length statistics, keyed by a
     *       {@link BlockManagerId} that uses {@code DUMMY_HOST} and port 99999.</li>
     * </ul>
     * When {@code success} is false, an empty option is returned.
     *
     * <p>The following cleanup always runs, whatever the outcome:
     * <ul>
     *   <li>write metrics are reported to the driver on a best-effort basis;</li>
     *   <li>failed-send block state is checked (on success) or released (on failure);</li>
     *   <li>buffer memory is freed and closed;</li>
     *   <li>the task's metadata is cleared from the shuffle manager.</li>
     * </ul>
     * An exception from the report step propagates after this cleanup rather than being discarded.
     *
     * @param success whether the task produced all of its records successfully
     * @return the map status on success; empty when {@code success} is false, or when block resend
     *     retry is enabled and failed block ids are still outstanding
     * @throws RssException when reporting the shuffle result fails and write-failure retry is
     *     enabled (possibly wrapping a FetchFailedException); other failures are rethrown as-is
     */
    @Override
    public Option<MapStatus> stop(boolean success) {
        Option<MapStatus> result;
        boolean hasRemainingFailedBlocks = false;
        try {
            if (success) {
                long start = System.currentTimeMillis();
                this.shuffleWriteClient.reportShuffleResult(this.serverToPartitionToBlockIds, this.appId, this.shuffleId, this.taskAttemptId, this.bitmapSplitNum, this.recordReportFailedShuffleservers, this.enableWriteFailureRetry);
                long reportDuration = System.currentTimeMillis() - start;
                LOG.info("Reported all shuffle result for shuffleId[{}] task[{}] with bitmapNum[{}] cost {} ms", new Object[]{this.shuffleId, this.taskAttemptId, this.bitmapSplitNum, reportDuration});
                this.shuffleWriteMetrics.incWriteTime(TimeUnit.MILLISECONDS.toNanos(reportDuration));
                BlockManagerId blockManagerId = BlockManagerId.apply(this.appId + "_" + this.taskId, DUMMY_HOST, 99999, Option.apply(Long.toString(this.taskAttemptId)));
                MapStatus mapStatus = MapStatus.apply(blockManagerId, this.partitionLengthStatistic.toArray(), this.taskAttemptId);
                result = Option.apply(mapStatus);
            } else {
                result = Option.empty();
            }
        }
        catch (Exception e) {
            if (this.enableWriteFailureRetry) {
                throw this.throwFetchFailedIfNecessary(e, this.recordReportFailedShuffleservers);
            }
            throw e;
        }
        finally {
            // Metrics are auxiliary. A reporting failure must neither skip the cleanup below nor
            // replace the outcome of the try block.
            try {
                ShuffleManagerClient shuffleManagerClient = this.managerClientSupplier == null ? null : this.managerClientSupplier.get();
                if (shuffleManagerClient != null) {
                    RssReportShuffleWriteMetricRequest.TaskShuffleWriteTimes writeTimes = new RssReportShuffleWriteMetricRequest.TaskShuffleWriteTimes(this.totalShuffleWriteMills, this.bufferManager.getCopyTime(), this.bufferManager.getSerializeTime(), this.bufferManager.getCompressTime(), this.bufferManager.getSortTime(), this.bufferManager.getRequireMemoryTime(), this.checkSendResultMills);
                    RssReportShuffleWriteMetricResponse response = shuffleManagerClient.reportShuffleWriteMetric(new RssReportShuffleWriteMetricRequest(this.taskContext.stageId(), this.shuffleId, this.taskContext.taskAttemptId(), this.bufferManager.getShuffleServerPushCostTracker().toMetric(), writeTimes, this.isShuffleWriteFailed, this.shuffleWriteFailureReason, this.bufferManager.getUncompressedDataLen()));
                    if (response.getStatusCode() != StatusCode.SUCCESS) {
                        LOG.error("Errors on reporting shuffle write metrics to driver");
                    }
                }
            }
            catch (Exception metricsError) {
                LOG.warn("Errors on reporting shuffle write metrics to driver", metricsError);
            }
            if (this.blockFailSentRetryEnabled) {
                if (success) {
                    // Only record the condition here. Returning from finally would skip the
                    // cleanup below and discard any in-flight exception.
                    hasRemainingFailedBlocks = CollectionUtils.isNotEmpty(this.shuffleManager.getFailedBlockIds(this.taskId));
                } else {
                    this.shuffleManager.getBlockIdsFailedSendTracker(this.taskId).clearAndReleaseBlockResources();
                }
            }
            // Free all memory and metadata; otherwise the executor leaks them.
            if (this.bufferManager != null) {
                this.bufferManager.freeAllMemory();
                this.bufferManager.close();
            }
            if (this.shuffleManager != null) {
                this.shuffleManager.clearTaskMeta(this.taskId);
            }
        }
        if (hasRemainingFailedBlocks) {
            LOG.error("Errors on stopping writer due to the remaining failed blockIds. This should not happen.");
            return Option.empty();
        }
        return result;
    }

    /**
     * Flattens {@code serverToPartitionToBlockIds} into a partition-to-block-ids view.
     *
     * <p>If a partition appears under several servers, their block id sets are merged into a new
     * set. If it appears under only one server, the returned value is that server's own set
     * instance, not a copy.
     *
     * @return block ids per partition id, across all servers
     */
    @VisibleForTesting
    Map<Integer, Set<Long>> getPartitionToBlockIds() {
        return this.serverToPartitionToBlockIds.values().stream()
            .flatMap(partitionToBlockIds -> partitionToBlockIds.entrySet().stream())
            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (existingSet, newSet) -> {
                // Copy rather than mutate, so the per-server sets are left untouched.
                Set<Long> mergedSet = new HashSet<>(existingSet);
                mergedSet.addAll(newSet);
                return mergedSet;
            }));
    }

    /**
     * Exposes the write buffer manager backing this writer (test hook).
     *
     * @return the buffer manager, which may be null if none was set
     */
    @VisibleForTesting
    public WriteBufferManager getBufferManager() {
        return bufferManager;
    }

    /**
     * Turns a failure raised while finishing the write into an {@link RssException}.
     *
     * <p>If {@code e} is an {@link RssSendFailedException}, the faulty servers are reported to the
     * shuffle manager. These are the servers from this task's failed-block tracker plus
     * {@code reportFailuredServers}. If the manager answers that the whole stage must be
     * resubmitted, an RssException wrapping a FetchFailedException is thrown. If no shuffle
     * manager client is available, the report is skipped, consistent with {@code stop}.
     *
     * <p>This method never returns normally. Its return type exists so callers can write
     * {@code throw throwFetchFailedIfNecessary(...)}.
     *
     * @param e the original failure
     * @param reportFailuredServers servers that failed while the shuffle result was reported
     * @return never returns
     * @throws RssException always: either wrapping a FetchFailedException, or wrapping {@code e}
     */
    private RssException throwFetchFailedIfNecessary(Exception e, Set<ShuffleServerInfo> reportFailuredServers) {
        if (e instanceof RssSendFailedException) {
            ShuffleManagerClient managerClient = this.managerClientSupplier == null ? null : this.managerClientSupplier.get();
            if (managerClient == null) {
                // Without a client the failure cannot be reported. Fall through and rethrow e
                // rather than masking it with an NPE.
                LOG.warn("No shuffle manager client available, skip reporting shuffle write failure");
            } else {
                FailedBlockSendTracker blockIdsFailedSendTracker = this.shuffleManager.getBlockIdsFailedSendTracker(this.taskId);
                ArrayList<ShuffleServerInfo> shuffleServerInfos = Lists.newArrayList(blockIdsFailedSendTracker.getFaultyShuffleServers());
                shuffleServerInfos.addAll(reportFailuredServers);
                RssReportShuffleWriteFailureRequest req = new RssReportShuffleWriteFailureRequest(this.appId, this.shuffleId, this.taskContext.stageId(), this.taskContext.stageAttemptNumber(), shuffleServerInfos, e.getMessage());
                RssReportShuffleWriteFailureResponse response = managerClient.reportShuffleWriteFailure(req);
                if (response.getReSubmitWholeStage()) {
                    LOG.warn("{}", response.getMessage());
                    FetchFailedException ffe = RssSparkShuffleUtils.createFetchFailedException(this.shuffleId, -1, this.taskContext.stageAttemptNumber(), e);
                    throw new RssException((Throwable)ffe);
                }
            }
        }
        throw new RssException(e);
    }

    /** Turns on resending of blocks whose send failed (test hook). */
    @VisibleForTesting
    protected void enableBlockFailSentRetry() {
        blockFailSentRetryEnabled = true;
    }

    /**
     * Overrides the retry limit used for failed block sends (test hook).
     *
     * @param blockFailSentRetryMaxTimes the new retry limit
     */
    @VisibleForTesting
    protected void setBlockFailSentRetryMaxTimes(int blockFailSentRetryMaxTimes) {
        this.blockFailSentRetryMaxTimes = blockFailSentRetryMaxTimes;
    }

    /**
     * Replaces the task id used to look up this task's state in the shuffle manager (test hook).
     *
     * @param taskId the task id to use
     */
    @VisibleForTesting
    protected void setTaskId(String taskId) {
        this.taskId = taskId;
    }

    /**
     * Returns the live server-to-partition-to-block-ids map (test hook; not a copy).
     *
     * @return the internal tracking map
     */
    @VisibleForTesting
    protected Map<ShuffleServerInfo, Map<Integer, Set<Long>>> getServerToPartitionToBlockIds() {
        return serverToPartitionToBlockIds;
    }

    /**
     * Returns the shuffle manager this writer was created with (test hook).
     *
     * @return the shuffle manager
     */
    @VisibleForTesting
    protected RssShuffleManager getShuffleManager() {
        return shuffleManager;
    }

    /**
     * Returns the partition-to-server assignment this task attempt uses.
     *
     * @return the task attempt assignment
     */
    public TaskAttemptAssignment getTaskAttemptAssignment() {
        return taskAttemptAssignment;
    }
}

