package org.apache.kafka.streams.processor.assignment.assignors;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.assignment.ApplicationState;
import org.apache.kafka.streams.processor.assignment.KafkaStreamsAssignment;
import org.apache.kafka.streams.processor.assignment.KafkaStreamsState;
import org.apache.kafka.streams.processor.assignment.ProcessId;
import org.apache.kafka.streams.processor.assignment.TaskAssignmentUtils;
import org.apache.kafka.streams.processor.assignment.TaskAssignor;
import org.apache.kafka.streams.processor.assignment.TaskInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/kafka/streams/processor/assignment/assignors/StickyTaskAssignor.class */
/**
 * A {@link TaskAssignor} that keeps tasks on the clients that previously owned them where
 * possible, while capping active tasks per client by its share of processing threads.
 *
 * <p>An assignment run has four phases:
 * <ol>
 *   <li>active tasks are placed: previous active owner first, then a previous standby owner
 *       with room, then the best remaining client;</li>
 *   <li>unless active assignments must be preserved, active tasks are re-optimized via
 *       {@link TaskAssignmentUtils#optimizeRackAwareActiveTasks};</li>
 *   <li>standby replicas are placed for every task that has at least one changelog partition;</li>
 *   <li>if standbys are configured and active assignments need not be preserved, standby tasks
 *       are re-optimized via {@link TaskAssignmentUtils#optimizeRackAwareStandbyTasks}.</li>
 * </ol>
 */
public class StickyTaskAssignor implements TaskAssignor {
    private static final Logger LOG = LoggerFactory.getLogger(StickyTaskAssignor.class);

    /** Rack-aware traffic cost used when the assignment configs do not override it. */
    public static final int DEFAULT_STICKY_TRAFFIC_COST = 1;
    /** Rack-aware non-overlap cost used when the assignment configs do not override it. */
    public static final int DEFAULT_STICKY_NON_OVERLAP_COST = 10;

    /**
     * When true, every task stays on its previous active owner regardless of load, rack-aware
     * optimization is skipped, and a follow-up rebalance is requested.
     */
    private final boolean mustPreserveActiveTaskAssignment;

    /**
     * Mutable bookkeeping for one {@link #assign(ApplicationState)} run: where tasks lived
     * previously, the assignment being built, and which task pairs have been co-located.
     */
    public static class AssignmentState {
        private final Map<ProcessId, KafkaStreamsState> clients;
        private final Map<TaskId, ProcessId> previousActiveAssignment;
        private final Map<TaskId, Set<ProcessId>> previousStandbyAssignment;
        private final TaskPairs taskPairs;
        // For each task, the clients (active or standby) holding it in the new assignment.
        private Map<TaskId, Set<ProcessId>> newTaskLocations;
        // The assignment under construction, keyed by client.
        private Map<ProcessId, KafkaStreamsAssignment> newAssignments;

        private AssignmentState(final ApplicationState applicationState,
                                final Map<ProcessId, KafkaStreamsState> clients,
                                final Map<TaskId, ProcessId> previousActiveAssignment,
                                final Map<TaskId, Set<ProcessId>> previousStandbyAssignment) {
            this.clients = clients;
            this.previousActiveAssignment = Collections.unmodifiableMap(previousActiveAssignment);
            this.previousStandbyAssignment = Collections.unmodifiableMap(previousStandbyAssignment);

            // Number of distinct unordered task pairs, n * (n - 1) / 2. Computed in long so a
            // very large task count cannot overflow int into a negative bound.
            final long taskCount = applicationState.allTasks().size();
            final long maxPairs = taskCount * (taskCount - 1) / 2;
            this.taskPairs = new TaskPairs((int) Math.min(Integer.MAX_VALUE, maxPairs));

            this.newTaskLocations = previousActiveAssignment.keySet().stream()
                .collect(Collectors.toMap(Function.identity(), ignored -> new HashSet<>()));
            this.newAssignments = clients.values().stream()
                .collect(Collectors.toMap(
                    KafkaStreamsState::processId,
                    client -> KafkaStreamsAssignment.of(client.processId(), new HashSet<>())));
        }

        /**
         * Assigns {@code taskId} to {@code processId} with the given type, recording the pairs it
         * forms with the tasks already on that client.
         */
        private void finalizeAssignment(final TaskId taskId,
                                        final ProcessId processId,
                                        final KafkaStreamsAssignment.AssignedTask.Type type) {
            final KafkaStreamsAssignment assignment = newAssignments.get(processId);
            // Pairs are recorded against the client's existing tasks, before this task is added.
            taskPairs.addPairs(taskId, assignment.tasks().keySet());
            assignment.assignTask(new KafkaStreamsAssignment.AssignedTask(taskId, type));
            newTaskLocations.computeIfAbsent(taskId, ignored -> new HashSet<>()).add(processId);
        }

        /**
         * Adopts {@code optimizedAssignments} as the current assignment and rebuilds the
         * task-to-clients index from it.
         */
        private void processOptimizedAssignments(final Map<ProcessId, KafkaStreamsAssignment> optimizedAssignments) {
            final Map<TaskId, Set<ProcessId>> taskLocations = new HashMap<>();
            for (final Map.Entry<ProcessId, KafkaStreamsAssignment> entry : optimizedAssignments.entrySet()) {
                for (final KafkaStreamsAssignment.AssignedTask task : entry.getValue().tasks().values()) {
                    taskLocations.computeIfAbsent(task.id(), ignored -> new HashSet<>()).add(entry.getKey());
                }
            }
            newTaskLocations = taskLocations;
            newAssignments = optimizedAssignments;
        }

        /**
         * Returns whether {@code processId} holds fewer active tasks than
         * {@code numProcessingThreads * tasksPerThread}. Creates an empty assignment for the
         * client if it has none yet.
         */
        private boolean hasRoomForActiveTask(final ProcessId processId, final int tasksPerThread) {
            final long activeTaskCount = newAssignments
                .computeIfAbsent(processId, id -> KafkaStreamsAssignment.of(id, new HashSet<>()))
                .tasks().values().stream()
                .filter(task -> task.type() == KafkaStreamsAssignment.AssignedTask.Type.ACTIVE)
                .count();
            // long arithmetic so a large thread count times tasks-per-thread cannot overflow.
            return activeTaskCount < (long) clients.get(processId).numProcessingThreads() * tasksPerThread;
        }

        /**
         * Chooses the client in {@code candidates} that should receive {@code taskId}.
         *
         * <p>A previous owner is preferred: the previous active owner if it is a candidate,
         * otherwise the least-loaded previous standby owner. That owner is used only if no client
         * is strictly less loaded (see {@link #shouldBalanceLoad}). Otherwise a previous standby
         * owner is tried the same way, and finally the least-loaded candidate is chosen.
         */
        private ProcessId findBestClientForTask(final TaskId taskId, final Set<ProcessId> candidates) {
            if (candidates.size() == 1) {
                return candidates.iterator().next();
            }
            final ProcessId previousOwner = findLeastLoadedClientWithPreviousActiveOrStandbyTask(taskId, candidates);
            if (previousOwner == null) {
                return findLeastLoadedClient(taskId, candidates);
            }
            if (!shouldBalanceLoad(previousOwner)) {
                return previousOwner;
            }
            final ProcessId standbyOwner = findLeastLoadedClientWithPreviousStandbyTask(taskId, candidates);
            if (standbyOwner != null && !shouldBalanceLoad(standbyOwner)) {
                return standbyOwner;
            }
            return findLeastLoadedClient(taskId, candidates);
        }

        /** Returns every known client that does not yet hold {@code taskId} in any role. */
        private Set<ProcessId> findClientsWithoutAssignedTask(final TaskId taskId) {
            // getOrDefault: a task with no recorded location is held by no client.
            final Set<ProcessId> holders = newTaskLocations.getOrDefault(taskId, Collections.emptySet());
            return clients.values().stream()
                .map(KafkaStreamsState::processId)
                .filter(processId -> !holders.contains(processId))
                .collect(Collectors.toSet());
        }

        /** Returns the client's assigned tasks (active and standby) per processing thread. */
        private double clientLoad(final ProcessId processId) {
            // Floating-point division. Integer division would truncate, for example reporting a
            // client with 1 task on 2 threads as having zero load.
            final double taskCount = newAssignments.get(processId).tasks().size();
            return taskCount / clients.get(processId).numProcessingThreads();
        }

        /**
         * Returns the least-loaded client among those in {@code candidates} on which placing
         * {@code taskId} would form a not-yet-recorded task pair. If no candidate would form a new
         * pair, returns the least-loaded candidate overall. A candidate with zero load is returned
         * immediately. Returns {@code null} when {@code candidates} is empty.
         */
        private ProcessId findLeastLoadedClient(final TaskId taskId, final Set<ProcessId> candidates) {
            ProcessId leastLoaded = null;
            for (final ProcessId candidate : candidates) {
                final double load = clientLoad(candidate);
                if (load == 0.0d) {
                    return candidate;
                }
                if ((leastLoaded == null || load < clientLoad(leastLoaded))
                        && taskPairs.hasNewPair(taskId, assignedTaskIds(candidate))) {
                    leastLoaded = candidate;
                }
            }
            if (leastLoaded != null) {
                return leastLoaded;
            }
            // No candidate forms a new pair: fall back to plain least load.
            for (final ProcessId candidate : candidates) {
                if (leastLoaded == null || clientLoad(candidate) < clientLoad(leastLoaded)) {
                    leastLoaded = candidate;
                }
            }
            return leastLoaded;
        }

        /** Returns the ids of all tasks currently assigned to {@code processId}. */
        private Set<TaskId> assignedTaskIds(final ProcessId processId) {
            return newAssignments.get(processId).tasks().values().stream()
                .map(KafkaStreamsAssignment.AssignedTask::id)
                .collect(Collectors.toSet());
        }

        /**
         * Returns the task's previous active owner if it is a candidate. Otherwise returns the
         * least-loaded candidate that previously held it as standby, or {@code null} if none.
         */
        private ProcessId findLeastLoadedClientWithPreviousActiveOrStandbyTask(final TaskId taskId,
                                                                             final Set<ProcessId> candidates) {
            final ProcessId previousActive = previousActiveAssignment.get(taskId);
            if (previousActive != null && candidates.contains(previousActive)) {
                return previousActive;
            }
            return findLeastLoadedClientWithPreviousStandbyTask(taskId, candidates);
        }

        /**
         * Returns the least-loaded candidate that previously held {@code taskId} as a standby, or
         * {@code null} if no candidate did.
         */
        private ProcessId findLeastLoadedClientWithPreviousStandbyTask(final TaskId taskId,
                                                                     final Set<ProcessId> candidates) {
            final Set<ProcessId> standbyOwners =
                new HashSet<>(previousStandbyAssignment.getOrDefault(taskId, Collections.emptySet()));
            standbyOwners.retainAll(candidates);
            return findLeastLoadedClient(taskId, standbyOwners);
        }

        /**
         * Returns whether some client is strictly less loaded than {@code processId}. A client
         * with load below 1.0 is never considered for rebalancing.
         */
        private boolean shouldBalanceLoad(final ProcessId processId) {
            final double load = clientLoad(processId);
            if (load < 1.0d) {
                return false;
            }
            for (final ProcessId other : clients.keySet()) {
                if (clientLoad(other) < load) {
                    return true;
                }
            }
            return false;
        }
    }

    /** An unordered pair of tasks, stored with the smaller {@link TaskId} first. */
    public static class TaskPair {
        private final TaskId task1;
        private final TaskId task2;

        TaskPair(final TaskId task1, final TaskId task2) {
            this.task1 = task1;
            this.task2 = task2;
        }

        @Override
        public boolean equals(final Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            final TaskPair other = (TaskPair) obj;
            return Objects.equals(task1, other.task1) && Objects.equals(task2, other.task2);
        }

        @Override
        public int hashCode() {
            return Objects.hash(task1, task2);
        }
    }

    /** The set of task pairs that have been placed on the same client. */
    public static class TaskPairs {
        private final Set<TaskPair> pairs;
        private final int maxPairs;

        /**
         * @param maxPairs upper bound on distinct pairs. Once reached, no pair is considered new.
         */
        public TaskPairs(final int maxPairs) {
            this.maxPairs = maxPairs;
            // Not presized to maxPairs. That bound grows quadratically with the task count, and
            // presizing would allocate a huge table (or throw for a negative bound).
            this.pairs = new HashSet<>();
        }

        /**
         * Returns whether pairing {@code taskId} with at least one task in {@code taskIds} would
         * form a pair not recorded yet. Always false once {@code maxPairs} pairs are recorded.
         */
        public boolean hasNewPair(final TaskId taskId, final Set<TaskId> taskIds) {
            if (pairs.size() >= maxPairs) {
                return false;
            }
            for (final TaskId other : taskIds) {
                if (!pairs.contains(pair(taskId, other))) {
                    return true;
                }
            }
            return false;
        }

        /** Records {@code taskId} as paired with every task in {@code taskIds}. */
        public void addPairs(final TaskId taskId, final Set<TaskId> taskIds) {
            for (final TaskId other : taskIds) {
                pairs.add(pair(other, taskId));
            }
        }

        /** Builds the canonical (smaller-first) pair for the two tasks. */
        public TaskPair pair(final TaskId first, final TaskId second) {
            return first.compareTo(second) < 0 ? new TaskPair(first, second) : new TaskPair(second, first);
        }
    }

    /** Creates an assignor that is free to move active tasks. */
    public StickyTaskAssignor() {
        this(false);
    }

    /**
     * @param mustPreserveActiveTaskAssignment if true, every task is kept on its previous active
     *        owner regardless of load and rack-aware optimization is skipped
     */
    public StickyTaskAssignor(final boolean mustPreserveActiveTaskAssignment) {
        this.mustPreserveActiveTaskAssignment = mustPreserveActiveTaskAssignment;
    }

    @Override
    public TaskAssignor.TaskAssignment assign(final ApplicationState applicationState) {
        final Map<ProcessId, KafkaStreamsState> clients = applicationState.kafkaStreamsStates(false);
        final AssignmentState assignmentState = new AssignmentState(
            applicationState, clients, mapPreviousActiveTasks(clients), mapPreviousStandbyTasks(clients));

        assignActive(applicationState, clients.values(), assignmentState, mustPreserveActiveTaskAssignment);
        optimizeActive(applicationState, assignmentState);
        assignStandby(applicationState, assignmentState);
        optimizeStandby(applicationState, assignmentState);

        final Map<ProcessId, KafkaStreamsAssignment> finalAssignments = assignmentState.newAssignments;
        if (mustPreserveActiveTaskAssignment && !finalAssignments.isEmpty()) {
            // Request a follow-up rebalance on one (arbitrary) client with a deadline of epoch 0,
            // a time already in the past. Presumably this means "as soon as possible"; confirm
            // against the TaskAssignor contract.
            final ProcessId anyClient = finalAssignments.keySet().iterator().next();
            finalAssignments.put(anyClient,
                finalAssignments.get(anyClient).withFollowupRebalance(Instant.ofEpochMilli(0L)));
        }
        return new TaskAssignor.TaskAssignment(finalAssignments.values());
    }

    /** Re-optimizes active tasks for rack awareness, unless active assignments are preserved. */
    private void optimizeActive(final ApplicationState applicationState, final AssignmentState assignmentState) {
        if (mustPreserveActiveTaskAssignment) {
            return;
        }
        final Map<ProcessId, KafkaStreamsAssignment> assignments = assignmentState.newAssignments;
        final int trafficCost = applicationState.assignmentConfigs().rackAwareTrafficCost()
            .orElse(DEFAULT_STICKY_TRAFFIC_COST);
        final int nonOverlapCost = applicationState.assignmentConfigs().rackAwareNonOverlapCost()
            .orElse(DEFAULT_STICKY_NON_OVERLAP_COST);
        TaskAssignmentUtils.optimizeRackAwareActiveTasks(
            TaskAssignmentUtils.RackAwareOptimizationParams.of(applicationState)
                .withTrafficCostOverride(trafficCost)
                .withNonOverlapCostOverride(nonOverlapCost)
                .forStatefulTasks(),
            assignments);
        // Stateless tasks use a fixed traffic cost of 1 and no non-overlap cost.
        TaskAssignmentUtils.optimizeRackAwareActiveTasks(
            TaskAssignmentUtils.RackAwareOptimizationParams.of(applicationState)
                .forStatelessTasks()
                .withTrafficCostOverride(1)
                .withNonOverlapCostOverride(0),
            assignments);
        assignmentState.processOptimizedAssignments(assignments);
    }

    /**
     * Re-optimizes standby tasks for rack awareness when standbys are configured and active
     * assignments are not preserved.
     */
    private void optimizeStandby(final ApplicationState applicationState, final AssignmentState assignmentState) {
        if (applicationState.assignmentConfigs().numStandbyReplicas() <= 0 || mustPreserveActiveTaskAssignment) {
            return;
        }
        final Map<ProcessId, KafkaStreamsAssignment> assignments = assignmentState.newAssignments;
        final int trafficCost = applicationState.assignmentConfigs().rackAwareTrafficCost()
            .orElse(DEFAULT_STICKY_TRAFFIC_COST);
        final int nonOverlapCost = applicationState.assignmentConfigs().rackAwareNonOverlapCost()
            .orElse(DEFAULT_STICKY_NON_OVERLAP_COST);
        TaskAssignmentUtils.optimizeRackAwareStandbyTasks(
            TaskAssignmentUtils.RackAwareOptimizationParams.of(applicationState)
                .withTrafficCostOverride(trafficCost)
                .withNonOverlapCostOverride(nonOverlapCost),
            assignments);
        assignmentState.processOptimizedAssignments(assignments);
    }

    /**
     * Places every task as an active task on some client.
     *
     * @param mustPreserveActiveTaskAssignment if true, previous active owners keep their tasks
     *        without a capacity check
     */
    private static void assignActive(final ApplicationState applicationState,
                                     final Collection<KafkaStreamsState> clients,
                                     final AssignmentState assignmentState,
                                     final boolean mustPreserveActiveTaskAssignment) {
        final int totalProcessingThreads = computeTotalProcessingThreads(clients);
        final Set<TaskId> allTaskIds = applicationState.allTasks().keySet();
        final int tasksPerThread = allTaskIds.size() / totalProcessingThreads;
        final Set<TaskId> unassigned = new HashSet<>(allTaskIds);

        // 1. Keep each task on its previous active owner if allowed (forced, or the owner has room).
        for (final Map.Entry<TaskId, ProcessId> entry : assignmentState.previousActiveAssignment.entrySet()) {
            final TaskId taskId = entry.getKey();
            final ProcessId previousOwner = entry.getValue();
            if (allTaskIds.contains(taskId)
                    && (mustPreserveActiveTaskAssignment
                        || assignmentState.hasRoomForActiveTask(previousOwner, tasksPerThread))) {
                assignmentState.finalizeAssignment(taskId, previousOwner, KafkaStreamsAssignment.AssignedTask.Type.ACTIVE);
                unassigned.remove(taskId);
            }
        }

        // 2. Promote the first previous standby owner that has room. Tasks with no standby owner,
        //    or none with room, stay unassigned for step 3 (the decompiled loop spun forever here).
        final Iterator<TaskId> it = unassigned.iterator();
        while (it.hasNext()) {
            final TaskId taskId = it.next();
            for (final ProcessId standbyOwner
                    : assignmentState.previousStandbyAssignment.getOrDefault(taskId, Collections.emptySet())) {
                if (assignmentState.hasRoomForActiveTask(standbyOwner, tasksPerThread)) {
                    assignmentState.finalizeAssignment(taskId, standbyOwner, KafkaStreamsAssignment.AssignedTask.Type.ACTIVE);
                    it.remove();
                    break;
                }
            }
        }

        // 3. Place the remaining tasks, in sorted task order, on the best client overall.
        //    The candidate set is never mutated by findBestClientForTask, so it is built once.
        final Set<ProcessId> allClientIds = clients.stream()
            .map(KafkaStreamsState::processId)
            .collect(Collectors.toSet());
        final ArrayList<TaskId> remaining = new ArrayList<>(unassigned);
        Collections.sort(remaining);
        for (final TaskId taskId : remaining) {
            assignmentState.finalizeAssignment(
                taskId,
                assignmentState.findBestClientForTask(taskId, allClientIds),
                KafkaStreamsAssignment.AssignedTask.Type.ACTIVE);
        }
    }

    /**
     * Places up to {@code numStandbyReplicas} standby replicas for every task that has at least
     * one changelog partition. Each replica goes on a client that does not already hold the task.
     * Logs a warning and moves on when no such client is left.
     */
    private static void assignStandby(final ApplicationState applicationState, final AssignmentState assignmentState) {
        final Set<TaskInfo> tasksWithChangelogs = applicationState.allTasks().values().stream()
            .filter(task -> task.topicPartitions().stream().anyMatch(partition -> partition.isChangelog()))
            .collect(Collectors.toSet());
        final int numStandbyReplicas = applicationState.assignmentConfigs().numStandbyReplicas();
        for (final TaskInfo task : tasksWithChangelogs) {
            // Bounded loop: the decompiled version never exited once all replicas were placed.
            for (int placed = 0; placed < numStandbyReplicas; placed++) {
                final Set<ProcessId> candidates = assignmentState.findClientsWithoutAssignedTask(task.id());
                if (candidates.isEmpty()) {
                    LOG.warn("Unable to assign {} of {} standby tasks for task [{}]. There is not enough available capacity. You should increase the number of threads and/or application instances to maintain the requested number of standby replicas.",
                        numStandbyReplicas - placed, numStandbyReplicas, task.id());
                    break;
                }
                assignmentState.finalizeAssignment(
                    task.id(),
                    assignmentState.findBestClientForTask(task.id(), candidates),
                    KafkaStreamsAssignment.AssignedTask.Type.STANDBY);
            }
        }
    }

    /**
     * Maps each previously-active task to the client that ran it. If several clients report the
     * same task, the last one iterated wins.
     */
    private static Map<TaskId, ProcessId> mapPreviousActiveTasks(final Map<ProcessId, KafkaStreamsState> clients) {
        final Map<TaskId, ProcessId> previousActiveTasks = new HashMap<>();
        for (final KafkaStreamsState client : clients.values()) {
            for (final TaskId taskId : client.previousActiveTasks()) {
                previousActiveTasks.put(taskId, client.processId());
            }
        }
        return previousActiveTasks;
    }

    /** Maps each previously-standby task to every client that held it as a standby. */
    private static Map<TaskId, Set<ProcessId>> mapPreviousStandbyTasks(final Map<ProcessId, KafkaStreamsState> clients) {
        final Map<TaskId, Set<ProcessId>> previousStandbyTasks = new HashMap<>();
        for (final KafkaStreamsState client : clients.values()) {
            for (final TaskId taskId : client.previousStandbyTasks()) {
                previousStandbyTasks.computeIfAbsent(taskId, ignored -> new HashSet<>()).add(client.processId());
            }
        }
        return previousStandbyTasks;
    }

    /** Sums the processing threads of all clients. */
    private static int computeTotalProcessingThreads(final Collection<KafkaStreamsState> clients) {
        int total = 0;
        for (final KafkaStreamsState client : clients) {
            total += client.numProcessingThreads();
        }
        return total;
    }
}
