package org.apache.kafka.streams.processor.internals.assignment;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.SortedMap;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.UUID;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.BiPredicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.kafka.common.Cluster;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.PartitionInfo;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.TopicPartitionInfo;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.InternalTopicManager;
import org.apache.kafka.streams.processor.internals.TopologyMetadata;
import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.class */
/**
 * Rack-aware re-balancer for an existing Kafka Streams task assignment.
 *
 * <p>The optimizers build a min-cost-flow graph from the current assignment and solve it.
 * A task is cheaper on a client whose rack hosts a replica of the task's partitions: source
 * partitions for active tasks, changelog partitions for standby tasks. The resulting flow is
 * then written back into the {@link ClientState}s. The number of tasks held by each client is
 * the same before and after optimization, and this is verified after every reassignment.
 *
 * <p>{@link #canEnableRackAwareAssignor()} populates the partition-to-rack information that the
 * cost and optimization methods read. It should be called, and should return {@code true}, before
 * those methods are used.
 */
public class RackAwareTaskAssignor {
    public static final int STATELESS_TRAFFIC_COST = 1;
    public static final int STATELESS_NON_OVERLAP_COST = 0;
    private static final Logger log = LoggerFactory.getLogger(RackAwareTaskAssignor.class);

    // Node id of the flow-graph source. Task nodes are 0..T-1, client nodes are T..T+C-1,
    // and the sink is T+C (see getClientNodeId / getSinkNodeID).
    private static final int SOURCE_ID = -1;
    // Upper bound on the number of full pairwise passes in optimizeStandbyTasks.
    private static final int STANDBY_OPTIMIZER_MAX_ITERATION = 4;

    private final Cluster fullMetadata;
    private final Map<TaskId, Set<TopicPartition>> partitionsForTask;
    private final Map<TaskId, Set<TopicPartition>> changelogPartitionsForTask;
    private final AssignorConfiguration.AssignmentConfigs assignmentConfigs;
    private final InternalTopicManager internalTopicManager;
    private final boolean validClientRack;
    private final Time time;
    // Cached result of canEnableRackAwareAssignor(); null until first computed.
    private Boolean canEnable = null;
    private final Map<TopicPartition, Set<String>> racksForPartition = new HashMap<>();
    private final Map<UUID, String> racksForProcess = new HashMap<>();

    /**
     * Decides whether a standby task may be moved from {@code source} to {@code destination}.
     */
    @FunctionalInterface
    public interface MoveStandbyTaskPredicate {
        boolean canMove(ClientState source,
                        ClientState destination,
                        TaskId taskId,
                        Map<UUID, ClientState> clientStateMap);
    }

    /**
     * Creates the assignor and validates the per-process client rack information up front.
     *
     * @param fullMetadata               cluster metadata used to look up replicas of source partitions
     * @param partitionsForTask          source partitions of each task
     * @param changelogPartitionsForTask changelog partitions of each task
     * @param tasksForTopicGroup         not used by this class
     * @param racksForProcessConsumer    optional rack id of every consumer, grouped by process
     * @param internalTopicManager       used to describe topics missing from {@code fullMetadata}
     * @param assignmentConfigs          assignment configuration (rack-aware strategy, standby count)
     * @param time                       clock used to time the optimizers; must not be null
     * @throws NullPointerException if {@code time} is null
     */
    public RackAwareTaskAssignor(final Cluster fullMetadata,
                                 final Map<TaskId, Set<TopicPartition>> partitionsForTask,
                                 final Map<TaskId, Set<TopicPartition>> changelogPartitionsForTask,
                                 final Map<TopologyMetadata.Subtopology, Set<TaskId>> tasksForTopicGroup,
                                 final Map<UUID, Map<String, Optional<String>>> racksForProcessConsumer,
                                 final InternalTopicManager internalTopicManager,
                                 final AssignorConfiguration.AssignmentConfigs assignmentConfigs,
                                 final Time time) {
        this.fullMetadata = fullMetadata;
        this.partitionsForTask = partitionsForTask;
        this.changelogPartitionsForTask = changelogPartitionsForTask;
        this.internalTopicManager = internalTopicManager;
        this.assignmentConfigs = assignmentConfigs;
        this.time = Objects.requireNonNull(time, "Time was not specified");
        this.validClientRack = validateClientRack(racksForProcessConsumer);
    }

    /** Returns whether every process reported one consistent rack id for all its consumers. */
    public boolean validClientRack() {
        return validClientRack;
    }

    /**
     * Returns whether rack-aware assignment can be used.
     *
     * <p>Returns {@code false} immediately when the configured strategy is {@code "none"}.
     * Otherwise client racks must be valid and every source partition must have rack
     * information. When standby replicas are configured, every changelog partition must have
     * rack information as well. The result is cached after the first evaluation.
     */
    public synchronized boolean canEnableRackAwareAssignor() {
        if ("none".equals(assignmentConfigs.rackAwareAssignmentStrategy)) {
            return false;
        }
        if (canEnable != null) {
            return canEnable;
        }
        canEnable = validClientRack && validateTopicPartitionRack(false);
        if (assignmentConfigs.numStandbyReplicas == 0 || !canEnable) {
            return canEnable;
        }
        // Standby cost is computed from changelog partitions, so their racks are needed too.
        canEnable = validateTopicPartitionRack(true);
        return canEnable;
    }

    /**
     * Collects topics whose partition replica racks must be fetched by describing the topic.
     *
     * <p>With {@code changelog == true}, every changelog topic is added and {@code true} is
     * returned. Otherwise each task source partition is looked up in the cluster metadata:
     * <ul>
     *   <li>partitions without replica information have their topic added to {@code topicsToDescribe};</li>
     *   <li>partitions with replicas have their racks recorded in {@code racksForPartition}.</li>
     * </ul>
     *
     * @param topicsToDescribe output set that receives topic names to describe
     * @param changelog        whether to collect changelog topics instead of source topics
     * @return {@code false} if a partition is missing from the metadata or a replica has no rack
     */
    public boolean populateTopicsToDescribe(final Set<String> topicsToDescribe, final boolean changelog) {
        if (changelog) {
            changelogPartitionsForTask.values().stream()
                .flatMap(Set::stream)
                .forEach(topicPartition -> topicsToDescribe.add(topicPartition.topic()));
            return true;
        }

        for (final Set<TopicPartition> topicPartitions : partitionsForTask.values()) {
            for (final TopicPartition topicPartition : topicPartitions) {
                final PartitionInfo partitionInfo = fullMetadata.partition(topicPartition);
                if (partitionInfo == null) {
                    log.error("TopicPartition {} doesn't exist in cluster", topicPartition);
                    return false;
                }
                final Node[] replicas = partitionInfo.replicas();
                if (replicas == null || replicas.length == 0) {
                    // Replica information must be fetched by describing the topic.
                    topicsToDescribe.add(topicPartition.topic());
                    continue;
                }
                for (final Node node : replicas) {
                    if (!node.hasRack()) {
                        log.warn("Node {} for topic partition {} doesn't have rack", node, topicPartition);
                        return false;
                    }
                    racksForPartition.computeIfAbsent(topicPartition, k -> new HashSet<>()).add(node.rack());
                }
            }
        }
        return true;
    }

    /**
     * Ensures rack information is available for all source partitions, or all changelog
     * partitions when {@code changelogTopics} is true. Topics are described through the
     * {@link InternalTopicManager} where necessary.
     *
     * @return {@code false} if a topic cannot be described, or a partition has no replicas or a
     *         replica without a rack
     */
    private boolean validateTopicPartitionRack(final boolean changelogTopics) {
        final Set<String> topicsToDescribe = new HashSet<>();
        if (!populateTopicsToDescribe(topicsToDescribe, changelogTopics)) {
            return false;
        }
        if (topicsToDescribe.isEmpty()) {
            return true;
        }

        log.info("Fetching PartitionInfo for topics {}", topicsToDescribe);
        try {
            final Map<String, List<TopicPartitionInfo>> topicPartitionInfo =
                internalTopicManager.getTopicPartitionInfo(topicsToDescribe);
            if (topicsToDescribe.size() > topicPartitionInfo.size()) {
                topicsToDescribe.removeAll(topicPartitionInfo.keySet());
                log.error("Failed to describe topic for {}", topicsToDescribe);
                return false;
            }

            for (final Map.Entry<String, List<TopicPartitionInfo>> topicEntry : topicPartitionInfo.entrySet()) {
                final String topic = topicEntry.getKey();
                for (final TopicPartitionInfo partitionInfo : topicEntry.getValue()) {
                    final int partition = partitionInfo.partition();
                    final List<Node> replicas = partitionInfo.replicas();
                    if (replicas == null || replicas.isEmpty()) {
                        log.error("No replicas found for topic partition {}: {}", topic, partition);
                        return false;
                    }
                    final TopicPartition topicPartition = new TopicPartition(topic, partition);
                    for (final Node node : replicas) {
                        if (!node.hasRack()) {
                            // Previously this failed silently; log like populateTopicsToDescribe does.
                            log.warn("Node {} for topic partition {} doesn't have rack", node, topicPartition);
                            return false;
                        }
                        racksForPartition.computeIfAbsent(topicPartition, k -> new HashSet<>()).add(node.rack());
                    }
                }
            }
            return true;
        } catch (final Exception e) {
            log.error("Failed to describe topics {}", topicsToDescribe, e);
            return false;
        }
    }

    /**
     * Validates that every consumer of every process reports a rack id, and that all consumers
     * of one process agree. Records the rack of each process in {@code racksForProcess}.
     *
     * @return {@code false} if the input is null, or if a process has a missing, empty or
     *         conflicting rack id
     */
    private boolean validateClientRack(final Map<UUID, Map<String, Optional<String>>> racksForProcessConsumer) {
        if (racksForProcessConsumer == null) {
            return false;
        }
        for (final Map.Entry<UUID, Map<String, Optional<String>>> processEntry : racksForProcessConsumer.entrySet()) {
            final UUID processId = processEntry.getKey();
            // First (consumer, rack) seen for this process; later consumers must match its rack.
            KeyValue<String, String> previousRackInfo = null;
            for (final Map.Entry<String, Optional<String>> consumerEntry : processEntry.getValue().entrySet()) {
                final Optional<String> rackId = consumerEntry.getValue();
                if (!rackId.isPresent()) {
                    log.error("RackId doesn't exist for process {} and consumer {}", processId, consumerEntry.getKey());
                    return false;
                }
                if (previousRackInfo == null) {
                    previousRackInfo = KeyValue.pair(consumerEntry.getKey(), rackId.get());
                } else if (!previousRackInfo.value.equals(rackId.get())) {
                    log.error("Consumers {} and {} for same process {} has different rackId {} and {}. File a ticket for this bug",
                        previousRackInfo.key, consumerEntry.getKey(), processId, previousRackInfo.value, rackId.get());
                    return false;
                }
            }
            if (previousRackInfo == null) {
                log.error("RackId doesn't exist for process {}", processId);
                return false;
            }
            racksForProcess.put(processId, previousRackInfo.value);
        }
        return true;
    }

    /** Returns an unmodifiable view of the rack of each process. */
    public Map<UUID, String> racksForProcess() {
        return Collections.unmodifiableMap(racksForProcess);
    }

    /** Returns an unmodifiable view of the replica racks recorded for each partition. */
    public Map<TopicPartition, Set<String>> racksForPartition() {
        return Collections.unmodifiableMap(racksForPartition);
    }

    /**
     * Cost of placing {@code taskId} on client {@code processId}.
     *
     * <p>The cost is {@code trafficCost} for each of the task's partitions (changelog
     * partitions when {@code isStandby}) whose replica racks do not include the client's rack,
     * plus {@code nonOverlapCost} if the task is not currently assigned to that client.
     *
     * @throws IllegalStateException if the client's rack, the task's partitions, or a
     *         partition's racks are unknown
     */
    private int getCost(final TaskId taskId,
                        final UUID processId,
                        final boolean inCurrentAssignment,
                        final int trafficCost,
                        final int nonOverlapCost,
                        final boolean isStandby) {
        final String clientRack = racksForProcess.get(processId);
        if (clientRack == null) {
            throw new IllegalStateException("Client " + processId + " doesn't have rack configured. Maybe forgot to call canEnableRackAwareAssignor first");
        }

        final Set<TopicPartition> topicPartitions = isStandby
            ? changelogPartitionsForTask.get(taskId)
            : partitionsForTask.get(taskId);
        if (topicPartitions == null || topicPartitions.isEmpty()) {
            throw new IllegalStateException("Task " + taskId + " has no TopicPartitions");
        }

        int cost = 0;
        for (final TopicPartition topicPartition : topicPartitions) {
            final Set<String> partitionRacks = racksForPartition.get(topicPartition);
            if (partitionRacks == null || partitionRacks.isEmpty()) {
                throw new IllegalStateException("TopicPartition " + topicPartition + " has no rack information. Maybe forgot to call canEnableRackAwareAssignor first");
            }
            if (!partitionRacks.contains(clientRack)) {
                cost += trafficCost;
            }
        }
        if (!inCurrentAssignment) {
            cost += nonOverlapCost;
        }
        return cost;
    }

    // Sink node id: one past the last client node.
    private static int getSinkNodeID(final List<UUID> clientList, final List<TaskId> taskIdList) {
        return clientList.size() + taskIdList.size();
    }

    // Client nodes are numbered after all task nodes.
    private static int getClientNodeId(final List<TaskId> taskIdList, final int clientIndex) {
        return clientIndex + taskIdList.size();
    }

    // Inverse of getClientNodeId.
    private static int getClientIndex(final List<TaskId> taskIdList, final int clientNodeId) {
        return clientNodeId - taskIdList.size();
    }

    /** Total cost of the current active-task assignment; 0 when {@code activeTasks} is empty. */
    long activeTasksCost(final SortedSet<TaskId> activeTasks,
                         final SortedMap<UUID, ClientState> clientStates,
                         final int trafficCost,
                         final int nonOverlapCost) {
        return tasksCost(activeTasks, clientStates, trafficCost, nonOverlapCost,
            ClientState::hasActiveTask, false, false);
    }

    /** Total cost of the current standby-task assignment; 0 when {@code standbyTasks} is empty. */
    long standByTasksCost(final SortedSet<TaskId> standbyTasks,
                          final SortedMap<UUID, ClientState> clientStates,
                          final int trafficCost,
                          final int nonOverlapCost) {
        return tasksCost(standbyTasks, clientStates, trafficCost, nonOverlapCost,
            ClientState::hasStandbyTask, true, true);
    }

    // Builds the flow graph of the current assignment, without solving it, and returns its total cost.
    private long tasksCost(final SortedSet<TaskId> tasks,
                           final SortedMap<UUID, ClientState> clientStates,
                           final int trafficCost,
                           final int nonOverlapCost,
                           final BiPredicate<ClientState, TaskId> hasAssignedTask,
                           final boolean hasReplica,
                           final boolean isStandby) {
        if (tasks.isEmpty()) {
            return 0;
        }
        final Graph<Integer> graph = constructTaskGraph(
            new ArrayList<>(clientStates.keySet()),
            new ArrayList<>(tasks),
            clientStates,
            new HashMap<>(),
            new HashMap<>(),
            hasAssignedTask,
            trafficCost,
            nonOverlapCost,
            hasReplica,
            isStandby);
        return graph.totalCost();
    }

    /**
     * Re-assigns active tasks across all clients by solving a min-cost flow. The
     * {@code clientStates} are mutated in place.
     *
     * @return the cost of the optimized assignment, or 0 if {@code activeTasks} is empty
     */
    public long optimizeActiveTasks(final SortedSet<TaskId> activeTasks,
                                    final SortedMap<UUID, ClientState> clientStates,
                                    final int trafficCost,
                                    final int nonOverlapCost) {
        if (activeTasks.isEmpty()) {
            return 0;
        }

        log.info("Assignment before active task optimization is {}\n with cost {}", clientStates,
            activeTasksCost(activeTasks, clientStates, trafficCost, nonOverlapCost));

        final long startTime = time.milliseconds();
        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
        final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
        final Graph<Integer> graph = constructTaskGraph(clientList, taskIdList, clientStates, taskClientMap,
            originalAssignedTaskNumber, ClientState::hasActiveTask, trafficCost, nonOverlapCost, false, false);

        graph.solveMinCostFlow();
        final long cost = graph.totalCost();

        assignTaskFromMinCostFlow(graph, clientList, taskIdList, clientStates, originalAssignedTaskNumber,
            taskClientMap, ClientState::assignActive, ClientState::unassignActive, ClientState::hasActiveTask);

        log.info("Assignment after {} milliseconds for active task optimization is {}\n with cost {}",
            time.milliseconds() - startTime, clientStates, cost);
        return cost;
    }

    /**
     * Improves standby placement by running a two-client min-cost flow for every pair of clients
     * on different racks. Only standby tasks accepted by {@code moveStandbyTask} are considered.
     * Passes are repeated while some task moved, up to {@code STANDBY_OPTIMIZER_MAX_ITERATION}
     * times. The {@code clientStates} are mutated in place.
     *
     * @return the standby cost after optimization
     */
    public long optimizeStandbyTasks(final SortedMap<UUID, ClientState> clientStates,
                                     final int trafficCost,
                                     final int nonOverlapCost,
                                     final MoveStandbyTaskPredicate moveStandbyTask) {
        // Sorted standby tasks of `source` that `destination` does not host and may receive.
        final BiFunction<ClientState, ClientState, List<TaskId>> getMovableStandby = (source, destination) ->
            source.standbyTasks().stream()
                .filter(taskId -> !destination.hasAssignedTask(taskId))
                .filter(taskId -> moveStandbyTask.canMove(source, destination, taskId, clientStates))
                .sorted()
                .collect(Collectors.toList());

        final long startTime = time.milliseconds();
        final List<UUID> clientIds = new ArrayList<>(clientStates.keySet());
        final SortedSet<TaskId> standbyTasks = new TreeSet<>();
        clientStates.values().forEach(clientState -> standbyTasks.addAll(clientState.standbyTasks()));

        log.info("Assignment before standby task optimization is {}\n with cost {}", clientStates,
            standByTasksCost(standbyTasks, clientStates, trafficCost, nonOverlapCost));

        boolean taskMoved = true;
        int round = 0;
        while (taskMoved && round < STANDBY_OPTIMIZER_MAX_ITERATION) {
            taskMoved = false;
            round++;
            for (int i = 0; i < clientIds.size(); i++) {
                final ClientState clientState1 = clientStates.get(clientIds.get(i));
                for (int j = i + 1; j < clientIds.size(); j++) {
                    final ClientState clientState2 = clientStates.get(clientIds.get(j));
                    // Only pairs of clients on different racks are considered.
                    if (racksForProcess.get(clientState1.processId()).equals(racksForProcess.get(clientState2.processId()))) {
                        continue;
                    }

                    final List<TaskId> movable1 = getMovableStandby.apply(clientState1, clientState2);
                    final List<TaskId> movable2 = getMovableStandby.apply(clientState2, clientState1);
                    // Per-client task counts are preserved, so a move needs candidates on both sides.
                    if (movable1.isEmpty() || movable2.isEmpty()) {
                        continue;
                    }

                    final List<TaskId> taskIdList = Stream.concat(movable1.stream(), movable2.stream())
                        .sorted()
                        .collect(Collectors.toList());
                    final List<UUID> clientList = Stream.of(clientIds.get(i), clientIds.get(j))
                        .sorted()
                        .collect(Collectors.toList());
                    final Map<TaskId, UUID> taskClientMap = new HashMap<>();
                    final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
                    final Graph<Integer> graph = constructTaskGraph(clientList, taskIdList, clientStates,
                        taskClientMap, originalAssignedTaskNumber, ClientState::hasStandbyTask,
                        trafficCost, nonOverlapCost, true, true);
                    graph.solveMinCostFlow();

                    taskMoved |= assignTaskFromMinCostFlow(graph, clientList, taskIdList, clientStates,
                        originalAssignedTaskNumber, taskClientMap, ClientState::assignStandby,
                        ClientState::unassignStandby, ClientState::hasStandbyTask);
                }
            }
        }

        final long cost = standByTasksCost(standbyTasks, clientStates, trafficCost, nonOverlapCost);
        log.info("Assignment after {} rounds and {} milliseconds for standby task optimization is {}\n with cost {}",
            round, time.milliseconds() - startTime, clientStates, cost);
        return cost;
    }

    /**
     * Builds the flow graph for the current assignment of {@code taskIdList} to {@code clientList}.
     *
     * <p>Edges are added as follows:
     * <ul>
     *   <li>The source has a capacity-1 edge to each task, carrying flow 1.</li>
     *   <li>Each task has a capacity-1 edge to each client, costed with {@link #getCost}. It
     *       carries flow 1 if the task is currently on that client.</li>
     *   <li>Each client has an edge to the sink whose capacity and flow both equal the number of
     *       tasks it currently holds.</li>
     * </ul>
     *
     * @param taskClientMap              output: task to the client currently holding it (last one
     *                                   wins when {@code hasReplica})
     * @param originalAssignedTaskNumber output: number of the given tasks each client currently holds,
     *                                   counted over all {@code clientStates}
     * @throws IllegalArgumentException if a task is on several clients while {@code !hasReplica},
     *                                  or is on none of {@code clientList}
     */
    private Graph<Integer> constructTaskGraph(final List<UUID> clientList,
                                              final List<TaskId> taskIdList,
                                              final Map<UUID, ClientState> clientStates,
                                              final Map<TaskId, UUID> taskClientMap,
                                              final Map<UUID, Integer> originalAssignedTaskNumber,
                                              final BiPredicate<ClientState, TaskId> hasAssignedTask,
                                              final int trafficCost,
                                              final int nonOverlapCost,
                                              final boolean hasReplica,
                                              final boolean isStandby) {
        final Graph<Integer> graph = new Graph<>();

        for (final TaskId taskId : taskIdList) {
            for (final Map.Entry<UUID, ClientState> clientEntry : clientStates.entrySet()) {
                if (hasAssignedTask.test(clientEntry.getValue(), taskId)) {
                    originalAssignedTaskNumber.merge(clientEntry.getKey(), 1, Integer::sum);
                }
            }
        }

        for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) {
            final TaskId taskId = taskIdList.get(taskNodeId);
            for (int clientIndex = 0; clientIndex < clientList.size(); clientIndex++) {
                final int clientNodeId = getClientNodeId(taskIdList, clientIndex);
                final UUID processId = clientList.get(clientIndex);
                final boolean inCurrentAssignment = hasAssignedTask.test(clientStates.get(processId), taskId);
                final int flow = inCurrentAssignment ? 1 : 0;
                final int cost = getCost(taskId, processId, inCurrentAssignment, trafficCost, nonOverlapCost, isStandby);
                if (inCurrentAssignment) {
                    if (!hasReplica && taskClientMap.containsKey(taskId)) {
                        throw new IllegalArgumentException("Task " + taskId + " assigned to multiple clients "
                            + processId + ", " + taskClientMap.get(taskId));
                    }
                    taskClientMap.put(taskId, processId);
                }
                graph.addEdge(taskNodeId, clientNodeId, 1, cost, flow);
            }
            if (!taskClientMap.containsKey(taskId)) {
                throw new IllegalArgumentException("Task " + taskId + " not assigned to any client");
            }
            graph.addEdge(SOURCE_ID, taskNodeId, 1, 0, 1);
        }

        final int sinkId = getSinkNodeID(clientList, taskIdList);
        for (int clientIndex = 0; clientIndex < clientList.size(); clientIndex++) {
            final int clientNodeId = getClientNodeId(taskIdList, clientIndex);
            final int capacity = originalAssignedTaskNumber.getOrDefault(clientList.get(clientIndex), 0);
            graph.addEdge(clientNodeId, sinkId, capacity, 0, capacity);
        }

        graph.setSourceNode(SOURCE_ID);
        graph.setSinkNode(sinkId);
        return graph;
    }

    /**
     * Applies a solved flow graph back onto the client states. For every task whose flow ends
     * at a different client than its current one, the task is unassigned from the old client
     * and assigned to the new one. Afterwards the method verifies that every task received flow
     * and that every client holds as many of the tasks as before.
     *
     * @return {@code true} if at least one task moved
     * @throws IllegalStateException if the flow does not cover every task, or per-client task
     *                               counts changed
     */
    private boolean assignTaskFromMinCostFlow(final Graph<Integer> graph,
                                              final List<UUID> clientList,
                                              final List<TaskId> taskIdList,
                                              final Map<UUID, ClientState> clientStates,
                                              final Map<UUID, Integer> originalAssignedTaskNumber,
                                              final Map<TaskId, UUID> taskClientMap,
                                              final BiConsumer<ClientState, TaskId> assignTask,
                                              final BiConsumer<ClientState, TaskId> unassignTask,
                                              final BiPredicate<ClientState, TaskId> hasAssignedTask) {
        int tasksAssigned = 0;
        boolean taskMoved = false;
        for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) {
            final TaskId taskId = taskIdList.get(taskNodeId);
            for (final Graph<Integer>.Edge edge : graph.edges(taskNodeId).values()) {
                if (edge.flow > 0) {
                    tasksAssigned++;
                    final UUID processId = clientList.get(getClientIndex(taskIdList, edge.destination));
                    final UUID originalProcessId = taskClientMap.get(taskId);
                    if (processId.equals(originalProcessId)) {
                        // Task stays on its current client; nothing to move.
                        break;
                    }
                    unassignTask.accept(clientStates.get(originalProcessId), taskId);
                    assignTask.accept(clientStates.get(processId), taskId);
                    taskMoved = true;
                }
            }
        }

        // Used for both active and standby optimization, so messages avoid the word "active".
        if (tasksAssigned != taskIdList.size()) {
            throw new IllegalStateException("Computed task assignment number " + tasksAssigned
                + " is different from the number of tasks " + taskIdList.size());
        }

        final Map<UUID, Integer> assignedTaskNumber = new HashMap<>();
        for (final TaskId taskId : taskIdList) {
            for (final Map.Entry<UUID, ClientState> clientEntry : clientStates.entrySet()) {
                if (hasAssignedTask.test(clientEntry.getValue(), taskId)) {
                    assignedTaskNumber.merge(clientEntry.getKey(), 1, Integer::sum);
                }
            }
        }

        if (originalAssignedTaskNumber.size() != assignedTaskNumber.size()) {
            throw new IllegalStateException("There are " + originalAssignedTaskNumber.size()
                + " clients with assigned tasks before assignment, but " + assignedTaskNumber.size()
                + " clients with assigned tasks after assignment");
        }

        for (final Map.Entry<UUID, Integer> entry : originalAssignedTaskNumber.entrySet()) {
            final int assigned = assignedTaskNumber.getOrDefault(entry.getKey(), 0);
            if (!Objects.equals(entry.getValue(), assigned)) {
                throw new IllegalStateException("There are " + entry.getValue() + " tasks assigned to client "
                    + entry.getKey() + " before assignment, but " + assigned + " tasks are assigned to it after assignment");
            }
        }

        return taskMoved;
    }
}
