/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.parser;

import java.util.HashMap;
import org.apache.sysds.common.Types;
import org.apache.sysds.parser.ConstIdentifier;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.Expression;
import org.apache.sysds.parser.FunctionCallIdentifier;
import org.apache.sysds.parser.Identifier;
import org.apache.sysds.parser.ParseInfo;
import org.apache.sysds.parser.StringIdentifier;
import org.apache.sysds.parser.VariableSet;

/**
 * Parser node for a binary operation: a left operand, a right operand and the
 * {@link Expression.BinaryOp} that combines them.
 *
 * <p>Validation derives a temporary output identifier whose data type, value
 * type and dimensions are computed from the two operands.
 */
public class BinaryExpression
extends Expression {
    private Expression _left;
    private Expression _right;
    private Expression.BinaryOp _opcode;

    /**
     * Creates an operand-less binary expression with placeholder location data:
     * file name {@code "MAIN SCRIPT"}, all line/column positions 0 and no text.
     *
     * @param bop the binary operator of this expression
     */
    public BinaryExpression(Expression.BinaryOp bop) {
        _opcode = bop;
        setFilename("MAIN SCRIPT");
        setBeginLine(0);
        setBeginColumn(0);
        setEndLine(0);
        setEndColumn(0);
        setText(null);
    }

    /**
     * Creates an operand-less binary expression that takes its location data
     * from the given parse info.
     *
     * @param bop       the binary operator of this expression
     * @param parseInfo source of the location data handed to {@code setParseInfo}
     */
    public BinaryExpression(Expression.BinaryOp bop, ParseInfo parseInfo) {
        _opcode = bop;
        setParseInfo(parseInfo);
    }

    /**
     * Builds a new binary expression with the same operator whose operands are
     * the rewritten versions of this expression's operands.
     *
     * @param prefix prefix forwarded unchanged to both operands' rewrite
     * @return the newly created expression; this instance is not modified
     */
    @Override
    public Expression rewriteExpression(String prefix) {
        BinaryExpression rewritten = new BinaryExpression(_opcode, this);
        rewritten.setLeft(_left.rewriteExpression(prefix));
        rewritten.setRight(_right.rewriteExpression(prefix));
        return rewritten;
    }

    /** @return the binary operator of this expression */
    public Expression.BinaryOp getOpCode() {
        return _opcode;
    }

    /** @return the left operand, or {@code null} if none has been set */
    public Expression getLeft() {
        return _left;
    }

    /** @return the right operand, or {@code null} if none has been set */
    public Expression getRight() {
        return _right;
    }

    /**
     * Stores the left operand. A non-null operand is additionally passed to
     * {@code setParseInfo}, so it replaces this expression's own parse info.
     *
     * @param l the new left operand, may be {@code null}
     */
    public void setLeft(Expression l) {
        _left = l;
        if (l == null) {
            return;
        }
        setParseInfo(l);
    }

    /**
     * Stores the right operand. A non-null operand is additionally passed to
     * {@code setParseInfo}, so it replaces this expression's own parse info.
     *
     * @param r the new right operand, may be {@code null}
     */
    public void setRight(Expression r) {
        _right = r;
        if (r == null) {
            return;
        }
        setParseInfo(r);
    }

    /**
     * Validates both operands and derives the output identifier of this
     * expression.
     *
     * <p>Steps, in order: reject function-call operands, validate left then
     * right, substitute known constants (only when {@code conditional} is
     * false), compute data type and value type, compute dimensions, and apply
     * the matrix-multiplication specific dimension handling.
     *
     * @param ids         identifiers in scope; only forwarded to the operands
     * @param constVars   constant bindings by variable name; also forwarded to
     *                    the operands
     * @param conditional forwarded to the operands and to the dimension-related
     *                    validation errors
     */
    @Override
    public void validateExpression(HashMap<String, DataIdentifier> ids, HashMap<String, ConstIdentifier> constVars, boolean conditional) {
        boolean hasFunctionCallOperand = _left instanceof FunctionCallIdentifier
            || _right instanceof FunctionCallIdentifier;
        if (hasFunctionCallOperand) {
            // Always reported with conditional=false, regardless of the caller's flag.
            // NOTE(review): presumably raiseValidateError throws in that case - confirm in Expression.
            raiseValidateError("User-defined function calls not supported in binary expressions.", false, "Unsupported Expression");
        }

        _left.validateExpression(ids, constVars, conditional);
        _right.validateExpression(ids, constVars, conditional);

        if (!conditional) {
            // Fields are assigned directly (not via the setters), so the parse info stays untouched.
            _left = resolveConstant(_left, constVars);
            _right = resolveConstant(_right, constVars);
        }

        String tempName = getTempName();
        DataIdentifier output = new DataIdentifier(tempName);
        output.setParseInfo(this);
        output.setDataType(computeDataType(getLeft(), getRight(), true));

        // The value type is always inferred first; power and division then override it with FP64.
        Types.ValueType inferredType = computeValueType(getLeft(), getRight(), true);
        Expression.BinaryOp op = getOpCode();
        boolean alwaysDouble = op == Expression.BinaryOp.POW || op == Expression.BinaryOp.DIV;
        output.setValueType(alwaysDouble ? Types.ValueType.FP64 : inferredType);

        checkAndSetDimensions(output, conditional);
        if (getOpCode() == Expression.BinaryOp.MATMULT) {
            applyMatMultDimensions(output, conditional);
        }
        setOutput(output);
    }

    /**
     * Returns the constant bound to {@code operand}'s name if the operand is a
     * {@link DataIdentifier} whose name is a key of {@code constVars};
     * otherwise returns the operand unchanged.
     */
    private static Expression resolveConstant(Expression operand, HashMap<String, ConstIdentifier> constVars) {
        if (!(operand instanceof DataIdentifier)) {
            return operand;
        }
        DataIdentifier variable = (DataIdentifier) operand;
        return constVars.containsKey(variable.getName())
            ? constVars.get(variable.getName())
            : operand;
    }

    /**
     * Matrix-multiplication handling: reports mismatching inner dimensions and
     * sets the output to (rows of left) x (columns of right).
     */
    private void applyMatMultDimensions(DataIdentifier output, boolean conditional) {
        Identifier lhs = getLeft().getOutput();
        Identifier rhs = getRight().getOutput();
        // NOTE(review): operands whose data type is not MATRIX are not rejected here;
        // the original code only had an empty placeholder branch for that case.
        long innerLeft = lhs.getDim2();
        long innerRight = rhs.getDim1();
        // NOTE(review): -1 presumably marks an unknown dimension - confirm against Identifier.
        boolean bothInnerSet = innerLeft != -1L && innerRight != -1L;
        if (bothInnerSet && innerLeft != innerRight) {
            raiseValidateError("invalid dimensions for matrix multiplication (k1=" + innerLeft + ", k2=" + innerRight + ")", conditional, "Invalid Parameters");
        }
        output.setDimensions(lhs.getDim1(), rhs.getDim2());
    }

    /**
     * Copies the dimensions of the first matrix-typed operand output (left
     * before right) onto {@code output}. When both operands are matrices, the
     * operator is a same-dimension operator and both dimensions are known, a
     * validation error is raised if the right operand differs from the left in
     * a dimension where the right operand's size is greater than 1.
     *
     * <p>If neither operand is a matrix, {@code output} is left untouched.
     */
    private void checkAndSetDimensions(DataIdentifier output, boolean conditional) {
        Identifier lhs = getLeft().getOutput();
        Identifier rhs = getRight().getOutput();
        boolean lhsIsMatrix = lhs.getDataType() == Types.DataType.MATRIX;
        boolean rhsIsMatrix = rhs.getDataType() == Types.DataType.MATRIX;
        if (!lhsIsMatrix && !rhsIsMatrix) {
            return;
        }

        if (lhsIsMatrix && rhsIsMatrix) {
            boolean comparable = isSameDimensionBinaryOp(getOpCode()) && lhs.dimsKnown() && rhs.dimsKnown();
            if (comparable && hasDimensionMismatch(lhs, rhs)) {
                raiseValidateError("Mismatch in dimensions for operation '" + getText() + "'. " + lhs + " is " + lhs.getDim1() + "x" + lhs.getDim2() + " and " + rhs + " is " + rhs.getDim1() + "x" + rhs.getDim2() + ".", conditional);
            }
        }

        // Dimensions are still propagated after a reported mismatch.
        Identifier pivot = lhsIsMatrix ? lhs : rhs;
        output.setDimensions(pivot.getDim1(), pivot.getDim2());
    }

    /**
     * Tells whether {@code aux} disagrees with {@code pivot} in rows or columns
     * while having more than one row or column, respectively.
     */
    private static boolean hasDimensionMismatch(Identifier pivot, Identifier aux) {
        if (pivot.getDim1() != aux.getDim1() && aux.getDim1() > 1L) {
            return true;
        }
        return pivot.getDim2() != aux.getDim2() && aux.getDim2() > 1L;
    }

    /**
     * Renders the expression as {@code (left op right)}; operands that are
     * {@link StringIdentifier}s are wrapped in double quotes.
     */
    @Override
    public String toString() {
        String leftText = quoteIfStringLiteral(_left);
        String rightText = quoteIfStringLiteral(_right);
        return "(" + leftText + " " + _opcode.toString() + " " + rightText + ")";
    }

    /** Returns the operand's string form, double-quoted for string identifiers. */
    private static String quoteIfStringLiteral(Expression operand) {
        String text = operand.toString();
        if (operand instanceof StringIdentifier) {
            return "\"" + text + "\"";
        }
        return text;
    }

    /** @return a new set holding the variables read by the left and right operands */
    @Override
    public VariableSet variablesRead() {
        VariableSet reads = new VariableSet();
        reads.addVariables(_left.variablesRead());
        reads.addVariables(_right.variablesRead());
        return reads;
    }

    /** @return a new set holding the variables updated by the left and right operands */
    @Override
    public VariableSet variablesUpdated() {
        VariableSet updates = new VariableSet();
        updates.addVariables(_left.variablesUpdated());
        updates.addVariables(_right.variablesUpdated());
        return updates;
    }

    /**
     * Tells whether the operator is one of PLUS, MINUS, MULT, DIV, MODULUS,
     * INTDIV or POW.
     *
     * @param op the operator to classify, may be {@code null}
     * @return {@code true} for the operators listed above, {@code false}
     *         otherwise (including {@code null})
     */
    public static boolean isSameDimensionBinaryOp(Expression.BinaryOp op) {
        if (op == null) {
            return false;
        }
        switch (op) {
            case PLUS:
            case MINUS:
            case MULT:
            case DIV:
            case MODULUS:
            case INTDIV:
            case POW:
                return true;
            default:
                return false;
        }
    }
}

