/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.exec.tez;

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;
import com.google.protobuf.ByteString;
import java.io.DataOutput;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.ql.exec.tez.CustomEdgeConfiguration;
import org.apache.hadoop.hive.ql.exec.tez.CustomPartitionEdge;
import org.apache.hadoop.hive.ql.io.HiveInputFormat;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.serializer.SerializationFactory;
import org.apache.hadoop.mapred.FileSplit;
import org.apache.hadoop.mapred.InputSplit;
import org.apache.hadoop.mapred.split.TezGroupedSplitsInputFormat;
import org.apache.hadoop.mapred.split.TezMapredSplitsGrouper;
import org.apache.tez.dag.api.EdgeManagerDescriptor;
import org.apache.tez.dag.api.EdgeProperty;
import org.apache.tez.dag.api.InputDescriptor;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.dag.api.VertexLocationHint;
import org.apache.tez.dag.api.VertexManagerPlugin;
import org.apache.tez.dag.api.VertexManagerPluginContext;
import org.apache.tez.mapreduce.hadoop.MRHelpers;
import org.apache.tez.mapreduce.protos.MRRuntimeProtos;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.events.RootInputConfigureVertexTasksEvent;
import org.apache.tez.runtime.api.events.RootInputDataInformationEvent;
import org.apache.tez.runtime.api.events.RootInputUpdatePayloadEvent;
import org.apache.tez.runtime.api.events.VertexManagerEvent;

public class CustomPartitionVertex
implements VertexManagerPlugin {
    // Class-wide logger.
    private static final Log LOG = LogFactory.getLog((String)CustomPartitionVertex.class.getName());
    // Configuration key (read with a default of true) that decides whether the
    // splits of each bucket are grouped before tasks are created.
    public static final String GROUP_SPLITS = "hive.enable.custom.grouped.splits";
    // Plugin context handed over in initialize(); used for every call back into the framework.
    VertexManagerPluginContext context;
    // Bucket id -> indexes of the tasks that process that bucket; filled in
    // processAllEvents() and serialized into the custom edge payload.
    private Multimap<Integer, Integer> bucketToTaskMap = HashMultimap.create();
    // Bucket id -> file splits as received from the input initializer, before grouping.
    private Multimap<Integer, InputSplit> bucketToInitialSplitMap = ArrayListMultimap.create();
    // Most recent configure-tasks event seen in onRootVertexInitialized(), if any.
    private RootInputConfigureVertexTasksEvent configureVertexTaskEvent;
    // Data-information events collected in onRootVertexInitialized(); allocated
    // there when a configure-tasks event is seen.
    private List<RootInputDataInformationEvent> dataInformationEvents;
    // File path -> splits of that file. A TreeMap, so iteration follows the
    // ordering of Path; setBucketNumForPath() relies on that iteration order
    // when it assigns bucket ids.
    private Map<Path, List<FileSplit>> pathFileSplitsMap = new TreeMap<Path, List<FileSplit>>();
    // Bucket count decoded from the vertex user payload; -1 until initialize() runs.
    private int numBuckets = -1;
    // Job configuration rebuilt from the input descriptor payload; null until
    // onRootVertexInitialized() runs.
    private Configuration conf = null;
    // Guards against onRootVertexInitialized() being processed more than once.
    private boolean rootVertexInitialized = false;
    // Bucket id -> splits after optional grouping; created in groupSplits().
    Multimap<Integer, InputSplit> bucketToGroupedSplitMap;
    // Bucket id -> desired number of tasks, as computed by estimateBucketSizes().
    private Map<Integer, Integer> bucketToNumTaskMap = new HashMap<Integer, Integer>();

    /**
     * Remembers the plugin context and decodes the bucket count from the
     * vertex user payload.
     *
     * <p>The bucket count is the first {@code int} of the payload, read through
     * a {@link ByteBuffer} in its default byte order.
     *
     * @param context framework context for this vertex manager
     */
    public void initialize(VertexManagerPluginContext context) {
        this.context = context;
        byte[] userPayload = context.getUserPayload();
        this.numBuckets = ByteBuffer.wrap(userPayload).getInt();
    }

    /**
     * Schedules every task of this vertex as soon as the vertex starts.
     *
     * <p>Task indexes {@code 0 .. numTasks-1} are handed to the context in
     * ascending order.
     *
     * @param completions not used by this implementation
     */
    public void onVertexStarted(Map<String, List<Integer>> completions) {
        int numTasks = this.context.getVertexNumTasks(this.context.getVertexName());
        List<Integer> scheduledTasks = new ArrayList<Integer>(numTasks);
        for (int taskIndex = 0; taskIndex < numTasks; ++taskIndex) {
            // Integer.valueOf reuses cached instances; the boxing constructor is deprecated.
            scheduledTasks.add(Integer.valueOf(taskIndex));
        }
        this.context.scheduleVertexTasks(scheduledTasks);
    }

    /**
     * Callback for a completed source task. This implementation does nothing.
     *
     * @param srcVertexName name of the source vertex (ignored)
     * @param attemptId identifier passed by the framework (ignored)
     */
    public void onSourceTaskCompleted(String srcVertexName, Integer attemptId) {
        // Intentionally empty: all tasks are already scheduled in onVertexStarted().
    }

    /**
     * Callback for vertex manager events. This implementation does nothing.
     *
     * @param vmEvent the received event (ignored)
     */
    public void onVertexManagerEventReceived(VertexManagerEvent vmEvent) {
        // Intentionally empty: this plugin does not react to vertex manager events.
    }

    /**
     * Handles the events produced by the root input initializer: rebuilds the
     * job configuration, collects the file splits, assigns them to buckets,
     * optionally groups them and finally configures the vertex.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Parse the input descriptor payload into {@link #conf}. When
     *       {@link #GROUP_SPLITS} is enabled (the default), the input format is
     *       switched to {@code TezGroupedSplitsInputFormat} and the descriptor
     *       payload is rewritten accordingly.</li>
     *   <li>Walk the events, recording the configure-tasks event and indexing
     *       every data-information event's file split by file path.</li>
     *   <li>Assign buckets, group splits and publish the resulting tasks.</li>
     * </ol>
     *
     * @param inputName name of the root input being initialized
     * @param inputDescriptor descriptor whose user payload carries the job configuration
     * @param events events emitted by the input initializer
     * @throws IllegalStateException if this method was already called, if a
     *         configure-tasks event arrives after a data-information event or
     *         while the vertex parallelism is not -1, or if an update-payload
     *         event is received
     * @throws UnsupportedOperationException if an event carries a split that is
     *         not a {@code FileSplit}
     * @throws RuntimeException wrapping any {@link IOException} raised while
     *         parsing payloads, grouping splits or publishing events
     */
    public void onRootVertexInitialized(String inputName, InputDescriptor inputDescriptor, List<Event> events) {
        // Only a single root-input initialization is supported per instance.
        Preconditions.checkState(!this.rootVertexInitialized);
        this.rootVertexInitialized = true;
        try {
            MRRuntimeProtos.MRInputUserPayloadProto protoPayload = MRHelpers.parseMRInputPayload(inputDescriptor.getUserPayload());
            this.conf = MRHelpers.createConfFromByteString(protoPayload.getConfigurationBytes());
            if (this.conf.getBoolean(GROUP_SPLITS, true)) {
                // Tasks will receive grouped splits, so they must read them through the grouped input format.
                this.conf.set("mapred.input.format.class", TezGroupedSplitsInputFormat.class.getName());
                MRRuntimeProtos.MRInputUserPayloadProto updatedPayload = MRRuntimeProtos.MRInputUserPayloadProto.newBuilder(protoPayload)
                        .setConfigurationBytes(MRHelpers.createByteStringFromConf(this.conf))
                        .build();
                inputDescriptor.setUserPayload(updatedPayload.toByteArray());
            }
        } catch (IOException e) {
            // The wrapped exception already carries the stack trace; no separate printStackTrace().
            throw new RuntimeException(e);
        }

        boolean dataInformationEventSeen = false;
        for (Event event : events) {
            if (event instanceof RootInputConfigureVertexTasksEvent) {
                // The configure event must precede all data events and may only set parallelism once.
                Preconditions.checkState(!dataInformationEventSeen);
                Preconditions.checkState(this.context.getVertexNumTasks(this.context.getVertexName()) == -1, (Object)"Parallelism for the vertex should be set to -1 if the InputInitializer is setting parallelism");
                this.configureVertexTaskEvent = (RootInputConfigureVertexTasksEvent) event;
                this.dataInformationEvents = Lists.newArrayListWithCapacity(this.configureVertexTaskEvent.getNumTasks());
            }
            if (event instanceof RootInputUpdatePayloadEvent) {
                // This plugin does not support payload updates; fail loudly.
                throw new IllegalStateException("Unexpected RootInputUpdatePayloadEvent for input " + inputName);
            }
            if (!(event instanceof RootInputDataInformationEvent)) {
                continue;
            }
            dataInformationEventSeen = true;
            RootInputDataInformationEvent diEvent = (RootInputDataInformationEvent) event;
            if (this.dataInformationEvents == null) {
                // No configure event preceded the data events; allocate lazily instead of failing with an NPE.
                this.dataInformationEvents = new ArrayList<RootInputDataInformationEvent>();
            }
            this.dataInformationEvents.add(diEvent);

            FileSplit fileSplit;
            try {
                fileSplit = this.getFileSplitFromEvent(diEvent);
            } catch (IOException e) {
                // Keep the original exception as the cause so the failure stays diagnosable.
                throw new RuntimeException("Failed to get file split for event: " + diEvent, e);
            }
            List<FileSplit> fsList = this.pathFileSplitsMap.get(fileSplit.getPath());
            if (fsList == null) {
                fsList = new ArrayList<FileSplit>();
                this.pathFileSplitsMap.put(fileSplit.getPath(), fsList);
            }
            fsList.add(fileSplit);
        }

        this.setBucketNumForPath(this.pathFileSplitsMap);
        try {
            this.groupSplits();
            this.processAllEvents(inputName);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Turns the grouped splits into tasks: builds the bucket-to-task routing
     * table, installs the custom edge manager on matching incoming edges, sets
     * the vertex parallelism and sends one data-information event per task.
     *
     * <p>Tasks are numbered consecutively in the iteration order of
     * {@link #bucketToGroupedSplitMap}; the split at position {@code i} of the
     * flattened list is delivered to task {@code i}.
     *
     * @param inputName name of the root input the events are published for
     * @throws IOException if the routing table or a split cannot be serialized
     */
    private void processAllEvents(String inputName) throws IOException {
        List<InputSplit> finalSplits = new ArrayList<InputSplit>();
        int taskCount = 0;
        for (Map.Entry<Integer, Collection<InputSplit>> entry : this.bucketToGroupedSplitMap.asMap().entrySet()) {
            Integer bucketNum = entry.getKey();
            Collection<InputSplit> bucketSplits = entry.getValue();
            finalSplits.addAll(bucketSplits);
            // One task per split; remember which tasks belong to which bucket.
            for (int i = 0; i < bucketSplits.size(); ++i) {
                this.bucketToTaskMap.put(bucketNum, taskCount);
                ++taskCount;
            }
        }

        String customEdgeClassName = CustomPartitionEdge.class.getName();
        EdgeManagerDescriptor hiveEdgeManagerDesc = new EdgeManagerDescriptor(customEdgeClassName);
        hiveEdgeManagerDesc.setUserPayload(this.getBytePayload(this.bucketToTaskMap));

        // Only edges that are CUSTOM and already use CustomPartitionEdge get the new routing table.
        Map<String, EdgeManagerDescriptor> emMap = Maps.newHashMap();
        for (Map.Entry<String, EdgeProperty> edgeEntry : this.context.getInputVertexEdgeProperties().entrySet()) {
            EdgeProperty edgeProperty = edgeEntry.getValue();
            if (edgeProperty.getDataMovementType() == EdgeProperty.DataMovementType.CUSTOM
                    && edgeProperty.getEdgeManagerDescriptor().getClassName().equals(customEdgeClassName)) {
                emMap.put(edgeEntry.getKey(), hiveEdgeManagerDesc);
            }
        }
        LOG.info((Object)("Task count is " + taskCount));

        List<RootInputDataInformationEvent> taskEvents = Lists.newArrayListWithCapacity(finalSplits.size());
        int count = 0;
        for (InputSplit inputSplit : finalSplits) {
            MRRuntimeProtos.MRSplitProto serializedSplit = MRHelpers.createSplitProto(inputSplit);
            RootInputDataInformationEvent diEvent = new RootInputDataInformationEvent(count, serializedSplit.toByteArray());
            diEvent.setTargetIndex(count);
            ++count;
            taskEvents.add(diEvent);
        }

        InputSplit[] splitArray = finalSplits.toArray(new InputSplit[finalSplits.size()]);
        this.context.setVertexParallelism(taskCount, new VertexLocationHint(CustomPartitionVertex.createTaskLocationHintsFromSplits(splitArray)), emMap);
        this.context.addRootInputEvents(inputName, taskEvents);
    }

    /**
     * Serializes the bucket-to-task routing table into the byte payload of the
     * custom edge manager.
     *
     * @param routingTable bucket id to task index mapping; its number of
     *        distinct keys is passed on as the bucket count
     * @return the serialized {@code CustomEdgeConfiguration}, exactly as long as
     *         the bytes that were written
     * @throws IOException if serialization fails
     */
    private byte[] getBytePayload(Multimap<Integer, Integer> routingTable) throws IOException {
        CustomEdgeConfiguration edgeConf = new CustomEdgeConfiguration(routingTable.keySet().size(), routingTable);
        DataOutputBuffer dob = new DataOutputBuffer();
        edgeConf.write(dob);
        // getData() exposes the whole backing array, which is only valid up to
        // getLength(); copy just the written prefix so no stale bytes are shipped.
        return Arrays.copyOf(dob.getData(), dob.getLength());
    }

    /**
     * Extracts the file split carried by a data-information event.
     *
     * <p>An already deserialized payload is used directly; otherwise the raw
     * user payload is parsed as an {@code MRSplitProto} and converted to an
     * old-format split using a freshly created {@link Configuration}.
     *
     * @param event the event to read the split from
     * @return the split, cast to {@link FileSplit}
     * @throws IOException if the serialized payload cannot be parsed or converted
     * @throws UnsupportedOperationException if the split is not a {@link FileSplit}
     */
    private FileSplit getFileSplitFromEvent(RootInputDataInformationEvent event) throws IOException {
        InputSplit inputSplit;
        Object deserialized = event.getDeserializedUserPayload();
        if (deserialized == null) {
            MRRuntimeProtos.MRSplitProto splitProto = MRRuntimeProtos.MRSplitProto.parseFrom(event.getUserPayload());
            inputSplit = MRHelpers.createOldFormatSplitFromUserPayload(splitProto, new SerializationFactory(new Configuration()));
        } else {
            inputSplit = (InputSplit) deserialized;
        }
        if (inputSplit instanceof FileSplit) {
            return (FileSplit) inputSplit;
        }
        throw new UnsupportedOperationException("Cannot handle splits other than FileSplit for the moment");
    }

    /**
     * Assigns every file to a bucket and records its splits under that bucket
     * in {@link #bucketToInitialSplitMap}.
     *
     * <p>Files are visited in the iteration order of the given map; the file at
     * position {@code i} goes to bucket {@code i % numBuckets}, and all splits
     * of one file share that bucket.
     *
     * @param pathFileSplitsMap file path to the splits of that file
     */
    private void setBucketNumForPath(Map<Path, List<FileSplit>> pathFileSplitsMap) {
        int fileIndex = 0;
        int splitTotal = 0;
        for (List<FileSplit> splitsOfFile : pathFileSplitsMap.values()) {
            Integer bucketId = fileIndex % this.numBuckets;
            this.bucketToInitialSplitMap.putAll(bucketId, splitsOfFile);
            splitTotal += splitsOfFile.size();
            ++fileIndex;
        }
        LOG.info((Object)("Total number of splits counted: " + splitTotal + " and total files encountered: " + pathFileSplitsMap.size()));
    }

    /**
     * Populates {@link #bucketToGroupedSplitMap}.
     *
     * <p>The map starts as a copy of {@link #bucketToInitialSplitMap}. When
     * {@link #GROUP_SPLITS} is enabled (the default), the splits of each bucket
     * are replaced by the output of {@code TezMapredSplitsGrouper}, using the
     * per-bucket task count computed by {@link #estimateBucketSizes()} as the
     * desired number of groups.
     *
     * @throws IOException if the grouper fails
     */
    private void groupSplits() throws IOException {
        this.bucketToGroupedSplitMap = ArrayListMultimap.create(this.bucketToInitialSplitMap);
        if (!this.conf.getBoolean(GROUP_SPLITS, true)) {
            return;
        }
        this.estimateBucketSizes();
        // Iterate the initial map (not the grouped one) so the grouped map can be modified safely.
        for (Map.Entry<Integer, Collection<InputSplit>> entry : this.bucketToInitialSplitMap.asMap().entrySet()) {
            Integer bucketId = entry.getKey();
            InputSplit[] originalSplits = entry.getValue().toArray(new InputSplit[0]);
            TezMapredSplitsGrouper grouper = new TezMapredSplitsGrouper();
            InputSplit[] groupedSplits = grouper.getGroupedSplits(this.conf, originalSplits, this.bucketToNumTaskMap.get(bucketId).intValue(), HiveInputFormat.class.getName());
            LOG.info((Object)("Original split size is " + originalSplits.length + " grouped split size is " + groupedSplits.length));
            this.bucketToGroupedSplitMap.removeAll(bucketId);
            this.bucketToGroupedSplitMap.putAll(bucketId, Arrays.asList(groupedSplits));
        }
    }

    /**
     * Estimates how many tasks each bucket should get and stores the result in
     * {@link #bucketToNumTaskMap}.
     *
     * <p>The overall task budget is
     * {@code totalAvailableMemory * waves / memoryPerTask}, where {@code waves}
     * comes from {@code tez.am.grouping.split-waves}. Each bucket receives a
     * share of that budget proportional to the summed length of its splits,
     * with a minimum of one task per bucket.
     */
    private void estimateBucketSizes() {
        Map<Integer, Long> bucketSizeMap = new HashMap<Integer, Long>();
        long totalSize = 0L;
        for (Map.Entry<Integer, Collection<InputSplit>> entry : this.bucketToInitialSplitMap.asMap().entrySet()) {
            // Primitive accumulator: avoids boxing a new Long on every addition.
            long bucketSize = 0L;
            for (InputSplit split : entry.getValue()) {
                // setBucketNumForPath() only ever stores FileSplit instances in this map.
                bucketSize += ((FileSplit) split).getLength();
            }
            totalSize += bucketSize;
            bucketSizeMap.put(entry.getKey(), bucketSize);
        }

        int totalResource = this.context.getTotalAVailableResource().getMemory();
        int taskResource = this.context.getVertexTaskResource().getMemory();
        float waves = this.conf.getFloat("tez.am.grouping.split-waves", TezConfiguration.TEZ_AM_GROUPING_SPLIT_WAVES_DEFAULT);
        int numTasks = (int)((float)totalResource * waves / (float)taskResource);
        LOG.info((Object)("Total resource: " + totalResource + " Task Resource: " + taskResource + " waves: " + waves + " total size of splits: " + totalSize + " total number of tasks: " + numTasks));

        for (Map.Entry<Integer, Long> sizeEntry : bucketSizeMap.entrySet()) {
            int bucketId = sizeEntry.getKey();
            int numEstimatedTasks = 0;
            if (totalSize != 0L) {
                numEstimatedTasks = (int)((long)numTasks * sizeEntry.getValue() / totalSize);
            }
            LOG.info((Object)("Estimated number of tasks: " + numEstimatedTasks + " for bucket " + bucketId));
            // Every bucket needs at least one task; this also guards against a
            // negative estimate caused by arithmetic overflow.
            if (numEstimatedTasks < 1) {
                numEstimatedTasks = 1;
            }
            this.bucketToNumTaskMap.put(bucketId, numEstimatedTasks);
        }
    }

    /**
     * Builds one task location hint per split, in the same order as the input array.
     *
     * <p>A split that reports locations yields a hint whose host set contains
     * those locations; a split whose locations are {@code null} yields a hint
     * with no hosts. The rack argument of every hint is {@code null}.
     *
     * @param oldFormatSplits splits to derive the hints from
     * @return a new mutable list with one hint per split
     * @throws RuntimeException wrapping any {@link IOException} raised while
     *         querying a split for its locations
     */
    private static List<VertexLocationHint.TaskLocationHint> createTaskLocationHintsFromSplits(InputSplit[] oldFormatSplits) {
        List<VertexLocationHint.TaskLocationHint> hints = new ArrayList<VertexLocationHint.TaskLocationHint>();
        for (InputSplit split : oldFormatSplits) {
            String[] locations;
            try {
                locations = split.getLocations();
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
            if (locations == null) {
                LOG.info((Object)"NULL Location: returning an empty location hint");
                hints.add(new VertexLocationHint.TaskLocationHint(null, null));
            } else {
                hints.add(new VertexLocationHint.TaskLocationHint(new HashSet<String>(Arrays.asList(locations)), null));
            }
        }
        return hints;
    }
}

