/*
 * Decompiled with CFR 0.152.
 */
package eu.stratosphere.test.broadcastvars;

import eu.stratosphere.api.common.operators.util.UserCodeClassWrapper;
import eu.stratosphere.api.common.operators.util.UserCodeObjectWrapper;
import eu.stratosphere.api.common.operators.util.UserCodeWrapper;
import eu.stratosphere.api.common.typeutils.TypeSerializerFactory;
import eu.stratosphere.api.java.record.functions.MapFunction;
import eu.stratosphere.api.java.record.io.CsvInputFormat;
import eu.stratosphere.api.java.record.io.CsvOutputFormat;
import eu.stratosphere.configuration.Configuration;
import eu.stratosphere.core.fs.Path;
import eu.stratosphere.nephele.io.DistributionPattern;
import eu.stratosphere.nephele.io.channels.ChannelType;
import eu.stratosphere.nephele.jobgraph.AbstractJobVertex;
import eu.stratosphere.nephele.jobgraph.JobGraph;
import eu.stratosphere.nephele.jobgraph.JobGraphDefinitionException;
import eu.stratosphere.nephele.jobgraph.JobInputVertex;
import eu.stratosphere.nephele.jobgraph.JobOutputVertex;
import eu.stratosphere.nephele.jobgraph.JobTaskVertex;
import eu.stratosphere.pact.runtime.plugable.pactrecord.RecordSerializerFactory;
import eu.stratosphere.pact.runtime.shipping.ShipStrategyType;
import eu.stratosphere.pact.runtime.task.CollectorMapDriver;
import eu.stratosphere.pact.runtime.task.DriverStrategy;
import eu.stratosphere.pact.runtime.task.RegularPactTask;
import eu.stratosphere.pact.runtime.task.util.LocalStrategy;
import eu.stratosphere.pact.runtime.task.util.TaskConfig;
import eu.stratosphere.test.iterative.nephele.JobGraphUtils;
import eu.stratosphere.test.util.RecordAPITestBase;
import eu.stratosphere.types.LongValue;
import eu.stratosphere.types.Record;
import eu.stratosphere.types.Value;
import eu.stratosphere.util.Collector;
import java.io.BufferedReader;
import java.util.Collection;
import java.util.Random;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.junit.Assert;

/**
 * Integration test that assembles a Nephele job graph by hand in which a map task receives
 * the "models" data set as a broadcast input and emits, for every (model, point) pair, the
 * dot product of their feature columns.
 *
 * <p>Both inputs are generated from fixed seeds, so {@link #postSubmit()} can regenerate the
 * same feature values, recompute every expected dot product and compare it with the job
 * output.
 */
public class BroadcastVarsNepheleITCase
extends RecordAPITestBase {
    private static final long SEED_POINTS = 3287269182979823L;
    private static final long SEED_MODELS = 1004078042382130L;
    private static final int NUM_POINTS = 10000;
    private static final int NUM_MODELS = 42;
    private static final int NUM_FEATURES = 3;
    /** Exclusive upper bound of the feature values generated for points. */
    private static final int POINT_FEATURE_BOUND = 1000;
    /** Exclusive upper bound of the feature values generated for models. */
    private static final int MODEL_FEATURE_BOUND = 100;
    /** Value passed as {@code numSubTasks} when the job graph is built. */
    private static final int NUM_SUB_TASKS = 4;
    /** Name under which the models are registered as broadcast input and looked up again. */
    private static final String BROADCAST_NAME = "models";
    /** One result line: three unsigned integers separated by single spaces. */
    private static final Pattern RESULT_LINE = Pattern.compile("(\\d+) (\\d+) (\\d+)");
    protected String pointsPath;
    protected String modelsPath;
    protected String resultPath;

    /**
     * Generates the textual points input: one line per point of the form
     * {@code "<id> <f1> ... <fn> \n"} with ids running from 1 to {@code numPoints} and
     * feature values in {@code [0, 1000)}.
     *
     * @param numPoints number of lines to generate, must be in {@code [1, 1000000]}
     * @param numDimensions number of feature values per line, must not be negative
     * @param seed base seed; the generator is reseeded per line with {@code seed + 1000 * id}
     * @return the generated input text
     * @throws IllegalArgumentException if an argument is out of range
     */
    public static final String getInputPoints(int numPoints, int numDimensions, long seed) {
        if (numPoints < 1 || numPoints > 1000000) {
            throw new IllegalArgumentException("numPoints must be in [1, 1000000] but was " + numPoints);
        }
        return BroadcastVarsNepheleITCase.generateInput(numPoints, numDimensions, seed, POINT_FEATURE_BOUND);
    }

    /**
     * Generates the textual models input: one line per model of the form
     * {@code "<id> <f1> ... <fn> \n"} with ids running from 1 to {@code numModels} and
     * feature values in {@code [0, 100)}.
     *
     * @param numModels number of lines to generate, must be in {@code [1, 100]}
     * @param numDimensions number of feature values per line, must not be negative
     * @param seed base seed; the generator is reseeded per line with {@code seed + 1000 * id}
     * @return the generated input text
     * @throws IllegalArgumentException if an argument is out of range
     */
    public static final String getInputModels(int numModels, int numDimensions, long seed) {
        if (numModels < 1 || numModels > 100) {
            throw new IllegalArgumentException("numModels must be in [1, 100] but was " + numModels);
        }
        return BroadcastVarsNepheleITCase.generateInput(numModels, numDimensions, seed, MODEL_FEATURE_BOUND);
    }

    /** Renders {@code numRows} lines, each holding the row id followed by its generated features. */
    private static String generateInput(int numRows, int numDimensions, long seed, int bound) {
        if (numDimensions < 0) {
            throw new IllegalArgumentException("numDimensions must not be negative but was " + numDimensions);
        }
        Random rnd = new Random();
        StringBuilder bld = new StringBuilder(3 * (1 + numDimensions) * numRows);
        for (int id = 1; id <= numRows; ++id) {
            bld.append(id);
            bld.append(' ');
            for (int feature : BroadcastVarsNepheleITCase.generateFeatures(rnd, seed, id, numDimensions, bound)) {
                bld.append(feature);
                bld.append(' ');
            }
            bld.append('\n');
        }
        return bld.toString();
    }

    /**
     * Produces the feature values of the row with the given 1-based id. The generator is
     * reseeded for every row, so a row's values depend only on {@code seed} and {@code id};
     * this is what lets the verification step regenerate them independently.
     */
    private static int[] generateFeatures(Random rnd, long seed, int id, int numDimensions, int bound) {
        rnd.setSeed(seed + 1000L * id);
        int[] features = new int[numDimensions];
        for (int d = 0; d < numDimensions; ++d) {
            features[d] = rnd.nextInt(bound);
        }
        return features;
    }

    /** Writes the generated points and models to temp files and picks the result location. */
    protected void preSubmit() throws Exception {
        this.pointsPath = this.createTempFile("points.txt", BroadcastVarsNepheleITCase.getInputPoints(NUM_POINTS, NUM_FEATURES, SEED_POINTS));
        this.modelsPath = this.createTempFile("models.txt", BroadcastVarsNepheleITCase.getInputModels(NUM_MODELS, NUM_FEATURES, SEED_MODELS));
        this.resultPath = this.getTempFilePath("results");
    }

    /** Builds the job graph over the files prepared in {@link #preSubmit()}. */
    protected JobGraph getJobGraph() throws Exception {
        return this.createJobGraphV1(this.pointsPath, this.modelsPath, this.resultPath, NUM_SUB_TASKS);
    }

    /**
     * Verifies the job output: every line must be well formed, every (point, model) pair must
     * appear exactly once, and each reported dot product must equal the recomputed one.
     */
    protected void postSubmit() throws Exception {
        long[][] expected = BroadcastVarsNepheleITCase.computeExpectedDotProducts();
        boolean[][] seen = new boolean[NUM_POINTS][NUM_MODELS];
        for (BufferedReader reader : this.getResultReader(this.resultPath)) {
            try {
                String line;
                while ((line = reader.readLine()) != null) {
                    BroadcastVarsNepheleITCase.verifyResultLine(line, expected, seen);
                }
            }
            finally {
                BroadcastVarsNepheleITCase.closeQuietly(reader);
            }
        }
        for (int i = 0; i < NUM_POINTS; ++i) {
            for (int j = 0; j < NUM_MODELS; ++j) {
                Assert.assertTrue("Dot product for record (" + (i + 1) + ", " + (j + 1) + ") does not occur", seen[i][j]);
            }
        }
    }

    /**
     * Regenerates the feature values of all points and models and returns the dot product of
     * every pair, indexed as {@code [pointId - 1][modelId - 1]}.
     */
    private static long[][] computeExpectedDotProducts() {
        Random rnd = new Random();
        // Generate each feature vector once instead of reseeding for every (point, model) pair.
        int[][] points = new int[NUM_POINTS][];
        for (int i = 0; i < NUM_POINTS; ++i) {
            points[i] = BroadcastVarsNepheleITCase.generateFeatures(rnd, SEED_POINTS, i + 1, NUM_FEATURES, POINT_FEATURE_BOUND);
        }
        int[][] models = new int[NUM_MODELS][];
        for (int j = 0; j < NUM_MODELS; ++j) {
            models[j] = BroadcastVarsNepheleITCase.generateFeatures(rnd, SEED_MODELS, j + 1, NUM_FEATURES, MODEL_FEATURE_BOUND);
        }
        long[][] expected = new long[NUM_POINTS][NUM_MODELS];
        for (int i = 0; i < NUM_POINTS; ++i) {
            for (int j = 0; j < NUM_MODELS; ++j) {
                long dotProduct = 0L;
                for (int z = 0; z < NUM_FEATURES; ++z) {
                    dotProduct += (long)points[i][z] * models[j][z];
                }
                expected[i][j] = dotProduct;
            }
        }
        return expected;
    }

    /**
     * Checks a single output line of the form {@code "<modelId> <pointId> <dotProduct>"} and
     * marks the pair as seen.
     */
    private static void verifyResultLine(String line, long[][] expected, boolean[][] seen) {
        Matcher m = RESULT_LINE.matcher(line);
        Assert.assertTrue("Malformed result line: '" + line + "'", m.matches());
        int modelId = Integer.parseInt(m.group(1));
        int pointId = Integer.parseInt(m.group(2));
        long actualDotProd = Long.parseLong(m.group(3));
        // Fail with a readable message instead of an ArrayIndexOutOfBoundsException.
        Assert.assertTrue("Point id out of range: " + pointId, pointId >= 1 && pointId <= NUM_POINTS);
        Assert.assertTrue("Model id out of range: " + modelId, modelId >= 1 && modelId <= NUM_MODELS);
        Assert.assertFalse("Dot product for record (" + pointId + ", " + modelId + ") occurs more than once", seen[pointId - 1][modelId - 1]);
        Assert.assertEquals(String.format("Bad product for (%04d, %04d)", pointId, modelId), expected[pointId - 1][modelId - 1], actualDotProd);
        seen[pointId - 1][modelId - 1] = true;
    }

    /** Closes the reader; a failure to close must not mask an assertion error raised while reading. */
    private static void closeQuietly(BufferedReader reader) {
        try {
            reader.close();
        }
        catch (java.io.IOException ignored) {
            // Nothing useful can be done here; the content has already been verified or rejected.
        }
    }

    /**
     * Creates the input vertex that reads the points file as four space-separated
     * {@link LongValue} columns and ships its output with the FORWARD strategy.
     */
    private static JobInputVertex createPointsInput(JobGraph jobGraph, String pointsPath, int numSubTasks, TypeSerializerFactory<?> serializer) {
        CsvInputFormat pointsInFormat = new CsvInputFormat(' ', new Class[]{LongValue.class, LongValue.class, LongValue.class, LongValue.class});
        JobInputVertex pointsInput = JobGraphUtils.createInput(pointsInFormat, pointsPath, "Input[Points]", jobGraph, numSubTasks, numSubTasks);
        TaskConfig taskConfig = new TaskConfig(pointsInput.getConfiguration());
        taskConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
        taskConfig.setOutputSerializer(serializer);
        return pointsInput;
    }

    /**
     * Creates the input vertex that reads the models file as four space-separated
     * {@link LongValue} columns and ships its output with the BROADCAST strategy.
     */
    private static JobInputVertex createModelsInput(JobGraph jobGraph, String modelsPath, int numSubTasks, TypeSerializerFactory<?> serializer) {
        CsvInputFormat modelsInFormat = new CsvInputFormat(' ', new Class[]{LongValue.class, LongValue.class, LongValue.class, LongValue.class});
        JobInputVertex modelsInput = JobGraphUtils.createInput(modelsInFormat, modelsPath, "Input[Models]", jobGraph, numSubTasks, numSubTasks);
        TaskConfig taskConfig = new TaskConfig(modelsInput.getConfiguration());
        taskConfig.addOutputShipStrategy(ShipStrategyType.BROADCAST);
        taskConfig.setOutputSerializer(serializer);
        return modelsInput;
    }

    /**
     * Creates the map vertex running {@link DotProducts}: regular input 0 carries the points,
     * broadcast input 0 is registered under the name {@code "models"}.
     */
    private static JobTaskVertex createMapper(JobGraph jobGraph, int numSubTasks, TypeSerializerFactory<?> serializer) {
        JobTaskVertex mapper = JobGraphUtils.createTask(RegularPactTask.class, "Map[DotProducts]", jobGraph, numSubTasks, numSubTasks);
        TaskConfig taskConfig = new TaskConfig(mapper.getConfiguration());
        taskConfig.setStubWrapper(new UserCodeClassWrapper<DotProducts>(DotProducts.class));
        taskConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
        taskConfig.setOutputSerializer(serializer);
        taskConfig.setDriver(CollectorMapDriver.class);
        taskConfig.setDriverStrategy(DriverStrategy.COLLECTOR_MAP);
        // Regular input: the points.
        taskConfig.addInputToGroup(0);
        taskConfig.setInputLocalStrategy(0, LocalStrategy.NONE);
        taskConfig.setInputSerializer(serializer, 0);
        // Broadcast input: the models, looked up by name in DotProducts.open().
        taskConfig.setBroadcastInputName(BROADCAST_NAME, 0);
        taskConfig.addBroadcastInputToGroup(0);
        taskConfig.setBroadcastInputSerializer(serializer, 0);
        return mapper;
    }

    /**
     * Creates the output vertex that writes three {@link LongValue} columns per record,
     * separated by spaces and terminated by a newline, to {@code resultPath}.
     */
    private static JobOutputVertex createOutput(JobGraph jobGraph, String resultPath, int numSubTasks, TypeSerializerFactory<?> serializer) {
        JobOutputVertex output = JobGraphUtils.createFileOutput(jobGraph, "Output", numSubTasks, numSubTasks);
        TaskConfig taskConfig = new TaskConfig(output.getConfiguration());
        taskConfig.addInputToGroup(0);
        taskConfig.setInputSerializer(serializer, 0);
        CsvOutputFormat outFormat = new CsvOutputFormat("\n", " ", new Class[]{LongValue.class, LongValue.class, LongValue.class});
        outFormat.setOutputFilePath(new Path(resultPath));
        taskConfig.setStubWrapper(new UserCodeObjectWrapper<CsvOutputFormat>(outFormat));
        return output;
    }

    /**
     * Wires the four vertices together: points and mapper output are connected POINTWISE,
     * the models are connected to the mapper BIPARTITE; all connections use NETWORK channels.
     * Every other vertex is set to share instances with the output vertex.
     */
    private JobGraph createJobGraphV1(String pointsPath, String centersPath, String resultPath, int numSubTasks) throws JobGraphDefinitionException {
        RecordSerializerFactory serializer = RecordSerializerFactory.get();
        JobGraph jobGraph = new JobGraph("Distance Builder");
        JobInputVertex points = BroadcastVarsNepheleITCase.createPointsInput(jobGraph, pointsPath, numSubTasks, serializer);
        JobInputVertex models = BroadcastVarsNepheleITCase.createModelsInput(jobGraph, centersPath, numSubTasks, serializer);
        JobTaskVertex mapper = BroadcastVarsNepheleITCase.createMapper(jobGraph, numSubTasks, serializer);
        JobOutputVertex output = BroadcastVarsNepheleITCase.createOutput(jobGraph, resultPath, numSubTasks, serializer);
        JobGraphUtils.connect(points, mapper, ChannelType.NETWORK, DistributionPattern.POINTWISE);
        JobGraphUtils.connect(models, mapper, ChannelType.NETWORK, DistributionPattern.BIPARTITE);
        JobGraphUtils.connect(mapper, output, ChannelType.NETWORK, DistributionPattern.POINTWISE);
        points.setVertexToShareInstancesWith(output);
        models.setVertexToShareInstancesWith(output);
        mapper.setVertexToShareInstancesWith(output);
        return jobGraph;
    }

    /**
     * Map function that, for each incoming point record, iterates over the broadcast models
     * and emits one record {@code (model field 0, point field 0, dot product)} per model.
     * The dot product is taken over fields {@code 1..NUM_FEATURES} of both records.
     */
    public static final class DotProducts
    extends MapFunction {
        private static final long serialVersionUID = 1L;
        // Reused output record; map() fills positions 0, 1 and 2.
        private final Record result = new Record(3);
        // Reused holders for the field values read from the model and the point.
        private final LongValue lft = new LongValue();
        private final LongValue rgt = new LongValue();
        private final LongValue prd = new LongValue();
        private Collection<Record> models;

        /** Fetches the broadcast variable registered under the name {@code "models"}. */
        public void open(Configuration parameters) throws Exception {
            this.models = this.getRuntimeContext().getBroadcastVariable(BROADCAST_NAME);
        }

        /**
         * Emits one result per broadcast model for the given point record.
         *
         * @param record the point; field 0 is copied to output position 1
         * @param out collector receiving the (reused) result record
         */
        public void map(Record record, Collector<Record> out) throws Exception {
            for (Record model : this.models) {
                long product = 0L;
                for (int i = 1; i <= NUM_FEATURES; ++i) {
                    product += ((LongValue)model.getField(i, (Value)this.lft)).getValue() * ((LongValue)record.getField(i, (Value)this.rgt)).getValue();
                }
                this.prd.setValue(product);
                this.result.copyFrom(model, new int[]{0}, new int[]{0});
                this.result.copyFrom(record, new int[]{0}, new int[]{1});
                this.result.setField(2, this.prd);
                out.collect(this.result);
            }
        }
    }
}

