/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.tasks;

import org.apache.logging.log4j.Level;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.admin.cluster.node.tasks.TaskManagerTestCase;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.VersionInformation;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.test.MockLog;
import org.elasticsearch.test.junit.annotations.TestLogging;
import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.test.transport.StubbableTransport;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.AbstractSimpleTransportTestCase;
import org.elasticsearch.transport.AbstractTransportRequest;
import org.elasticsearch.transport.EmptyRequest;
import org.elasticsearch.transport.NodeDisconnectedException;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportResponseHandler;

import java.io.Closeable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Function;

import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.instanceOf;

public class BanFailureLoggingTests extends TaskManagerTestCase {

    /**
     * Closes the connection before every ban-action request is sent. The test expects both the "cannot send ban" and the
     * "failed to remove ban" messages to be logged at DEBUG for the child node.
     */
    @TestLogging(reason = "testing logging at DEBUG", value = "org.elasticsearch.tasks.TaskCancellationService:DEBUG")
    public void testLogsAtDebugOnDisconnectionDuringBan() throws Exception {
        runTest((connection, requestId, action, request, options) -> {
            if (action.equals(TaskCancellationService.BAN_PARENT_ACTION_NAME)) {
                connection.close();
            }
            connection.sendRequest(requestId, action, request, options);
        }, childNode -> List.of(cannotSendBanExpectation(childNode, true), cannotRemoveBanExpectation(childNode)));
    }

    /**
     * Lets the first ban-action request through and closes the connection before every later one. The later request is
     * presumably the ban removal. The test expects the "failed to remove ban" message at DEBUG and no "cannot send ban"
     * message.
     */
    @TestLogging(reason = "testing logging at DEBUG", value = "org.elasticsearch.tasks.TaskCancellationService:DEBUG")
    public void testLogsAtDebugOnDisconnectionDuringBanRemoval() throws Exception {
        final AtomicInteger banCount = new AtomicInteger();
        runTest((connection, requestId, action, request, options) -> {
            // the counter is only incremented for ban-action requests, thanks to the short-circuiting &&
            if (action.equals(TaskCancellationService.BAN_PARENT_ACTION_NAME) && banCount.incrementAndGet() >= 2) {
                connection.close();
            }
            connection.sendRequest(requestId, action, request, options);
        }, childNode -> List.of(cannotSendBanExpectation(childNode, false), cannotRemoveBanExpectation(childNode)));
    }

    /**
     * Builds an expectation about the DEBUG message logged when a ban cannot be sent to {@code childNode}.
     *
     * @param childNode the node the ban was addressed to
     * @param seen      {@code true} to require the message, {@code false} to require that it is never logged
     * @return the logging expectation
     */
    private static MockLog.LoggingExpectation cannotSendBanExpectation(DiscoveryNode childNode, boolean seen) {
        return debugExpectation(seen, "cannot send ban", "*cannot send ban for tasks*" + childNode.getId() + "*");
    }

    /**
     * Builds an expectation that the DEBUG message about failing to remove a ban on {@code childNode} is logged.
     *
     * @param childNode the node the ban removal was addressed to
     * @return the logging expectation
     */
    private static MockLog.LoggingExpectation cannotRemoveBanExpectation(DiscoveryNode childNode) {
        return debugExpectation(true, "cannot remove ban", "*failed to remove ban for tasks*" + childNode.getId() + "*");
    }

    /**
     * Builds a DEBUG-level expectation on the {@link TaskCancellationService} logger.
     *
     * @param seen    whether the message must be seen ({@code true}) or must not be seen ({@code false})
     * @param name    the expectation's name, used in failure reports
     * @param pattern the wildcard pattern the log message is matched against
     * @return the logging expectation
     */
    private static MockLog.LoggingExpectation debugExpectation(boolean seen, String name, String pattern) {
        final String loggerName = TaskCancellationService.class.getName();
        if (seen) {
            return new MockLog.SeenEventExpectation(name, loggerName, Level.DEBUG, pattern);
        }
        return new MockLog.UnseenEventExpectation(name, loggerName, Level.DEBUG, pattern);
    }

    /**
     * Runs the ban-failure scenario and checks it against the given logging expectations.
     *
     * <p>The method starts a parent and a child {@link MockTransportService}, both with a {@link TaskCancellationService}.
     * It sends a cancellable child request from a registered parent task, then cancels the parent task and its
     * descendants. During the cancellation it captures {@link TaskCancellationService} log output and verifies it
     * against {@code expectations}.
     *
     * @param sendRequestBehavior intercepts every request sent by the parent transport service
     * @param expectations        builds the logging expectations for the child node
     */
    private void runTest(
        StubbableTransport.SendRequestBehavior sendRequestBehavior,
        Function<DiscoveryNode, List<MockLog.LoggingExpectation>> expectations
    ) throws Exception {

        // closed in reverse order of creation in the finally block below
        final List<Closeable> resources = new ArrayList<>(2);

        try {

            // the child task might not run, but if it does we must wait for it to be cancelled before shutting everything down
            final ReentrantLock childTaskLock = new ReentrantLock();

            final MockTransportService parentTransportService = MockTransportService.createNewService(
                Settings.EMPTY,
                VersionInformation.CURRENT,
                TransportVersion.current(),
                threadPool
            );
            resources.add(parentTransportService);
            parentTransportService.getTaskManager().setTaskCancellationService(new TaskCancellationService(parentTransportService));
            parentTransportService.start();
            parentTransportService.acceptIncomingRequests();

            final MockTransportService childTransportService = MockTransportService.createNewService(
                Settings.EMPTY,
                VersionInformation.CURRENT,
                TransportVersion.current(),
                threadPool
            );
            resources.add(childTransportService);
            childTransportService.getTaskManager().setTaskCancellationService(new TaskCancellationService(childTransportService));
            childTransportService.registerRequestHandler(
                "internal:testAction[c]",
                threadPool.executor(ThreadPool.Names.MANAGEMENT), // busy-wait for cancellation but not on a transport thread
                (StreamInput in) -> new AbstractTransportRequest(in) {
                    @Override
                    public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
                        return new CancellableTask(id, type, action, "", parentTaskId, headers);
                    }
                },
                (request, channel, task) -> {
                    final CancellableTask cancellableTask = (CancellableTask) task;
                    // only wait while holding the lock, so the final tryLock below blocks until this wait has finished;
                    // if the test thread already holds the lock, reply immediately instead
                    if (childTaskLock.tryLock()) {
                        try {
                            assertBusy(() -> assertTrue("task " + task.getId() + " should be cancelled", cancellableTask.isCancelled()));
                        } finally {
                            childTaskLock.unlock();
                        }
                    }
                    // the child never succeeds; ChildResponseHandler fails the test on a successful response
                    channel.sendResponse(new TaskCancelledException("task cancelled"));
                }
            );

            childTransportService.start();
            childTransportService.acceptIncomingRequests();

            // installed before connecting so that every subsequent request from the parent goes through the behavior
            parentTransportService.addSendBehavior(sendRequestBehavior);

            AbstractSimpleTransportTestCase.connectToNode(parentTransportService, childTransportService.getLocalNode());

            final CancellableTask parentTask = (CancellableTask) parentTransportService.getTaskManager()
                .register("transport", "internal:testAction", new ParentRequest());

            parentTransportService.sendChildRequest(
                childTransportService.getLocalNode(),
                "internal:testAction[c]",
                new EmptyRequest(),
                parentTask,
                TransportRequestOptions.EMPTY,
                new ChildResponseHandler(() -> parentTransportService.getTaskManager().unregister(parentTask))
            );

            try (MockLog mockLog = MockLog.capture(TaskCancellationService.class)) {
                for (MockLog.LoggingExpectation expectation : expectations.apply(childTransportService.getLocalNode())) {
                    mockLog.addExpectation(expectation);
                }

                final PlainActionFuture<Void> cancellationFuture = new PlainActionFuture<>();
                parentTransportService.getTaskManager().cancelTaskAndDescendants(parentTask, "test", true, cancellationFuture);
                try {
                    cancellationFuture.actionGet(TimeValue.timeValueSeconds(10));
                } catch (NodeDisconnectedException ignored) {
                    // acceptable; we mostly ignore the result of cancellation anyway
                }

                // await since failure to remove a ban may be logged after cancellation completed
                mockLog.awaitAllExpectationsMatched();
            }

            assertTrue("child tasks did not finish in time", childTaskLock.tryLock(15, TimeUnit.SECONDS));
        } finally {
            Collections.reverse(resources);
            IOUtils.close(resources);
        }
    }

    /**
     * Request used to register the parent task locally.
     *
     * <p>It always creates a {@link CancellableTask} and reports no parent. It is registered directly with the task
     * manager and never sent over the wire, so the setters fail the test if they are invoked.
     */
    private static class ParentRequest implements TaskAwareRequest {
        @Override
        public void setParentTask(TaskId taskId) {
            fail("setParentTask should not be called");
        }

        @Override
        public void setRequestId(long requestId) {
            fail("setRequestId should not be called");
        }

        @Override
        public TaskId getParentTask() {
            return TaskId.EMPTY_TASK_ID;
        }

        @Override
        public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
            return new CancellableTask(id, type, action, "", parentTaskId, headers);
        }
    }

    /**
     * Response handler for the child request, which is expected never to succeed.
     *
     * <p>On failure it asserts that the cause is either a {@link TaskCancelledException} or a
     * {@link NodeDisconnectedException}, then runs the supplied callback. In this test the callback unregisters the
     * parent task. The handler runs directly on the transport worker thread.
     */
    private static class ChildResponseHandler extends TransportResponseHandler.Empty {
        private final Runnable onException;

        ChildResponseHandler(Runnable onException) {
            this.onException = onException;
        }

        @Override
        public Executor executor() {
            return TransportResponseHandler.TRANSPORT_WORKER;
        }

        @Override
        public void handleResponse() {
            fail("should not get successful response");
        }

        @Override
        public void handleException(TransportException exp) {
            assertThat(exp.unwrapCause(), anyOf(instanceOf(TaskCancelledException.class), instanceOf(NodeDisconnectedException.class)));
            onException.run();
        }
    }

}
