package org.apache.spark.sql.catalyst.optimizer;

import org.apache.spark.SparkException$;
import org.apache.spark.sql.catalyst.expressions.Alias;
import org.apache.spark.sql.catalyst.expressions.ArrayTransform;
import org.apache.spark.sql.catalyst.expressions.CaseWhen;
import org.apache.spark.sql.catalyst.expressions.Coalesce;
import org.apache.spark.sql.catalyst.expressions.CreateArray;
import org.apache.spark.sql.catalyst.expressions.CreateMap;
import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.GetStructField;
import org.apache.spark.sql.catalyst.expressions.GetStructField$;
import org.apache.spark.sql.catalyst.expressions.If;
import org.apache.spark.sql.catalyst.expressions.IsNull;
import org.apache.spark.sql.catalyst.expressions.KnownFloatingPointNormalized;
import org.apache.spark.sql.catalyst.expressions.LambdaFunction;
import org.apache.spark.sql.catalyst.expressions.LambdaFunction$;
import org.apache.spark.sql.catalyst.expressions.Literal;
import org.apache.spark.sql.catalyst.expressions.Literal$;
import org.apache.spark.sql.catalyst.expressions.NamedLambdaVariable;
import org.apache.spark.sql.catalyst.expressions.NamedLambdaVariable$;
import org.apache.spark.sql.catalyst.expressions.TransformValues;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.catalyst.rules.Rule;
import org.apache.spark.sql.catalyst.trees.TreePattern$;
import org.apache.spark.sql.catalyst.trees.TreePatternBits;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.FloatType$;
import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.util.ArrayImplicits$;
import scala.Enumeration;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.ArrayOps$;
import scala.collection.immutable.$colon;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.Seq;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;

/* compiled from: NormalizeFloatingNumbers.scala */
/* loaded from: input_file:org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers$.class */
/**
 * Decompiled singleton (Scala {@code object}) for an optimizer {@link Rule} that wraps
 * floating-point-bearing expressions in normalization logic.
 *
 * <p>"Normalization" here, as shown by {@link #FLOAT_NORMALIZER()} and
 * {@link #DOUBLE_NORMALIZER()}, means:
 * <ul>
 *   <li>every NaN is replaced by the canonical {@code Float.NaN} / {@code Double.NaN} constant;
 *   <li>negative zero is replaced by positive zero.
 * </ul>
 *
 * <p>{@link #apply(LogicalPlan)} only descends into plan subtrees whose tree-pattern bits contain
 * WINDOW or JOIN (see {@code $anonfun$apply$1}).
 *
 * <p>NOTE(review): this is JADX output, not hand-written source.
 * <ul>
 *   <li>{@code mo283dataType()} is presumably the decompiler's rename of
 *       {@code Expression.dataType()}.
 *   <li>Several redundant {@code instanceof}/{@code MatchError} checks are Scala pattern-match
 *       artifacts.
 *   <li>Some assignments (e.g. a {@code DataType} result stored into an {@code ArrayType} local
 *       without a cast) would not compile as plain Java. Treat this as a read-only reference.
 * </ul>
 */
public final class NormalizeFloatingNumbers$ extends Rule<LogicalPlan> {
    // The singleton instance; Scala `object` members are reached through MODULE$.
    public static final NormalizeFloatingNumbers$ MODULE$ = new NormalizeFloatingNumbers$();
    // Boxed float -> boxed float.
    // - Any NaN becomes the canonical Float.NaN.
    // - A value comparing == -0.0f becomes 0.0f. Because Java's == treats 0.0f and -0.0f as
    //   equal, both zeros map to +0.0f.
    // - Every other value is re-boxed unchanged.
    private static final Function1<Object, Object> FLOAT_NORMALIZER = obj -> {
        float unboxToFloat = BoxesRunTime.unboxToFloat(obj);
        return Float.isNaN(unboxToFloat) ? BoxesRunTime.boxToFloat(Float.NaN) : unboxToFloat == -0.0f ? BoxesRunTime.boxToFloat(0.0f) : BoxesRunTime.boxToFloat(unboxToFloat);
    };
    // Double counterpart of FLOAT_NORMALIZER, with the same rules: canonical NaN, and both zeros -> +0.0d.
    private static final Function1<Object, Object> DOUBLE_NORMALIZER = obj -> {
        double unboxToDouble = BoxesRunTime.unboxToDouble(obj);
        return Double.isNaN(unboxToDouble) ? BoxesRunTime.boxToDouble(Double.NaN) : unboxToDouble == -0.0d ? BoxesRunTime.boxToDouble(0.0d) : BoxesRunTime.boxToDouble(unboxToDouble);
    };

    /**
     * Applies the rule to {@code logicalPlan}.
     *
     * <p>Pruning: only subtrees whose pattern bits contain WINDOW or JOIN are visited (see
     * {@code $anonfun$apply$1}). The rule id is the default from
     * {@code transformWithPruning$default$2()}.
     *
     * <p>The per-node rewrite is the synthetic {@code NormalizeFloatingNumbers$$anonfun$apply$2}
     * partial function, which is not visible in this file. Presumably it calls
     * {@link #normalize(Expression)} on join/window keys — TODO confirm against that class.
     *
     * @param logicalPlan the plan to rewrite
     * @return the rewritten plan
     */
    /* JADX WARN: Multi-variable type inference failed */
    @Override // org.apache.spark.sql.catalyst.rules.Rule
    public LogicalPlan apply(LogicalPlan logicalPlan) {
        return (LogicalPlan) logicalPlan.transformWithPruning(treePatternBits -> {
            return BoxesRunTime.boxToBoolean($anonfun$apply$1(treePatternBits));
        }, logicalPlan.transformWithPruning$default$2(), new NormalizeFloatingNumbers$$anonfun$apply$2());
    }

    /**
     * Name-mangled (Scala {@code private[optimizer]}-style) check for whether an expression still
     * needs normalizing.
     *
     * @param expression the expression to inspect
     * @return {@code false} if the expression is already wrapped in
     *     {@link KnownFloatingPointNormalized}; otherwise whether its data type needs
     *     normalization, per {@link #needNormalize(DataType)}
     */
    public boolean org$apache$spark$sql$catalyst$optimizer$NormalizeFloatingNumbers$$needNormalize(Expression expression) {
        if (expression instanceof KnownFloatingPointNormalized) {
            return false;
        }
        return needNormalize(expression.mo283dataType());
    }

    /**
     * Returns whether a value of {@code dataType} may contain float/double data that needs
     * normalizing.
     *
     * <ul>
     *   <li>{@code FloatType} / {@code DoubleType}: {@code true}.
     *   <li>{@code StructType}: {@code true} if any field's type needs normalizing (checked
     *       recursively via {@code $anonfun$needNormalize$1}).
     *   <li>{@code ArrayType}: the element type is examined.
     *   <li>{@code MapType}: only the value type is examined; the key type is not.
     *   <li>Anything else: {@code false}.
     * </ul>
     *
     * <p>The {@code while (true)} loop is the compiled form of tail recursion on the
     * array-element and map-value cases.
     *
     * @param dataType the type to inspect
     * @return whether normalization is needed
     */
    private boolean needNormalize(DataType dataType) {
        while (true) {
            DataType dataType2 = dataType;
            if (FloatType$.MODULE$.equals(dataType2) ? true : DoubleType$.MODULE$.equals(dataType2)) {
                return true;
            }
            if (dataType2 instanceof StructType) {
                return ArrayOps$.MODULE$.exists$extension(Predef$.MODULE$.refArrayOps(((StructType) dataType2).fields()), structField -> {
                    return BoxesRunTime.boxToBoolean($anonfun$needNormalize$1(structField));
                });
            }
            if (dataType2 instanceof ArrayType) {
                // Descend into the array's element type.
                dataType = ((ArrayType) dataType2).elementType();
            } else {
                if (!(dataType2 instanceof MapType)) {
                    return false;
                }
                // Only the map's value type is inspected; keys are ignored.
                dataType = ((MapType) dataType2).valueType();
            }
        }
    }

    /**
     * Wraps {@code expression} so that its float/double content is normalized.
     *
     * <p>Returns the input unchanged when no normalization is needed: the expression is already
     * {@link KnownFloatingPointNormalized}, or its type contains no float/double.
     *
     * <p>Otherwise, in this order:
     * <ul>
     *   <li>{@link Alias}, {@link CreateNamedStruct}, {@link CreateArray}, {@link CreateMap}:
     *       children are normalized recursively and the node is rebuilt. Non-float children
     *       (e.g. struct field-name literals) come back unchanged via the guard above.
     *   <li>{@code FloatType} / {@code DoubleType}: wrapped as
     *       {@code KnownFloatingPointNormalized(NormalizeNaNAndZero(expression))}.
     *   <li>{@link If}, {@link CaseWhen}, {@link Coalesce}: result branches are normalized; the
     *       If predicate and CaseWhen conditions are kept as-is.
     *   <li>Struct type: rebuilt field by field via {@link GetStructField}. A null input stays a
     *       typed null literal.
     *   <li>Array type: an {@link ArrayTransform} normalizes each element.
     *   <li>Map type: a {@link TransformValues} normalizes each value; keys are left untouched.
     * </ul>
     *
     * @param expression the expression to normalize
     * @return the original expression, or a normalized replacement (wrapped in
     *     {@link KnownFloatingPointNormalized} for the type-driven cases)
     * @throws RuntimeException the Spark internal error built by
     *     {@code SparkException.internalError("fail to normalize ...")} when the type needs
     *     normalization but none of the above cases match. Given {@link #needNormalize(DataType)},
     *     this looks defensive/unreachable.
     */
    public Expression normalize(Expression expression) {
        if (!org$apache$spark$sql$catalyst$optimizer$NormalizeFloatingNumbers$$needNormalize(expression)) {
            return expression;
        }
        if (expression instanceof Alias) {
            // Keep the alias (name/metadata) and normalize only its single child.
            Alias alias = (Alias) expression;
            return alias.withNewChildren(new $colon.colon(normalize(alias.child()), Nil$.MODULE$));
        }
        if (expression instanceof CreateNamedStruct) {
            return new CreateNamedStruct((Seq) ((CreateNamedStruct) expression).children().map(expression2 -> {
                return MODULE$.normalize(expression2);
            }));
        }
        if (expression instanceof CreateArray) {
            CreateArray createArray = (CreateArray) expression;
            Seq<Expression> children = createArray.children();
            return new CreateArray((Seq) children.map(expression3 -> {
                return MODULE$.normalize(expression3);
            }), createArray.useStringTypeWhenEmpty());
        }
        if (expression instanceof CreateMap) {
            // Every child (keys and values alike) goes through normalize; non-float ones are returned unchanged.
            CreateMap createMap = (CreateMap) expression;
            Seq<Expression> children2 = createMap.children();
            return new CreateMap((Seq) children2.map(expression4 -> {
                return MODULE$.normalize(expression4);
            }), createMap.useStringTypeWhenEmpty());
        }
        // The two nested null-safe comparisons below amount to "type is neither FloatType nor DoubleType".
        // Scalar float/double falls through to the NormalizeNaNAndZero wrapper at the end.
        DataType mo283dataType = expression.mo283dataType();
        FloatType$ floatType$ = FloatType$.MODULE$;
        if (mo283dataType != null ? !mo283dataType.equals(floatType$) : floatType$ != null) {
            DataType mo283dataType2 = expression.mo283dataType();
            DoubleType$ doubleType$ = DoubleType$.MODULE$;
            if (mo283dataType2 != null ? !mo283dataType2.equals(doubleType$) : doubleType$ != null) {
                if (expression instanceof If) {
                    // Normalize both result branches; the predicate is not a result value.
                    If r0 = (If) expression;
                    return new If(r0.predicate(), normalize(r0.trueValue()), normalize(r0.falseValue()));
                }
                if (expression instanceof CaseWhen) {
                    // Each branch is (condition, value): only the value is normalized, plus the optional else value.
                    CaseWhen caseWhen = (CaseWhen) expression;
                    return new CaseWhen((Seq) caseWhen.branches().map(tuple2 -> {
                        return new Tuple2(tuple2._1(), MODULE$.normalize((Expression) tuple2._2()));
                    }), caseWhen.elseValue().map(expression5 -> {
                        return MODULE$.normalize(expression5);
                    }));
                }
                if (expression instanceof Coalesce) {
                    return new Coalesce((Seq) ((Coalesce) expression).children().map(expression6 -> {
                        return MODULE$.normalize(expression6);
                    }));
                }
                if (expression.mo283dataType() instanceof StructType) {
                    // Build CreateNamedStruct(name_0, normalize(expr.field0), name_1, normalize(expr.field1), ...)
                    // using the struct's field names, their indices and GetStructField.
                    CreateNamedStruct createNamedStruct = new CreateNamedStruct(ArrayImplicits$.MODULE$.SparkArrayOps(ArrayOps$.MODULE$.flatten$extension(Predef$.MODULE$.refArrayOps((Seq[]) ArrayOps$.MODULE$.map$extension(Predef$.MODULE$.refArrayOps(ArrayOps$.MODULE$.zipWithIndex$extension(Predef$.MODULE$.refArrayOps(expression.mo283dataType().fieldNames()))), tuple22 -> {
                        if (tuple22 == null) {
                            throw new MatchError(tuple22);
                        }
                        return new $colon.colon(Literal$.MODULE$.apply((String) tuple22._1()), new $colon.colon(MODULE$.normalize(new GetStructField(expression, tuple22._2$mcI$sp(), GetStructField$.MODULE$.apply$default$3())), Nil$.MODULE$));
                    }, ClassTag$.MODULE$.apply(Seq.class))), Predef$.MODULE$.$conforms(), ClassTag$.MODULE$.apply(Expression.class))).toImmutableArraySeq());
                    // Preserve nullness: a null input struct yields a null literal of the rebuilt struct's type.
                    return new KnownFloatingPointNormalized(new If(new IsNull(expression), new Literal(null, createNamedStruct.mo283dataType()), createNamedStruct));
                }
                if (expression.mo283dataType() instanceof ArrayType) {
                    // Redundant re-check below is a decompiled Scala pattern-match extractor.
                    ArrayType mo283dataType3 = expression.mo283dataType();
                    if (!(mo283dataType3 instanceof ArrayType)) {
                        throw new MatchError(mo283dataType3);
                    }
                    ArrayType arrayType = mo283dataType3;
                    Tuple2 tuple23 = new Tuple2(arrayType.elementType(), BoxesRunTime.boxToBoolean(arrayType.containsNull()));
                    // Lambda variable "arg" typed as the element type, nullable iff the array may contain nulls.
                    NamedLambdaVariable namedLambdaVariable = new NamedLambdaVariable("arg", (DataType) tuple23._1(), tuple23._2$mcZ$sp(), NamedLambdaVariable$.MODULE$.apply$default$4(), NamedLambdaVariable$.MODULE$.apply$default$5());
                    // transform(expression, arg -> normalize(arg))
                    return new KnownFloatingPointNormalized(new ArrayTransform(expression, new LambdaFunction(normalize(namedLambdaVariable), new $colon.colon(namedLambdaVariable, Nil$.MODULE$), LambdaFunction$.MODULE$.apply$default$3())));
                }
                if (!(expression.mo283dataType() instanceof MapType)) {
                    throw SparkException$.MODULE$.internalError("fail to normalize " + expression);
                }
                // Redundant re-check below is a decompiled Scala pattern-match extractor.
                MapType mo283dataType4 = expression.mo283dataType();
                if (!(mo283dataType4 instanceof MapType)) {
                    throw new MatchError(mo283dataType4);
                }
                MapType mapType = mo283dataType4;
                Tuple3 tuple3 = new Tuple3(mapType.keyType(), mapType.valueType(), BoxesRunTime.boxToBoolean(mapType.valueContainsNull()));
                DataType dataType = (DataType) tuple3._1();
                DataType dataType2 = (DataType) tuple3._2();
                boolean unboxToBoolean = BoxesRunTime.unboxToBoolean(tuple3._3());
                // NOTE(review): both the key and the value lambda variables take their nullability from
                // valueContainsNull; the key variable does not use a key-specific flag — confirm intended.
                NamedLambdaVariable namedLambdaVariable2 = new NamedLambdaVariable("arg", dataType, unboxToBoolean, NamedLambdaVariable$.MODULE$.apply$default$4(), NamedLambdaVariable$.MODULE$.apply$default$5());
                NamedLambdaVariable namedLambdaVariable3 = new NamedLambdaVariable("arg", dataType2, unboxToBoolean, NamedLambdaVariable$.MODULE$.apply$default$4(), NamedLambdaVariable$.MODULE$.apply$default$5());
                // transform_values(expression, (k, v) -> normalize(v)); keys pass through unchanged.
                return new KnownFloatingPointNormalized(new TransformValues(expression, new LambdaFunction(normalize(namedLambdaVariable3), new $colon.colon(namedLambdaVariable2, new $colon.colon(namedLambdaVariable3, Nil$.MODULE$)), LambdaFunction$.MODULE$.apply$default$3())));
            }
        }
        // Scalar FloatType/DoubleType: wrap in the NaN/-0.0 normalizing expression.
        return new KnownFloatingPointNormalized(new NormalizeNaNAndZero(expression));
    }

    /** Accessor for the boxed float normalizer (canonical NaN, -0.0f -> 0.0f). */
    public Function1<Object, Object> FLOAT_NORMALIZER() {
        return FLOAT_NORMALIZER;
    }

    /** Accessor for the boxed double normalizer (canonical NaN, -0.0d -> 0.0d). */
    public Function1<Object, Object> DOUBLE_NORMALIZER() {
        return DOUBLE_NORMALIZER;
    }

    /** Pruning predicate for {@link #apply}: true when the subtree contains a WINDOW or JOIN pattern. */
    public static final /* synthetic */ boolean $anonfun$apply$1(TreePatternBits treePatternBits) {
        return treePatternBits.containsAnyPattern(ScalaRunTime$.MODULE$.wrapRefArray(new Enumeration.Value[]{TreePattern$.MODULE$.WINDOW(), TreePattern$.MODULE$.JOIN()}));
    }

    /** Lambda body used by {@link #needNormalize(DataType)} to test one struct field's type. */
    public static final /* synthetic */ boolean $anonfun$needNormalize$1(StructField structField) {
        return MODULE$.needNormalize(structField.dataType());
    }

    // Private constructor: the only instance is MODULE$.
    private NormalizeFloatingNumbers$() {
    }
}
