/*
 * Decompiled with CFR 0.152.
 */
package org.apache.gluten.utils;

import org.apache.gluten.exception.GlutenNotSupportException;
import org.apache.gluten.sql.shims.SparkShimLoader$;
import org.apache.spark.sql.catalyst.expressions.Add;
import org.apache.spark.sql.catalyst.expressions.BinaryArithmetic;
import org.apache.spark.sql.catalyst.expressions.Cast;
import org.apache.spark.sql.catalyst.expressions.Cast$;
import org.apache.spark.sql.catalyst.expressions.Divide;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.Literal;
import org.apache.spark.sql.catalyst.expressions.Multiply;
import org.apache.spark.sql.catalyst.expressions.Pmod;
import org.apache.spark.sql.catalyst.expressions.PromotePrecision;
import org.apache.spark.sql.catalyst.expressions.Remainder;
import org.apache.spark.sql.catalyst.expressions.Subtract;
import org.apache.spark.sql.internal.SQLConf$;
import org.apache.spark.sql.types.ByteType$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.IntegerType$;
import org.apache.spark.sql.types.LongType$;
import org.apache.spark.sql.types.ShortType$;
import org.apache.spark.sql.utils.DecimalTypeUtil$;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.immutable.;
import scala.collection.immutable.List;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.Seq;
import scala.runtime.BoxesRunTime;

public final class DecimalArithmeticUtil$ {
    public static final DecimalArithmeticUtil$ MODULE$ = new DecimalArithmeticUtil$();
    private static final int MIN_ADJUSTED_SCALE = 6;
    private static final int MAX_PRECISION = 38;
    private static final int MAX_SCALE = 38;

    public int MIN_ADJUSTED_SCALE() {
        return MIN_ADJUSTED_SCALE;
    }

    public int MAX_PRECISION() {
        return MAX_PRECISION;
    }

    public int MAX_SCALE() {
        return MAX_SCALE;
    }

    public DecimalType getResultType(BinaryArithmetic expr, DecimalType type1, DecimalType type2) {
        boolean allowPrecisionLoss = SQLConf$.MODULE$.get().decimalOperationsAllowPrecisionLoss();
        int resultScale = 0;
        int resultPrecision = 0;
        BinaryArithmetic binaryArithmetic = expr;
        if (binaryArithmetic instanceof Add) {
            resultScale = Math.max(type1.scale(), type2.scale());
            resultPrecision = resultScale + Math.max(type1.precision() - type1.scale(), type2.precision() - type2.scale()) + 1;
        } else if (binaryArithmetic instanceof Subtract) {
            resultScale = Math.max(type1.scale(), type2.scale());
            resultPrecision = resultScale + Math.max(type1.precision() - type1.scale(), type2.precision() - type2.scale()) + 1;
        } else if (binaryArithmetic instanceof Multiply) {
            resultScale = type1.scale() + type2.scale();
            resultPrecision = type1.precision() + type2.precision() + 1;
        } else if (binaryArithmetic instanceof Divide) {
            if (allowPrecisionLoss) {
                resultScale = Math.max(this.MIN_ADJUSTED_SCALE(), type1.scale() + type2.precision() + 1);
                resultPrecision = type1.precision() - type1.scale() + type2.scale() + resultScale;
            } else {
                int decDig;
                int intDig = Math.min(this.MAX_SCALE(), type1.precision() - type1.scale() + type2.scale());
                int diff = intDig + (decDig = Math.min(this.MAX_SCALE(), Math.max(6, type1.scale() + type2.precision() + 1))) - this.MAX_SCALE();
                if (diff > 0) {
                    intDig = this.MAX_SCALE() - (decDig -= diff / 2 + 1);
                }
                resultPrecision = intDig + decDig;
                resultScale = decDig;
            }
        } else {
            throw new GlutenNotSupportException(binaryArithmetic + " is not supported.");
        }
        if (allowPrecisionLoss) {
            return DecimalTypeUtil$.MODULE$.adjustPrecisionScale(resultPrecision, resultScale);
        }
        return this.bounded(resultPrecision, resultScale);
    }

    public DecimalType bounded(int precision, int scale) {
        return new DecimalType(Math.min(precision, this.MAX_PRECISION()), Math.min(scale, this.MAX_SCALE()));
    }

    public boolean isDecimalArithmetic(BinaryArithmetic b) {
        if (((Expression)b.left()).dataType() instanceof DecimalType && ((Expression)b.right()).dataType() instanceof DecimalType) {
            BinaryArithmetic binaryArithmetic = b;
            return binaryArithmetic instanceof Divide ? true : (binaryArithmetic instanceof Multiply ? true : (binaryArithmetic instanceof Add ? true : (binaryArithmetic instanceof Subtract ? true : (binaryArithmetic instanceof Remainder ? true : binaryArithmetic instanceof Pmod))));
        }
        return false;
    }

    private Tuple2<Integer, Integer> getNewPrecisionScale(Decimal dec) {
        String input = dec.abs().toJavaBigDecimal().toPlainString();
        int dotIndex = input.indexOf(".");
        if (dotIndex == -1) {
            return new Tuple2((Object)Predef$.MODULE$.int2Integer(input.length()), (Object)Predef$.MODULE$.int2Integer(0));
        }
        if (dec.toBigDecimal().isValidLong()) {
            return new Tuple2((Object)Predef$.MODULE$.int2Integer(dotIndex), (Object)Predef$.MODULE$.int2Integer(0));
        }
        return new Tuple2((Object)Predef$.MODULE$.int2Integer(dec.precision()), (Object)Predef$.MODULE$.int2Integer(dec.scale()));
    }

    public BinaryArithmetic rescaleLiteral(BinaryArithmetic arithmeticExpr) {
        if (arithmeticExpr.left() instanceof PromotePrecision && arithmeticExpr.right() instanceof Literal) {
            Literal lit = (Literal)arithmeticExpr.right();
            Object object = lit.value();
            if (object instanceof Decimal) {
                Decimal decimal = (Decimal)object;
                Tuple2<Integer, Integer> tuple2 = this.getNewPrecisionScale(decimal);
                if (tuple2 == null) {
                    throw new MatchError(tuple2);
                }
                Integer precision = (Integer)tuple2._1();
                Integer scale = (Integer)tuple2._2();
                Tuple2 tuple22 = new Tuple2((Object)precision, (Object)scale);
                Integer precision2 = (Integer)tuple22._1();
                Integer scale2 = (Integer)tuple22._2();
                if (!BoxesRunTime.equalsNumObject((Number)precision2, (Object)BoxesRunTime.boxToInteger((int)decimal.precision())) || !BoxesRunTime.equalsNumObject((Number)scale2, (Object)BoxesRunTime.boxToInteger((int)decimal.scale()))) {
                    return (BinaryArithmetic)arithmeticExpr.withNewChildren((Seq)new .colon.colon((Object)((Expression)arithmeticExpr.left()), (List)new .colon.colon((Object)new Cast((Expression)lit, (DataType)new DecimalType(Predef$.MODULE$.Integer2int(precision2), Predef$.MODULE$.Integer2int(scale2)), Cast$.MODULE$.apply$default$3(), Cast$.MODULE$.apply$default$4()), (List)Nil$.MODULE$)));
                }
                return arithmeticExpr;
            }
            return arithmeticExpr;
        }
        if (arithmeticExpr.right() instanceof PromotePrecision && arithmeticExpr.left() instanceof Literal) {
            Literal lit = (Literal)arithmeticExpr.left();
            Object object = lit.value();
            if (object instanceof Decimal) {
                Decimal decimal = (Decimal)object;
                Tuple2<Integer, Integer> tuple2 = this.getNewPrecisionScale(decimal);
                if (tuple2 == null) {
                    throw new MatchError(tuple2);
                }
                Integer precision = (Integer)tuple2._1();
                Integer scale = (Integer)tuple2._2();
                Tuple2 tuple23 = new Tuple2((Object)precision, (Object)scale);
                Integer precision3 = (Integer)tuple23._1();
                Integer scale3 = (Integer)tuple23._2();
                if (!BoxesRunTime.equalsNumObject((Number)precision3, (Object)BoxesRunTime.boxToInteger((int)decimal.precision())) || !BoxesRunTime.equalsNumObject((Number)scale3, (Object)BoxesRunTime.boxToInteger((int)decimal.scale()))) {
                    return (BinaryArithmetic)arithmeticExpr.withNewChildren((Seq)new .colon.colon((Object)new Cast((Expression)lit, (DataType)new DecimalType(Predef$.MODULE$.Integer2int(precision3), Predef$.MODULE$.Integer2int(scale3)), Cast$.MODULE$.apply$default$3(), Cast$.MODULE$.apply$default$4()), (List)new .colon.colon((Object)((Expression)arithmeticExpr.right()), (List)Nil$.MODULE$)));
                }
                return arithmeticExpr;
            }
            return arithmeticExpr;
        }
        return arithmeticExpr;
    }

    private boolean isPromoteCast(Expression expr) {
        Cast cast;
        PromotePrecision promotePrecision;
        Expression expression;
        Expression expression2 = expr;
        return expression2 instanceof PromotePrecision && (expression = (promotePrecision = (PromotePrecision)expression2).child()) instanceof Cast && (cast = (Cast)expression).dataType() instanceof DecimalType;
    }

    public Tuple2<Expression, Expression> rescaleCastForDecimal(Expression left, Expression right) {
        if (!this.isPromoteCast(left) && this.isPromoteCastIntegral(right)) {
            return this.doScale$1(left, right);
        }
        if (!this.isPromoteCast(right) && this.isPromoteCastIntegral(left)) {
            Tuple2 tuple2 = this.doScale$1(right, left);
            if (tuple2 == null) {
                throw new MatchError((Object)tuple2);
            }
            Expression r = (Expression)tuple2._1();
            Expression l = (Expression)tuple2._2();
            Tuple2 tuple22 = new Tuple2((Object)r, (Object)l);
            Expression r2 = (Expression)tuple22._1();
            Expression l2 = (Expression)tuple22._2();
            return new Tuple2((Object)l2, (Object)r2);
        }
        return new Tuple2((Object)left, (Object)right);
    }

    public Expression removeCastForDecimal(Expression arithmeticExpr) {
        PromotePrecision promotePrecision;
        Expression expression;
        Expression expression2 = arithmeticExpr;
        if (expression2 instanceof PromotePrecision && (expression = (promotePrecision = (PromotePrecision)expression2).child()) instanceof Cast) {
            Cast cast = (Cast)expression;
            Expression child = cast.child();
            if (cast.dataType() instanceof DecimalType && child.dataType() instanceof DecimalType) {
                return child;
            }
        }
        return arithmeticExpr;
    }

    private boolean isPromoteCastIntegral(Expression expr) {
        PromotePrecision promotePrecision;
        Expression expression;
        Expression expression2 = expr;
        if (expression2 instanceof PromotePrecision && (expression = (promotePrecision = (PromotePrecision)expression2).child()) instanceof Cast) {
            Cast cast = (Cast)expression;
            Expression child = cast.child();
            if (cast.dataType() instanceof DecimalType) {
                DataType dataType = child.dataType();
                return IntegerType$.MODULE$.equals(dataType) ? true : (ByteType$.MODULE$.equals(dataType) ? true : (ShortType$.MODULE$.equals(dataType) ? true : LongType$.MODULE$.equals(dataType)));
            }
        }
        return false;
    }

    private Expression rescaleCastForOneSide(Expression expr) {
        PromotePrecision promotePrecision;
        Expression expression;
        Expression expression2 = expr;
        if (expression2 instanceof PromotePrecision && (expression = (promotePrecision = (PromotePrecision)expression2).child()) instanceof Cast) {
            Cast cast = (Cast)expression;
            Expression child = cast.child();
            if (cast.dataType() instanceof DecimalType) {
                DataType dataType = child.dataType();
                if (IntegerType$.MODULE$.equals(dataType) ? true : (ByteType$.MODULE$.equals(dataType) ? true : ShortType$.MODULE$.equals(dataType))) {
                    return (Expression)promotePrecision.withNewChildren((Seq)new .colon.colon((Object)new Cast(child, (DataType)new DecimalType(10, 0), Cast$.MODULE$.apply$default$3(), Cast$.MODULE$.apply$default$4()), (List)Nil$.MODULE$));
                }
                if (LongType$.MODULE$.equals(dataType)) {
                    return (Expression)promotePrecision.withNewChildren((Seq)new .colon.colon((Object)new Cast(child, (DataType)new DecimalType(20, 0), Cast$.MODULE$.apply$default$3(), Cast$.MODULE$.apply$default$4()), (List)Nil$.MODULE$));
                }
                return expr;
            }
        }
        return expr;
    }

    private boolean checkIsWiderType(DecimalType left, DecimalType right, DecimalType wider) {
        DecimalType widerType = SparkShimLoader$.MODULE$.getSparkShims().widerDecimalType(left, right);
        return widerType.equals((Object)wider);
    }

    private final Tuple2 doScale$1(Expression e1, Expression e2) {
        Expression newE2 = this.rescaleCastForOneSide(e2);
        boolean isWiderType = this.checkIsWiderType((DecimalType)e1.dataType(), (DecimalType)newE2.dataType(), (DecimalType)e2.dataType());
        if (isWiderType) {
            return new Tuple2((Object)e1, (Object)newE2);
        }
        return new Tuple2((Object)e1, (Object)e2);
    }

    private DecimalArithmeticUtil$() {
    }
}

