package org.nd4j.autodiff.samediff.internal;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import lombok.NonNull;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.LoopCond;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/autodiff/samediff/internal/AbstractSession.class */
public abstract class AbstractSession<T, O> {
    private static final Logger log = LoggerFactory.getLogger(AbstractSession.class);

    /** Name of the outermost frame: variables with no inputs are first scheduled in this frame, at iteration 0. */
    public static final String OUTER_FRAME = "main";

    /** The graph this session executes. */
    protected final SameDiff sameDiff;

    /**
     * Values computed so far, keyed by variable name + frame + iteration + parent frame.
     * Control-flow ops (Enter/Exit/NextIteration/LoopCond) may also record null entries here.
     */
    protected final Map<VarId, T> nodeOutputs = new HashMap<>();

    /** Cleared at the start of every output(...) call. NOTE(review): presumably populated by subclasses for TensorArray ops. */
    protected final Map<VarId, List<T>> tensorArrays = new HashMap<>();

    /** FIFO queue of variable instances that are ready to be executed (or forwarded to their descendants). */
    protected final Queue<VarId> availableForExec = new LinkedList<>();

    /** Mirrors the contents of availableForExec so membership can be checked without scanning the queue. */
    protected final Set<VarId> availableForExecSet = new HashSet<>();

    /** Names of every variable that must be computed to produce the requested outputs. */
    protected final Set<String> subgraph = new HashSet<>();

    /** For each variable instance, the input variable instances recorded for executing it. */
    protected final Map<VarId, Set<VarId>> execInputs = new HashMap<>();

    /** Recorded inputs looked up via the iteration-0 id of a variable, whatever the current iteration is. */
    protected final Map<VarId, Set<VarId>> execInputsAllIter = new HashMap<>();

    /** Inputs keyed by variable name only (no frame/iteration); presumably constant and placeholder inputs. */
    protected final Map<String, Set<String>> execConstInputs = new HashMap<>();

    /**
     * Frame name to the frame/iteration of the variable that fed the Enter op for that frame.
     * Note: unlike the other state above, this is not cleared by output(...).
     */
    protected final Map<String, FrameIter> frameParents = new HashMap<>();

    /**
     * Identifies an execution context: a named frame, the iteration within that frame, and (optionally) the
     * context of the enclosing frame. Instances take part in hashed map keys, so equality is value based over
     * all three components; because the fields are mutable, do not modify an instance while it is a key.
     */
    public static class FrameIter {
        private String frame;
        private int iteration;
        private FrameIter parentFrame;

        public FrameIter(String frame, int iteration, FrameIter parentFrame) {
            this.frame = frame;
            this.iteration = iteration;
            this.parentFrame = parentFrame;
        }

        /** Deep copy: the whole parent chain is copied as well, so the copy can be mutated independently. */
        public FrameIter m1463clone() {
            FrameIter parentCopy = (parentFrame != null) ? parentFrame.m1463clone() : null;
            return new FrameIter(frame, iteration, parentCopy);
        }

        public String getFrame() {
            return this.frame;
        }

        public void setFrame(String newFrame) {
            this.frame = newFrame;
        }

        public int getIteration() {
            return this.iteration;
        }

        public void setIteration(int newIteration) {
            this.iteration = newIteration;
        }

        public FrameIter getParentFrame() {
            return this.parentFrame;
        }

        public void setParentFrame(FrameIter newParent) {
            this.parentFrame = newParent;
        }

        /** Formats as ("frame",iteration) with ",parent=..." appended when a parent exists. */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("(\"").append(frame).append("\",").append(iteration);
            if (parentFrame != null) {
                text.append(",parent=").append(parentFrame.toString());
            }
            return text.append(")").toString();
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof FrameIter)) {
                return false;
            }
            FrameIter other = (FrameIter) obj;
            return other.canEqual(this)
                    && nullSafeEquals(getFrame(), other.getFrame())
                    && getIteration() == other.getIteration()
                    && nullSafeEquals(getParentFrame(), other.getParentFrame());
        }

        protected boolean canEqual(Object obj) {
            return obj instanceof FrameIter;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            final int nullHash = 43;
            int acc = 1;
            String f = getFrame();
            acc = acc * prime + (f == null ? nullHash : f.hashCode());
            acc = acc * prime + getIteration();
            FrameIter p = getParentFrame();
            acc = acc * prime + (p == null ? nullHash : p.hashCode());
            return acc;
        }

        /** True when both are null, or when {@code a.equals(b)}. */
        private static boolean nullSafeEquals(Object a, Object b) {
            return a == null ? b == null : a.equals(b);
        }
    }

    /**
     * Identifies one value of a variable during execution: the variable name together with the frame,
     * iteration and parent frame it belongs to. Used as the key of the session's output and scheduling maps,
     * so equality is value based over all four components; the fields are mutable, so do not modify an
     * instance while it is a key in a hashed collection.
     */
    public static class VarId {
        private String variable;
        private String frame;
        private int iteration;
        private FrameIter parentFrame;

        public VarId(String variable, String frame, int iteration, FrameIter parentFrame) {
            this.variable = variable;
            this.frame = frame;
            this.iteration = iteration;
            this.parentFrame = parentFrame;
        }

        /** Returns this id's frame/iteration/parent as a FrameIter; the parent reference is shared, not copied. */
        public FrameIter toFrameIter() {
            FrameIter asFrame = new FrameIter(frame, iteration, parentFrame);
            return asFrame;
        }

        public String getVariable() {
            return this.variable;
        }

        public void setVariable(String newVariable) {
            this.variable = newVariable;
        }

        public String getFrame() {
            return this.frame;
        }

        public void setFrame(String newFrame) {
            this.frame = newFrame;
        }

        public int getIteration() {
            return this.iteration;
        }

        public void setIteration(int newIteration) {
            this.iteration = newIteration;
        }

        public FrameIter getParentFrame() {
            return this.parentFrame;
        }

        public void setParentFrame(FrameIter newParent) {
            this.parentFrame = newParent;
        }

        @Override
        public String toString() {
            return new StringBuilder("VarId(\"")
                    .append(variable)
                    .append("\",\"")
                    .append(frame)
                    .append("\",")
                    .append(iteration)
                    .append(",parent=")
                    .append(parentFrame)
                    .append(")")
                    .toString();
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof VarId)) {
                return false;
            }
            VarId other = (VarId) obj;
            return other.canEqual(this)
                    && nullSafeEquals(getVariable(), other.getVariable())
                    && nullSafeEquals(getFrame(), other.getFrame())
                    && getIteration() == other.getIteration()
                    && nullSafeEquals(getParentFrame(), other.getParentFrame());
        }

        protected boolean canEqual(Object obj) {
            return obj instanceof VarId;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            final int nullHash = 43;
            int acc = 1;
            String v = getVariable();
            acc = acc * prime + (v == null ? nullHash : v.hashCode());
            String f = getFrame();
            acc = acc * prime + (f == null ? nullHash : f.hashCode());
            acc = acc * prime + getIteration();
            FrameIter p = getParentFrame();
            acc = acc * prime + (p == null ? nullHash : p.hashCode());
            return acc;
        }

        /** True when both are null, or when {@code a.equals(b)}. */
        private static boolean nullSafeEquals(Object a, Object b) {
            return a == null ? b == null : a.equals(b);
        }
    }

    /**
     * Creates a session that executes the given graph.
     *
     * @param sameDiff the graph to execute; must not be null
     * @throws NullPointerException if {@code sameDiff} is null
     */
    public AbstractSession(@NonNull SameDiff sameDiff) {
        if (sameDiff != null) {
            this.sameDiff = sameDiff;
        } else {
            throw new NullPointerException("sameDiff is marked @NonNull but is null");
        }
    }

    /**
     * Reports whether an entry (possibly null) has been recorded for the variable in the given context.
     *
     * @return true if nodeOutputs has a key for (variable, frame, iteration, parentFrame)
     */
    public boolean contains(String variable, String frame, int iteration, FrameIter parentFrame) {
        VarId key = newVarId(variable, frame, iteration, parentFrame);
        return nodeOutputs.containsKey(key);
    }

    /**
     * Returns the recorded output for the variable in the given context, failing if none is recorded.
     * Equivalent to {@code get(variable, frame, iteration, parentFrame, true)}.
     */
    public T get(String variable, String frame, int iteration, FrameIter parentFrame) {
        final boolean enforceExistence = true;
        return get(variable, frame, iteration, parentFrame, enforceExistence);
    }

    /**
     * Looks up the recorded output for the variable in the given context.
     *
     * @param enforceExistence when true, a null result is rejected via Preconditions.checkNotNull
     * @return the recorded value, or null when absent and {@code enforceExistence} is false
     */
    public T get(String variable, String frame, int iteration, FrameIter parentFrame, boolean enforceExistence) {
        VarId key = newVarId(variable, frame, iteration, parentFrame);
        T value = nodeOutputs.get(key);
        if (!enforceExistence) {
            return value;
        }
        Preconditions.checkNotNull(value, "No output found for variable %s (frame %s, iteration %s)", variable, frame, Integer.valueOf(iteration));
        return value;
    }

    /** Factory for {@link VarId} instances; all ids built by this session go through here or its overload. */
    public VarId newVarId(String variable, String frame, int iteration, FrameIter parentFrame) {
        VarId created = new VarId(variable, frame, iteration, parentFrame);
        return created;
    }

    /** Builds a {@link VarId} for {@code variable} in the frame, iteration and parent described by {@code frameIter}. */
    public VarId newVarId(String variable, FrameIter frameIter) {
        String frame = frameIter.getFrame();
        int iteration = frameIter.getIteration();
        FrameIter parent = frameIter.getParentFrame();
        return newVarId(variable, frame, iteration, parent);
    }

    /**
     * Legacy overload that derives the execution context from a boolean training flag and runs without listeners.
     *
     * @param training used only when {@code at} is null: selects {@code Operation.TRAINING} or {@code Operation.INFERENCE}
     * @deprecated use {@link #output(List, Map, MultiDataSet, Collection, List, At)} with an explicit listener list
     *     and {@link At} instead.
     */
    @Deprecated
    public Map<String, T> output(@NonNull List<String> variables, Map<String, T> placeholderValues, MultiDataSet batch, Collection<String> requiredActivations, boolean training, At at) {
        if (variables == null) {
            throw new NullPointerException("variables is marked @NonNull but is null");
        }
        At effectiveAt = at;
        if (effectiveAt == null) {
            Operation phase = training ? Operation.TRAINING : Operation.INFERENCE;
            effectiveAt = At.defaultAt(phase);
        }
        List<Listener> noListeners = Collections.emptyList();
        return output(variables, placeholderValues, batch, requiredActivations, noListeners, effectiveAt);
    }

    /**
     * Executes the part of the graph needed for the requested variables and returns their values.
     *
     * <p>All session state (execution queue, subgraph, recorded inputs, node outputs, tensor arrays) is reset
     * first. Variables are then taken from the execution queue one at a time:
     * <ul>
     *   <li>placeholders take their value from the preprocessed placeholder map;</li>
     *   <li>constants and {@link VariableType#VARIABLE}s come from {@code getConstantOrVariable};</li>
     *   <li>everything else is produced by executing the op that outputs it.</li>
     * </ul>
     * Every recorded value is passed to {@code updateDescendentsForExec} so that dependents can be scheduled.
     *
     * @param variables           names of the variables to return; must be non-empty and every name must exist in
     *                            the SameDiff instance. Repeated names are allowed (each is returned once).
     * @param placeholderValues   placeholder values; may be null when no placeholder value is required
     * @param batch               passed through to {@code getOutputs} for every executed op
     * @param requiredActivations extra variables to include in the executed subgraph without returning them;
     *                            may be null
     * @param listeners           passed through to {@code getOutputs} for every executed op
     * @param at                  execution context; {@link At#defaultAt()} is used when null
     * @return map from each requested variable name to its computed value
     * @throws NullPointerException  if {@code variables} is null
     * @throws IllegalStateException if {@code variables} is empty or names an unknown variable, if a required
     *                               placeholder value is missing, or if execution stalls before every requested
     *                               variable has been produced
     */
    public Map<String, T> output(@NonNull List<String> variables, Map<String, T> placeholderValues, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> listeners, At at) {
        if (variables == null) {
            throw new NullPointerException("variables is marked @NonNull but is null");
        }
        Preconditions.checkState(!variables.isEmpty(), "Variables to perform forward pass for must not be empty");
        Collection<String> extraActivations = requiredActivations == null ? Collections.<String>emptyList() : requiredActivations;
        At execAt = at == null ? At.defaultAt() : at;
        for (String name : variables) {
            Preconditions.checkState(sameDiff.variableMap().containsKey(name), "Requested output variable %s does not exist in SameDiff instance", name);
        }
        // The result map is keyed by name, so it can hold at most one entry per distinct name. Waiting for
        // variables.size() entries made any repeated name end in a spurious "No variable are available" failure.
        Set<String> requested = new HashSet<>(variables);

        Map<String, T> placeholders = preprocessPlaceholders(placeholderValues);

        availableForExec.clear();
        availableForExecSet.clear();
        subgraph.clear();
        execInputs.clear();
        execInputsAllIter.clear();
        execConstInputs.clear();
        nodeOutputs.clear();
        tensorArrays.clear();

        List<String> subgraphTargets = new ArrayList<>(extraActivations);
        subgraphTargets.addAll(variables);
        initSubgraph(subgraphTargets);

        placeholders = ensureRequiredPlaceholders(requested, placeholders, execAt);

        Map<String, T> out = new HashMap<>();
        int step = 0;
        while (out.size() < requested.size()) {
            if (availableForExec.isEmpty()) {
                throw new IllegalStateException(describeStalledExecution(step, requested, out));
            }
            VarId varToExec = availableForExec.remove();
            availableForExecSet.remove(varToExec);
            String varName = varToExec.getVariable();

            if (nodeOutputs.containsKey(varToExec)) {
                // An entry already exists: only report it (if requested) and propagate to descendants.
                // The step counter is deliberately not advanced, as nothing is executed here.
                if (requested.contains(varName)) {
                    out.put(varName, nodeOutputs.get(varToExec));
                }
                updateDescendentsForExec(step, varToExec);
                continue;
            }

            Set<VarId> inputs = execInputs.get(varToExec);
            Set<VarId> allIterInputs = execInputsAllIter.get(newVarId(varName, varToExec.getFrame(), 0, varToExec.getParentFrame()));
            Set<String> constPhInputs = execConstInputs.get(varName);
            log.trace("Beginning execution step {}: variable {}", Integer.valueOf(step), varToExec);

            SDVariable sdVar = sameDiff.getVariable(varName);
            if (sdVar.isPlaceHolder()) {
                T value = placeholders.get(varName);
                nodeOutputs.put(varToExec, value);
                updateDescendentsForExec(step, varToExec);
                if (requested.contains(varName)) {
                    out.put(varName, value);
                }
            } else if (sdVar.isConstant() || sdVar.getVariableType() == VariableType.VARIABLE) {
                T value = getConstantOrVariable(varName);
                Preconditions.checkNotNull(value, "Encountered null placeholder array for constant: %s", varToExec);
                nodeOutputs.put(varToExec, value);
                updateDescendentsForExec(step, varToExec);
                if (requested.contains(varName)) {
                    out.put(varName, value);
                }
            } else {
                if (sameDiff.getVariableOutputOp(varName) == null) {
                    throw new IllegalStateException("Unable to execute variable " + varToExec + " of type " + sameDiff.getVariables().get(varName).getVariable().getVariableType());
                }
                String opName = sameDiff.getVariables().get(varName).getOutputOfOp();
                FrameIter frameIter = varToExec.toFrameIter();
                O op = getAndParameterizeOp(opName, frameIter, inputs, allIterInputs, constPhInputs, placeholders);
                T[] opOutputs = getOutputs(op, frameIter, inputs, allIterInputs, constPhInputs, listeners, execAt, batch);
                String[] opOutputNames = sameDiff.getOpById(opName).outputVariablesNames();
                Preconditions.checkState(opOutputs.length == opOutputNames.length, "Unexpected number of outputs from executed op %s: got %s outputs when %s outputs were expected (%s)", op.getClass().getSimpleName(), Integer.valueOf(opOutputs.length), Integer.valueOf(opOutputNames.length), opOutputNames);

                for (int j = 0; j < opOutputNames.length; j++) {
                    T value = opOutputs[j];
                    if (value == null && op instanceof Switch) {
                        // Switch may leave an output null (presumably the branch not taken); it is not propagated.
                        continue;
                    }
                    Preconditions.checkNotNull(value, "Encountered null output (output %s) for op %s at execution step %s", Integer.valueOf(j), op.getClass().getSimpleName(), Integer.valueOf(step));
                    String outName = opOutputNames[j];
                    boolean controlFlowOp = op instanceof Enter || op instanceof Exit || op instanceof NextIteration || op instanceof LoopCond;
                    VarId outputId = op instanceof Enter
                            ? enterOutputVarId(outName, (Enter) op, varToExec)
                            : newVarId(outName, varToExec.getFrame(), varToExec.getIteration(), varToExec.getParentFrame());
                    if (controlFlowOp) {
                        // Also record a null entry in the executing frame/iteration, presumably so nodeOutputs marks
                        // the op as executed there (for Enter the real output lives in the new frame).
                        nodeOutputs.put(newVarId(outName, varToExec.getFrame(), varToExec.getIteration(), varToExec.getParentFrame()), null);
                    }
                    nodeOutputs.put(outputId, value);
                    updateDescendentsForExec(step, outputId);
                    if (requested.contains(outName)) {
                        out.put(outName, value);
                    }
                }
            }
            step++;
        }
        return out;
    }

    /**
     * Verifies that every graph input placeholder required for this execution has a value.
     *
     * <p>A placeholder is treated as required if it is itself requested, or if any name returned by its
     * {@code getInputsForOp()} is in {@link #subgraph}. A missing {@code keras_learning_phase} placeholder is
     * filled in with a scalar holding {@code at.operation().isTrainingPhase()}. Any other missing placeholder is
     * an error.
     *
     * @return {@code placeholders} unchanged, or a new mutable copy (a fresh map when {@code placeholders} was
     *     null) with the learning-phase value added. The copy avoids a NullPointerException on a null map and
     *     avoids mutating or failing on an unmodifiable caller-supplied map.
     * @throws IllegalStateException if a required non-keras placeholder has no value
     */
    @SuppressWarnings("unchecked")
    private Map<String, T> ensureRequiredPlaceholders(Set<String> requested, Map<String, T> placeholders, At at) {
        List<String> graphInputs = sameDiff.inputs();
        if (placeholders != null && placeholders.keySet().containsAll(graphInputs)) {
            return placeholders;
        }
        Map<String, T> result = placeholders;
        boolean copied = false;
        for (String ph : graphInputs) {
            boolean required = requested.contains(ph);
            if (!required) {
                List<String> consumers = sameDiff.getVariables().get(ph).getInputsForOp();
                if (consumers != null) {
                    for (String consumer : consumers) {
                        if (subgraph.contains(consumer)) {
                            required = true;
                            break;
                        }
                    }
                }
            }
            if (!required || (result != null && result.containsKey(ph))) {
                continue;
            }
            if (!ph.endsWith("keras_learning_phase")) {
                throw new IllegalStateException("An input placeholder \"" + ph + "\" is required to calculate the requested outputs, but a placeholder value was not provided");
            }
            if (!copied) {
                result = result == null ? new HashMap<>() : new HashMap<>(result);
                copied = true;
            }
            result.put(ph, (T) Nd4j.scalar(at.operation().isTrainingPhase()));
        }
        return result;
    }

    /** Builds the error message used when the execution queue is empty but requested outputs are still missing. */
    private static String describeStalledExecution(int step, Set<String> requested, Map<String, ?> computed) {
        Set<String> missing = new HashSet<>();
        for (String name : requested) {
            if (!computed.containsKey(name)) {
                missing.add(name);
            }
        }
        int remaining = missing.size();
        StringBuilder sb = new StringBuilder();
        sb.append("No variable are available for execution at step ").append(step).append(": ").append(remaining).append(" values remaining");
        if (remaining <= 10) {
            sb.append(". Missing variables: ");
            sb.append(missing);
        } else {
            sb.append(". First 10 missing variables: ");
            Iterator<String> it = missing.iterator();
            for (int j = 0; j < 10 && it.hasNext(); j++) {
                if (j > 0) {
                    sb.append(",");
                }
                sb.append(it.next());
            }
        }
        return sb.toString();
    }

    /**
     * Computes the id of an Enter op's output: the Enter's target frame at iteration 0, with the executing
     * variable's parent frame as parent. For constant Enter ops the parent chain is cloned and every
     * iteration in it reset to 0; the original chain is left untouched.
     */
    private VarId enterOutputVarId(String outputName, Enter enter, VarId executing) {
        String targetFrame = enter.getFrameName();
        FrameIter outParent = executing.getParentFrame();
        if (enter.isConstant() && outParent != null) {
            outParent = outParent.m1463clone();
            for (FrameIter f = outParent; f != null; f = f.getParentFrame()) {
                f.setIteration(0);
            }
        }
        return newVarId(outputName, targetFrame, 0, outParent);
    }

    /**
     * Computes {@link #subgraph}, the names of every variable needed to produce {@code list}, and seeds the
     * execution queue.
     *
     * <p>Walks backwards breadth-first from the given names. For each visited variable it enqueues:
     * <ul>
     *   <li>the variable's own control dependencies (only on its first visit);</li>
     *   <li>the inputs and control dependencies of the op that produces it.</li>
     * </ul>
     * Names already in the subgraph are not enqueued. A variable with no op inputs and no control dependencies
     * is scheduled immediately in {@link #OUTER_FRAME} at iteration 0, with an empty {@code execInputs} entry.
     *
     * @param list names to start the walk from (requested outputs plus any extra required activations)
     */
    protected void initSubgraph(List<String> list) {
        Queue<String> toVisit = new LinkedList<>(list);
        while (!toVisit.isEmpty()) {
            String varName = toVisit.remove();
            DifferentialFunction producer = sameDiff.getVariableOutputOp(varName);
            String opName = producer == null ? null : producer.getOwnName();
            // Computed once. Null when there is no producing op; also guarded below in case getInputsForOp
            // returns null for an op without inputs (the original code iterated it without a check).
            String[] opInputs = opName == null ? null : sameDiff.getInputsForOp(sameDiff.getOpById(opName));

            if (!subgraph.contains(varName)) {
                List<String> varControlDeps = sameDiff.getVariables().get(varName).getControlDeps();
                int numInputs = (opInputs == null ? 0 : opInputs.length)
                        + (varControlDeps == null ? 0 : varControlDeps.size());
                if (numInputs == 0) {
                    VarId seed = newVarId(varName, OUTER_FRAME, 0, null);
                    if (!availableForExecSet.contains(seed)) {
                        availableForExec.add(seed);
                        availableForExecSet.add(seed);
                    }
                    execInputs.put(seed, new HashSet<>());
                }
                subgraph.add(varName);
                enqueueOutsideSubgraph(toVisit, varControlDeps);
            }

            if (opName != null) {
                if (opInputs != null) {
                    enqueueOutsideSubgraph(toVisit, Arrays.asList(opInputs));
                }
                enqueueOutsideSubgraph(toVisit, sameDiff.getOps().get(opName).getControlDeps());
            }
        }
    }

    /** Appends each name not yet in {@link #subgraph} to {@code queue}; a null collection is ignored. */
    private void enqueueOutsideSubgraph(Queue<String> queue, Collection<String> names) {
        if (names == null) {
            return;
        }
        for (String name : names) {
            if (!subgraph.contains(name)) {
                queue.add(name);
            }
        }
    }

    protected void updateDescendentsForExec(int i, VarId varId) {
        varId.getVariable();
        Variable variable = this.sameDiff.getVariables().get(varId.getVariable());
        List<String> inputsForOp = this.sameDiff.getVariables().get(varId.getVariable()).getInputsForOp();
        String[] strArr = inputsForOp == null ? null : (String[]) inputsForOp.toArray(new String[inputsForOp.size()]);
        List<String> controlDepsForVar = variable.getControlDepsForVar();
        List<String> controlDepsForOp = variable.getControlDepsForOp();
        SDVariable variable2 = variable.getVariable();
        boolean z = variable2.isPlaceHolder() || variable2.isConstant();
        if (strArr != null) {
            for (String str : strArr) {
                DifferentialFunction opById = this.sameDiff.getOpById(str);
                if (opById instanceof Merge) {
                    List<String> outputsOfOp = this.sameDiff.getOps().get(str).getOutputsOfOp();
                    Preconditions.checkState(outputsOfOp.size() == 1, "Expected only 1 output variable for merge op, got %s", outputsOfOp);
                    VarId newVarId = newVarId(outputsOfOp.get(0), varId.getFrame(), varId.getIteration(), varId.getParentFrame());
                    if (!this.nodeOutputs.containsKey(newVarId) && this.subgraph.contains(newVarId.getVariable()) && !this.availableForExecSet.contains(newVarId)) {
                        this.availableForExec.add(newVarId);
                        this.availableForExecSet.add(newVarId);
                        log.trace("Marked merge op ({}) variable {} as available for execution: input {} is now available", new Object[]{str, newVarId, varId});
                    }
                    addToExecInputs(z, varId, newVarId);
                } else if (opById instanceof Enter) {
                    List<String> outputsOfOp2 = this.sameDiff.getOps().get(str).getOutputsOfOp();
                    Preconditions.checkState(outputsOfOp2.size() == 1, "Expected only 1 output variable for enter op, got %s", outputsOfOp2);
                    Enter enter = (Enter) opById;
                    boolean isConstant = enter.isConstant();
                    VarId newVarId2 = newVarId(outputsOfOp2.get(0), enter.getFrameName(), 0, varId.toFrameIter());
                    if (isConstant && varId.getParentFrame() != null) {
                        newVarId2.setParentFrame(newVarId2.getParentFrame().m1463clone());
                        FrameIter parentFrame = newVarId2.getParentFrame();
                        while (true) {
                            FrameIter frameIter = parentFrame;
                            if (frameIter == null) {
                                break;
                            }
                            frameIter.setIteration(0);
                            parentFrame = frameIter.getParentFrame();
                        }
                    }
                    if (!this.nodeOutputs.containsKey(newVarId2) && this.subgraph.contains(newVarId2.getVariable()) && !this.availableForExecSet.contains(newVarId2)) {
                        this.availableForExec.add(newVarId2);
                        this.availableForExecSet.add(newVarId2);
                        log.trace("Marked enter op ({}) variable {} as available for execution: input {} is now available", new Object[]{str, newVarId2, varId});
                    }
                    this.frameParents.put(enter.getFrameName(), varId.toFrameIter());
                    addToExecInputs(z, varId, newVarId2);
                } else if (opById instanceof Exit) {
                    List<String> outputsOfOp3 = this.sameDiff.getOps().get(str).getOutputsOfOp();
                    FrameIter frameIter2 = this.frameParents.get(varId.getFrame());
                    Preconditions.checkNotNull(frameIter2, "Parent frame must not be null for exit op: variable to exec is %s", varId);
                    VarId varId2 = new VarId(outputsOfOp3.get(0), frameIter2.getFrame(), frameIter2.getIteration(), varId.getParentFrame().getParentFrame());
                    if (!this.nodeOutputs.containsKey(varId2) && this.subgraph.contains(varId2.getVariable()) && !this.availableForExecSet.contains(varId2)) {
                        this.availableForExec.add(varId2);
                        this.availableForExecSet.add(varId2);
                        log.trace("Marked Exit op ({}) variable {} as available for execution: input {} is now available", new Object[]{str, varId2, varId});
                    }
                    addToExecInputs(z, varId, varId2);
                } else if (opById instanceof NextIteration) {
                    List<String> outputsOfOp4 = this.sameDiff.getOps().get(str).getOutputsOfOp();
                    Preconditions.checkState(outputsOfOp4.size() == 1, "Expected exactly 1 output for NextIteration op: got %s", outputsOfOp4);
                    VarId newVarId3 = newVarId(outputsOfOp4.get(0), varId.getFrame(), varId.getIteration() + 1, varId.getParentFrame());
                    if (!this.nodeOutputs.containsKey(newVarId3) && this.subgraph.contains(newVarId3.getVariable()) && !this.availableForExecSet.contains(newVarId3)) {
                        this.availableForExec.add(newVarId3);
                        this.availableForExecSet.add(newVarId3);
                        log.trace("Marked NextIteration op ({}) variable {} as available for execution: input {} is now available", new Object[]{str, newVarId3, varId});
                    }
                    addToExecInputs(z, varId, newVarId3);
                } else {
                    String[] argNames = opById.argNames();
                    boolean allInputsAvailable = argNames != null ? allInputsAvailable(i, argNames, varId) : true;
                    List<String> controlDeps = this.sameDiff.getOps().get(str).getControlDeps();
                    if (controlDeps != null && allInputsAvailable) {
                        Iterator<String> it = controlDeps.iterator();
                        while (true) {
                            if (it.hasNext()) {
                                if (!this.nodeOutputs.containsKey(newVarId(it.next(), varId.getFrame(), varId.getIteration(), varId.getParentFrame()))) {
                                    allInputsAvailable = false;
                                    break;
                                }
                            } else {
                                break;
                            }
                        }
                    }
                    List<String> outputsOfOp5 = this.sameDiff.getOps().get(str).getOutputsOfOp();
                    if (outputsOfOp5 != null) {
                        for (String str2 : outputsOfOp5) {
                            SDVariable variable3 = this.sameDiff.getVariable(str2);
                            Variable variable4 = this.sameDiff.getVariables().get(str2);
                            addToExecInputs(z, varId, (variable3.isConstant() || variable3.isPlaceHolder()) ? (variable4.getControlDeps() == null || variable.getControlDeps().isEmpty()) ? newVarId(str2, OUTER_FRAME, 0, null) : newVarId(str2, varId.getFrame(), varId.getIteration(), varId.getParentFrame()) : newVarId(str2, varId.getFrame(), varId.getIteration(), varId.getParentFrame()));
                            if (allInputsAvailable && variable4.getControlDeps() != null && !variable4.getControlDeps().isEmpty()) {
                                for (String str3 : variable4.getControlDeps()) {
                                    Variable variable5 = this.sameDiff.getVariables().get(str3);
                                    allInputsAvailable &= this.nodeOutputs.containsKey((variable5.getVariable().isConstant() || variable5.getVariable().isPlaceHolder()) ? (variable4.getControlDeps() == null || variable.getControlDeps().isEmpty()) ? newVarId(str3, OUTER_FRAME, 0, null) : newVarId(str3, varId.getFrame(), varId.getIteration(), varId.getParentFrame()) : newVarId(str3, varId.getFrame(), varId.getIteration(), varId.getParentFrame()));
                                    if (!allInputsAvailable) {
                                        break;
                                    }
                                }
                            }
                        }
                        if (allInputsAvailable) {
                            for (String str4 : outputsOfOp5) {
                                if (this.subgraph.contains(str4)) {
                                    VarId newVarId4 = newVarId(str4, varId.getFrame(), varId.getIteration(), varId.getParentFrame());
                                    if (!this.availableForExecSet.contains(newVarId4)) {
                                        this.availableForExec.add(newVarId4);
                                        this.availableForExecSet.add(newVarId4);
                                        Logger logger = log;
                                        Object[] objArr = new Object[4];
                                        objArr[0] = newVarId4;
                                        objArr[1] = str;
                                        objArr[2] = opById.getClass().getSimpleName();
                                        objArr[3] = argNames == null ? "<none>" : Arrays.toString(argNames);
                                        logger.trace("Marked variable as available for execution: {} - output of op {} ({}) with op inputs {}", objArr);
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
        if (controlDepsForVar != null) {
            for (String str5 : controlDepsForVar) {
                if (this.subgraph.contains(str5)) {
                    if (this.sameDiff.getVariable(str5).getVariableType() != VariableType.ARRAY) {
                        VarId newVarId5 = newVarId(str5, varId.getFrame(), varId.getIteration(), varId.getParentFrame());
                        if (!this.availableForExecSet.contains(newVarId5)) {
                            this.availableForExec.add(newVarId5);
                            this.availableForExecSet.add(newVarId5);
                            log.trace("Marked variable as available for execution: {} - control dependency {} -> {} exists", new Object[]{newVarId5, varId.getVariable(), str5});
                        }
                    } else {
                        String outputOfOp = this.sameDiff.getVariables().get(str5).getOutputOfOp();
                        if (outputOfOp != null) {
                            SameDiffOp sameDiffOp = this.sameDiff.getOps().get(outputOfOp);
                            boolean z2 = true;
                            if (sameDiffOp.getInputsToOp() != null && !sameDiffOp.getInputsToOp().isEmpty()) {
                                List<String> inputsToOp = sameDiffOp.getInputsToOp();
                                z2 = allInputsAvailable(i, (String[]) inputsToOp.toArray(new String[inputsToOp.size()]), varId);
                            }
                            if (z2 && sameDiffOp.getControlDeps() != null) {
                                Iterator<String> it2 = sameDiffOp.getControlDeps().iterator();
                                while (it2.hasNext()) {
                                    z2 &= this.nodeOutputs.containsKey(newVarId(it2.next(), varId.getFrame(), varId.getIteration(), varId.getParentFrame()));
                                    if (!z2) {
                                        break;
                                    }
                                }
                            }
                            if (z2) {
                                Iterator<String> it3 = sameDiffOp.getOutputsOfOp().iterator();
                                while (it3.hasNext()) {
                                    Variable variable6 = this.sameDiff.getVariables().get(it3.next());
                                    if (variable6.getControlDeps() != null) {
                                        Iterator<String> it4 = variable6.getControlDeps().iterator();
                                        while (it4.hasNext()) {
                                            z2 &= this.nodeOutputs.containsKey(newVarId(it4.next(), varId.getFrame(), varId.getIteration(), varId.getParentFrame()));
                                            if (!z2) {
                                                break;
                                            }
                                        }
                                    }
                                }
                            }
                            if (z2) {
                                VarId newVarId6 = newVarId(str5, varId.getFrame(), varId.getIteration(), varId.getParentFrame());
                                if (!this.availableForExecSet.contains(newVarId6)) {
                                    this.availableForExec.add(newVarId6);
                                    log.trace("Marked variable as available for execution: {} - is output of op {} with no inputs (but has control dependencies)", newVarId6, sameDiffOp.getName());
                                }
                            }
                        }
                    }
                }
            }
        }
        if (controlDepsForOp != null) {
            for (String str6 : controlDepsForOp) {
                SameDiffOp sameDiffOp2 = this.sameDiff.getOps().get(str6);
                if (sameDiffOp2.getInputsToOp() == null || sameDiffOp2.getInputsToOp().isEmpty()) {
                    for (String str7 : sameDiffOp2.getOutputsOfOp()) {
                        if (this.subgraph.contains(str7)) {
                            VarId newVarId7 = newVarId(str7, OUTER_FRAME, 0, null);
                            if (!this.availableForExecSet.contains(newVarId7)) {
                                this.availableForExec.add(newVarId7);
                                this.availableForExecSet.add(newVarId7);
                                log.trace("Marked variable as available for execution: {} - op control dependency variable {} -> op {} exists", new Object[]{newVarId7, varId.getVariable(), str6});
                            }
                        }
                    }
                }
            }
        }
    }

    /**
     * Checks whether every named op input already has a value in {@code nodeOutputs} for the
     * frame/iteration context of {@code varId}.
     *
     * <p>How each input's lookup key is formed:
     * <ul>
     *   <li>Constants and placeholders without control dependencies are looked up in
     *       {@code OUTER_FRAME}, iteration 0, with no parent frame.</li>
     *   <li>Constants and placeholders with control dependencies use varId's frame, iteration and
     *       parent frame.</li>
     *   <li>ARRAY variables produced by an {@code Enter} op use iteration 0. If that Enter is
     *       constant, a copy of the parent-frame chain is also used, with every iteration reset
     *       to 0.</li>
     *   <li>All other inputs use varId's frame, iteration and parent frame.</li>
     * </ul>
     *
     * @param i      not read by this implementation
     * @param strArr names of the inputs to check
     * @param varId  variable whose frame/iteration/parent frame supply the lookup context
     * @return true if every input key is present in {@code nodeOutputs}; false at the first
     *         missing one
     */
    protected boolean allInputsAvailable(int i, String[] strArr, VarId varId) {
        for (String inputName : strArr) {
            SDVariable sdVar = this.sameDiff.getVariable(inputName);
            Variable meta = this.sameDiff.getVariables().get(inputName);
            VarId required;
            if (sdVar.isConstant() || sdVar.isPlaceHolder()) {
                boolean noControlDeps = meta.getControlDeps() == null || meta.getControlDeps().isEmpty();
                if (noControlDeps) {
                    required = newVarId(inputName, OUTER_FRAME, 0, null);
                } else {
                    required = newVarId(inputName, varId.getFrame(), varId.getIteration(), varId.getParentFrame());
                }
            } else {
                int iter = varId.getIteration();
                FrameIter parent = varId.getParentFrame();
                if (sdVar.getVariableType() == VariableType.ARRAY) {
                    SameDiffOp producer = this.sameDiff.getOps().get(meta.getOutputOfOp());
                    if (producer.getOp() instanceof Enter) {
                        // Enter outputs are looked up at iteration 0 rather than the current one
                        iter = 0;
                        if (((Enter) producer.getOp()).isConstant()) {
                            // Work on a copy so the caller's frame chain is left untouched
                            parent = parent.m1463clone();
                            for (FrameIter f = parent; f != null; f = f.getParentFrame()) {
                                f.setIteration(0);
                            }
                        }
                    }
                }
                required = newVarId(inputName, varId.getFrame(), iter, parent);
            }
            if (!this.nodeOutputs.containsKey(required)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Hook that lets subclasses transform placeholder values. The default implementation
     * applies no transformation and returns the very same map instance it was given.
     *
     * @param placeholders placeholder values keyed by name
     * @return {@code placeholders}, unchanged
     */
    protected Map<String, T> preprocessPlaceholders(Map<String, T> placeholders) {
        final Map<String, T> unchanged = placeholders;
        return unchanged;
    }

    /**
     * Returns the session-specific value for the named constant or variable.
     *
     * @param variableName name of the constant/variable
     * @return its value; exact semantics are defined by the concrete session
     */
    public abstract T getConstantOrVariable(String variableName);

    /**
     * Fetches an op and prepares it for execution in the given frame/iteration context.
     *
     * <p>NOTE(review): the parameter roles below are inferred from this class's
     * {@code execInputs}, {@code execInputsAllIter} and {@code execConstInputs} bookkeeping.
     * Confirm them against the subclass implementations.
     *
     * @param opName            presumably the name of the op to fetch
     * @param frameIter         frame/iteration the op executes in
     * @param opInputs          presumably frame/iteration-specific inputs
     * @param allIterInputs     presumably inputs valid for all iterations (e.g. Enter outputs)
     * @param constAndPhInputs  presumably names of constant/placeholder inputs
     * @param placeholderValues presumably placeholder values keyed by name
     * @return the op, in the session's own op representation
     */
    public abstract O getAndParameterizeOp(String opName, FrameIter frameIter, Set<VarId> opInputs,
            Set<VarId> allIterInputs, Set<String> constAndPhInputs, Map<String, T> placeholderValues);

    /**
     * Executes a prepared op and returns its output values.
     *
     * <p>NOTE(review): the input-set roles mirror {@code getAndParameterizeOp}'s and are
     * inferred, not verified here.
     *
     * @param op               op in the session's own representation
     * @param outputFrameIter  frame/iteration the outputs belong to
     * @param opInputs         presumably frame/iteration-specific inputs
     * @param allIterInputs    presumably inputs valid for all iterations
     * @param constAndPhInputs presumably names of constant/placeholder inputs
     * @param listeners        listeners to notify during execution
     * @param at               current execution position information
     * @param batch            current batch, if any
     * @return the op's outputs
     */
    public abstract T[] getOutputs(O op, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
            Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch);

    /**
     * Records that {@code varId2} depends on the input {@code varId}.
     *
     * <p>Nothing is recorded when {@code varId2}'s variable is not in {@code subgraph}.
     * Otherwise the dependency goes into one of three maps:
     * <ul>
     *   <li>If {@code z} is true, it is recorded by variable name only in
     *       {@code execConstInputs} (frame/iteration independent).</li>
     *   <li>If the input's producing op is an {@code Enter}, it is recorded in
     *       {@code execInputsAllIter}. The key is {@code varId2}'s variable and frame at
     *       iteration 0. For a constant Enter, the key's parent-frame chain is also copied and
     *       reset to iteration 0 at every level.</li>
     *   <li>Otherwise it is recorded in {@code execInputs} under {@code varId2} itself.</li>
     * </ul>
     *
     * <p>{@code varId2} is never mutated: a fresh {@code VarId} is always built for the
     * Enter case. Mutating the caller's object could corrupt any hash-based collection that
     * already holds it.
     *
     * @param z      presumably true when the input is a constant or placeholder
     * @param varId  the input variable instance
     * @param varId2 the variable instance that consumes {@code varId}
     */
    protected void addToExecInputs(boolean z, VarId varId, VarId varId2) {
        String dependentName = varId2.getVariable();
        if (!this.subgraph.contains(dependentName)) {
            return;
        }
        if (z) {
            if (!this.execConstInputs.containsKey(dependentName)) {
                this.execConstInputs.put(dependentName, new HashSet());
            }
            this.execConstInputs.get(dependentName).add(varId.getVariable());
            return;
        }
        Variable inputMeta = this.sameDiff.getVariables().get(varId.getVariable());
        if (!(this.sameDiff.getVariableOutputOp(inputMeta.getVariable().getVarName()) instanceof Enter)) {
            if (!this.execInputs.containsKey(varId2)) {
                this.execInputs.put(varId2, new HashSet());
            }
            this.execInputs.get(varId2).add(varId);
            return;
        }
        // Enter outputs are available to every iteration, so key them at iteration 0
        FrameIter parentFrame = varId2.getParentFrame();
        if (((Enter) this.sameDiff.getOps().get(inputMeta.getOutputOfOp()).getOp()).isConstant()) {
            parentFrame = cloneWithZeroIterations(parentFrame);
        }
        // Always a fresh key: the original reused varId2 at iteration 0 and then mutated it
        VarId iter0 = newVarId(varId2.getVariable(), varId2.getFrame(), 0, parentFrame);
        if (!this.execInputsAllIter.containsKey(iter0)) {
            this.execInputsAllIter.put(iter0, new HashSet());
        }
        this.execInputsAllIter.get(iter0).add(varId);
    }

    /**
     * Returns a copy of {@code frame}, made with {@code m1463clone()}, with the iteration set to 0
     * on it and on every parent reached via {@code getParentFrame()}.
     * Returns null for a null input.
     * NOTE(review): this assumes the clone copies the parent chain rather than sharing it;
     * otherwise the shared parents are zeroed in place.
     */
    private static FrameIter cloneWithZeroIterations(FrameIter frame) {
        if (frame == null) {
            return null;
        }
        FrameIter copy = frame.m1463clone();
        for (FrameIter f = copy; f != null; f = f.getParentFrame()) {
            f.setIteration(0);
        }
        return copy;
    }

    /**
     * Finds the first {@code VarId} in {@code collection} whose variable name equals
     * {@code str}, in iteration order.
     *
     * <p>Decompiler note (JADX): the access modifier was changed from protected to public.
     *
     * @param str        variable name to search for
     * @param collection candidates to scan
     * @param z          if true, throw when no match is found; if false, return null
     * @return the first matching {@code VarId}, or null when absent and {@code z} is false
     * @throws RuntimeException when absent and {@code z} is true
     */
    public static VarId lookup(String str, Collection<VarId> collection, boolean z) {
        Iterator<VarId> candidates = collection.iterator();
        while (candidates.hasNext()) {
            VarId candidate = candidates.next();
            if (candidate.getVariable().equals(str)) {
                return candidate;
            }
        }
        if (!z) {
            return null;
        }
        throw new RuntimeException("Could not find VarId to input " + str);
    }

    /**
     * Returns the live (not copied) map of computed values keyed by {@code VarId}.
     */
    public Map<VarId, T> getNodeOutputs() {
        final Map<VarId, T> outputs = nodeOutputs;
        return outputs;
    }

    /**
     * Returns the live (not copied) map of tensor-array contents keyed by {@code VarId}.
     */
    public Map<VarId, List<T>> getTensorArrays() {
        final Map<VarId, List<T>> arrays = tensorArrays;
        return arrays;
    }

    /**
     * Returns the live (not copied) {@code frameParents} map, keyed by name.
     */
    public Map<String, FrameIter> getFrameParents() {
        final Map<String, FrameIter> parents = frameParents;
        return parents;
    }
}
