/*
 * Decompiled with CFR 0.152.
 */
package io.confluent.ksql.execution.interpreter;

import io.confluent.ksql.execution.expression.tree.ArithmeticUnaryExpression;
import io.confluent.ksql.execution.interpreter.CastInterpreter;
import io.confluent.ksql.execution.interpreter.terms.ArithmeticBinaryTerm;
import io.confluent.ksql.execution.interpreter.terms.ArithmeticUnaryTerm;
import io.confluent.ksql.execution.interpreter.terms.CastTerm;
import io.confluent.ksql.execution.interpreter.terms.Term;
import io.confluent.ksql.schema.Operator;
import io.confluent.ksql.schema.ksql.types.SqlBaseType;
import io.confluent.ksql.schema.ksql.types.SqlDecimal;
import io.confluent.ksql.schema.ksql.types.SqlType;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.util.DecimalUtil;
import io.confluent.ksql.util.KsqlConfig;
import io.confluent.ksql.util.KsqlException;
import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;

/**
 * Builds interpreter {@link Term}s for unary and binary arithmetic expressions.
 *
 * <p>All type-dependent decisions (which numeric implementation to use, which casts to insert,
 * which {@link MathContext} to apply) are made once, when the term is constructed. The returned
 * lambdas only perform the per-row arithmetic.
 */
public final class ArithmeticInterpreter {
    private ArithmeticInterpreter() {
    }

    /**
     * Builds a term that applies a unary sign ({@code +} or {@code -}) to {@code value}.
     *
     * @param sign  the unary sign to apply
     * @param value the operand; its SQL type selects the implementation
     * @return a term evaluating the signed operand
     * @throws UnsupportedOperationException if the sign, or the operand's SQL type, is not supported
     */
    public static Term doUnaryArithmetic(ArithmeticUnaryExpression.Sign sign, Term value) {
        final ArithmeticUnaryTerm.ArithmeticUnaryFunction function;
        switch (sign) {
            case MINUS:
                function = getUnaryMinusFunction(value);
                break;
            case PLUS:
                function = getUnaryPlusFunction(value);
                break;
            default:
                throw new UnsupportedOperationException("Unsupported sign: " + sign);
        }
        return new ArithmeticUnaryTerm(value, function);
    }

    /**
     * Builds a term that applies a binary arithmetic {@code operator} to {@code left} and
     * {@code right}.
     *
     * <p>If {@code resultType} is DECIMAL, both operands are cast to their decimal equivalent
     * (as given by {@code DecimalUtil.toSqlDecimal}) and combined with exact decimal arithmetic
     * at the result's precision and scale. Otherwise, any DECIMAL operand is cast to DOUBLE and
     * the operation is chosen from the (possibly cast) operand types.
     *
     * @param operator   the arithmetic operator
     * @param left       the left operand
     * @param right      the right operand
     * @param resultType the already-resolved SQL type of the expression's result
     * @param ksqlConfig configuration passed through to the cast interpreter
     * @return a term evaluating the arithmetic expression
     * @throws KsqlException if the operator or operand type combination is not supported
     */
    public static Term doBinaryArithmetic(
            Operator operator, Term left, Term right, SqlType resultType, KsqlConfig ksqlConfig) {
        if (resultType.baseType() == SqlBaseType.DECIMAL) {
            final SqlDecimal decimal = (SqlDecimal) resultType;
            final CastTerm leftTerm = CastInterpreter.cast(
                    left, left.getSqlType(), DecimalUtil.toSqlDecimal(left.getSqlType()), ksqlConfig);
            final CastTerm rightTerm = CastInterpreter.cast(
                    right, right.getSqlType(), DecimalUtil.toSqlDecimal(right.getSqlType()), ksqlConfig);
            final TypedArithmeticBinaryFunction<BigDecimal> fn = getDecimalFunction(decimal, operator);
            return new ArithmeticBinaryTerm(
                    leftTerm, rightTerm, (o1, o2) -> fn.doFunction((BigDecimal) o1, (BigDecimal) o2), resultType);
        }
        // Non-decimal result: mixed DECIMAL operands are evaluated as DOUBLE.
        final Term leftTerm = left.getSqlType().baseType() == SqlBaseType.DECIMAL
                ? CastInterpreter.cast(left, left.getSqlType(), SqlTypes.DOUBLE, ksqlConfig)
                : left;
        final Term rightTerm = right.getSqlType().baseType() == SqlBaseType.DECIMAL
                ? CastInterpreter.cast(right, right.getSqlType(), SqlTypes.DOUBLE, ksqlConfig)
                : right;
        return new ArithmeticBinaryTerm(
                leftTerm,
                rightTerm,
                getNonDecimalArithmeticFunction(operator, leftTerm.getSqlType(), rightTerm.getSqlType()),
                resultType);
    }

    /**
     * Selects the per-row function for non-decimal binary arithmetic.
     *
     * <p>The checks are applied in this order:
     * <ol>
     *   <li>STRING with STRING concatenates. Only {@code ADD} is allowed.</li>
     *   <li>If either side is DOUBLE, the operation uses double arithmetic.</li>
     *   <li>Otherwise, if either side is BIGINT, it uses long arithmetic.</li>
     *   <li>Otherwise, if either side is INTEGER, it uses int arithmetic.</li>
     * </ol>
     * In the numeric cases, each operand is converted with the cast function for its own type.
     *
     * @throws KsqlException if the operand types, or the operator for STRING operands, are not
     *                       supported
     */
    private static ArithmeticBinaryTerm.ArithmeticBinaryFunction getNonDecimalArithmeticFunction(
            Operator operator, SqlType leftType, SqlType rightType) {
        final SqlBaseType leftBaseType = leftType.baseType();
        final SqlBaseType rightBaseType = rightType.baseType();
        if (leftBaseType == SqlBaseType.STRING && rightBaseType == SqlBaseType.STRING) {
            // Only '+' has a meaning for strings. Other operators must not silently concatenate.
            if (operator != Operator.ADD) {
                throw new KsqlException(
                        "Operator " + operator + " is not supported for types " + leftType + " and " + rightType);
            }
            return (o1, o2) -> (String) o1 + (String) o2;
        }
        if (leftBaseType == SqlBaseType.DOUBLE || rightBaseType == SqlBaseType.DOUBLE) {
            final TypedArithmeticBinaryFunction<Double> fn = getDoubleFunction(operator);
            final CastTerm.ComparableCastFunction<Double> castLeft = CastInterpreter.castToDoubleFunction(leftType);
            final CastTerm.ComparableCastFunction<Double> castRight = CastInterpreter.castToDoubleFunction(rightType);
            return (o1, o2) -> fn.doFunction((Double) castLeft.cast(o1), (Double) castRight.cast(o2));
        }
        if (leftBaseType == SqlBaseType.BIGINT || rightBaseType == SqlBaseType.BIGINT) {
            final TypedArithmeticBinaryFunction<Long> fn = getLongFunction(operator);
            final CastTerm.ComparableCastFunction<Long> castLeft = CastInterpreter.castToLongFunction(leftType);
            final CastTerm.ComparableCastFunction<Long> castRight = CastInterpreter.castToLongFunction(rightType);
            return (o1, o2) -> fn.doFunction((Long) castLeft.cast(o1), (Long) castRight.cast(o2));
        }
        if (leftBaseType == SqlBaseType.INTEGER || rightBaseType == SqlBaseType.INTEGER) {
            final TypedArithmeticBinaryFunction<Integer> fn = getIntegerFunction(operator);
            final CastTerm.ComparableCastFunction<Integer> castLeft = CastInterpreter.castToIntegerFunction(leftType);
            final CastTerm.ComparableCastFunction<Integer> castRight = CastInterpreter.castToIntegerFunction(rightType);
            return (o1, o2) -> fn.doFunction((Integer) castLeft.cast(o1), (Integer) castRight.cast(o2));
        }
        throw new KsqlException("Can't do arithmetic for types " + leftType + " and " + rightType);
    }

    /**
     * Returns the per-row negation function for the operand's SQL type.
     *
     * <p>For DECIMAL, negation uses the operand's precision with {@link RoundingMode#UNNECESSARY}.
     * The {@link MathContext} is built once here instead of on every row.
     *
     * @throws UnsupportedOperationException for types other than DECIMAL, DOUBLE, INTEGER or BIGINT
     */
    private static ArithmeticUnaryTerm.ArithmeticUnaryFunction getUnaryMinusFunction(Term term) {
        final SqlType type = term.getSqlType();
        switch (type.baseType()) {
            case DECIMAL: {
                final MathContext mc = new MathContext(((SqlDecimal) type).getPrecision(), RoundingMode.UNNECESSARY);
                return o -> ((BigDecimal) o).negate(mc);
            }
            case DOUBLE:
                return o -> -((Double) o).doubleValue();
            case INTEGER:
                return o -> -((Integer) o).intValue();
            case BIGINT:
                return o -> -((Long) o).longValue();
            default:
                throw new UnsupportedOperationException("Negation on unsupported type: " + type);
        }
    }

    /**
     * Returns the per-row unary-plus function for the operand's SQL type.
     *
     * <p>For DECIMAL, the value is passed through {@link BigDecimal#plus(MathContext)} at the
     * operand's precision. DOUBLE, INTEGER and BIGINT values are returned unchanged.
     *
     * @throws UnsupportedOperationException for types other than DECIMAL, DOUBLE, INTEGER or BIGINT
     */
    private static ArithmeticUnaryTerm.ArithmeticUnaryFunction getUnaryPlusFunction(Term term) {
        final SqlType type = term.getSqlType();
        switch (type.baseType()) {
            case DECIMAL: {
                final MathContext mc = new MathContext(((SqlDecimal) type).getPrecision(), RoundingMode.UNNECESSARY);
                return o -> ((BigDecimal) o).plus(mc);
            }
            case DOUBLE:
            case INTEGER:
            case BIGINT:
                return o -> o;
            default:
                throw new UnsupportedOperationException("Unary plus on unsupported type: " + type);
        }
    }

    /**
     * Returns the double implementation of {@code operator}.
     *
     * <p>Java floating-point semantics apply, so division by zero yields an infinity or NaN
     * rather than throwing.
     */
    private static TypedArithmeticBinaryFunction<Double> getDoubleFunction(Operator operator) {
        switch (operator) {
            case ADD:
                return (a, b) -> a + b;
            case SUBTRACT:
                return (a, b) -> a - b;
            case MULTIPLY:
                return (a, b) -> a * b;
            case DIVIDE:
                return (a, b) -> a / b;
            case MODULUS:
                return (a, b) -> a % b;
            default:
                throw new KsqlException("Unknown operator " + operator);
        }
    }

    /**
     * Returns the int implementation of {@code operator}.
     *
     * <p>Java integer semantics apply: overflow wraps, and division or modulus by zero throws
     * {@link ArithmeticException}.
     */
    private static TypedArithmeticBinaryFunction<Integer> getIntegerFunction(Operator operator) {
        switch (operator) {
            case ADD:
                return (a, b) -> a + b;
            case SUBTRACT:
                return (a, b) -> a - b;
            case MULTIPLY:
                return (a, b) -> a * b;
            case DIVIDE:
                return (a, b) -> a / b;
            case MODULUS:
                return (a, b) -> a % b;
            default:
                throw new KsqlException("Unknown operator " + operator);
        }
    }

    /**
     * Returns the long implementation of {@code operator}.
     *
     * <p>Java integer semantics apply: overflow wraps, and division or modulus by zero throws
     * {@link ArithmeticException}.
     */
    private static TypedArithmeticBinaryFunction<Long> getLongFunction(Operator operator) {
        switch (operator) {
            case ADD:
                return (a, b) -> a + b;
            case SUBTRACT:
                return (a, b) -> a - b;
            case MULTIPLY:
                return (a, b) -> a * b;
            case DIVIDE:
                return (a, b) -> a / b;
            case MODULUS:
                return (a, b) -> a % b;
            default:
                throw new KsqlException("Unknown operator " + operator);
        }
    }

    /**
     * Returns the exact-decimal implementation of {@code operator} for the given result type.
     *
     * <p>Each operation runs with the result precision and {@link RoundingMode#UNNECESSARY}. The
     * result is then rescaled with {@link BigDecimal#setScale(int)}, which does no rounding. As a
     * result, {@link ArithmeticException} is thrown if the exact result does not fit the result
     * type (for example, a non-terminating quotient).
     *
     * @throws KsqlException if the operator is not supported
     */
    private static TypedArithmeticBinaryFunction<BigDecimal> getDecimalFunction(SqlDecimal decimal, Operator operator) {
        final MathContext mc = new MathContext(decimal.getPrecision(), RoundingMode.UNNECESSARY);
        final int scale = decimal.getScale();
        switch (operator) {
            case ADD:
                return (a, b) -> a.add(b, mc).setScale(scale);
            case SUBTRACT:
                return (a, b) -> a.subtract(b, mc).setScale(scale);
            case MULTIPLY:
                return (a, b) -> a.multiply(b, mc).setScale(scale);
            case DIVIDE:
                return (a, b) -> a.divide(b, mc).setScale(scale);
            case MODULUS:
                return (a, b) -> a.remainder(b, mc).setScale(scale);
            default:
                throw new KsqlException("DECIMAL operator not supported: " + operator);
        }
    }

    /** A binary arithmetic operation on two operands of the same type. */
    @FunctionalInterface
    private interface TypedArithmeticBinaryFunction<T> {
        T doFunction(T left, T right);
    }
}

