/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.exec.tez;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluator;
import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluatorFactory;
import org.apache.hadoop.hive.ql.exec.SerializationUtilities;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.plan.ExprNodeConstantDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDynamicListDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.PartitionDesc;
import org.apache.hadoop.hive.ql.plan.TableDesc;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPAnd;
import org.apache.hadoop.hive.serde2.AbstractSerDe;
import org.apache.hadoop.hive.serde2.Deserializer;
import org.apache.hadoop.hive.serde2.SerDeException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.AbstractPrimitiveWritableObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.tez.dag.api.event.VertexState;
import org.apache.tez.runtime.api.InputInitializerContext;
import org.apache.tez.runtime.api.events.InputInitializerEvent;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Runtime (dynamic) partition pruning for a Tez map vertex.
 *
 * <p>Event-source vertices send {@link InputInitializerEvent}s whose payloads carry a column name
 * and the partition-key values that are actually needed.
 * <ul>
 *   <li>{@link #initialize} builds one {@link SourceInfo} per (source vertex, table) entry
 *       described by the {@link MapWork}.</li>
 *   <li>{@link #addEvent} and {@link #processVertex} enqueue incoming events and
 *       vertex-succeeded notifications.</li>
 *   <li>{@link #prune} consumes the queue until every source has delivered its expected number
 *       of events. It then either removes non-matching paths from the {@link MapWork} or, for
 *       sources that carry a predicate, stores a serialized filter expression in the job
 *       configuration.</li>
 * </ul>
 *
 * <p>NOTE(review): {@link #addEvent} and {@link #processVertex} only touch the blocking queues,
 * which suggests they are called from a different thread than {@link #prune}. All other state is
 * only used by the thread that runs initialize/prune. Confirm against the caller.
 */
public class DynamicPartitionPruner {
    private static final Logger LOG = LoggerFactory.getLogger(DynamicPartitionPruner.class);

    /** Placed on {@link #queue} after a vertex name has been added to {@link #finishedVertices}. */
    private static final Object VERTEX_FINISH_TOKEN = new Object();

    private InputInitializerContext context;
    private MapWork work;
    private JobConf jobConf;

    /** Source vertex name -> one SourceInfo per table/column fed by that vertex. */
    private final Map<String, List<SourceInfo>> sourceInfoMap = new HashMap<>();
    /** Reusable buffer for the serialized rows contained in event payloads. */
    private final BytesWritable writable = new BytesWritable();
    /** Incoming {@link InputInitializerEvent}s interleaved with {@link #VERTEX_FINISH_TOKEN}s. */
    private final BlockingQueue<Object> queue = new LinkedBlockingQueue<>();
    /** Vertices reported via {@link #processVertex}; one is polled per finish token. */
    private final BlockingQueue<String> finishedVertices = new LinkedBlockingQueue<>();
    /** Sources that have not yet delivered all of their expected events. */
    private final Set<String> sourcesWaitingForEvents = new HashSet<>();
    /**
     * Per source: minus the number of SourceInfos until the vertex succeeds, then
     * (number of SourceInfos x vertex task count). A negative value means "total not known yet".
     */
    private final Map<String, MutableInt> numExpectedEventsPerSource = new HashMap<>();
    /** Per source: number of events consumed so far. */
    private final Map<String, MutableInt> numEventsSeenPerSource = new HashMap<>();
    private int sourceInfoCount = 0;
    private int totalEventCount = 0;

    /**
     * Waits for all dynamic-pruning events and prunes the partitions of the initialized work.
     *
     * <p>Returns immediately if there are no event sources. Otherwise it registers for
     * {@code SUCCEEDED} updates of every source vertex, then blocks until every source has
     * delivered its expected events, and finally applies the pruning.
     *
     * @throws SerDeException if an event payload cannot be deserialized
     * @throws IOException if an event payload cannot be read
     * @throws InterruptedException if interrupted while waiting for events
     * @throws HiveException if pruning fails or the received event count is inconsistent
     */
    public void prune() throws SerDeException, IOException, InterruptedException, HiveException {
        if (sourcesWaitingForEvents.isEmpty()) {
            return;
        }
        Set<VertexState> states = Collections.singleton(VertexState.SUCCEEDED);
        for (String source : sourcesWaitingForEvents) {
            context.registerForVertexStateUpdates(source, states);
        }
        LOG.info("Waiting for events ({} sources) ...", sourceInfoCount);
        processEvents();
        prunePartitions();
        LOG.info("Ok to proceed.");
    }

    /**
     * Resets all per-initialization bookkeeping so that {@link #initialize} starts from scratch.
     *
     * <p>Every counter and the waiting set are reset, not just the SourceInfo map. This keeps a
     * repeated initialize() from waiting on stale sources or comparing against stale counts.
     *
     * <p>The event queues are deliberately left untouched. NOTE(review): events may be handed to
     * {@link #addEvent} before initialize() runs; confirm before clearing them here.
     */
    private void clear() {
        sourceInfoMap.clear();
        sourceInfoCount = 0;
        sourcesWaitingForEvents.clear();
        numExpectedEventsPerSource.clear();
        numEventsSeenPerSource.clear();
        totalEventCount = 0;
    }

    /**
     * Prepares the pruner for one map vertex.
     *
     * @param context Tez initializer context, used to register for vertex state updates and to
     *     look up vertex task counts
     * @param work map work whose event-source maps describe the expected events and whose
     *     path-to-partition map is pruned
     * @param jobConf job configuration used to initialize the SerDes and to receive the pruning
     *     filter
     * @throws SerDeException if the SerDe for an event-source table cannot be initialized
     */
    public void initialize(InputInitializerContext context, MapWork work, JobConf jobConf) throws SerDeException {
        clear();
        this.context = context;
        this.work = work;
        this.jobConf = jobConf;
        // Most recent SourceInfo per column name. Later SourceInfos on the same column share its
        // value set and skip flag, so all values received for that column accumulate in one set.
        Map<String, SourceInfo> columnMap = new HashMap<>();
        Set<String> sources = work.getEventSourceTableDescMap().keySet();
        sourcesWaitingForEvents.addAll(sources);
        for (String s : sources) {
            MutableInt expectedEvents = new MutableInt(0);
            numExpectedEventsPerSource.put(s, expectedEvents);
            numEventsSeenPerSource.put(s, new MutableInt(0));
            List<TableDesc> tables = work.getEventSourceTableDescMap().get(s);
            // The per-source lists are consumed in lock-step with the table list.
            Iterator<String> columnNames = work.getEventSourceColumnNameMap().get(s).iterator();
            Iterator<String> columnTypes = work.getEventSourceColumnTypeMap().get(s).iterator();
            Iterator<ExprNodeDesc> partKeyExprs = work.getEventSourcePartKeyExprMap().get(s).iterator();
            Iterator<ExprNodeDesc> predicates = work.getEventSourcePredicateExprMap().get(s).iterator();
            for (TableDesc t : tables) {
                // Count SourceInfos negatively. The vertex task count is only applied once the
                // vertex succeeds (see calculateFinishCondition).
                expectedEvents.decrement();
                ++sourceInfoCount;
                String columnName = columnNames.next();
                String columnType = columnTypes.next();
                ExprNodeDesc partKeyExpr = partKeyExprs.next();
                ExprNodeDesc predicate = predicates.next();
                SourceInfo si = createSourceInfo(t, partKeyExpr, predicate, columnName, columnType, jobConf);
                sourceInfoMap.computeIfAbsent(s, k -> new ArrayList<>()).add(si);
                SourceInfo sameColumn = columnMap.get(columnName);
                if (sameColumn != null) {
                    si.values = sameColumn.values;
                    si.skipPruning = sameColumn.skipPruning;
                }
                columnMap.put(columnName, si);
            }
        }
    }

    /**
     * Applies every SourceInfo's received values and publishes the resulting filter.
     *
     * <p>If any source produced a predicate, the conjunction of those predicates is stored under
     * {@code hive.io.pruning.filter} in the job configuration.
     *
     * @throws HiveException if pruning fails, or if the number of events received differs from
     *     the sum, over all SourceInfos, of their source vertex's task count
     */
    private void prunePartitions() throws HiveException {
        int expectedEvents = 0;
        List<ExprNodeDesc> prunerExprs = new ArrayList<>();
        for (Map.Entry<String, List<SourceInfo>> entry : sourceInfoMap.entrySet()) {
            String source = entry.getKey();
            for (SourceInfo si : entry.getValue()) {
                int taskNum = context.getVertexNumTasks(source);
                LOG.info("Expecting {} events for vertex {}, for column {}", taskNum, source, si.columnName);
                expectedEvents += taskNum;
                ExprNodeDesc prunerExpr = prunePartitionSingleSource(jobConf, source, si);
                if (prunerExpr != null) {
                    prunerExprs.add(prunerExpr);
                }
            }
        }
        if (!prunerExprs.isEmpty()) {
            ExprNodeGenericFuncDesc prunerExpr = prunerExprs.size() == 1
                    ? (ExprNodeGenericFuncDesc) prunerExprs.get(0)
                    : new ExprNodeGenericFuncDesc(TypeInfoFactory.booleanTypeInfo, new GenericUDFOPAnd(), "and", prunerExprs);
            jobConf.set("hive.io.pruning.filter", SerializationUtilities.serializeExpression(prunerExpr));
        }
        // Sanity check: every task of every source must have sent an event per SourceInfo.
        if (expectedEvents != totalEventCount) {
            LOG.error("Expecting: {} events, received: {}", expectedEvents, totalEventCount);
            throw new HiveException("Incorrect event count in dynamic partition pruning");
        }
    }

    /**
     * Prunes using the values collected for a single SourceInfo.
     *
     * @param jobConf job configuration (not read by this implementation)
     * @param source name of the event-source vertex, used for logging
     * @param si the source/column whose values should be applied
     * @return {@code null} if pruning is skipped, or if the partitions were pruned directly in the
     *     MapWork because {@code si} has no predicate. Otherwise, a copy of the predicate whose
     *     dynamic-list placeholders are replaced by the received values.
     * @throws HiveException if the partition-key expression cannot be initialized or evaluated
     */
    @VisibleForTesting
    protected ExprNodeDesc prunePartitionSingleSource(JobConf jobConf, String source, SourceInfo si) throws HiveException {
        if (si.skipPruning.get()) {
            LOG.info("Skip pruning on {}, column {}", source, si.columnName);
            return null;
        }
        Set<Object> values = si.values;
        String columnName = si.columnName;
        if (LOG.isDebugEnabled()) {
            StringBuilder sb = new StringBuilder("Pruning ").append(columnName).append(" with ");
            for (Object value : values) {
                sb.append(value).append(", ");
            }
            LOG.debug(sb.toString());
        }
        AbstractPrimitiveWritableObjectInspector oi = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
                TypeInfoFactory.getPrimitiveTypeInfo(si.columnType));
        if (si.predicate != null) {
            // Turn each received (writable) value into a constant and splice the constants into a
            // copy of the predicate, so the shared SourceInfo predicate is never mutated.
            List<ExprNodeConstantDesc> dynArgs = values.stream()
                    .map(v -> new ExprNodeConstantDesc(oi.getPrimitiveJavaObject(v)))
                    .collect(Collectors.toList());
            ExprNodeDesc clone = si.predicate.clone();
            replaceDynamicLists(clone, dynArgs);
            return clone;
        }
        // No predicate: evaluate the partition-key expression against each partition's value.
        ObjectInspectorConverters.Converter converter = ObjectInspectorConverters.getConverter(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector, oi);
        StandardStructObjectInspector soi = ObjectInspectorFactory.getStandardStructObjectInspector(
                Collections.singletonList(columnName), Collections.<ObjectInspector>singletonList(oi));
        ExprNodeEvaluator eval = ExprNodeEvaluatorFactory.get(si.partKey);
        eval.initialize(soi);
        applyFilterToPartitions(converter, eval, columnName, values, si.mustKeepOnePartition);
        return null;
    }

    /**
     * Removes every path from the MapWork whose evaluated partition-key value is not in
     * {@code values}.
     *
     * <p>When {@code mustKeepOnePartition} is set, a path is not removed if at most one path
     * remains.
     *
     * @throws IllegalStateException if a partition has no spec or no value for {@code columnName}
     * @throws HiveException if evaluating the partition-key expression fails
     */
    private void applyFilterToPartitions(ObjectInspectorConverters.Converter converter, ExprNodeEvaluator eval, String columnName, Set<Object> values, boolean mustKeepOnePartition) throws HiveException {
        Object[] row = new Object[1];
        Map<Path, PartitionDesc> pathToPartitionInfo = work.getPathToPartitionInfo();
        Iterator<Map.Entry<Path, PartitionDesc>> it = pathToPartitionInfo.entrySet().iterator();
        while (it.hasNext()) {
            Map.Entry<Path, PartitionDesc> entry = it.next();
            Path p = entry.getKey();
            Map<String, String> spec = entry.getValue().getPartSpec();
            if (spec == null) {
                throw new IllegalStateException("No partition spec found in dynamic pruning");
            }
            String partValueString = spec.get(columnName);
            if (partValueString == null) {
                throw new IllegalStateException("Could not find partition value for column: " + columnName);
            }
            Object converted = converter.convert(partValueString);
            LOG.debug("Converted partition value: {} original ({})", converted, partValueString);
            row[0] = converted;
            Object keyValue = eval.evaluate(row);
            LOG.debug("part key expr applied: {}", keyValue);
            boolean keep = values.contains(keyValue)
                    || (mustKeepOnePartition && pathToPartitionInfo.size() <= 1);
            if (!keep) {
                LOG.info("Pruning path: {}", p);
                it.remove();
                work.removePathToAlias(p);
            }
        }
    }

    /** Factory for {@link SourceInfo}; overridable so tests can avoid real SerDe setup. */
    @VisibleForTesting
    protected SourceInfo createSourceInfo(TableDesc t, ExprNodeDesc partKeyExpr, ExprNodeDesc predicate, String columnName, String columnType, JobConf jobConf) throws SerDeException {
        return new SourceInfo(t, partKeyExpr, predicate, columnName, columnType, jobConf);
    }

    /**
     * Consumes the queue until {@link #checkForSourceCompletion} reports that no source is still
     * waiting. Each entry is either an event or a vertex-finished token.
     */
    private void processEvents() throws SerDeException, IOException, InterruptedException {
        int eventCount = 0;
        while (true) {
            Object element = queue.take();
            if (element == VERTEX_FINISH_TOKEN) {
                // processVertex adds the name before offering the token, so poll() finds it.
                String updatedSource = finishedVertices.poll();
                calculateFinishCondition(updatedSource);
                if (checkForSourceCompletion(updatedSource)) {
                    break;
                }
                continue;
            }
            InputInitializerEvent event = (InputInitializerEvent) element;
            String sourceName = event.getSourceVertexName();
            ByteBuffer payload = event.getUserPayload();
            numEventsSeenPerSource.computeIfAbsent(sourceName, vn -> new MutableInt(0)).increment();
            ++totalEventCount;
            LOG.info("Input event ({} -> {} {}), event payload size: {}",
                    sourceName, event.getTargetVertexName(), event.getTargetInputName(), payload.remaining());
            processPayload(payload, sourceName);
            ++eventCount;
            if (checkForSourceCompletion(sourceName)) {
                break;
            }
        }
        LOG.info("Received events: {}", eventCount);
    }

    /**
     * Decodes one event payload and records its values on the matching SourceInfo.
     *
     * <p>Layout as read here:
     * <ol>
     *   <li>a UTF column name;</li>
     *   <li>unless pruning is already skipped for that column, a boolean "skip" flag;</li>
     *   <li>when not skipping, {@link BytesWritable} records until the buffer is exhausted, each
     *       deserialized with the SourceInfo's SerDe.</li>
     * </ol>
     * The buffer's position is advanced.
     *
     * @return {@code sourceName}, unchanged
     * @throws IllegalStateException if no SourceInfo exists for the source or column
     */
    @VisibleForTesting
    protected String processPayload(ByteBuffer payload, String sourceName) throws SerDeException, IOException {
        try (DataInputStream in = new DataInputStream(new ByteBufferBackedInputStream(payload))) {
            String columnName = in.readUTF();
            LOG.info("Source of event: {}", sourceName);
            List<SourceInfo> infos = sourceInfoMap.get(sourceName);
            if (infos == null) {
                throw new IllegalStateException("no source info for event source: " + sourceName);
            }
            SourceInfo info = null;
            for (SourceInfo si : infos) {
                if (columnName.equals(si.columnName)) {
                    info = si;
                    break;
                }
            }
            if (info == null) {
                throw new IllegalStateException("no source info for column: " + columnName);
            }
            if (!info.skipPruning.get()) {
                if (in.readBoolean()) {
                    info.skipPruning.set(true);
                } else {
                    int partitionCount = 0;
                    // The stream reads straight from the buffer (no buffering), so the buffer's
                    // remaining bytes are exactly the unread records.
                    while (payload.hasRemaining()) {
                        writable.readFields(in);
                        Object row = info.deserializer.deserialize(writable);
                        Object value = info.soi.getStructFieldData(row, info.field);
                        value = ObjectInspectorUtils.copyToStandardObject(value, info.fieldInspector);
                        LOG.debug("Adding: {} to list of required partitions", value);
                        info.values.add(value);
                        ++partitionCount;
                    }
                    LOG.info("Received {} partitions (source: {}, column: {})", partitionCount, sourceName, columnName);
                }
            }
        }
        return sourceName;
    }

    /**
     * Enqueues an event for {@link #prune}.
     *
     * @throws IllegalStateException if the queue rejects the event
     */
    public void addEvent(InputInitializerEvent event) {
        if (!queue.offer(event)) {
            throw new IllegalStateException("Queue full");
        }
    }

    /**
     * Converts a source's expected-event counter from its negative placeholder (minus the number
     * of SourceInfos) to the real total: SourceInfos x the vertex's task count.
     *
     * @throws IllegalStateException if the counter is not negative, i.e. already converted
     */
    private void calculateFinishCondition(String sourceName) {
        MutableInt prevVal = numExpectedEventsPerSource.get(sourceName);
        int prevValInt = prevVal.intValue();
        Preconditions.checkState(prevValInt < 0,
                "Invalid value for numExpectedEvents for source: %s, oldVal=%s", sourceName, prevValInt);
        prevVal.setValue(-1 * prevValInt * context.getVertexNumTasks(sourceName));
    }

    /** Records that vertex {@code name} succeeded and wakes up {@link #prune}. */
    public void processVertex(String name) {
        LOG.info("Vertex succeeded: {}", name);
        // Add the name before offering the token so the consumer always finds it.
        finishedVertices.add(name);
        queue.offer(VERTEX_FINISH_TOKEN);
    }

    /**
     * Marks {@code name} complete once its expected event count is known and has been reached.
     *
     * @return true when, afterwards, no source is still waiting for events
     * @throws IllegalStateException if more events than expected arrived for {@code name}
     */
    private boolean checkForSourceCompletion(String name) {
        int expectedEvents = numExpectedEventsPerSource.get(name).intValue();
        if (expectedEvents < 0) {
            // The vertex has not succeeded yet, so its total is still unknown.
            return false;
        }
        int processedEvents = numEventsSeenPerSource.get(name).intValue();
        if (processedEvents > expectedEvents) {
            throw new IllegalStateException("Received too many events for " + name + ", Expected=" + expectedEvents + ", Received=" + processedEvents);
        }
        if (processedEvents < expectedEvents) {
            return false;
        }
        sourcesWaitingForEvents.remove(name);
        if (sourcesWaitingForEvents.isEmpty()) {
            return true;
        }
        LOG.info("Waiting for {} sources.", sourcesWaitingForEvents.size());
        return false;
    }

    /**
     * Replaces every {@link ExprNodeDynamicListDesc} below {@code node}, in place, with the
     * constants in {@code dynArgs}.
     *
     * <p>Only children are replaced, never {@code node} itself, and the inserted constants are not
     * visited again.
     */
    private void replaceDynamicLists(ExprNodeDesc node, Collection<ExprNodeConstantDesc> dynArgs) {
        List<ExprNodeDesc> children = node.getChildren();
        if (children == null || children.isEmpty()) {
            return;
        }
        ListIterator<ExprNodeDesc> iterator = children.listIterator();
        while (iterator.hasNext()) {
            ExprNodeDesc child = iterator.next();
            if (child instanceof ExprNodeDynamicListDesc) {
                iterator.remove();
                for (ExprNodeConstantDesc arg : dynArgs) {
                    iterator.add(arg);
                }
            } else {
                replaceDynamicLists(child, dynArgs);
            }
        }
    }

    /** Everything needed to decode and apply the events for one (source vertex, table) entry. */
    @VisibleForTesting
    static class SourceInfo {
        public final ExprNodeDesc partKey;
        public final ExprNodeDesc predicate;
        public final Deserializer deserializer;
        public final StructObjectInspector soi;
        public final StructField field;
        public final ObjectInspector fieldInspector;
        /** Values received so far; may be shared with other SourceInfos on the same column. */
        public Set<Object> values = new HashSet<Object>();
        /** Set once an event asks to skip pruning; may be shared like {@link #values}. */
        public AtomicBoolean skipPruning = new AtomicBoolean();
        public final String columnName;
        public final String columnType;
        /** From the ENSURE_OPERATORS_EXECUTED setting; when true, the last remaining path is kept. */
        private boolean mustKeepOnePartition;

        /** Test-only constructor: no SerDe is created, so the SerDe-derived fields stay null. */
        @VisibleForTesting
        SourceInfo(TableDesc table, ExprNodeDesc partKey, ExprNodeDesc predicate, String columnName, String columnType, JobConf jobConf, Object forTesting) {
            this.partKey = partKey;
            this.predicate = predicate;
            this.columnName = columnName;
            this.columnType = columnType;
            this.deserializer = null;
            this.soi = null;
            this.field = null;
            this.fieldInspector = null;
        }

        /**
         * Instantiates and initializes the table's SerDe and resolves the single struct field
         * that carries the partition-key values.
         *
         * @throws SerDeException if the SerDe cannot be initialized or its row type has no fields
         */
        public SourceInfo(TableDesc table, ExprNodeDesc partKey, ExprNodeDesc predicate, String columnName, String columnType, JobConf jobConf) throws SerDeException {
            this.partKey = partKey;
            this.predicate = predicate;
            this.columnName = columnName;
            this.columnType = columnType;
            this.mustKeepOnePartition = jobConf.getBoolean("ENSURE_OPERATORS_EXECUTED", false);
            AbstractSerDe serDe = (AbstractSerDe) ReflectionUtils.newInstance(table.getSerDeClass(), null);
            serDe.initialize(jobConf, table.getProperties(), null);
            this.deserializer = serDe;
            ObjectInspector inspector = serDe.getObjectInspector();
            LOG.debug("Type of obj insp: {}", inspector.getTypeName());
            this.soi = (StructObjectInspector) inspector;
            List<? extends StructField> fields = soi.getAllStructFieldRefs();
            if (fields.isEmpty()) {
                // Previously surfaced as an IndexOutOfBoundsException from fields.get(0).
                throw new SerDeException("expecting single field in input, found none");
            }
            if (fields.size() > 1) {
                LOG.error("expecting single field in input");
            }
            this.field = fields.get(0);
            this.fieldInspector = ObjectInspectorUtils.getStandardObjectInspector(field.getFieldObjectInspector());
        }
    }

    /** Minimal unbuffered {@link InputStream} over a {@link ByteBuffer}; reads advance its position. */
    private static class ByteBufferBackedInputStream extends InputStream {
        private final ByteBuffer buf;

        public ByteBufferBackedInputStream(ByteBuffer buf) {
            this.buf = buf;
        }

        @Override
        public int read() throws IOException {
            if (!buf.hasRemaining()) {
                return -1;
            }
            return buf.get() & 0xFF;
        }

        @Override
        public int read(byte[] bytes, int off, int len) throws IOException {
            // InputStream contract: a zero-length read returns 0, even at end of stream.
            if (len == 0) {
                return 0;
            }
            if (!buf.hasRemaining()) {
                return -1;
            }
            int n = Math.min(len, buf.remaining());
            buf.get(bytes, off, n);
            return n;
        }
    }
}

