package org.apache.hadoop.hive.ql.exec.vector.expressions;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import junit.framework.Assert;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluator;
import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluatorFactory;
import org.apache.hadoop.hive.ql.exec.vector.VectorExtractRow;
import org.apache.hadoop.hive.ql.exec.vector.VectorRandomBatchSource;
import org.apache.hadoop.hive.ql.exec.vector.VectorRandomRowSource;
import org.apache.hadoop.hive.ql.exec.vector.VectorizationContext;
import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatchCtx;
import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFIf;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFWhen;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
import org.junit.Ignore;
import org.junit.Test;

/**
 * Cross-checks IF-statement evaluation across execution paths: row-mode evaluation of
 * {@link GenericUDFIf}, vectorized evaluation of {@link GenericUDFWhen} (ADAPTOR_WHEN mode) and
 * vectorized evaluation of {@link GenericUDFIf} (VECTOR_EXPRESSION mode).
 *
 * <p>For each tested data type, every column/scalar combination of the THEN and ELSE operands is
 * evaluated over random rows, and each non-row-mode result must equal the row-mode result for the
 * same row.
 */
public class TestVectorIfStatement {

    /** Whether the THEN and ELSE operands of the IF are data columns or constant scalars. */
    public enum ColumnScalarMode {
        COLUMN_COLUMN,
        COLUMN_SCALAR,
        SCALAR_COLUMN,
        SCALAR_SCALAR;

        static final int count = values().length;
    }

    /** The evaluation path whose results are produced and compared. */
    public enum IfStmtTestMode {
        ROW_MODE,
        ADAPTOR_WHEN,
        VECTOR_EXPRESSION;

        static final int count = values().length;
    }

    @Test
    public void testBoolean() throws Exception {
        doIfTests(new Random(12882L), "boolean");
    }

    @Test
    public void testInt() throws Exception {
        doIfTests(new Random(12882L), "int");
    }

    @Test
    public void testBigInt() throws Exception {
        doIfTests(new Random(12882L), "bigint");
    }

    @Test
    public void testString() throws Exception {
        doIfTests(new Random(12882L), "string");
    }

    @Test
    public void testTimestamp() throws Exception {
        doIfTests(new Random(12882L), "timestamp");
    }

    @Test
    public void testDate() throws Exception {
        doIfTests(new Random(12882L), "date");
    }

    @Test
    public void testIntervalDayTime() throws Exception {
        doIfTests(new Random(12882L), "interval_day_time",
                Arrays.asList(IfStmtTestMode.ROW_MODE, IfStmtTestMode.VECTOR_EXPRESSION));
    }

    @Test
    public void testIntervalYearMonth() throws Exception {
        doIfTests(new Random(12882L), "interval_year_month",
                Arrays.asList(IfStmtTestMode.ROW_MODE, IfStmtTestMode.VECTOR_EXPRESSION));
    }

    @Test
    public void testDouble() throws Exception {
        doIfTests(new Random(12882L), "double");
    }

    @Test
    public void testChar() throws Exception {
        doIfTests(new Random(12882L), "char(10)");
    }

    @Test
    public void testVarchar() throws Exception {
        doIfTests(new Random(12882L), "varchar(15)");
    }

    @Test
    @Ignore("There is no vector expression for binary if expressions and vector adapter doesn't support it")
    public void testBinary() throws Exception {
        doIfTests(new Random(12882L), "binary",
                Arrays.asList(IfStmtTestMode.ROW_MODE, IfStmtTestMode.VECTOR_EXPRESSION));
    }

    @Test
    public void testDecimalLarge() throws Exception {
        doIfTests(new Random(9300L), "decimal(20,8)");
    }

    @Test
    public void testDecimalSmall() throws Exception {
        doIfTests(new Random(12882L), "decimal(10,4)");
    }

    /** Runs every column/scalar combination for {@code typeName} using all test modes. */
    private void doIfTests(Random random, String typeName) throws Exception {
        doIfTests(random, typeName, null);
    }

    /**
     * Runs every column/scalar combination for {@code typeName}.
     *
     * @param testModes modes to compare; {@code null} means all {@link IfStmtTestMode} values.
     *     The first mode is treated as the expected (row-mode) result.
     */
    private void doIfTests(Random random, String typeName, List<IfStmtTestMode> testModes)
            throws Exception {
        for (ColumnScalarMode columnScalarMode : ColumnScalarMode.values()) {
            doIfTestsWithDiffColumnScalar(random, typeName, columnScalarMode, testModes);
        }
    }

    /**
     * Runs the test for a single {@code columnScalarMode}: once with non-null scalars, plus once
     * with a NULL scalar operand for the modes that have exactly one scalar operand.
     */
    private void doIfTestsWithDiffColumnScalar(Random random, String typeName,
            ColumnScalarMode columnScalarMode, List<IfStmtTestMode> testModes) throws Exception {
        doIfTestsWithDiffColumnScalar(random, typeName, columnScalarMode, testModes, false, false);
        switch (columnScalarMode) {
            case COLUMN_SCALAR:
                // NULL ELSE scalar.
                doIfTestsWithDiffColumnScalar(random, typeName, columnScalarMode, testModes, false, true);
                break;
            case SCALAR_COLUMN:
                // NULL THEN scalar.
                doIfTestsWithDiffColumnScalar(random, typeName, columnScalarMode, testModes, true, false);
                break;
            default:
                break;
        }
    }

    /**
     * Builds {@code IF(col0, then, else)} over random rows and checks that every test mode
     * produces the same result as the first mode for every row.
     *
     * @param isNullScalar1 when the THEN operand is a scalar, use a NULL constant for it
     * @param isNullScalar2 when the ELSE operand is a scalar, use a NULL constant for it
     */
    private void doIfTestsWithDiffColumnScalar(Random random, String typeName,
            ColumnScalarMode columnScalarMode, List<IfStmtTestMode> testModes,
            boolean isNullScalar1, boolean isNullScalar2) throws Exception {

        System.out.println("*DEBUG* typeName " + typeName + " columnScalarMode " + columnScalarMode
                + " isNullScalar1 " + isNullScalar1 + " isNullScalar2 " + isNullScalar2);

        // getTypeInfoFromTypeString returns TypeInfo; the cast is needed for randomPrimitiveObject.
        PrimitiveTypeInfo primitiveTypeInfo =
                (PrimitiveTypeInfo) TypeInfoUtils.getTypeInfoFromTypeString(typeName);

        // Column 0 is the boolean predicate; THEN/ELSE operands add a data column only when
        // they are columns rather than scalars.
        List<String> explicitTypeNames = new ArrayList<>();
        explicitTypeNames.add("boolean");
        if (columnScalarMode != ColumnScalarMode.SCALAR_SCALAR) {
            explicitTypeNames.add(typeName);
            if (columnScalarMode == ColumnScalarMode.COLUMN_COLUMN) {
                explicitTypeNames.add(typeName);
            }
        }

        VectorRandomRowSource rowSource = new VectorRandomRowSource();
        rowSource.initExplicitSchema(random, explicitTypeNames, 0, true);

        List<String> columns = new ArrayList<>();
        columns.add("col0");
        ExprNodeDesc predicateExpr = new ExprNodeColumnDesc(Boolean.class, "col0", "table", false);

        int columnNum = 1;

        ExprNodeDesc thenExpr;
        if (columnScalarMode == ColumnScalarMode.COLUMN_COLUMN
                || columnScalarMode == ColumnScalarMode.COLUMN_SCALAR) {
            String columnName = "col" + (columnNum++);
            thenExpr = new ExprNodeColumnDesc(primitiveTypeInfo, columnName, "table", false);
            columns.add(columnName);
        } else {
            Object scalar1Object = isNullScalar1
                    ? null : VectorRandomRowSource.randomPrimitiveObject(random, primitiveTypeInfo);
            thenExpr = new ExprNodeConstantDesc(primitiveTypeInfo, scalar1Object);
        }

        ExprNodeDesc elseExpr;
        if (columnScalarMode == ColumnScalarMode.COLUMN_COLUMN
                || columnScalarMode == ColumnScalarMode.SCALAR_COLUMN) {
            String columnName = "col" + (columnNum++);
            elseExpr = new ExprNodeColumnDesc(primitiveTypeInfo, columnName, "table", false);
            columns.add(columnName);
        } else {
            Object scalar2Object = isNullScalar2
                    ? null : VectorRandomRowSource.randomPrimitiveObject(random, primitiveTypeInfo);
            elseExpr = new ExprNodeConstantDesc(primitiveTypeInfo, scalar2Object);
        }

        List<ExprNodeDesc> children = new ArrayList<>();
        children.add(predicateExpr);
        children.add(thenExpr);
        children.add(elseExpr);

        String[] columnNames = columns.toArray(new String[0]);
        // One scratch column of the tested type (presumably where the IF result is written).
        String[] outputScratchTypeNames = new String[] { typeName };
        VectorizedRowBatchCtx batchContext = new VectorizedRowBatchCtx(
                columnNames, rowSource.typeInfos(), (int[]) null, 0, outputScratchTypeNames);

        Object[][] randomRows = rowSource.randomRows(100000);
        VectorRandomBatchSource batchSource =
                VectorRandomBatchSource.createInterestingBatches(random, rowSource, randomRows, null);

        final int rowCount = randomRows.length;
        List<IfStmtTestMode> modes =
                (testModes == null) ? Arrays.asList(IfStmtTestMode.values()) : testModes;

        // One result array per mode; resultObjectsArray[mode][row].
        Object[][] resultObjectsArray = new Object[modes.size()][];
        for (int modeIndex = 0; modeIndex < modes.size(); modeIndex++) {
            Object[] resultObjects = new Object[rowCount];
            resultObjectsArray[modeIndex] = resultObjects;

            IfStmtTestMode ifStmtTestMode = modes.get(modeIndex);
            switch (ifStmtTestMode) {
                case ROW_MODE:
                    doRowIfTest(primitiveTypeInfo, columns, children, randomRows,
                            rowSource.rowStructObjectInspector(), resultObjects);
                    break;
                case ADAPTOR_WHEN:
                case VECTOR_EXPRESSION:
                    doVectorIfTest(primitiveTypeInfo, columns, rowSource.typeInfos(), children,
                            ifStmtTestMode, columnScalarMode, batchSource, batchContext, resultObjects);
                    break;
                default:
                    throw new RuntimeException("Unexpected IF statement test mode " + ifStmtTestMode);
            }
        }

        // The first mode's results are the expected values (row mode in every caller).
        for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) {
            Object expectedResult = resultObjectsArray[0][rowIndex];
            for (int modeIndex = 1; modeIndex < modes.size(); modeIndex++) {
                Object vectorResult = resultObjectsArray[modeIndex][rowIndex];
                IfStmtTestMode testMode = modes.get(modeIndex);
                if (expectedResult == null || vectorResult == null) {
                    if (expectedResult != null || vectorResult != null) {
                        Assert.fail("Row " + rowIndex + " " + testMode + " " + columnScalarMode
                                + " result is NULL " + (vectorResult == null)
                                + " does not match row-mode expected result is NULL "
                                + (expectedResult == null));
                    }
                } else if (!expectedResult.equals(vectorResult)) {
                    Assert.fail("Row " + rowIndex + " " + testMode + " " + columnScalarMode
                            + " result " + vectorResult.toString()
                            + " (" + vectorResult.getClass().getSimpleName()
                            + ") does not match row-mode expected result " + expectedResult.toString()
                            + " (" + expectedResult.getClass().getSimpleName() + ")");
                }
            }
        }
    }

    /**
     * Evaluates {@code IF(children)} row by row with the non-vectorized expression evaluator and
     * stores the per-row result in {@code resultObjects}.
     *
     * @param columns currently unused
     */
    private void doRowIfTest(TypeInfo typeInfo, List<String> columns, List<ExprNodeDesc> children,
            Object[][] randomRows, ObjectInspector rowInspector, Object[] resultObjects)
            throws Exception {
        ExprNodeGenericFuncDesc exprDesc =
                new ExprNodeGenericFuncDesc(typeInfo, new GenericUDFIf(), children);
        ExprNodeEvaluator evaluator = ExprNodeEvaluatorFactory.get(exprDesc);
        evaluator.initialize(rowInspector);

        for (int rowIndex = 0; rowIndex < randomRows.length; rowIndex++) {
            resultObjects[rowIndex] = evaluator.evaluate(randomRows[rowIndex]);
        }
    }

    /**
     * Copies column {@code columnNum} of every logical row in {@code batch} into
     * {@code resultObjects}, starting at {@code rowIndex}, as standard writable objects.
     *
     * @param scratchRow currently unused
     */
    private void extractResultObjects(VectorizedRowBatch batch, int rowIndex,
            VectorExtractRow resultVectorExtractRow, Object[] scratchRow, Object[] resultObjects,
            int columnNum, TypeInfo targetTypeInfo) {
        ObjectInspector objectInspector =
                TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo(targetTypeInfo);

        boolean selectedInUse = batch.selectedInUse;
        int[] selected = batch.selected;
        for (int logicalIndex = 0; logicalIndex < batch.size; logicalIndex++) {
            // Honor the selection vector when present.
            int batchIndex = selectedInUse ? selected[logicalIndex] : logicalIndex;
            Object result = resultVectorExtractRow.extractRowColumn(batch, batchIndex, columnNum);
            resultObjects[rowIndex++] = ObjectInspectorUtils.copyToStandardObject(
                    result, objectInspector, ObjectInspectorUtils.ObjectInspectorCopyOption.WRITABLE);
        }
    }

    /**
     * Vectorizes {@code IF(children)} (as {@link GenericUDFWhen} for ADAPTOR_WHEN, or
     * {@link GenericUDFIf} for VECTOR_EXPRESSION), evaluates it over every batch from
     * {@code batchSource}, and stores the per-row results in {@code resultObjects}.
     *
     * @param typeInfos currently unused
     */
    private void doVectorIfTest(TypeInfo typeInfo, List<String> columns, TypeInfo[] typeInfos,
            List<ExprNodeDesc> children, IfStmtTestMode ifStmtTestMode,
            ColumnScalarMode columnScalarMode, VectorRandomBatchSource batchSource,
            VectorizedRowBatchCtx batchContext, Object[] resultObjects) throws Exception {

        // GenericUDFWhen is not a GenericUDFIf, so build the descriptor per branch instead of
        // sharing a GenericUDFIf-typed variable.
        ExprNodeGenericFuncDesc exprDesc;
        switch (ifStmtTestMode) {
            case ADAPTOR_WHEN:
                exprDesc = new ExprNodeGenericFuncDesc(typeInfo, new GenericUDFWhen(), children);
                break;
            case VECTOR_EXPRESSION:
                exprDesc = new ExprNodeGenericFuncDesc(typeInfo, new GenericUDFIf(), children);
                break;
            default:
                throw new RuntimeException("Unexpected IF statement test mode " + ifStmtTestMode);
        }

        VectorizationContext vectorizationContext =
                new VectorizationContext("name", columns, new HiveConf());
        VectorExpression vectorExpression = vectorizationContext.getVectorExpression(exprDesc);

        VectorizedRowBatch batch = batchContext.createVectorizedRowBatch();

        // Extract over the input columns plus the output column of the tested type.
        VectorExtractRow resultVectorExtractRow = new VectorExtractRow();
        List<String> extractTypeNames = new ArrayList<>(batchSource.getRowSource().typeNames());
        extractTypeNames.add(typeInfo.getTypeName());
        resultVectorExtractRow.init(extractTypeNames);
        Object[] scratchRow = new Object[extractTypeNames.size()];

        System.out.println("*DEBUG* typeInfo " + typeInfo.toString() + " ifStmtTestMode "
                + ifStmtTestMode + " columnScalarMode " + columnScalarMode + " vectorExpression "
                + vectorExpression.getClass().getSimpleName());

        batchSource.resetBatchIteration();
        int rowIndex = 0;
        while (batchSource.fillNextBatch(batch)) {
            vectorExpression.evaluate(batch);
            extractResultObjects(batch, rowIndex, resultVectorExtractRow, scratchRow,
                    resultObjects, vectorExpression.getOutputColumn(), typeInfo);
            rowIndex += batch.size;
        }
    }
}
