/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kylin.query.mask;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.SetOp;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.core.Values;
import org.apache.calcite.rel.core.Window;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.kylin.common.KylinConfig;
import org.apache.kylin.common.QueryContext;
import org.apache.kylin.common.exception.ErrorCodeSupplier;
import org.apache.kylin.common.exception.KylinException;
import org.apache.kylin.common.exception.ServerErrorCode;
import org.apache.kylin.guava30.shaded.common.collect.Sets;
import org.apache.kylin.metadata.acl.AclTCRManager;
import org.apache.kylin.metadata.acl.DependentColumn;
import org.apache.kylin.metadata.acl.DependentColumnInfo;
import org.apache.kylin.metadata.model.ColumnDesc;
import org.apache.kylin.metadata.project.NProjectManager;
import org.apache.kylin.query.mask.MaskUtil;
import org.apache.kylin.query.mask.QueryResultMask;
import org.apache.kylin.query.relnode.OlapTableScan;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.Literal;
import org.apache.spark.sql.catalyst.parser.ParseException;

/**
 * {@link QueryResultMask} that enforces dependent-column ACL rules on a query result.
 *
 * <p>For every result column, the source columns ({@code DB.TABLE.COLUMN} identities) it is derived
 * from are traced through the relational plan rooted at the node passed to
 * {@link #doSetRootRelNode(RelNode)}. If any of those source columns has dependent-column rules, the
 * result value is kept only when each dependent column holds one of its allowed values in the same
 * row; otherwise the value is replaced with {@code NULL}. The check is only possible when the
 * dependent column itself appears in the result as a plain projection (not computed, not
 * aggregated); if it does not, the whole result column is replaced with {@code NULL}.
 */
public class QueryDependentColumnMask implements QueryResultMask {
    private RelNode rootRelNode;
    private String defaultDatabase;
    private DependentColumnInfo dependentInfo;
    private List<ResultColumnMaskInfo> resultColumnMaskInfos;
    private boolean needMask = false;

    /**
     * Creates a mask for the user recorded in the current {@link QueryContext}.
     *
     * <p>If the current query context carries no ACL info, no dependent-column rules are loaded and
     * {@link #doMaskResult(Dataset)} returns its input unchanged.
     *
     * @param project     project name, used to resolve the default database and the ACL rules
     * @param kylinConfig configuration used to obtain the project and ACL managers
     */
    public QueryDependentColumnMask(String project, KylinConfig kylinConfig) {
        this.defaultDatabase = NProjectManager.getInstance(kylinConfig).getProject(project).getDefaultDatabase();
        QueryContext.AclInfo aclInfo = QueryContext.current().getAclInfo();
        if (aclInfo != null) {
            this.dependentInfo = AclTCRManager.getInstance(kylinConfig, project)
                    .getDependentColumns(aclInfo.getUsername(), aclInfo.getGroups());
        }
    }

    /**
     * Creates a mask from explicitly supplied rules.
     *
     * @param defaultDatabase database prepended to two-part column names found in computed-column
     *                        expressions
     * @param dependentInfo   dependent-column rules; {@code null} disables masking
     */
    public QueryDependentColumnMask(String defaultDatabase, DependentColumnInfo dependentInfo) {
        this.defaultDatabase = defaultDatabase;
        this.dependentInfo = dependentInfo;
    }

    @Override
    public void doSetRootRelNode(RelNode relNode) {
        this.rootRelNode = relNode;
    }

    /**
     * Computes the per-result-column mask info from the root rel node and records whether any column
     * needs masking. Requires {@link #doSetRootRelNode(RelNode)} to have been called.
     */
    @Override
    public void init() {
        assert this.rootRelNode != null;
        this.resultColumnMaskInfos = this.buildResultColumnMaskInfo(this.getRefCols(this.rootRelNode));
        // Recompute from scratch so a repeated init() reflects the current root node only.
        this.needMask = this.resultColumnMaskInfos.stream().anyMatch(ResultColumnMaskInfo::needMask);
    }

    /**
     * Returns {@code df} with dependent-column masking applied, or {@code df} itself when there are no
     * rules, no root node, or no result column that needs masking.
     */
    @Override
    public Dataset<Row> doMaskResult(Dataset<Row> df) {
        if (this.dependentInfo == null || this.rootRelNode == null || !this.dependentInfo.needMask()) {
            return df;
        }
        if (this.resultColumnMaskInfos == null) {
            this.init();
        }
        if (!this.needMask) {
            return df;
        }
        return this.doResultMaskInternal(df);
    }

    /**
     * Rebuilds every column of {@code df}. Each column is kept as-is, replaced with a typed NULL, or
     * wrapped in a {@code CASE WHEN <dependent condition> THEN col ELSE NULL END} expression.
     * The original column names are restored on the result.
     *
     * @throws KylinException if the generated condition expression cannot be parsed
     */
    private Dataset<Row> doResultMaskInternal(Dataset<Row> df) {
        Dataset<Row> dfWithIndexedCol = MaskUtil.dFToDFWithIndexedColumns(df);
        // Read the column names once instead of on every loop iteration.
        String[] colNames = dfWithIndexedCol.columns();
        Column[] columns = new Column[colNames.length];
        for (int i = 0; i < colNames.length; ++i) {
            String colName = colNames[i];
            ResultColumnMaskInfo maskInfo = this.resultColumnMaskInfos.get(i);
            if (!maskInfo.needMask()) {
                columns[i] = dfWithIndexedCol.col(colName);
                continue;
            }
            if (maskInfo.maskAsNull) {
                // A NULL literal of the column's own type keeps the result schema unchanged.
                columns[i] = new Column((Expression) new Literal(null,
                        dfWithIndexedCol.schema().fields()[i].dataType())).as(colName);
                continue;
            }
            try {
                String condExpr = this.maskDependentCondition(colNames, maskInfo);
                Expression expr = dfWithIndexedCol.sparkSession().sessionState().sqlParser().parseExpression(
                        String.format(Locale.ROOT, "CASE WHEN (%s) THEN %s ELSE NULL END", condExpr,
                                quoteIdentifier(colName)));
                columns[i] = new Column(expr).as(colName);
            } catch (ParseException e) {
                throw new KylinException((ErrorCodeSupplier) ServerErrorCode.ACL_DEPENDENT_COLUMN_PARSE_ERROR,
                        (Throwable) e);
            }
        }
        return dfWithIndexedCol.select(columns).toDF(df.columns());
    }

    /**
     * Builds the Spark SQL condition under which the masked column's value may be shown.
     * The condition is {@code (depCol1 IN (...)) AND (depCol2 IN (...)) ...}. A dependent column with
     * no allowed value contributes {@code (false)}; an empty {@code IN ()} would not parse, and showing
     * nothing is the restrictive choice.
     *
     * @param colNames column names of the indexed data frame, addressed by
     *                 {@link ResultDependentValues#colIdx}
     * @param maskInfo mask info whose {@code dependentValues} are turned into the condition
     */
    private String maskDependentCondition(String[] colNames, ResultColumnMaskInfo maskInfo) {
        List<String> conjuncts = new ArrayList<>(maskInfo.dependentValues.size());
        for (ResultDependentValues dependentValue : maskInfo.dependentValues) {
            if (dependentValue.values.isEmpty()) {
                conjuncts.add("(false)");
                continue;
            }
            String inList = dependentValue.values.stream()
                    .map(QueryDependentColumnMask::quoteStringLiteral)
                    .collect(Collectors.joining(","));
            conjuncts.add("(" + quoteIdentifier(colNames[dependentValue.colIdx]) + " IN (" + inList + "))");
        }
        return String.join(" AND ", conjuncts);
    }

    /** Quotes a name as a Spark SQL back-quoted identifier, doubling any embedded back-quote. */
    private static String quoteIdentifier(String name) {
        return "`" + name.replace("`", "``") + "`";
    }

    /**
     * Quotes a value as a Spark SQL single-quoted string literal. Backslashes and single quotes are
     * backslash-escaped, so a value cannot close the literal early and inject extra conditions.
     * A {@code null} value becomes the text {@code 'null'}.
     *
     * <p>NOTE(review): relies on Spark interpreting backslash escapes in string literals (the default,
     * {@code spark.sql.parser.escapedStringLiterals=false}) — confirm for deployments that change it.
     */
    private static String quoteStringLiteral(String value) {
        String text = String.valueOf(value);
        return "'" + text.replace("\\", "\\\\").replace("'", "\\'") + "'";
    }

    /**
     * Derives one {@link ResultColumnMaskInfo} per result column.
     *
     * @param resultColRefs source-column references of each result column, in result order
     * @return mask infos in the same order as {@code resultColRefs}
     */
    private List<ResultColumnMaskInfo> buildResultColumnMaskInfo(List<ColumnReferences> resultColRefs) {
        // Source column identity -> index of a result column that is a plain projection of it.
        // Only such columns carry the row's real value and can serve as a dependent-column check.
        Map<String, Integer> simpleProjectColumnMap = new HashMap<>();
        int i = 0;
        for (ColumnReferences ref : resultColRefs) {
            if (ref.isSimpleSingleColumnProject()) {
                simpleProjectColumnMap.put(ref.references.iterator().next(), i);
            }
            ++i;
        }
        List<ResultColumnMaskInfo> resultMaskInfos = new ArrayList<>(resultColRefs.size());
        for (ColumnReferences resultColRef : resultColRefs) {
            resultMaskInfos.add(this.buildColumnMaskInfo(resultColRef, simpleProjectColumnMap));
        }
        return resultMaskInfos;
    }

    /** Builds the mask info of a single result column from the rules on its source columns. */
    private ResultColumnMaskInfo buildColumnMaskInfo(ColumnReferences resultColRef,
            Map<String, Integer> simpleProjectColumnMap) {
        ResultColumnMaskInfo maskInfo = new ResultColumnMaskInfo();
        for (String referenceId : resultColRef.references) {
            Collection<DependentColumn> dependentColumns = this.dependentInfo == null
                    ? null
                    : this.dependentInfo.get(referenceId);
            if (dependentColumns == null || dependentColumns.isEmpty()) {
                continue;
            }
            for (DependentColumn dependentColumn : dependentColumns) {
                Integer depIdx = simpleProjectColumnMap.get(dependentColumn.getDependentColumnIdentity());
                if (depIdx == null) {
                    // The dependent column is not available as a plain result column, so the rule
                    // cannot be checked per row: mask the whole column and move on to the next reference.
                    maskInfo.maskAsNull = true;
                    break;
                }
                maskInfo.addDependentValues(new ResultDependentValues(depIdx, dependentColumn.getDependentValues()));
            }
        }
        return maskInfo;
    }

    /**
     * Traces, for each output field of {@code relNode}, the source columns it is derived from.
     * Node types without a dedicated handler get the concatenation of their inputs' references.
     * That is presumably correct for pass-through nodes and joins; TODO confirm for other node types.
     */
    private List<ColumnReferences> getRefCols(RelNode relNode) {
        if (relNode instanceof TableScan) {
            return this.getTableColRefs((TableScan) relNode);
        }
        if (relNode instanceof Values) {
            // Literal rows reference no source column.
            return relNode.getRowType().getFieldList().stream()
                    .map(f -> new ColumnReferences())
                    .collect(Collectors.toCollection(ArrayList::new));
        }
        if (relNode instanceof Aggregate) {
            return this.getAggregateColRefs((Aggregate) relNode);
        }
        if (relNode instanceof Project) {
            return this.getProjectColRefs((Project) relNode);
        }
        if (relNode instanceof SetOp) {
            return this.getUnionColRefs((SetOp) relNode);
        }
        if (relNode instanceof Window) {
            return this.getWindowColRefs((Window) relNode);
        }
        List<ColumnReferences> refs = new ArrayList<>();
        for (RelNode input : relNode.getInputs()) {
            refs.addAll(this.getRefCols(input));
        }
        return refs;
    }

    /**
     * The input columns are passed through, followed by one entry per window aggregate call. Window
     * results are flagged as aggregated: their value generally comes from other rows (e.g. MIN/LAG),
     * so they must never stand in for a plain dependent column.
     */
    private List<ColumnReferences> getWindowColRefs(Window window) {
        List<ColumnReferences> inputRefs = this.getRefCols(window.getInput(0));
        List<ColumnReferences> colRefs = new ArrayList<>(inputRefs);
        for (Window.Group group : window.groups) {
            for (Window.RexWinAggCall aggCall : group.aggCalls) {
                ColumnReferences ref = new ColumnReferences();
                for (Integer bit : RelOptUtil.InputFinder.bits(aggCall)) {
                    // Bits beyond the input width do not map to an input column (presumably window
                    // constants) and carry no source references.
                    if (bit >= inputRefs.size() || inputRefs.get(bit) == null) {
                        continue;
                    }
                    ref = ref.merge(inputRefs.get(bit));
                }
                ref.hasAggregation = true;
                colRefs.add(ref);
            }
        }
        return colRefs;
    }

    /** Merges the references of each output position across all inputs of a set operation. */
    private List<ColumnReferences> getUnionColRefs(SetOp setOp) {
        List<ColumnReferences> refs = null;
        for (RelNode input : setOp.getInputs()) {
            List<ColumnReferences> inputRefs = this.getRefCols(input);
            if (refs == null) {
                // Copy so the list produced for the first input is not modified in place.
                refs = new ArrayList<>(inputRefs);
                continue;
            }
            int width = Math.min(refs.size(), inputRefs.size());
            for (int i = 0; i < width; ++i) {
                refs.set(i, refs.get(i).merge(inputRefs.get(i)));
            }
        }
        return refs == null ? new ArrayList<>() : refs;
    }

    /** Each projected expression references the union of its input columns; non-RexInputRef is a calculation. */
    private List<ColumnReferences> getProjectColRefs(Project project) {
        List<ColumnReferences> inputRefs = this.getRefCols(project.getInput(0));
        List<ColumnReferences> refs = new ArrayList<>(project.getProjects().size());
        for (RexNode expr : project.getProjects()) {
            ColumnReferences ref = new ColumnReferences();
            for (Integer input : RelOptUtil.InputFinder.bits(expr)) {
                ref = ref.merge(inputRefs.get(input));
            }
            if (!(expr instanceof RexInputRef)) {
                ref.hasCalculation = true;
            }
            refs.add(ref);
        }
        return refs;
    }

    /** Group keys pass through their input references; aggregate calls are flagged as aggregated. */
    private List<ColumnReferences> getAggregateColRefs(Aggregate aggregate) {
        List<ColumnReferences> inputRefs = this.getRefCols(aggregate.getInput(0));
        List<ColumnReferences> refs = new ArrayList<>();
        for (Integer groupInputIdx : aggregate.getGroupSet()) {
            refs.add(inputRefs.get(groupInputIdx));
        }
        for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
            ColumnReferences ref = new ColumnReferences();
            for (Integer argInputIdx : aggregateCall.getArgList()) {
                ref = ref.merge(inputRefs.get(argInputIdx));
            }
            ref.hasAggregation = true;
            refs.add(ref);
        }
        return refs;
    }

    /**
     * Maps every scanned field to {@code DB.TABLE.FIELD}. A computed column is instead mapped to the
     * columns its expression references.
     */
    private List<ColumnReferences> getTableColRefs(TableScan tableScan) {
        List<String> qualifiedName = tableScan.getTable().getQualifiedName();
        assert qualifiedName.size() == 2;
        String columnPrefix = qualifiedName.get(0) + "." + qualifiedName.get(1) + ".";
        List<ColumnReferences> refs = new ArrayList<>();
        for (RelDataTypeField field : tableScan.getRowType().getFieldList()) {
            ColumnDesc columnDesc = (ColumnDesc) ((OlapTableScan) tableScan).getOlapTable().getSourceColumns()
                    .get(field.getIndex());
            if (columnDesc.isComputedColumn()) {
                refs.add(this.getCCReferences(columnDesc.getComputedColumnExpr()));
                continue;
            }
            refs.add(new ColumnReferences(columnPrefix + field.getName()));
        }
        return refs;
    }

    /**
     * Collects the column identifiers used in a computed-column expression. Two-part names are
     * prefixed with the default database. Identifiers with any other number of parts are ignored.
     */
    private ColumnReferences getCCReferences(String ccExpr) {
        ColumnReferences columnReferences = new ColumnReferences();
        List<SqlIdentifier> ids = MaskUtil.getCCCols(ccExpr);
        for (SqlIdentifier id : ids) {
            if (id.names.size() == 2) {
                columnReferences.addReference(this.defaultDatabase + "." + id.toString());
            } else if (id.names.size() == 3) {
                columnReferences.addReference(id.toString());
            }
        }
        columnReferences.hasCalculation = true;
        return columnReferences;
    }

    /** Returns the mask info computed by {@link #init()}, or {@code null} before initialization. */
    public List<ResultColumnMaskInfo> getResultColumnMaskInfos() {
        return this.resultColumnMaskInfos;
    }

    /** Allowed values of one dependent column, addressed by its index in the result. */
    public static class ResultDependentValues {
        public int colIdx;
        public Set<String> values;

        /**
         * @param colIdx index of the dependent column in the result
         * @param values allowed values; {@code null} is treated as "no allowed value"
         */
        public ResultDependentValues(int colIdx, String[] values) {
            this.colIdx = colIdx;
            if (values == null) {
                this.values = new HashSet<>();
            } else {
                this.values = Sets.newHashSet(values);
            }
        }
    }

    /** How a single result column must be masked. */
    public static class ResultColumnMaskInfo {
        /** When true the whole column is replaced with NULL, regardless of {@link #dependentValues}. */
        public boolean maskAsNull = false;
        /** Conditions (ANDed together) under which the value may be shown. */
        public List<ResultDependentValues> dependentValues = new ArrayList<>();

        public boolean needMask() {
            return !this.dependentValues.isEmpty() || this.maskAsNull;
        }

        void addDependentValues(ResultDependentValues values) {
            this.dependentValues.add(values);
        }
    }

    /** Source columns an output field is derived from, plus how it was derived. */
    static class ColumnReferences {
        boolean hasCalculation = false;
        boolean hasAggregation = false;
        private final Set<String> references = new HashSet<>();

        public ColumnReferences() {
        }

        public ColumnReferences(String column) {
            this.references.add(column);
        }

        void addReference(String column) {
            this.references.add(column);
        }

        void addReferences(Collection<String> columns) {
            this.references.addAll(columns);
        }

        /** Returns a new instance with the union of both reference sets and flags; {@code this} if other is null. */
        ColumnReferences merge(ColumnReferences other) {
            if (other == null) {
                return this;
            }
            ColumnReferences ref = new ColumnReferences();
            ref.addReferences(this.references);
            ref.addReferences(other.references);
            ref.hasCalculation = this.hasCalculation || other.hasCalculation;
            ref.hasAggregation = this.hasAggregation || other.hasAggregation;
            return ref;
        }

        /** True when the field is exactly one source column, neither computed nor aggregated. */
        boolean isSimpleSingleColumnProject() {
            return this.references.size() == 1 && !this.hasAggregation && !this.hasCalculation;
        }
    }
}

