/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */
package org.elasticsearch.action.support.replication;

import org.apache.logging.log4j.Logger;
import org.apache.lucene.store.AlreadyClosedException;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.UnavailableShardsException;
import org.elasticsearch.action.support.ActiveShardCount;
import org.elasticsearch.action.support.RefCountingListener;
import org.elasticsearch.action.support.RetryableAction;
import org.elasticsearch.action.support.TransportActions;
import org.elasticsearch.cluster.action.shard.ShardStateAction;
import org.elasticsearch.cluster.routing.IndexShardRoutingTable;
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.seqno.SequenceNumbers;
import org.elasticsearch.index.shard.ReplicationGroup;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.node.NodeClosedException;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.ConnectTransportException;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.LongSupplier;

import static org.elasticsearch.core.Strings.format;

/**
 * Coordinates a single replicated operation. The request is first performed on the primary. If the primary produces a replica request,
 * that request is then sent to every replication target of the primary's current {@link ReplicationGroup} other than the primary itself.
 * <p>
 * Per-shard outcomes (total, successful and failed shard copies) are accumulated while the operation runs. Once every pending
 * sub-operation has completed, they are attached to the primary result as a {@link ReplicationResponse.ShardInfo}, and the result is
 * handed to the result listener.
 *
 * @param <Request>        the request type performed on the primary
 * @param <ReplicaRequest> the request type sent to the replicas
 * @param <PrimaryResultT> the result type produced by the primary operation
 */
public class ReplicationOperation<
    Request extends ReplicationRequest<Request>,
    ReplicaRequest extends ReplicationRequest<ReplicaRequest>,
    PrimaryResultT extends ReplicationOperation.PrimaryResult<ReplicaRequest>> {
    private final Logger logger;
    private final ThreadPool threadPool;
    private final Request request;
    /** Operation type; only used in log and failure messages. */
    private final String opType;
    /** Number of shard copies counted towards this operation's total (primary, replicas and skipped copies). */
    private final AtomicInteger totalShards = new AtomicInteger();
    /** Number of shard copies on which the operation completed successfully. */
    private final AtomicInteger successfulShards = new AtomicInteger();
    private final Primary<Request, ReplicaRequest, PrimaryResultT> primary;
    private final Replicas<ReplicaRequest> replicasProxy;
    // NOTE(review): this field is not read anywhere in this class
    private final AtomicBoolean finished = new AtomicBoolean();
    /** Initial retry backoff bound passed to the {@link RetryableAction} that sends each replica request. */
    private final TimeValue initialRetryBackoffBound;
    /** Retry timeout passed to the {@link RetryableAction} that sends each replica request. */
    private final TimeValue retryTimeout;
    private final long primaryTerm;

    private final ActionListener<PrimaryResultT> resultListener;

    // assigned once the primary operation has completed; read when all pending sub-operations have finished
    private volatile PrimaryResultT primaryResult = null;

    // failures may be added from several listener callbacks, hence the synchronized wrapper
    private final List<ReplicationResponse.ShardInfo.Failure> shardReplicaFailures = Collections.synchronizedList(new ArrayList<>());

    /**
     * Creates a new replication operation. Nothing is executed until {@link #execute()} is called.
     *
     * @param request                  the request to perform on the primary
     * @param primary                  the primary shard the request is performed on
     * @param listener                 notified with the primary result (including shard info) or a failure
     * @param replicas                 proxy used to reach, fail, or mark stale the replica copies
     * @param logger                   logger used for trace / error messages
     * @param threadPool               thread pool used for retries and for forking shard-failing work
     * @param opType                   operation type, used in log and failure messages
     * @param primaryTerm              the primary term under which replica requests are sent
     * @param initialRetryBackoffBound initial backoff bound for retrying replica requests
     * @param retryTimeout             timeout after which replica request retries stop
     */
    public ReplicationOperation(
        Request request,
        Primary<Request, ReplicaRequest, PrimaryResultT> primary,
        ActionListener<PrimaryResultT> listener,
        Replicas<ReplicaRequest> replicas,
        Logger logger,
        ThreadPool threadPool,
        String opType,
        long primaryTerm,
        TimeValue initialRetryBackoffBound,
        TimeValue retryTimeout
    ) {
        this.replicasProxy = replicas;
        this.primary = primary;
        this.resultListener = listener;
        this.logger = logger;
        this.threadPool = threadPool;
        this.request = request;
        this.opType = opType;
        this.primaryTerm = primaryTerm;
        this.initialRetryBackoffBound = initialRetryBackoffBound;
        this.retryTimeout = retryTimeout;
    }

    /**
     * The execution is based on a {@link RefCountingListener} that encapsulates the pending sub-operations in this operation. A new
     * listener is acquired when each of the following sub-operations starts, and triggered when it completes:
     * <ul>
     * <li>The operation on the primary</li>
     * <li>The operation on each replica</li>
     * <li>Coordination of the operation as a whole. This prevents the operation from terminating early if the primary finishes
     * before any replica operations have started.</li>
     * </ul>
     * When all of them have completed successfully, the shard info is set on the primary result and the result is passed to the result
     * listener. Otherwise the result listener's {@code onFailure} is invoked.
     * <p>
     * If the required number of active shard copies is not available, the operation fails with an {@link UnavailableShardsException}
     * before anything is performed on the primary.
     */
    public void execute() throws Exception {
        try (var pendingActionsListener = new RefCountingListener(ActionListener.wrap((ignored) -> {
            primaryResult.setShardInfo(
                ReplicationResponse.ShardInfo.of(
                    totalShards.get(),
                    successfulShards.get(),
                    shardReplicaFailures.toArray(ReplicationResponse.NO_FAILURES)
                )
            );
            resultListener.onResponse(primaryResult);
        }, resultListener::onFailure))) {
            ActionListener.run(pendingActionsListener.acquire(), (primaryCoordinationListener) -> { // triggered when we finish coordination
                final String activeShardCountFailure = checkActiveShardCount();
                final ShardRouting primaryRouting = primary.routingEntry();
                final ShardId primaryId = primaryRouting.shardId();
                if (activeShardCountFailure != null) {
                    throw new UnavailableShardsException(
                        primaryId,
                        "{} Timeout: [{}], request: [{}]",
                        activeShardCountFailure,
                        request.timeout(),
                        request
                    );
                }

                totalShards.incrementAndGet();
                primary.perform(request, primaryCoordinationListener.delegateFailureAndWrap((l, primaryResult) -> {
                    handlePrimaryResult(primaryResult, l, pendingActionsListener);
                }));
            });
        }
    }

    /**
     * Handles a successful primary operation. If the primary produced a replica request, this:
     * <ul>
     * <li>runs the replicas' {@code onPrimaryOperationComplete} hook</li>
     * <li>marks unavailable in-sync copies as stale</li>
     * <li>fans the request out to the replication targets</li>
     * </ul>
     * In all cases it then runs the primary result's post-replication actions and updates the primary's checkpoints.
     *
     * @param primaryResult                           the result of the primary operation
     * @param primaryCoordinationPendingActionListener completed once the post-replication actions and checkpoint updates are done
     * @param pendingActionsListener                  used to acquire a listener for each replica-side sub-operation
     */
    private void handlePrimaryResult(
        final PrimaryResultT primaryResult,
        final ActionListener<Void> primaryCoordinationPendingActionListener,
        final RefCountingListener pendingActionsListener
    ) {
        this.primaryResult = primaryResult;
        final ReplicaRequest replicaRequest = primaryResult.replicaRequest();
        if (replicaRequest != null) {
            if (logger.isTraceEnabled()) {
                logger.trace("[{}] op [{}] completed on primary for request [{}]", primary.routingEntry().shardId(), opType, request);
            }
            final ReplicationGroup replicationGroup = primary.getReplicationGroup();

            ActionListener.run(pendingActionsListener.acquire(), primaryOperationPendingActionListener -> {
                replicasProxy.onPrimaryOperationComplete(
                    replicaRequest,
                    replicationGroup.getRoutingTable(),
                    // a failure of the hook is recorded as a shard failure rather than failing the whole operation
                    ActionListener.wrap(ignored -> primaryOperationPendingActionListener.onResponse(null), exception -> {
                        totalShards.incrementAndGet();
                        shardReplicaFailures.add(
                            new ReplicationResponse.ShardInfo.Failure(
                                primary.routingEntry().shardId(),
                                null,
                                exception,
                                ExceptionsHelper.status(exception),
                                false
                            )
                        );
                        primaryOperationPendingActionListener.onResponse(null);
                    })
                );
            });

            // We have to get the replication group after successfully indexing into the primary, in order to honour recovery
            // semantics. Every operation indexed into the primary after recovery starts must also be replicated to the recovery
            // target; if we used an old replication group, we might miss a recovery that has started since then.
            // We also have to get the global checkpoint before the replication group, so that the global checkpoint is valid for
            // this replication group. If we sampled in the reverse order, the global checkpoint might be based on a subset of the
            // sampled replication group and have advanced further than that group allows. Some shards could then learn about a
            // global checkpoint that is higher than their local checkpoint.
            final long globalCheckpoint = primary.computedGlobalCheckpoint();
            // We have to capture max_seq_no_of_updates after this request completed on the primary. This ensures that, when the
            // request is executed on a replica, the replica's max_seq_no_of_updates is at least the primary's value at the time the
            // request was executed there.
            final long maxSeqNoOfUpdatesOrDeletes = primary.maxSeqNoOfUpdatesOrDeletes();
            assert maxSeqNoOfUpdatesOrDeletes != SequenceNumbers.UNASSIGNED_SEQ_NO : "seqno_of_updates still uninitialized";
            final PendingReplicationActions pendingReplicationActions = primary.getPendingReplicationActions();
            markUnavailableShardsAsStale(replicaRequest, replicationGroup, pendingActionsListener);
            performOnReplicas(
                replicaRequest,
                globalCheckpoint,
                maxSeqNoOfUpdatesOrDeletes,
                replicationGroup,
                pendingReplicationActions,
                pendingActionsListener
            );
        }
        primaryResult.runPostReplicationActions(new ActionListener<>() {

            @Override
            public void onResponse(Void aVoid) {
                successfulShards.incrementAndGet();
                updateCheckPoints(
                    primary.routingEntry(),
                    primary::localCheckpoint,
                    primary::globalCheckpoint,
                    () -> primaryCoordinationPendingActionListener.onResponse(null)
                );
            }

            @Override
            public void onFailure(Exception e) {
                // include the exception so the cause of the post-replication failure is not lost
                logger.trace(
                    () -> format(
                        "[%s] op [%s] post replication actions failed for [%s]",
                        primary.routingEntry().shardId(),
                        opType,
                        request
                    ),
                    e
                );
                // TODO: fail shard? Otherwise the local / global checkpoint info may lag, or replicas may go out of sync with the
                // primary.
                // We update the checkpoints anyway because a refresh might fail while the operations were still safely persisted.
                // If the fsync failed, the local checkpoint won't advance, and the engine will be marked as failed when the next
                // indexing operation is appended to the translog.
                updateCheckPoints(
                    primary.routingEntry(),
                    primary::localCheckpoint,
                    primary::globalCheckpoint,
                    () -> primaryCoordinationPendingActionListener.onFailure(e)
                );
            }
        });
    }

    /**
     * Asks the replicas proxy to mark as stale every in-sync allocation id of the replication group that is currently unavailable.
     * Failures are routed through {@link #onNoLongerPrimary(Exception, ActionListener)}.
     *
     * @param replicaRequest         the replica request, used for its shard id
     * @param replicationGroup       the replication group sampled after the primary operation
     * @param pendingActionsListener used to acquire one pending listener per stale copy
     */
    private void markUnavailableShardsAsStale(
        final ReplicaRequest replicaRequest,
        final ReplicationGroup replicationGroup,
        final RefCountingListener pendingActionsListener
    ) {
        // if inSyncAllocationIds contains allocation ids of shards that don't exist in RoutingTable, mark copies as stale
        for (String allocationId : replicationGroup.getUnavailableInSyncShards()) {
            ActionListener.run(pendingActionsListener.acquire(), (staleCopyPendingActionListener) -> {
                replicasProxy.markShardCopyAsStaleIfNeeded(
                    replicaRequest.shardId(),
                    allocationId,
                    primaryTerm,
                    staleCopyPendingActionListener.delegateResponse((l, e) -> onNoLongerPrimary(e, l))
                );
            });
        }
    }

    /**
     * Counts the skipped shards of the replication group towards the total. It then sends the replica request to every replication
     * target that is not the primary's own allocation.
     */
    private void performOnReplicas(
        final ReplicaRequest replicaRequest,
        final long globalCheckpoint,
        final long maxSeqNoOfUpdatesOrDeletes,
        final ReplicationGroup replicationGroup,
        final PendingReplicationActions pendingReplicationActions,
        final RefCountingListener pendingActionsListener
    ) {
        // For the total stats, add the number of unassigned shards and the number of initializing shards that are not yet ready to
        // receive operations (recovery has not yet opened the engine on the target).
        totalShards.addAndGet(replicationGroup.getSkippedShards().size());

        final ShardRouting primaryRouting = primary.routingEntry();

        for (final ShardRouting shard : replicationGroup.getReplicationTargets()) {
            if (shard.isSameAllocation(primaryRouting) == false) {
                performOnReplica(
                    shard,
                    replicaRequest,
                    globalCheckpoint,
                    maxSeqNoOfUpdatesOrDeletes,
                    pendingReplicationActions,
                    pendingActionsListener
                );
            }
        }
    }

    /**
     * Sends the replica request to a single replica through a {@link RetryableAction}. The action is registered with the primary's
     * pending replication actions while it runs.
     * <p>
     * The action retries on circuit-breaker, rejected-execution and connect-transport failures. On success the replica's checkpoints
     * are recorded on the primary. On failure the replica is handed to {@link Replicas#failShardIfNeeded}; the failure is also
     * recorded in the shard info unless it is a "shard not available" exception.
     */
    private void performOnReplica(
        final ShardRouting shard,
        final ReplicaRequest replicaRequest,
        final long globalCheckpoint,
        final long maxSeqNoOfUpdatesOrDeletes,
        final PendingReplicationActions pendingReplicationActions,
        final RefCountingListener pendingActionsListener
    ) {
        assert shard.isPromotableToPrimary() : "only promotable shards should receive replication requests";
        if (logger.isTraceEnabled()) {
            logger.trace("[{}] sending op [{}] to replica {} for request [{}]", shard.shardId(), opType, shard, replicaRequest);
        }
        totalShards.incrementAndGet();
        ActionListener.run(pendingActionsListener.acquire(), (replicationPendingActionListener) -> {
            final ActionListener<ReplicaResponse> replicationListener = new ActionListener<>() {
                @Override
                public void onResponse(ReplicaResponse response) {
                    successfulShards.incrementAndGet();
                    updateCheckPoints(
                        shard,
                        response::localCheckpoint,
                        response::globalCheckpoint,
                        () -> replicationPendingActionListener.onResponse(null)
                    );
                }

                @Override
                public void onFailure(Exception replicaException) {
                    logger.trace(
                        () -> format(
                            "[%s] failure while performing [%s] on replica %s, request [%s]",
                            shard.shardId(),
                            opType,
                            shard,
                            replicaRequest
                        ),
                        replicaException
                    );
                    // Only report "critical" exceptions - TODO: Reach out to the master node to get the latest shard state then report.
                    if (TransportActions.isShardNotAvailableException(replicaException) == false) {
                        RestStatus restStatus = ExceptionsHelper.status(replicaException);
                        shardReplicaFailures.add(
                            new ReplicationResponse.ShardInfo.Failure(
                                shard.shardId(),
                                shard.currentNodeId(),
                                replicaException,
                                restStatus,
                                false
                            )
                        );
                    }
                    String message = String.format(Locale.ROOT, "failed to perform %s on replica %s", opType, shard);
                    replicasProxy.failShardIfNeeded(
                        shard,
                        primaryTerm,
                        message,
                        replicaException,
                        replicationPendingActionListener.delegateResponse((l, e) -> onNoLongerPrimary(e, l))
                    );
                }

                @Override
                public String toString() {
                    return "[" + replicaRequest + "][" + shard + "]";
                }
            };

            final String allocationId = shard.allocationId().getId();
            final RetryableAction<ReplicaResponse> replicationAction = new RetryableAction<>(
                logger,
                threadPool,
                initialRetryBackoffBound,
                retryTimeout,
                replicationListener,
                EsExecutors.DIRECT_EXECUTOR_SERVICE
            ) {

                @Override
                public void tryAction(ActionListener<ReplicaResponse> listener) {
                    replicasProxy.performOn(shard, replicaRequest, primaryTerm, globalCheckpoint, maxSeqNoOfUpdatesOrDeletes, listener);
                }

                @Override
                public void onFinished() {
                    super.onFinished();
                    // deregister so the primary no longer tracks this in-flight replication action
                    pendingReplicationActions.removeReplicationAction(allocationId, this);
                }

                @Override
                public boolean shouldRetry(Exception e) {
                    final Throwable cause = ExceptionsHelper.unwrapCause(e);
                    return cause instanceof CircuitBreakingException
                        || cause instanceof EsRejectedExecutionException
                        || cause instanceof ConnectTransportException;
                }
            };

            pendingReplicationActions.addPendingAction(allocationId, replicationAction);
            replicationAction.run();
        });
    }

    /**
     * Records the given shard's local and global checkpoints on the primary, then runs {@code onCompletion} exactly once.
     * <p>
     * An {@link AlreadyClosedException} is ignored. Any other exception fails the primary on the {@code WRITE} thread pool, and
     * {@code onCompletion} then runs after that forked task.
     *
     * @param shard                    the shard copy whose checkpoints are being reported
     * @param localCheckpointSupplier  supplies the shard's local checkpoint
     * @param globalCheckpointSupplier supplies the shard's global checkpoint
     * @param onCompletion             invoked once the update (and, if needed, the forked primary failure) is done
     */
    private void updateCheckPoints(
        ShardRouting shard,
        LongSupplier localCheckpointSupplier,
        LongSupplier globalCheckpointSupplier,
        Runnable onCompletion
    ) {
        boolean forked = false;
        try {
            primary.updateLocalCheckpointForShard(shard.allocationId().getId(), localCheckpointSupplier.getAsLong());
            primary.updateGlobalCheckpointForShard(shard.allocationId().getId(), globalCheckpointSupplier.getAsLong());
        } catch (final AlreadyClosedException e) {
            // the index was deleted or this shard was never activated after a relocation; fall through and finish normally
        } catch (final Exception e) {
            threadPool.executor(ThreadPool.Names.WRITE).execute(new AbstractRunnable() {
                @Override
                public void onFailure(Exception failShardException) {
                    assert false : failShardException;
                    // do not swallow the failure silently when assertions are disabled
                    logger.error(() -> "unexpected failure while failing primary [" + primary.routingEntry() + "]", failShardException);
                }

                @Override
                public boolean isForceExecution() {
                    return true;
                }

                @Override
                protected void doRun() {
                    // fail the primary but fall through and let the rest of operation processing complete
                    primary.failShard(String.format(Locale.ROOT, "primary failed updating local checkpoint for replica %s", shard), e);
                }

                @Override
                public void onAfter() {
                    onCompletion.run();
                }
            });
            forked = true;
        } finally {
            if (forked == false) {
                onCompletion.run();
            }
        }
    }

    /**
     * Handles a failure while failing a replica or marking a copy as stale.
     * <ul>
     * <li>If the node is closing (the cause is a {@link NodeClosedException}), the listener fails with a
     * {@link RetryOnPrimaryException} without failing the primary.</li>
     * <li>Otherwise the primary is considered demoted. It is failed on the {@code WRITE} thread pool, and the listener then fails
     * with a {@link RetryOnPrimaryException}.</li>
     * </ul>
     *
     * @param failure  the failure reported while failing the replica / marking it stale
     * @param listener the pending-action listener to complete exceptionally
     */
    private void onNoLongerPrimary(Exception failure, ActionListener<Void> listener) {
        ActionListener.run(listener, (l) -> {
            final Throwable cause = ExceptionsHelper.unwrapCause(failure);
            final boolean nodeIsClosing = cause instanceof NodeClosedException;
            if (nodeIsClosing) {
                // We prefer not to fail the primary, to avoid an unnecessary warning log
                // when the node with the primary shard is gracefully shutting down.
                l.onFailure(
                    new RetryOnPrimaryException(
                        primary.routingEntry().shardId(),
                        String.format(
                            Locale.ROOT,
                            "node with primary [%s] is shutting down while failing replica shard",
                            primary.routingEntry()
                        ),
                        failure
                    )
                );
            } else {
                assert failure instanceof ShardStateAction.NoLongerPrimaryShardException : failure;
                threadPool.executor(ThreadPool.Names.WRITE).execute(new AbstractRunnable() {
                    @Override
                    protected void doRun() {
                        // we are no longer the primary, fail ourselves and start over
                        final var message = String.format(
                            Locale.ROOT,
                            "primary shard [%s] was demoted while failing replica shard",
                            primary.routingEntry()
                        );
                        primary.failShard(message, failure);
                        l.onFailure(new RetryOnPrimaryException(primary.routingEntry().shardId(), message, failure));
                    }

                    @Override
                    public boolean isForceExecution() {
                        return true;
                    }

                    @Override
                    public void onFailure(Exception e) {
                        e.addSuppressed(failure);
                        assert false : e;
                        logger.error(() -> "unexpected failure while failing primary [" + primary.routingEntry() + "]", e);
                        l.onFailure(
                            new RetryOnPrimaryException(
                                primary.routingEntry().shardId(),
                                String.format(Locale.ROOT, "unexpected failure while failing primary [%s]", primary.routingEntry()),
                                e
                            )
                        );
                    }
                });
            }
        });
    }

    /**
     * Checks whether we can perform a write based on the required active shard count setting.
     *
     * @return {@code null} if it is OK to proceed, or a message describing why the operation must stop
     */
    protected String checkActiveShardCount() {
        final ShardId shardId = primary.routingEntry().shardId();
        final ActiveShardCount waitForActiveShards = request.waitForActiveShards();
        if (waitForActiveShards == ActiveShardCount.NONE) {
            return null;  // not waiting for any shards
        }
        final IndexShardRoutingTable shardRoutingTable = primary.getReplicationGroup().getRoutingTable();
        ActiveShardCount.EnoughShards enoughShardsActive = waitForActiveShards.enoughShardsActive(shardRoutingTable);
        if (enoughShardsActive.enoughShards()) {
            return null;
        } else {
            // for ALL, report the concrete number of copies in the routing table rather than the symbolic value
            final String resolvedShards = waitForActiveShards == ActiveShardCount.ALL
                ? Integer.toString(shardRoutingTable.size())
                : waitForActiveShards.toString();
            logger.trace(
                "[{}] not enough active copies to meet shard count of [{}] (have {}, needed {}), scheduling a retry. op [{}], "
                    + "request [{}]",
                shardId,
                waitForActiveShards,
                enoughShardsActive.currentActiveShards(),
                resolvedShards,
                opType,
                request
            );
            return "Not enough active copies to meet shard count of ["
                + waitForActiveShards
                + "] (have "
                + enoughShardsActive.currentActiveShards()
                + ", needed "
                + resolvedShards
                + ").";
        }
    }

    /**
     * An encapsulation of an operation that is to be performed on the primary shard
     */
    public interface Primary<
        RequestT extends ReplicationRequest<RequestT>,
        ReplicaRequestT extends ReplicationRequest<ReplicaRequestT>,
        PrimaryResultT extends PrimaryResult<ReplicaRequestT>> {

        /**
         * routing entry for this primary
         */
        ShardRouting routingEntry();

        /**
         * Fail the primary shard.
         *
         * @param message   the failure message
         * @param exception the exception that triggered the failure
         */
        void failShard(String message, Exception exception);

        /**
         * Performs the given request on this primary. This returns as soon as it can with the request for the replicas, and calls a
         * listener when the primary request is completed. The primary request may complete either before or after this method
         * returns; callers must handle both cases.
         *
         * @param request the request to perform
         * @param listener result listener
         */
        void perform(RequestT request, ActionListener<PrimaryResultT> listener);

        /**
         * Notifies the primary of a local checkpoint for the given allocation.
         *
         * Note: The primary will use this information to advance the global checkpoint if possible.
         *
         * @param allocationId allocation ID of the shard corresponding to the supplied local checkpoint
         * @param checkpoint the *local* checkpoint for the shard
         */
        void updateLocalCheckpointForShard(String allocationId, long checkpoint);

        /**
         * Update the local knowledge of the global checkpoint for the specified allocation ID.
         *
         * @param allocationId     the allocation ID to update the global checkpoint for
         * @param globalCheckpoint the global checkpoint
         */
        void updateGlobalCheckpointForShard(String allocationId, long globalCheckpoint);

        /**
         * Returns the persisted local checkpoint on the primary shard.
         *
         * @return the local checkpoint
         */
        long localCheckpoint();

        /**
         * Returns the global checkpoint computed on the primary shard.
         *
         * @return the computed global checkpoint
         */
        long computedGlobalCheckpoint();

        /**
         * Returns the persisted global checkpoint on the primary shard.
         *
         * @return the persisted global checkpoint
         */
        long globalCheckpoint();

        /**
         * Returns the maximum seq_no of updates (index operations overwrite Lucene) or deletes on the primary.
         * This value must be captured after the execution of a replication request on the primary is completed.
         */
        long maxSeqNoOfUpdatesOrDeletes();

        /**
         * Returns the current replication group on the primary shard
         *
         * @return the replication group
         */
        ReplicationGroup getReplicationGroup();

        /**
         * Returns the pending replication actions on the primary shard
         *
         * @return the pending replication actions
         */
        PendingReplicationActions getPendingReplicationActions();
    }

    /**
     * An encapsulation of an operation that will be executed on the replica shards, if present.
     */
    public interface Replicas<RequestT extends ReplicationRequest<RequestT>> {

        /**
         * Performs the specified request on the specified replica.
         *
         * @param replica                    the shard this request should be executed on
         * @param replicaRequest             the operation to perform
         * @param primaryTerm                the primary term
         * @param globalCheckpoint           the global checkpoint on the primary
         * @param maxSeqNoOfUpdatesOrDeletes the max seq_no of updates (index operations overwriting Lucene) or deletes on primary
         *                                   after this replication was executed on it.
         * @param listener                   callback for handling the response or failure
         */
        void performOn(
            ShardRouting replica,
            RequestT replicaRequest,
            long primaryTerm,
            long globalCheckpoint,
            long maxSeqNoOfUpdatesOrDeletes,
            ActionListener<ReplicaResponse> listener
        );

        /**
         * Fail the specified shard if needed, removing it from the current set
         * of active shards. Whether a failure is needed is left up to the
         * implementation.
         *
         * @param replica      shard to fail
         * @param primaryTerm  the primary term
         * @param message      a (short) description of the reason
         * @param exception    the original exception which caused the ReplicationOperation to request the shard to be failed
         * @param listener     a listener that will be notified when the failing shard has been removed from the in-sync set
         */
        void failShardIfNeeded(ShardRouting replica, long primaryTerm, String message, Exception exception, ActionListener<Void> listener);

        /**
         * Marks shard copy as stale if needed, removing its allocation id from
         * the set of in-sync allocation ids. Whether marking as stale is needed
         * is left up to the implementation.
         *
         * @param shardId      shard id
         * @param allocationId allocation id to remove from the set of in-sync allocation ids
         * @param primaryTerm  the primary term
         * @param listener     a listener that will be notified when the failing shard has been removed from the in-sync set
         */
        void markShardCopyAsStaleIfNeeded(ShardId shardId, String allocationId, long primaryTerm, ActionListener<Void> listener);

        /**
         * Optional custom logic to execute when the primary operation is complete, before sending the replica requests.
         * The default implementation completes the listener immediately.
         *
         * @param replicaRequest             the operation that will be performed on replicas
         * @param indexShardRoutingTable     the replication group's index shard routing table
         * @param listener                   callback for handling the response or failure
         */
        default void onPrimaryOperationComplete(
            RequestT replicaRequest,
            IndexShardRoutingTable indexShardRoutingTable,
            ActionListener<Void> listener
        ) {
            listener.onResponse(null);
        }
    }

    /**
     * An interface to encapsulate the metadata needed from replica shards when they respond to operations performed on them.
     */
    public interface ReplicaResponse {

        /**
         * The persisted local checkpoint for the shard.
         *
         * @return the persisted local checkpoint
         **/
        long localCheckpoint();

        /**
         * The persisted global checkpoint for the shard.
         *
         * @return the persisted global checkpoint
         **/
        long globalCheckpoint();

    }

    /**
     * Failure raised when the operation must be retried on the primary. This operation raises it when the primary's node is shutting
     * down, or when the primary was found to be demoted, while failing a replica or marking a copy as stale.
     */
    public static final class RetryOnPrimaryException extends ElasticsearchException {
        public RetryOnPrimaryException(ShardId shardId, String msg) {
            this(shardId, msg, null);
        }

        RetryOnPrimaryException(ShardId shardId, String msg, Throwable cause) {
            super(msg, cause);
            setShard(shardId);
        }

        public RetryOnPrimaryException(StreamInput in) throws IOException {
            super(in);
        }
    }

    /**
     * The result of performing a request on the primary.
     *
     * @param <RequestT> the replica request type produced by the primary
     */
    public interface PrimaryResult<RequestT extends ReplicationRequest<RequestT>> {

        /**
         * @return null if no operation needs to be sent to a replica
         * (for example when the operation failed on the primary due to a parsing exception)
         */
        @Nullable
        RequestT replicaRequest();

        /**
         * Attaches the aggregated per-shard outcome of the replication operation to this result.
         *
         * @param shardInfo total, successful and failed shard copies
         */
        void setShardInfo(ReplicationResponse.ShardInfo shardInfo);

        /**
         * Run actions to be triggered post replication
         * @param listener callback that is invoked after post replication actions have completed
         * */
        void runPostReplicationActions(ActionListener<Void> listener);
    }

}
