/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.pregel;

import java.util.Optional;
import java.util.concurrent.atomic.DoubleAdder;
import java.util.stream.StreamSupport;
import org.immutables.value.Value;
import org.neo4j.gds.pregel.HitsConfigImpl;
import org.neo4j.graphalgo.annotation.Configuration;
import org.neo4j.graphalgo.annotation.ValueClass;
import org.neo4j.graphalgo.api.nodeproperties.ValueType;
import org.neo4j.graphalgo.beta.pregel.Messages;
import org.neo4j.graphalgo.beta.pregel.PregelComputation;
import org.neo4j.graphalgo.beta.pregel.PregelConfig;
import org.neo4j.graphalgo.beta.pregel.PregelSchema;
import org.neo4j.graphalgo.beta.pregel.annotation.PregelProcedure;
import org.neo4j.graphalgo.beta.pregel.context.ComputeContext;
import org.neo4j.graphalgo.beta.pregel.context.InitContext;
import org.neo4j.graphalgo.beta.pregel.context.MasterComputeContext;
import org.neo4j.graphalgo.config.GraphCreateConfig;
import org.neo4j.graphalgo.core.CypherMapWrapper;

@PregelProcedure(name="gds.alpha.hits", description="Hyperlink-Induced Topic Search (HITS) is a link analysis algorithm that rates nodes")
public class Hits
implements PregelComputation<HitsConfig> {
    static final String HUB_PROPERTY = "hub";
    static final String AUTH_PROPERTY = "auth";
    private static final String NEIGHBOR_IDS = "neighborIds";
    private final DoubleAdder globalNorm = new DoubleAdder();
    private HitsState state = HitsState.SEND_IDS;

    public PregelSchema schema() {
        return new PregelSchema.Builder().add(AUTH_PROPERTY, ValueType.DOUBLE).add(HUB_PROPERTY, ValueType.DOUBLE).add(NEIGHBOR_IDS, ValueType.LONG_ARRAY, PregelSchema.Visibility.PRIVATE).build();
    }

    public void init(InitContext<HitsConfig> context) {
        context.setNodeValue(AUTH_PROPERTY, 1.0);
        context.setNodeValue(HUB_PROPERTY, 1.0);
    }

    public void compute(ComputeContext<HitsConfig> context, Messages messages) {
        switch (this.state) {
            case SEND_IDS: {
                context.sendToNeighbors((double)context.nodeId());
                break;
            }
            case RECEIVE_IDS: {
                this.receiveIds(context, messages);
                break;
            }
            case CALCULATE_AUTHS: {
                this.calculateValue(context, messages, AUTH_PROPERTY);
                break;
            }
            case NORMALIZE_AUTHS: {
                this.normalizeAuthValue(context);
                break;
            }
            case CALCULATE_HUBS: {
                this.calculateValue(context, messages, HUB_PROPERTY);
                break;
            }
            case NORMALIZE_HUBS: {
                this.normalizeHubValue(context);
            }
        }
    }

    public void masterCompute(MasterComputeContext<HitsConfig> context) {
        if (this.state == HitsState.RECEIVE_IDS || this.state == HitsState.CALCULATE_AUTHS || this.state == HitsState.CALCULATE_HUBS) {
            double norm = this.globalNorm.sumThenReset();
            this.globalNorm.add(Math.sqrt(norm));
        } else if (this.state == HitsState.NORMALIZE_AUTHS || this.state == HitsState.NORMALIZE_HUBS) {
            this.globalNorm.reset();
        }
        this.state = this.state.advance();
    }

    private void receiveIds(ComputeContext<HitsConfig> context, Messages messages) {
        long[] neighborIds = StreamSupport.stream(messages.spliterator(), false).mapToLong(Double::longValue).toArray();
        context.setNodeValue(NEIGHBOR_IDS, neighborIds);
        int auth = neighborIds.length;
        context.setNodeValue(AUTH_PROPERTY, (double)auth);
        this.updateGlobalNorm(auth);
    }

    private void calculateValue(ComputeContext<HitsConfig> context, Messages messages, String authProperty) {
        double auth = 0.0;
        for (Double message : messages) {
            auth += message.doubleValue();
        }
        context.setNodeValue(authProperty, auth);
        this.updateGlobalNorm(auth);
    }

    private void normalizeHubValue(ComputeContext<HitsConfig> context) {
        double normalizedValue = this.normalize(context, HUB_PROPERTY);
        context.sendToNeighbors(normalizedValue);
    }

    private void normalizeAuthValue(ComputeContext<HitsConfig> context) {
        double normalizedValue = this.normalize(context, AUTH_PROPERTY);
        for (long neighbor : context.longArrayNodeValue(NEIGHBOR_IDS)) {
            context.sendTo(neighbor, normalizedValue);
        }
    }

    private void updateGlobalNorm(double value) {
        this.globalNorm.add(Math.pow(value, 2.0));
    }

    private double normalize(ComputeContext<HitsConfig> context, String property) {
        double value = context.doubleNodeValue(property);
        double norm = this.globalNorm.sum();
        double normalizedValue = value / norm;
        context.setNodeValue(property, normalizedValue);
        return normalizedValue;
    }

    private static enum HitsState {
        SEND_IDS{

            @Override
            HitsState advance() {
                return RECEIVE_IDS;
            }
        }
        ,
        RECEIVE_IDS{

            @Override
            HitsState advance() {
                return NORMALIZE_AUTHS;
            }
        }
        ,
        CALCULATE_AUTHS{

            @Override
            HitsState advance() {
                return NORMALIZE_AUTHS;
            }
        }
        ,
        NORMALIZE_AUTHS{

            @Override
            HitsState advance() {
                return CALCULATE_HUBS;
            }
        }
        ,
        CALCULATE_HUBS{

            @Override
            public HitsState advance() {
                return NORMALIZE_HUBS;
            }
        }
        ,
        NORMALIZE_HUBS{

            @Override
            HitsState advance() {
                return CALCULATE_AUTHS;
            }
        };


        abstract HitsState advance();
    }

    @ValueClass
    @Configuration
    public static interface HitsConfig
    extends PregelConfig {
        @Value
        public int hitsIterations();

        @Value.Derived
        @Configuration.Ignore
        default public int maxIterations() {
            return this.hitsIterations() * 4 + 1;
        }

        @Configuration.Ignore
        @Value.Derived
        default public boolean isAsynchronous() {
            return false;
        }

        public static HitsConfig of(String username, Optional<String> graphName, Optional<GraphCreateConfig> maybeImplicitConfig, CypherMapWrapper userConfig) {
            return new HitsConfigImpl(graphName, maybeImplicitConfig, username, userConfig);
        }
    }
}

