/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.kafka.listener;

import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.springframework.classify.BinaryExceptionClassifier;
import org.springframework.kafka.listener.CommonErrorHandler;
import org.springframework.kafka.listener.ListenerExecutionFailedException;
import org.springframework.kafka.listener.MessageListenerContainer;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
 * A {@link CommonErrorHandler} that routes an exception to a delegate handler chosen by
 * exception type, falling back to a default handler when no delegate matches.
 * <p>
 * Delegates are consulted in registration order; the first one whose key type is
 * assignable from the (possibly unwrapped) exception is used. Only
 * {@code handleRemaining}, {@code handleBatch} and {@code handleOtherException} are routed
 * here. The {@code remainingRecords()}, {@code seeksAfterHandling()} and
 * {@code isAckAfterHandle()} properties are answered by the default handler, so every
 * delegate must agree with it on those values; this is enforced when delegates are
 * registered.
 */
public class CommonDelegatingErrorHandler
implements CommonErrorHandler {

    /** Handler used when no delegate matches the exception. */
    private final CommonErrorHandler defaultErrorHandler;

    /** Exception type to handler; insertion order is the lookup order. */
    private final Map<Class<? extends Throwable>, CommonErrorHandler> delegates = new LinkedHashMap<>();

    /** When true, the cause chain is searched for a registered exception type. */
    private boolean causeChainTraversing = false;

    /** Built from the delegate keys; used to pick a cause while traversing the cause chain. */
    private BinaryExceptionClassifier classifier = new BinaryExceptionClassifier(new HashMap<>());

    /**
     * Create an instance with the supplied default handler.
     * @param defaultErrorHandler handler used when no delegate matches; must not be null.
     * @throws IllegalArgumentException if {@code defaultErrorHandler} is null.
     */
    public CommonDelegatingErrorHandler(CommonErrorHandler defaultErrorHandler) {
        Assert.notNull(defaultErrorHandler, "'defaultErrorHandler' cannot be null");
        this.defaultErrorHandler = defaultErrorHandler;
    }

    /**
     * Replace all delegates with the supplied map.
     * <p>
     * The new delegates are validated before anything is replaced, so a failed validation
     * leaves the previously registered delegates (and classifier) untouched.
     * @param delegates exception type to handler; iteration order becomes lookup order.
     * @throws IllegalArgumentException if the map is null, contains null keys or values, or
     * a handler disagrees with the default handler (see class Javadoc).
     */
    public void setErrorHandlers(Map<Class<? extends Throwable>, CommonErrorHandler> delegates) {
        Assert.notNull(delegates, "'delegates' cannot be null");
        Map<Class<? extends Throwable>, CommonErrorHandler> delegatesToCheck = new LinkedHashMap<>(delegates);
        // Validate (and rebuild the classifier) before mutating state so failure is atomic.
        checkDelegatesAndUpdateClassifier(delegatesToCheck);
        this.delegates.clear();
        this.delegates.putAll(delegatesToCheck);
    }

    /**
     * Set whether to search the whole cause chain for a registered exception type. When
     * false, only a top-level {@link ListenerExecutionFailedException} is unwrapped.
     * @param causeChainTraversing true to traverse the full cause chain.
     */
    public void setCauseChainTraversing(boolean causeChainTraversing) {
        this.causeChainTraversing = causeChainTraversing;
    }

    /** Answered by the default handler; delegates are required to agree. */
    @Override
    public boolean remainingRecords() {
        return this.defaultErrorHandler.remainingRecords();
    }

    /** Answered by the default handler; delegates are required to agree. */
    @Override
    public boolean seeksAfterHandling() {
        return this.defaultErrorHandler.seeksAfterHandling();
    }

    /** Clear thread state on the default handler and on every delegate. */
    @Override
    public void clearThreadState() {
        this.defaultErrorHandler.clearThreadState();
        this.delegates.values().forEach(CommonErrorHandler::clearThreadState);
    }

    /** Answered by the default handler; delegates are required to agree. */
    @Override
    public boolean isAckAfterHandle() {
        return this.defaultErrorHandler.isAckAfterHandle();
    }

    /**
     * Set the ack-after-handle flag on the default handler only.
     * <p>
     * NOTE(review): delegates are not updated, so after this call they may no longer agree
     * with the default handler's {@code isAckAfterHandle()}; confirm this is intended.
     */
    @Override
    public void setAckAfterHandle(boolean ack) {
        this.defaultErrorHandler.setAckAfterHandle(ack);
    }

    /**
     * Register (or replace) the delegate for one exception type.
     * @param throwable the exception type to route; must not be null.
     * @param handler the handler for that type; must not be null.
     * @throws IllegalArgumentException if an argument is null or the handler disagrees with
     * the default handler (see class Javadoc); existing delegates are then left unchanged.
     */
    public void addDelegate(Class<? extends Throwable> throwable, CommonErrorHandler handler) {
        Assert.notNull(throwable, "'throwable' cannot be null");
        Assert.notNull(handler, "'handler' cannot be null");
        Map<Class<? extends Throwable>, CommonErrorHandler> delegatesToCheck = new LinkedHashMap<>(this.delegates);
        delegatesToCheck.put(throwable, handler);
        checkDelegatesAndUpdateClassifier(delegatesToCheck);
        this.delegates.clear();
        this.delegates.putAll(delegatesToCheck);
    }

    /**
     * Verify that every candidate delegate agrees with the default handler, then rebuild
     * the classifier from the candidate keys. Nothing is changed if validation fails.
     */
    private void checkDelegatesAndUpdateClassifier(Map<Class<? extends Throwable>, CommonErrorHandler> delegatesToCheck) {
        boolean remainingRecords = this.defaultErrorHandler.remainingRecords();
        boolean ackAfterHandle = this.defaultErrorHandler.isAckAfterHandle();
        boolean seeksAfterHandling = this.defaultErrorHandler.seeksAfterHandling();
        // Validate the candidate map, not this.delegates: in addDelegate the new handler is
        // only present in the candidate map at this point.
        for (Map.Entry<Class<? extends Throwable>, CommonErrorHandler> entry : delegatesToCheck.entrySet()) {
            Assert.notNull(entry.getKey(), "Delegate exception types cannot be null");
            CommonErrorHandler handler = entry.getValue();
            Assert.notNull(handler, "Delegate error handlers cannot be null");
            Assert.isTrue(remainingRecords == handler.remainingRecords(),
                    "All delegates must return the same value when calling 'remainingRecords()'");
            Assert.isTrue(ackAfterHandle == handler.isAckAfterHandle(),
                    "All delegates must return the same value when calling 'isAckAfterHandle()'");
            Assert.isTrue(seeksAfterHandling == handler.seeksAfterHandling(),
                    "All delegates must return the same value when calling 'seeksAfterHandling()'");
        }
        updateClassifier(delegatesToCheck);
    }

    /** Rebuild the classifier so that every delegate key type classifies as {@code true}. */
    private void updateClassifier(Map<Class<? extends Throwable>, CommonErrorHandler> delegates) {
        Map<Class<? extends Throwable>, Boolean> classifications = delegates.keySet().stream()
                .collect(Collectors.toMap(type -> type, type -> Boolean.TRUE));
        this.classifier = new BinaryExceptionClassifier(classifications);
    }

    @Override
    public void handleRemaining(Exception thrownException, List<ConsumerRecord<?, ?>> records, Consumer<?, ?> consumer, MessageListenerContainer container) {
        handlerFor(thrownException).handleRemaining(thrownException, records, consumer, container);
    }

    @Override
    public void handleBatch(Exception thrownException, ConsumerRecords<?, ?> data, Consumer<?, ?> consumer, MessageListenerContainer container, Runnable invokeListener) {
        handlerFor(thrownException).handleBatch(thrownException, data, consumer, container, invokeListener);
    }

    @Override
    public void handleOtherException(Exception thrownException, Consumer<?, ?> consumer, MessageListenerContainer container, boolean batchListener) {
        handlerFor(thrownException).handleOtherException(thrownException, consumer, container, batchListener);
    }

    /** Return the matching delegate for the exception, or the default handler if none matches. */
    private CommonErrorHandler handlerFor(Throwable thrownException) {
        CommonErrorHandler delegate = findDelegate(thrownException);
        return delegate != null ? delegate : this.defaultErrorHandler;
    }

    /**
     * Find the first delegate (in registration order) whose key type is assignable from the
     * selected cause's class.
     * @return the delegate, or null if there is no cause or no key matches.
     */
    @Nullable
    private CommonErrorHandler findDelegate(Throwable thrownException) {
        Throwable cause = findCause(thrownException);
        if (cause == null) {
            return null;
        }
        Class<?> causeClass = cause.getClass();
        for (Map.Entry<Class<? extends Throwable>, CommonErrorHandler> entry : this.delegates.entrySet()) {
            if (entry.getKey().isAssignableFrom(causeClass)) {
                return entry.getValue();
            }
        }
        return null;
    }

    /** Select the throwable to match against delegates, according to the traversal mode. */
    @Nullable
    private Throwable findCause(Throwable thrownException) {
        return this.causeChainTraversing
                ? traverseCauseChain(thrownException)
                : shallowTraverseCauseChain(thrownException);
    }

    /**
     * Unwrap a top-level {@link ListenerExecutionFailedException} to its cause (which may
     * be null); any other throwable is returned unchanged.
     */
    @Nullable
    private Throwable shallowTraverseCauseChain(Throwable thrownException) {
        if (thrownException instanceof ListenerExecutionFailedException) {
            return thrownException.getCause();
        }
        return thrownException;
    }

    /**
     * Walk the cause chain and return the first throwable the classifier accepts. The loop
     * stops before the deepest cause, so if nothing earlier is accepted the root cause is
     * returned and matched (or not) by {@link #findDelegate}.
     */
    @Nullable
    private Throwable traverseCauseChain(Throwable thrownException) {
        Throwable cause = thrownException;
        while (cause != null && cause.getCause() != null) {
            if (this.classifier.classify(cause)) {
                return cause;
            }
            cause = cause.getCause();
        }
        return cause;
    }
}

