/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.network.shuffle;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.util.concurrent.EventExecutorGroup;
import java.nio.ByteBuffer;
import java.util.List;
import org.apache.spark.internal.SparkLogger;
import org.apache.spark.internal.SparkLoggerFactory;
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.protocol.MessageDecoder;
import org.apache.spark.network.protocol.RequestMessage;
import org.apache.spark.network.protocol.RpcRequest;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportChannelHandler;
import org.apache.spark.network.server.TransportRequestHandler;
import org.apache.spark.network.shuffle.ExternalBlockHandler;
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.util.IOMode;
import org.apache.spark.network.util.NettyUtils;
import org.apache.spark.network.util.TransportConf;

/**
 * Transport context for the shuffle service that can process
 * {@link BlockTransferMessage.Type#FINALIZE_SHUFFLE_MERGE} RPCs on a dedicated event loop group.
 *
 * <p>The dedicated handling is enabled only when the module name is {@code "shuffle"}
 * (case-insensitive) and {@link TransportConf#separateFinalizeShuffleMerge()} is {@code true}.
 * In that case this context:
 * <ul>
 *   <li>uses {@link ShuffleMessageDecoder}, which wraps finalize-merge RPCs in an
 *       {@link RpcRequestInternal}, instead of the parent's decoder; and</li>
 *   <li>appends a {@link FinalizedHandler}, bound to the dedicated event loop group, to every
 *       pipeline it initializes; that handler passes the wrapped request to the channel's
 *       {@link TransportRequestHandler}.</li>
 * </ul>
 * When the feature is disabled, the decoder and pipeline come unchanged from
 * {@link TransportContext}.
 */
public class ShuffleTransportContext extends TransportContext {
    private static final SparkLogger logger = SparkLoggerFactory.getLogger(ShuffleTransportContext.class);
    /** Single decoder instance shared by all channels; the decoder class is {@code @Sharable}. */
    private static final ShuffleMessageDecoder SHUFFLE_DECODER = new ShuffleMessageDecoder(MessageDecoder.INSTANCE);
    /**
     * Event loop group on which {@link FinalizedHandler} runs; {@code null} when separate
     * finalize-merge handling is disabled.
     */
    private final EventLoopGroup finalizeWorkers;

    /**
     * Equivalent to {@link #ShuffleTransportContext(TransportConf, RpcHandler, boolean, boolean)}
     * with {@code isClientOnly = false}.
     *
     * @param conf transport configuration
     * @param rpcHandler handler for incoming RPCs, passed through to {@link TransportContext}
     * @param closeIdleConnections passed through to {@link TransportContext}
     */
    public ShuffleTransportContext(TransportConf conf, ExternalBlockHandler rpcHandler, boolean closeIdleConnections) {
        this(conf, rpcHandler, closeIdleConnections, false);
    }

    /**
     * Creates the context and, if enabled by {@code conf}, the dedicated finalize-merge event
     * loop group. The group is sized by {@link TransportConf#finalizeShuffleMergeHandlerThreads()}
     * and uses the IO mode from {@link TransportConf#ioMode()}.
     *
     * @param conf transport configuration
     * @param rpcHandler handler for incoming RPCs, passed through to {@link TransportContext}
     * @param closeIdleConnections passed through to {@link TransportContext}
     * @param isClientOnly passed through to {@link TransportContext}
     */
    public ShuffleTransportContext(TransportConf conf, RpcHandler rpcHandler, boolean closeIdleConnections, boolean isClientOnly) {
        super(conf, rpcHandler, closeIdleConnections, isClientOnly);
        if ("shuffle".equalsIgnoreCase(conf.getModuleName()) && conf.separateFinalizeShuffleMerge()) {
            this.finalizeWorkers = NettyUtils.createEventLoop(
                IOMode.valueOf(conf.ioMode()),
                conf.finalizeShuffleMergeHandlerThreads(),
                "shuffle-finalize-merge-handler");
            logger.info("finalize shuffle merged workers created");
        } else {
            this.finalizeWorkers = null;
        }
    }

    /** Initializes the parent pipeline, then appends the finalize-merge handler if enabled. */
    @Override
    public TransportChannelHandler initializePipeline(SocketChannel channel, boolean isClient) {
        TransportChannelHandler ch = super.initializePipeline(channel, isClient);
        this.addHandlerToPipeline(channel, ch);
        return ch;
    }

    /** Initializes the parent pipeline, then appends the finalize-merge handler if enabled. */
    @Override
    public TransportChannelHandler initializePipeline(SocketChannel channel, RpcHandler channelRpcHandler, boolean isClient) {
        TransportChannelHandler ch = super.initializePipeline(channel, channelRpcHandler, isClient);
        this.addHandlerToPipeline(channel, ch);
        return ch;
    }

    /**
     * Appends a {@link FinalizedHandler} as the last pipeline handler, executed on
     * {@link #finalizeWorkers}. Does nothing when the feature is disabled.
     */
    private void addHandlerToPipeline(SocketChannel channel, TransportChannelHandler transportChannelHandler) {
        if (this.finalizeWorkers != null) {
            channel.pipeline().addLast(
                this.finalizeWorkers,
                FinalizedHandler.HANDLER_NAME,
                new FinalizedHandler(transportChannelHandler.getRequestHandler()));
        }
    }

    /** Returns {@link ShuffleMessageDecoder} when the feature is enabled, else the parent's decoder. */
    @Override
    protected MessageToMessageDecoder<ByteBuf> getDecoder() {
        return this.finalizeWorkers == null ? super.getDecoder() : SHUFFLE_DECODER;
    }

    /**
     * Closes the parent context. If a dedicated finalize-merge event loop group was created, also
     * starts its graceful shutdown so its threads do not outlive the context. The shutdown is
     * initiated but not awaited.
     */
    @Override
    public void close() {
        try {
            super.close();
        } finally {
            if (this.finalizeWorkers != null) {
                this.finalizeWorkers.shutdownGracefully();
            }
        }
    }

    /**
     * Inbound handler for {@link RpcRequestInternal} finalize-merge requests. It passes the
     * wrapped {@link RpcRequest} to the channel's {@link TransportRequestHandler}. Messages it does
     * not accept are forwarded along the pipeline by {@link SimpleChannelInboundHandler}.
     */
    static class FinalizedHandler extends SimpleChannelInboundHandler<RpcRequestInternal> {
        private static final SparkLogger logger = SparkLoggerFactory.getLogger(FinalizedHandler.class);
        /** Name under which this handler is registered in the channel pipeline. */
        public static final String HANDLER_NAME = "finalizeHandler";
        private final TransportRequestHandler transportRequestHandler;

        FinalizedHandler(TransportRequestHandler transportRequestHandler) {
            this.transportRequestHandler = transportRequestHandler;
        }

        /** Accepts only {@link RpcRequestInternal} messages of type FINALIZE_SHUFFLE_MERGE. */
        @Override
        public boolean acceptInboundMessage(Object msg) throws Exception {
            return msg instanceof RpcRequestInternal internal
                && internal.messageType() == BlockTransferMessage.Type.FINALIZE_SHUFFLE_MERGE;
        }

        /** Dispatches the wrapped RPC to the channel's request handler. */
        @Override
        protected void channelRead0(ChannelHandlerContext channelHandlerContext, RpcRequestInternal req) throws Exception {
            RpcRequest rpcRequest = req.rpcRequest();
            if (logger.isTraceEnabled()) {
                logger.trace("Finalize shuffle req from {} for rpc request {}",
                    NettyUtils.getRemoteAddress(channelHandlerContext.channel()), rpcRequest.requestId);
            }
            this.transportRequestHandler.handle(rpcRequest);
        }
    }

    /**
     * Decoder that delegates frame decoding to a {@link MessageDecoder}. When the decoded message
     * is an {@link RpcRequest} whose body starts with the FINALIZE_SHUFFLE_MERGE type id, it
     * replaces that message with an {@link RpcRequestInternal}.
     */
    @ChannelHandler.Sharable
    static class ShuffleMessageDecoder extends MessageToMessageDecoder<ByteBuf> {
        private final MessageDecoder delegate;

        ShuffleMessageDecoder(MessageDecoder delegate) {
            this.delegate = delegate;
        }

        @Override
        protected void decode(ChannelHandlerContext channelHandlerContext, ByteBuf byteBuf, List<Object> list) throws Exception {
            this.delegate.decode(channelHandlerContext, byteBuf, list);
            if (list.isEmpty()) {
                // Nothing was decoded; there is no message to inspect.
                return;
            }
            int lastIndex = list.size() - 1;
            Object msg = list.get(lastIndex);
            if (msg instanceof RpcRequest req) {
                ByteBuffer body = req.body().nioByteBuffer();
                // Peek at the type byte with an absolute get, so the buffer position is untouched.
                // An empty body cannot carry a type byte. Leave such requests for the regular RPC
                // path rather than failing inside the decoder.
                if (body.hasRemaining()
                        && body.get(body.position()) == BlockTransferMessage.Type.FINALIZE_SHUFFLE_MERGE.id()) {
                    list.remove(lastIndex);
                    RpcRequestInternal rpcRequestInternal =
                        new RpcRequestInternal(BlockTransferMessage.Type.FINALIZE_SHUFFLE_MERGE, req);
                    logger.trace("Created internal rpc request msg with rpcId {} for finalize merge req", req.requestId);
                    list.add(rpcRequestInternal);
                }
            }
        }
    }

    /**
     * Wrapper produced by {@link ShuffleMessageDecoder} for finalize-merge RPCs so that they can be
     * routed to {@link FinalizedHandler}.
     *
     * @param messageType block transfer message type of the wrapped request
     * @param rpcRequest the decoded RPC request
     */
    record RpcRequestInternal(BlockTransferMessage.Type messageType, RpcRequest rpcRequest) {
    }
}

