/*
 * Decompiled with CFR 0.152.
 */
package org.apache.celeborn.common.network.sasl;

import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.celeborn.common.network.client.RpcResponseCallback;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.protocol.RequestMessage;
import org.apache.celeborn.common.network.protocol.RpcRequest;
import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.network.sasl.CelebornSaslServer;
import org.apache.celeborn.common.network.sasl.SaslUtils;
import org.apache.celeborn.common.network.sasl.SecretRegistry;
import org.apache.celeborn.common.network.server.AbstractAuthRpcHandler;
import org.apache.celeborn.common.network.server.BaseMessageHandler;
import org.apache.celeborn.common.network.util.TransportConf;
import org.apache.celeborn.common.protocol.PbSaslRequest;
import org.apache.celeborn.shaded.io.netty.channel.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Authentication handler that runs a server-side SASL {@code DIGEST-MD5} handshake on top of
 * {@link AbstractAuthRpcHandler}.
 *
 * <p>Every challenge message is cast to {@link RpcRequest}. Its body is decoded as a
 * {@link TransportMessage} whose parsed payload is a {@link PbSaslRequest}. The SASL payload is
 * fed to a lazily created {@link CelebornSaslServer}, and the server's reply is returned through
 * the RPC callback. The SASL server is disposed when the handshake completes or the channel goes
 * inactive.
 *
 * <p>NOTE(review): how messages are routed after {@link #doAuthChallenge} returns {@code true}
 * is decided by the superclass, which is not visible here. Confirm there.
 */
public class SaslRpcHandler extends AbstractAuthRpcHandler {
    private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);

    /** SASL mechanism used for every handshake started by this handler. */
    private static final String SASL_MECHANISM = "DIGEST-MD5";

    /** Transport configuration; stored by the constructor but not read anywhere in this class. */
    private final TransportConf conf;

    /** Channel this handler was created for; stored but not read anywhere in this class. */
    private final Channel channel;

    /** Handed to the digest callback handler, presumably to resolve client secrets. */
    private final SecretRegistry secretRegistry;

    /** Handshake state: {@code null} before the first challenge and after {@link #cleanup()}. */
    private CelebornSaslServer saslServer;

    /**
     * Creates the handler.
     *
     * @param conf transport configuration (retained, currently unused by this class)
     * @param channel the underlying channel (retained, currently unused by this class)
     * @param delegate the wrapped handler, passed straight to the superclass
     * @param secretRegistry source of secrets for the digest callback handler
     */
    public SaslRpcHandler(TransportConf conf, Channel channel, BaseMessageHandler delegate, SecretRegistry secretRegistry) {
        super(delegate);
        this.secretRegistry = secretRegistry;
        this.channel = channel;
        this.conf = conf;
        // The SASL server is created lazily on the first challenge message.
        this.saslServer = null;
    }

    /**
     * Reports whether the wrapped delegate is registered.
     *
     * @return the delegate's {@code checkRegistered()} result, unchanged
     */
    @Override
    public boolean checkRegistered() {
        return delegate.checkRegistered();
    }

    /**
     * Advances the SASL handshake by one round.
     *
     * <p>While the handshake is still open, the request body is decoded into a
     * {@link PbSaslRequest}. Its payload is evaluated by the SASL server, and the resulting
     * bytes are sent back via {@link RpcResponseCallback#onSuccess}. If decoding fails with an
     * {@link IOException}, the error is logged, reported through
     * {@link RpcResponseCallback#onFailure}, and {@code false} is returned.
     *
     * @param client the peer being authenticated; given to the digest callback handler
     * @param message the incoming message; it is cast to {@link RpcRequest}
     * @param callback receives the SASL response, or the decoding failure
     * @return {@code true} once the handshake is complete (the SASL server is disposed at that
     *     point), otherwise {@code false}
     */
    @Override
    public boolean doAuthChallenge(TransportClient client, RequestMessage message, RpcResponseCallback callback) {
        boolean handshakeOpen = saslServer == null || !saslServer.isComplete();
        if (handshakeOpen) {
            // Cast before decoding so the request id is available for the failure log below.
            RpcRequest request = (RpcRequest) message;
            PbSaslRequest saslRequest;
            try {
                saslRequest = decodeSaslRequest(request);
            } catch (IOException e) {
                logger.error("Error while parsing Sasl Message with RPC id {}", request.requestId, e);
                callback.onFailure(e);
                return false;
            }
            if (saslServer == null) {
                saslServer = createSaslServer(client);
            }
            byte[] reply = saslServer.response(saslRequest.getPayload().toByteArray());
            callback.onSuccess(ByteBuffer.wrap(reply));
        }

        if (!saslServer.isComplete()) {
            return false;
        }
        logger.debug("SASL authentication successful for channel {}", client);
        complete();
        return true;
    }

    /**
     * Notifies the superclass that the channel went inactive, then releases any SASL state.
     *
     * @param client the client whose channel became inactive
     */
    @Override
    public void channelInactive(TransportClient client) {
        super.channelInactive(client);
        cleanup();
    }

    /**
     * Decodes the protobuf SASL request carried in an RPC body.
     *
     * @param request the RPC whose body holds a serialized {@link TransportMessage}
     * @return the parsed payload, cast to {@link PbSaslRequest}
     * @throws IOException if the body cannot be read or parsed
     */
    private static PbSaslRequest decodeSaslRequest(RpcRequest request) throws IOException {
        TransportMessage envelope = TransportMessage.fromByteBuffer(request.body().nioByteBuffer());
        return (PbSaslRequest) envelope.getParsedPayload();
    }

    /**
     * Builds the SASL server for a new handshake with {@code client}.
     *
     * @param client the peer being authenticated
     * @return a fresh server using the default server properties and a digest callback handler
     */
    private CelebornSaslServer createSaslServer(TransportClient client) {
        CelebornSaslServer.DigestCallbackHandler digestHandler =
            new CelebornSaslServer.DigestCallbackHandler(client, secretRegistry);
        return new CelebornSaslServer(SASL_MECHANISM, SaslUtils.DEFAULT_SASL_SERVER_PROPS, digestHandler);
    }

    /**
     * Releases handshake state after a successful authentication.
     *
     * <p>NOTE(review): disposing here assumes no SASL wrap/unwrap is needed once the handshake
     * ends (i.e. an auth-only QOP). Confirm against {@code SaslUtils.DEFAULT_SASL_SERVER_PROPS}.
     */
    private void complete() {
        cleanup();
    }

    /**
     * Disposes the current SASL server, if there is one, and clears the reference.
     *
     * <p>A {@link RuntimeException} thrown by {@code dispose()} is logged rather than rethrown.
     * The reference is cleared in every case.
     */
    public void cleanup() {
        if (saslServer == null) {
            return;
        }
        try {
            saslServer.dispose();
        } catch (RuntimeException e) {
            logger.error("Error while disposing SASL server", e);
        } finally {
            saslServer = null;
        }
    }
}

