/*
 * Decompiled with CFR 0.152.
 */
package io.confluent.ksql.execution.util;

import com.google.common.collect.ImmutableList;
import io.confluent.ksql.execution.expression.tree.ArithmeticBinaryExpression;
import io.confluent.ksql.execution.expression.tree.ArithmeticUnaryExpression;
import io.confluent.ksql.execution.expression.tree.BetweenPredicate;
import io.confluent.ksql.execution.expression.tree.BooleanLiteral;
import io.confluent.ksql.execution.expression.tree.Cast;
import io.confluent.ksql.execution.expression.tree.ComparisonExpression;
import io.confluent.ksql.execution.expression.tree.CreateArrayExpression;
import io.confluent.ksql.execution.expression.tree.CreateMapExpression;
import io.confluent.ksql.execution.expression.tree.CreateStructExpression;
import io.confluent.ksql.execution.expression.tree.DecimalLiteral;
import io.confluent.ksql.execution.expression.tree.DereferenceExpression;
import io.confluent.ksql.execution.expression.tree.DoubleLiteral;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.ExpressionVisitor;
import io.confluent.ksql.execution.expression.tree.FunctionCall;
import io.confluent.ksql.execution.expression.tree.InListExpression;
import io.confluent.ksql.execution.expression.tree.InPredicate;
import io.confluent.ksql.execution.expression.tree.IntegerLiteral;
import io.confluent.ksql.execution.expression.tree.IsNotNullPredicate;
import io.confluent.ksql.execution.expression.tree.IsNullPredicate;
import io.confluent.ksql.execution.expression.tree.LikePredicate;
import io.confluent.ksql.execution.expression.tree.LogicalBinaryExpression;
import io.confluent.ksql.execution.expression.tree.LongLiteral;
import io.confluent.ksql.execution.expression.tree.NotExpression;
import io.confluent.ksql.execution.expression.tree.NullLiteral;
import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.SearchedCaseExpression;
import io.confluent.ksql.execution.expression.tree.SimpleCaseExpression;
import io.confluent.ksql.execution.expression.tree.StringLiteral;
import io.confluent.ksql.execution.expression.tree.SubscriptExpression;
import io.confluent.ksql.execution.expression.tree.TimeLiteral;
import io.confluent.ksql.execution.expression.tree.TimestampLiteral;
import io.confluent.ksql.execution.expression.tree.Type;
import io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.WhenClause;
import io.confluent.ksql.execution.function.UdafUtil;
import io.confluent.ksql.execution.util.ComparisonUtil;
import io.confluent.ksql.function.AggregateFunctionInitArguments;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.function.KsqlAggregateFunction;
import io.confluent.ksql.function.KsqlFunctionException;
import io.confluent.ksql.function.KsqlTableFunction;
import io.confluent.ksql.function.UdfFactory;
import io.confluent.ksql.schema.ksql.Column;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.schema.ksql.SqlBaseType;
import io.confluent.ksql.schema.ksql.types.Field;
import io.confluent.ksql.schema.ksql.types.SqlArray;
import io.confluent.ksql.schema.ksql.types.SqlMap;
import io.confluent.ksql.schema.ksql.types.SqlStruct;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.util.DecimalUtil;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.VisitorUtil;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * Resolves the {@link SqlType} that an {@link Expression} evaluates to.
 *
 * <p>Unqualified column references are resolved against the supplied {@link LogicalSchema};
 * function calls (aggregate, table and scalar) are resolved through the supplied
 * {@link FunctionRegistry}. Resolution also validates parts of the expression tree: for
 * example unsupported casts, unknown columns or struct fields, and mismatched CASE / ARRAY /
 * MAP element types cause an exception to be thrown.
 */
public class ExpressionTypeManager {
    private final LogicalSchema schema;
    private final FunctionRegistry functionRegistry;

    /**
     * @param schema the schema used to resolve unqualified column references
     * @param functionRegistry the registry used to resolve function return types
     * @throws NullPointerException if either argument is {@code null}
     */
    public ExpressionTypeManager(LogicalSchema schema, FunctionRegistry functionRegistry) {
        this.schema = Objects.requireNonNull(schema, "schema");
        this.functionRegistry = Objects.requireNonNull(functionRegistry, "functionRegistry");
    }

    /**
     * Returns the SQL type of {@code expression}.
     *
     * @param expression the expression to type
     * @return the resolved type; may be {@code null} when no type can be determined,
     *     e.g. for an untyped {@code NULL} literal
     * @throws KsqlException if the expression references an unknown column or struct field,
     *     or builds a CASE / ARRAY / MAP from incompatible types
     * @throws KsqlFunctionException if the expression casts to an unsupported type
     */
    public SqlType getExpressionSqlType(Expression expression) {
        ExpressionTypeContext expressionTypeContext = new ExpressionTypeContext();
        new Visitor().process(expression, expressionTypeContext);
        return expressionTypeContext.getSqlType();
    }

    /**
     * Walks the expression tree; each visit method stores the type of the node it visits in
     * the shared {@link ExpressionTypeContext}, overwriting whatever a previous visit left
     * there. Non-static because it reads the enclosing manager's schema and registry.
     */
    private class Visitor implements ExpressionVisitor<Void, ExpressionTypeContext> {
        private Visitor() {
        }

        @Override
        public Void visitArithmeticBinary(ArithmeticBinaryExpression node, ExpressionTypeContext expressionTypeContext) {
            process(node.getLeft(), expressionTypeContext);
            SqlType leftType = expressionTypeContext.getSqlType();
            process(node.getRight(), expressionTypeContext);
            SqlType rightType = expressionTypeContext.getSqlType();
            // The operator decides how operand types combine (e.g. numeric widening).
            expressionTypeContext.setSqlType(node.getOperator().resultType(leftType, rightType));
            return null;
        }

        @Override
        public Void visitArithmeticUnary(ArithmeticUnaryExpression node, ExpressionTypeContext context) {
            // The operand's type is left in the context: unary +/- does not change the type.
            process(node.getValue(), context);
            return null;
        }

        @Override
        public Void visitNotExpression(NotExpression node, ExpressionTypeContext expressionTypeContext) {
            expressionTypeContext.setSqlType(SqlTypes.BOOLEAN);
            return null;
        }

        @Override
        public Void visitCast(Cast node, ExpressionTypeContext expressionTypeContext) {
            SqlType sqlType = node.getType().getSqlType();
            if (!sqlType.supportsCast()) {
                throw new KsqlFunctionException("Only casts to primitive types or decimals are supported: " + sqlType);
            }
            expressionTypeContext.setSqlType(sqlType);
            return null;
        }

        @Override
        public Void visitComparisonExpression(ComparisonExpression node, ExpressionTypeContext expressionTypeContext) {
            process(node.getLeft(), expressionTypeContext);
            SqlType leftSchema = expressionTypeContext.getSqlType();
            process(node.getRight(), expressionTypeContext);
            SqlType rightSchema = expressionTypeContext.getSqlType();
            // Called for its validation side effect; its return value is not used here.
            // NOTE(review): assumes ComparisonUtil throws on an invalid comparison - confirm.
            ComparisonUtil.isValidComparison(leftSchema, node.getType(), rightSchema);
            expressionTypeContext.setSqlType(SqlTypes.BOOLEAN);
            return null;
        }

        @Override
        public Void visitBetweenPredicate(BetweenPredicate node, ExpressionTypeContext context) {
            context.setSqlType(SqlTypes.BOOLEAN);
            return null;
        }

        @Override
        public Void visitColumnReference(UnqualifiedColumnReferenceExp node, ExpressionTypeContext expressionTypeContext) {
            // Typed Optional: a raw Optional's orElseThrow erases to "throws Throwable".
            Optional<Column> possibleColumn = schema.findValueColumn(node.getReference());
            Column schemaColumn = possibleColumn.orElseThrow(
                () -> new KsqlException("Unknown column " + node + "."));
            expressionTypeContext.setSqlType(schemaColumn.type());
            return null;
        }

        @Override
        public Void visitQualifiedColumnReference(QualifiedColumnReferenceExp node, ExpressionTypeContext expressionTypeContext) {
            throw new IllegalStateException("Qualified column references must be resolved to unqualified reference before type can be resolved");
        }

        @Override
        public Void visitDereferenceExpression(DereferenceExpression node, ExpressionTypeContext expressionTypeContext) {
            process(node.getBase(), expressionTypeContext);
            SqlType sqlType = expressionTypeContext.getSqlType();
            if (!(sqlType instanceof SqlStruct)) {
                throw new IllegalStateException("Expected STRUCT type, got: " + sqlType);
            }
            SqlStruct structType = (SqlStruct) sqlType;
            String fieldName = node.getFieldName();
            Field structField = structType.field(fieldName).orElseThrow(
                () -> new KsqlException("Could not find field '" + fieldName + "' in '" + node.getBase() + "'."));
            expressionTypeContext.setSqlType(structField.type());
            return null;
        }

        @Override
        public Void visitStringLiteral(StringLiteral node, ExpressionTypeContext expressionTypeContext) {
            expressionTypeContext.setSqlType(SqlTypes.STRING);
            return null;
        }

        @Override
        public Void visitBooleanLiteral(BooleanLiteral node, ExpressionTypeContext expressionTypeContext) {
            expressionTypeContext.setSqlType(SqlTypes.BOOLEAN);
            return null;
        }

        @Override
        public Void visitLongLiteral(LongLiteral node, ExpressionTypeContext expressionTypeContext) {
            expressionTypeContext.setSqlType(SqlTypes.BIGINT);
            return null;
        }

        @Override
        public Void visitIntegerLiteral(IntegerLiteral node, ExpressionTypeContext expressionTypeContext) {
            expressionTypeContext.setSqlType(SqlTypes.INTEGER);
            return null;
        }

        @Override
        public Void visitDoubleLiteral(DoubleLiteral node, ExpressionTypeContext expressionTypeContext) {
            expressionTypeContext.setSqlType(SqlTypes.DOUBLE);
            return null;
        }

        @Override
        public Void visitNullLiteral(NullLiteral node, ExpressionTypeContext context) {
            // An untyped NULL has no SQL type; callers must treat null as "unknown".
            context.setSqlType(null);
            return null;
        }

        @Override
        public Void visitLikePredicate(LikePredicate node, ExpressionTypeContext expressionTypeContext) {
            expressionTypeContext.setSqlType(SqlTypes.BOOLEAN);
            return null;
        }

        @Override
        public Void visitIsNotNullPredicate(IsNotNullPredicate node, ExpressionTypeContext expressionTypeContext) {
            expressionTypeContext.setSqlType(SqlTypes.BOOLEAN);
            return null;
        }

        @Override
        public Void visitIsNullPredicate(IsNullPredicate node, ExpressionTypeContext expressionTypeContext) {
            expressionTypeContext.setSqlType(SqlTypes.BOOLEAN);
            return null;
        }

        /**
         * The CASE type is the common type of all non-NULL THEN results and the ELSE result.
         * NULL branches are ignored, but at least one branch must be typed.
         */
        @Override
        public Void visitSearchedCaseExpression(SearchedCaseExpression node, ExpressionTypeContext context) {
            Optional<SqlType> whenType = validateWhenClauses(node.getWhenClauses(), context);
            // A NULL default maps to Optional.empty(), because getExpressionSqlType returns null.
            Optional<SqlType> defaultType = node.getDefaultValue().map(ExpressionTypeManager.this::getExpressionSqlType);
            if (whenType.isPresent() && defaultType.isPresent()) {
                if (!whenType.get().equals(defaultType.get())) {
                    throw new KsqlException("Invalid Case expression. Type for the default clause should be the same as for 'THEN' clauses." + System.lineSeparator() + "THEN type: " + whenType.get() + "." + System.lineSeparator() + "DEFAULT type: " + defaultType.get() + ".");
                }
                context.setSqlType(whenType.get());
            } else if (whenType.isPresent()) {
                context.setSqlType(whenType.get());
            } else if (defaultType.isPresent()) {
                context.setSqlType(defaultType.get());
            } else {
                throw new KsqlException("Invalid Case expression. All case branches have NULL type");
            }
            return null;
        }

        @Override
        public Void visitSubscriptExpression(SubscriptExpression node, ExpressionTypeContext expressionTypeContext) {
            process(node.getBase(), expressionTypeContext);
            SqlType arrayMapType = expressionTypeContext.getSqlType();
            SqlType valueType;
            if (arrayMapType instanceof SqlMap) {
                valueType = ((SqlMap) arrayMapType).getValueType();
            } else if (arrayMapType instanceof SqlArray) {
                valueType = ((SqlArray) arrayMapType).getItemType();
            } else {
                throw new UnsupportedOperationException("Unsupported container type: " + arrayMapType);
            }
            expressionTypeContext.setSqlType(valueType);
            return null;
        }

        /**
         * The array's item type is the single type shared by all non-NULL elements.
         * NULL elements are ignored for typing purposes.
         */
        @Override
        public Void visitCreateArrayExpression(CreateArrayExpression exp, ExpressionTypeContext context) {
            if (exp.getValues().isEmpty()) {
                throw new KsqlException("Array constructor cannot be empty. Please supply at least one element (see https://github.com/confluentinc/ksql/issues/4239).");
            }
            List<SqlType> sqlTypes = exp.getValues().stream()
                .map(val -> {
                    process(val, context);
                    return context.getSqlType();
                })
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
            if (sqlTypes.isEmpty()) {
                throw new KsqlException("Cannot construct an array with all NULL elements (see https://github.com/confluentinc/ksql/issues/4239). As a workaround, you may cast a NULL value to the desired type.");
            }
            if (new HashSet<>(sqlTypes).size() != 1) {
                throw new KsqlException(String.format("Cannot construct an array with mismatching types (%s) from expression %s.", sqlTypes, exp));
            }
            context.setSqlType(SqlArray.of(sqlTypes.get(0)));
            return null;
        }

        /**
         * Only STRING keys are accepted. The map's value type is the single type shared by all
         * non-NULL values. NULL values are ignored for typing, as array elements are.
         */
        @Override
        public Void visitCreateMapExpression(CreateMapExpression exp, ExpressionTypeContext context) {
            if (exp.getMap().isEmpty()) {
                throw new KsqlException("Map constructor cannot be empty. Please supply at least one key value pair (see https://github.com/confluentinc/ksql/issues/4239).");
            }
            List<SqlType> keyTypes = exp.getMap().keySet().stream()
                .map(key -> {
                    process(key, context);
                    return context.getSqlType();
                })
                .collect(Collectors.toList());
            if (keyTypes.stream().anyMatch(type -> !SqlTypes.STRING.equals(type))) {
                throw new KsqlException("Only STRING keys are supported in maps but got: " + keyTypes);
            }
            // Filter out NULL value types before comparing them, so a mix of typed and NULL
            // values does not count as "mismatching".
            List<SqlType> valueTypes = exp.getMap().values().stream()
                .map(val -> {
                    process(val, context);
                    return context.getSqlType();
                })
                .filter(Objects::nonNull)
                .distinct()
                .collect(Collectors.toList());
            if (valueTypes.isEmpty()) {
                throw new KsqlException("Cannot construct MAP with NULL values. As a workaround, you may cast a NULL value to the desired type.");
            }
            if (valueTypes.size() != 1) {
                throw new KsqlException(String.format("Cannot construct a map with mismatching value types (%s) from expression %s.", valueTypes, exp));
            }
            context.setSqlType(SqlMap.of(valueTypes.get(0)));
            return null;
        }

        @Override
        public Void visitStructExpression(CreateStructExpression exp, ExpressionTypeContext context) {
            SqlStruct.Builder builder = SqlStruct.builder();
            for (CreateStructExpression.Field field : exp.getFields()) {
                process(field.getValue(), context);
                builder.field(field.getName(), context.getSqlType());
            }
            context.setSqlType(builder.build());
            return null;
        }

        /**
         * Resolves the return type of a function call. Aggregate functions are looked up by the
         * first argument's type. Table functions are looked up by all argument types. Both fall
         * back to {@link FunctionRegistry#DEFAULT_FUNCTION_ARG_SCHEMA} when the call has no
         * arguments. Any other function is resolved as a scalar UDF via its factory.
         */
        @Override
        public Void visitFunctionCall(FunctionCall node, ExpressionTypeContext expressionTypeContext) {
            if (functionRegistry.isAggregate(node.getName())) {
                SqlType argType = node.getArguments().isEmpty()
                    ? FunctionRegistry.DEFAULT_FUNCTION_ARG_SCHEMA
                    : ExpressionTypeManager.this.getExpressionSqlType(node.getArguments().get(0));
                AggregateFunctionInitArguments args = UdafUtil.createAggregateFunctionInitArgs(0, node);
                KsqlAggregateFunction aggFunc = functionRegistry.getAggregateFunction(node.getName(), argType, args);
                expressionTypeContext.setSqlType(aggFunc.returnType());
                return null;
            }
            if (functionRegistry.isTableFunction(node.getName())) {
                // Declared as List: the stream branch yields a plain List, not an ImmutableList.
                List<SqlType> argumentTypes = node.getArguments().isEmpty()
                    ? ImmutableList.of(FunctionRegistry.DEFAULT_FUNCTION_ARG_SCHEMA)
                    : node.getArguments().stream()
                        .map(ExpressionTypeManager.this::getExpressionSqlType)
                        .collect(Collectors.toList());
                KsqlTableFunction tableFunction = functionRegistry.getTableFunction(node.getName(), argumentTypes);
                expressionTypeContext.setSqlType(tableFunction.getReturnType(argumentTypes));
                return null;
            }
            UdfFactory udfFactory = functionRegistry.getUdfFactory(node.getName());
            List<SqlType> argTypes = new ArrayList<>();
            for (Expression expression : node.getArguments()) {
                process(expression, expressionTypeContext);
                // May add null for an untyped NULL argument; overload resolution receives it as-is.
                argTypes.add(expressionTypeContext.getSqlType());
            }
            SqlType returnSchema = udfFactory.getFunction(argTypes).getReturnType(argTypes);
            expressionTypeContext.setSqlType(returnSchema);
            return null;
        }

        @Override
        public Void visitLogicalBinaryExpression(LogicalBinaryExpression node, ExpressionTypeContext context) {
            process(node.getLeft(), context);
            process(node.getRight(), context);
            // AND/OR always yields BOOLEAN; without this, the right operand's type (possibly
            // null for a NULL literal) would leak out as the expression's type.
            context.setSqlType(SqlTypes.BOOLEAN);
            return null;
        }

        @Override
        public Void visitType(Type type, ExpressionTypeContext expressionTypeContext) {
            throw VisitorUtil.illegalState(this, type);
        }

        @Override
        public Void visitTimeLiteral(TimeLiteral timeLiteral, ExpressionTypeContext expressionTypeContext) {
            throw VisitorUtil.unsupportedOperation(this, timeLiteral);
        }

        @Override
        public Void visitTimestampLiteral(TimestampLiteral timestampLiteral, ExpressionTypeContext expressionTypeContext) {
            throw VisitorUtil.unsupportedOperation(this, timestampLiteral);
        }

        @Override
        public Void visitDecimalLiteral(DecimalLiteral decimalLiteral, ExpressionTypeContext expressionTypeContext) {
            // Precision and scale are derived from the literal's value.
            BigDecimal value = decimalLiteral.getValue();
            expressionTypeContext.setSqlType(DecimalUtil.fromValue(value));
            return null;
        }

        @Override
        public Void visitSimpleCaseExpression(SimpleCaseExpression simpleCaseExpression, ExpressionTypeContext expressionTypeContext) {
            throw VisitorUtil.unsupportedOperation(this, simpleCaseExpression);
        }

        @Override
        public Void visitInListExpression(InListExpression inListExpression, ExpressionTypeContext expressionTypeContext) {
            throw VisitorUtil.unsupportedOperation(this, inListExpression);
        }

        @Override
        public Void visitInPredicate(InPredicate inPredicate, ExpressionTypeContext expressionTypeContext) {
            throw VisitorUtil.unsupportedOperation(this, inPredicate);
        }

        @Override
        public Void visitWhenClause(WhenClause whenClause, ExpressionTypeContext expressionTypeContext) {
            // WHEN clauses are typed via validateWhenClauses, never visited directly.
            throw VisitorUtil.illegalState(this, whenClause);
        }

        /**
         * Checks that every WHEN operand is BOOLEAN and that all non-NULL THEN results share a
         * single type.
         *
         * @return the shared THEN type, or empty if every THEN result is an untyped NULL
         * @throws KsqlException if an operand is not BOOLEAN (including an untyped NULL operand)
         *     or two THEN results have different types
         */
        private Optional<SqlType> validateWhenClauses(List<WhenClause> whenClauses, ExpressionTypeContext context) {
            Optional<SqlType> previousResult = Optional.empty();
            for (WhenClause whenClause : whenClauses) {
                process(whenClause.getOperand(), context);
                SqlType operandType = context.getSqlType();
                // Guard against an untyped NULL operand, which previously caused an NPE here.
                if (operandType == null || operandType.baseType() != SqlBaseType.BOOLEAN) {
                    throw new KsqlException("WHEN operand type should be boolean." + System.lineSeparator() + "Type for '" + whenClause.getOperand() + "' is " + operandType);
                }
                process(whenClause.getResult(), context);
                SqlType resultType = context.getSqlType();
                if (resultType == null) {
                    continue;
                }
                if (!previousResult.isPresent()) {
                    previousResult = Optional.of(resultType);
                    continue;
                }
                if (!previousResult.get().equals(resultType)) {
                    throw new KsqlException("Invalid Case expression. Type for all 'THEN' clauses should be the same." + System.lineSeparator() + "THEN expression '" + whenClause + "' has type: " + resultType + "." + System.lineSeparator() + "Previous THEN expression(s) type: " + previousResult.get() + ".");
                }
            }
            return previousResult;
        }
    }

    /**
     * Mutable holder for the type of the most recently visited node. {@code null} means the
     * node is untyped, e.g. a {@code NULL} literal.
     */
    private static class ExpressionTypeContext {
        private SqlType sqlType;

        private ExpressionTypeContext() {
        }

        SqlType getSqlType() {
            return this.sqlType;
        }

        void setSqlType(SqlType sqlType) {
            this.sqlType = sqlType;
        }
    }
}

