/*
 * SPDX-License-Identifier: Apache-2.0
 *
 * The OpenSearch Contributors require contributions made to
 * this file be licensed under the Apache-2.0 license or a
 * compatible open source license.
 */

package org.graylog.shaded.opensearch22.org.opensearch.indices.replication;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.graylog.shaded.opensearch22.org.opensearch.OpenSearchException;
import org.graylog.shaded.opensearch22.org.opensearch.action.ActionListener;
import org.graylog.shaded.opensearch22.org.opensearch.common.Nullable;
import org.graylog.shaded.opensearch22.org.opensearch.common.settings.Settings;
import org.graylog.shaded.opensearch22.org.opensearch.index.shard.IndexEventListener;
import org.graylog.shaded.opensearch22.org.opensearch.index.shard.IndexShard;
import org.graylog.shaded.opensearch22.org.opensearch.index.shard.ShardId;
import org.graylog.shaded.opensearch22.org.opensearch.indices.recovery.FileChunkRequest;
import org.graylog.shaded.opensearch22.org.opensearch.indices.recovery.RecoverySettings;
import org.graylog.shaded.opensearch22.org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint;
import org.graylog.shaded.opensearch22.org.opensearch.indices.replication.common.ReplicationCollection;
import org.graylog.shaded.opensearch22.org.opensearch.indices.replication.common.ReplicationCollection.ReplicationRef;
import org.graylog.shaded.opensearch22.org.opensearch.indices.replication.common.ReplicationListener;
import org.graylog.shaded.opensearch22.org.opensearch.indices.replication.common.ReplicationState;
import org.graylog.shaded.opensearch22.org.opensearch.tasks.Task;
import org.graylog.shaded.opensearch22.org.opensearch.threadpool.ThreadPool;
import org.graylog.shaded.opensearch22.org.opensearch.transport.TransportChannel;
import org.graylog.shaded.opensearch22.org.opensearch.transport.TransportRequestHandler;
import org.graylog.shaded.opensearch22.org.opensearch.transport.TransportService;

import java.util.concurrent.atomic.AtomicLong;

/**
 * Service class that orchestrates replication events on replicas.
 * <p>
 * It tracks in-flight {@link SegmentReplicationTarget}s in a {@link ReplicationCollection},
 * kicks them off on the generic thread pool, and registers the transport handler that
 * receives file chunks for those targets.
 *
 * @opensearch.internal
 */
public final class SegmentReplicationTargetService implements IndexEventListener {

    private static final Logger logger = LogManager.getLogger(SegmentReplicationTargetService.class);

    private final ThreadPool threadPool;
    private final RecoverySettings recoverySettings;

    /** Replications that have been started through this service, keyed by replication id. */
    private final ReplicationCollection<SegmentReplicationTarget> onGoingReplications;

    /** Supplies the replication source used when a target is built from a checkpoint. */
    private final SegmentReplicationSourceFactory sourceFactory;

    /**
     * The internal actions
     *
     * @opensearch.internal
     */
    public static class Actions {
        public static final String FILE_CHUNK = "internal:index/shard/replication/file_chunk";
    }

    /**
     * Creates the service and registers the {@link Actions#FILE_CHUNK} request handler
     * (executed on the {@link ThreadPool.Names#GENERIC} pool) with the given transport service.
     *
     * @param threadPool       thread pool used to run replications and handed to the replication collection
     * @param recoverySettings source of the activity timeout and rate limiter used by this service
     * @param transportService transport service on which the file chunk handler is registered
     * @param sourceFactory    factory producing the replication source for a given shard
     */
    public SegmentReplicationTargetService(
        final ThreadPool threadPool,
        final RecoverySettings recoverySettings,
        final TransportService transportService,
        final SegmentReplicationSourceFactory sourceFactory
    ) {
        this.threadPool = threadPool;
        this.recoverySettings = recoverySettings;
        this.onGoingReplications = new ReplicationCollection<>(logger, threadPool);
        this.sourceFactory = sourceFactory;

        transportService.registerRequestHandler(
            Actions.FILE_CHUNK,
            ThreadPool.Names.GENERIC,
            FileChunkRequest::new,
            new FileChunkTransportRequestHandler()
        );
    }

    /**
     * Cancels every ongoing replication for the shard that is about to be closed.
     * Nothing is cancelled when {@code indexShard} is {@code null}.
     */
    @Override
    public void beforeIndexShardClosed(ShardId shardId, @Nullable IndexShard indexShard, Settings indexSettings) {
        if (indexShard != null) {
            onGoingReplications.cancelForShard(shardId, "shard closed");
        }
    }

    /**
     * Builds a {@link SegmentReplicationTarget} for the given checkpoint and shard, using the
     * source obtained from the configured factory, and starts it.
     *
     * @param checkpoint the checkpoint the new target is created for
     * @param indexShard the shard the new target is created for
     * @param listener   listener passed to the new target
     */
    public void startReplication(
        final ReplicationCheckpoint checkpoint,
        final IndexShard indexShard,
        final SegmentReplicationListener listener
    ) {
        startReplication(new SegmentReplicationTarget(checkpoint, indexShard, sourceFactory.get(indexShard), listener));
    }

    /**
     * Registers the target with the ongoing replications (using the recovery activity timeout)
     * and schedules it to run on the generic thread pool.
     *
     * @param target the replication target to start
     */
    public void startReplication(final SegmentReplicationTarget target) {
        final long replicationId = onGoingReplications.start(target, recoverySettings.activityTimeout());
        logger.trace(() -> new ParameterizedMessage("Starting replication {}", replicationId));
        threadPool.generic().execute(new ReplicationRunner(replicationId));
    }

    /**
     * Listener that runs on changes in Replication state
     * <p>
     * Adapts the generic {@link ReplicationListener} callbacks to segment-replication-typed
     * callbacks by casting the state to {@link SegmentReplicationState}.
     *
     * @opensearch.internal
     */
    public interface SegmentReplicationListener extends ReplicationListener {

        @Override
        default void onDone(ReplicationState state) {
            onReplicationDone((SegmentReplicationState) state);
        }

        @Override
        default void onFailure(ReplicationState state, OpenSearchException e, boolean sendShardFailure) {
            onReplicationFailure((SegmentReplicationState) state, e, sendShardFailure);
        }

        /** Invoked from {@link #onDone(ReplicationState)} with the state cast to its segment replication type. */
        void onReplicationDone(SegmentReplicationState state);

        /**
         * Invoked from {@link #onFailure(ReplicationState, OpenSearchException, boolean)} with the
         * state cast to its segment replication type.
         */
        void onReplicationFailure(SegmentReplicationState state, OpenSearchException e, boolean sendShardFailure);
    }

    /**
     * Runnable implementation to trigger a replication event.
     */
    private class ReplicationRunner implements Runnable {

        final long replicationId;

        public ReplicationRunner(long replicationId) {
            this.replicationId = replicationId;
        }

        @Override
        public void run() {
            start(replicationId);
        }
    }

    /**
     * Looks up the replication with the given id and starts it. On success the replication is
     * marked as done; on failure it is failed with an {@link OpenSearchException} wrapping the cause.
     * <p>
     * If the replication can no longer be found the call is a no-op.
     *
     * @param replicationId id returned by the replication collection when the target was registered
     */
    private void start(final long replicationId) {
        try (ReplicationRef<SegmentReplicationTarget> replicationRef = onGoingReplications.get(replicationId)) {
            // The replication may have been removed between registration and this runner being
            // executed (e.g. cancelled by beforeIndexShardClosed). In that case the lookup yields
            // no reference; bail out instead of dereferencing null on the generic thread.
            if (replicationRef == null) {
                logger.trace(
                    () -> new ParameterizedMessage(
                        "not running replication with id [{}] - can not find it (probably finished or cancelled)",
                        replicationId
                    )
                );
                return;
            }
            replicationRef.get().startReplication(new ActionListener<>() {
                @Override
                public void onResponse(Void o) {
                    onGoingReplications.markAsDone(replicationId);
                }

                @Override
                public void onFailure(Exception e) {
                    onGoingReplications.fail(replicationId, new OpenSearchException("Segment Replication failed", e), true);
                }
            });
        }
    }

    /**
     * Transport handler for {@link Actions#FILE_CHUNK} requests: resolves the target for the
     * request's recovery id and shard id and hands the chunk to it.
     */
    private class FileChunkTransportRequestHandler implements TransportRequestHandler<FileChunkRequest> {

        // How many bytes we've copied since we last called RateLimiter.pause
        final AtomicLong bytesSinceLastPause = new AtomicLong();

        @Override
        public void messageReceived(final FileChunkRequest request, TransportChannel channel, Task task) throws Exception {
            try (ReplicationRef<SegmentReplicationTarget> ref = onGoingReplications.getSafe(request.recoveryId(), request.shardId())) {
                final SegmentReplicationTarget target = ref.get();
                final ActionListener<Void> listener = target.createOrFinishListener(channel, Actions.FILE_CHUNK, request);
                target.handleFileChunk(request, target, bytesSinceLastPause, recoverySettings.rateLimiter(), listener);
            }
        }
    }
}
