/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.cluster.structure;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ESAllocationTestCase;
import org.elasticsearch.cluster.ProjectState;
import org.elasticsearch.cluster.TestShardRoutingRoleStrategies;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.metadata.ProjectId;
import org.elasticsearch.cluster.metadata.ProjectMetadata;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.routing.GlobalRoutingTable;
import org.elasticsearch.cluster.routing.GlobalRoutingTableTestHelper;
import org.elasticsearch.cluster.routing.OperationRouting;
import org.elasticsearch.cluster.routing.RotationShardShuffler;
import org.elasticsearch.cluster.routing.RoutingTable;
import org.elasticsearch.cluster.routing.ShardIterator;
import org.elasticsearch.cluster.routing.ShardRouting;
import org.elasticsearch.cluster.routing.ShardShuffler;
import org.elasticsearch.cluster.routing.ShardsIterator;
import org.elasticsearch.cluster.routing.allocation.AllocationService;
import org.elasticsearch.cluster.routing.allocation.decider.ClusterRebalanceAllocationDecider;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.index.shard.ShardId;

import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.stream.StreamSupport;

import static java.util.Collections.singletonMap;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.sameInstance;

/**
 * Tests for iterating over shard copies in routing tables: exhaustion and ordering of {@link ShardIterator},
 * rotation via {@code shardsIt(int)} and {@code shardsRandomIt()}, node-selector filtering, and the shard/node
 * preferences resolved by {@link OperationRouting#searchShards}.
 */
public class RoutingIteratorTests extends ESAllocationTestCase {

    /**
     * An iterator over an empty shuffled shard list reports zero remaining shards and keeps returning {@code null}
     * from {@link ShardIterator#nextOrNull()} however often it is polled.
     */
    public void testEmptyIterator() {
        ShardShuffler shuffler = new RotationShardShuffler(0);
        // Build a fresh iterator several times from the same shuffler; every one of them must behave as empty.
        for (int round = 0; round < 4; round++) {
            ShardIterator shardIterator = new ShardIterator(
                new ShardId("test1", "_na_", 0),
                shuffler.shuffle(Collections.<ShardRouting>emptyList())
            );
            assertThat(shardIterator.remaining(), equalTo(0));
            assertThat(shardIterator.nextOrNull(), nullValue());
            assertThat(shardIterator.remaining(), equalTo(0));
            assertThat(shardIterator.nextOrNull(), nullValue());
            assertThat(shardIterator.remaining(), equalTo(0));
        }
    }

    /**
     * A shard with one primary and two replicas yields three distinct copies, counts {@code remaining()} down as it
     * goes, and then stays exhausted.
     */
    public void testIterator1() {
        ProjectMetadata metadata = ProjectMetadata.builder(randomProjectIdOrDefault())
            .put(IndexMetadata.builder("test1").settings(settings(IndexVersion.current())).numberOfShards(1).numberOfReplicas(2))
            .build();
        RoutingTable routingTable = RoutingTable.builder(TestShardRoutingRoleStrategies.DEFAULT_ROLE_ONLY)
            .addAsNew(metadata.index("test1"))
            .build();

        ShardIterator shardIterator = routingTable.index("test1").shard(0).shardsIt(0);
        assertThat(shardIterator.size(), equalTo(3));
        ShardRouting shardRouting1 = shardIterator.nextOrNull();
        assertThat(shardRouting1, notNullValue());
        assertThat(shardIterator.remaining(), equalTo(2));
        ShardRouting shardRouting2 = shardIterator.nextOrNull();
        assertThat(shardRouting2, notNullValue());
        assertThat(shardIterator.remaining(), equalTo(1));
        assertThat(shardRouting2, not(sameInstance(shardRouting1)));
        ShardRouting shardRouting3 = shardIterator.nextOrNull();
        assertThat(shardRouting3, notNullValue());
        assertThat(shardRouting3, not(sameInstance(shardRouting1)));
        assertThat(shardRouting3, not(sameInstance(shardRouting2)));
        assertThat(shardIterator.nextOrNull(), nullValue());
        assertThat(shardIterator.remaining(), equalTo(0));
        assertThat(shardIterator.nextOrNull(), nullValue());
        assertThat(shardIterator.remaining(), equalTo(0));
    }

    /**
     * For a shard with two copies, {@code shardsIt(offset)} rotates the iteration order by {@code offset}: even
     * offsets reproduce the order of {@code shardsIt(0)} and odd offsets reverse it.
     */
    public void testIterator2() {
        ProjectMetadata metadata = ProjectMetadata.builder(randomProjectIdOrDefault())
            .put(IndexMetadata.builder("test1").settings(settings(IndexVersion.current())).numberOfShards(1).numberOfReplicas(1))
            .put(IndexMetadata.builder("test2").settings(settings(IndexVersion.current())).numberOfShards(1).numberOfReplicas(1))
            .build();

        RoutingTable routingTable = RoutingTable.builder(TestShardRoutingRoleStrategies.DEFAULT_ROLE_ONLY)
            .addAsNew(metadata.index("test1"))
            .addAsNew(metadata.index("test2"))
            .build();

        ShardIterator shardIterator = routingTable.index("test1").shard(0).shardsIt(0);
        assertThat(shardIterator.size(), equalTo(2));
        ShardRouting shardRouting1 = shardIterator.nextOrNull();
        assertThat(shardRouting1, notNullValue());
        assertThat(shardIterator.remaining(), equalTo(1));
        ShardRouting shardRouting2 = shardIterator.nextOrNull();
        assertThat(shardRouting2, notNullValue());
        assertThat(shardIterator.remaining(), equalTo(0));
        assertThat(shardRouting2, not(sameInstance(shardRouting1)));
        assertThat(shardIterator.nextOrNull(), nullValue());
        assertThat(shardIterator.remaining(), equalTo(0));
        assertThat(shardIterator.nextOrNull(), nullValue());
        assertThat(shardIterator.remaining(), equalTo(0));

        shardIterator = routingTable.index("test1").shard(0).shardsIt(1);
        assertThat(shardIterator.size(), equalTo(2));
        ShardRouting shardRouting3 = shardIterator.nextOrNull();
        assertThat(shardRouting3, notNullValue());
        ShardRouting shardRouting4 = shardIterator.nextOrNull();
        assertThat(shardRouting4, notNullValue());
        assertThat(shardRouting4, not(sameInstance(shardRouting3)));
        assertThat(shardIterator.nextOrNull(), nullValue());
        assertThat(shardIterator.nextOrNull(), nullValue());

        // An offset of 1 swaps the two copies relative to offset 0.
        assertThat(shardRouting1, not(sameInstance(shardRouting3)));
        assertThat(shardRouting2, not(sameInstance(shardRouting4)));
        assertThat(shardRouting1, sameInstance(shardRouting4));
        assertThat(shardRouting2, sameInstance(shardRouting3));

        shardIterator = routingTable.index("test1").shard(0).shardsIt(2);
        assertThat(shardIterator.size(), equalTo(2));
        ShardRouting shardRouting5 = shardIterator.nextOrNull();
        assertThat(shardRouting5, notNullValue());
        ShardRouting shardRouting6 = shardIterator.nextOrNull();
        assertThat(shardRouting6, notNullValue());
        assertThat(shardRouting6, not(sameInstance(shardRouting5)));
        assertThat(shardIterator.nextOrNull(), nullValue());
        assertThat(shardIterator.nextOrNull(), nullValue());

        // Offset 2 wraps around to the offset-0 order.
        assertThat(shardRouting5, sameInstance(shardRouting1));
        assertThat(shardRouting6, sameInstance(shardRouting2));

        shardIterator = routingTable.index("test1").shard(0).shardsIt(3);
        assertThat(shardIterator.size(), equalTo(2));
        ShardRouting shardRouting7 = shardIterator.nextOrNull();
        assertThat(shardRouting7, notNullValue());
        ShardRouting shardRouting8 = shardIterator.nextOrNull();
        assertThat(shardRouting8, notNullValue());
        assertThat(shardRouting8, not(sameInstance(shardRouting7)));
        assertThat(shardIterator.nextOrNull(), nullValue());
        assertThat(shardIterator.nextOrNull(), nullValue());

        // Offset 3 matches offset 1.
        assertThat(shardRouting7, sameInstance(shardRouting3));
        assertThat(shardRouting8, sameInstance(shardRouting4));

        shardIterator = routingTable.index("test1").shard(0).shardsIt(4);
        assertThat(shardIterator.size(), equalTo(2));
        ShardRouting shardRouting9 = shardIterator.nextOrNull();
        assertThat(shardRouting9, notNullValue());
        ShardRouting shardRouting10 = shardIterator.nextOrNull();
        assertThat(shardRouting10, notNullValue());
        assertThat(shardRouting10, not(sameInstance(shardRouting9)));
        assertThat(shardIterator.nextOrNull(), nullValue());
        assertThat(shardIterator.nextOrNull(), nullValue());

        // Offset 4 matches offset 2.
        assertThat(shardRouting9, sameInstance(shardRouting5));
        assertThat(shardRouting10, sameInstance(shardRouting6));
    }

    /**
     * Two consecutive {@code shardsRandomIt()} calls on a two-copy shard each yield both copies. The second call
     * starts from the copy the first call did not start from.
     */
    public void testRandomRouting() {
        ProjectMetadata metadata = ProjectMetadata.builder(randomProjectIdOrDefault())
            .put(IndexMetadata.builder("test1").settings(settings(IndexVersion.current())).numberOfShards(1).numberOfReplicas(1))
            .put(IndexMetadata.builder("test2").settings(settings(IndexVersion.current())).numberOfShards(1).numberOfReplicas(1))
            .build();

        RoutingTable routingTable = RoutingTable.builder(TestShardRoutingRoleStrategies.DEFAULT_ROLE_ONLY)
            .addAsNew(metadata.index("test1"))
            .addAsNew(metadata.index("test2"))
            .build();

        ShardIterator shardIterator = routingTable.index("test1").shard(0).shardsRandomIt();
        ShardRouting shardRouting1 = shardIterator.nextOrNull();
        assertThat(shardRouting1, notNullValue());
        assertThat(shardIterator.nextOrNull(), notNullValue());
        assertThat(shardIterator.nextOrNull(), nullValue());

        shardIterator = routingTable.index("test1").shard(0).shardsRandomIt();
        ShardRouting shardRouting2 = shardIterator.nextOrNull();
        assertThat(shardRouting2, notNullValue());
        ShardRouting shardRouting3 = shardIterator.nextOrNull();
        assertThat(shardRouting3, notNullValue());
        assertThat(shardIterator.nextOrNull(), nullValue());
        assertThat(shardRouting1, not(sameInstance(shardRouting2)));
        assertThat(shardRouting1, sameInstance(shardRouting3));
    }

    /**
     * Node selectors (attribute {@code name:value} pairs, node names, and wildcard patterns of either) restrict the
     * returned active/initializing copies to the matching nodes. A selector that matches no node is rejected with
     * {@link IllegalArgumentException}.
     */
    public void testNodeSelectorRouting() {
        AllocationService strategy = createAllocationService(
            Settings.builder()
                .put("cluster.routing.allocation.node_concurrent_recoveries", 10)
                .put(ClusterRebalanceAllocationDecider.CLUSTER_ROUTING_ALLOCATION_ALLOW_REBALANCE_SETTING.getKey(), "always")
                .build()
        );

        ProjectId projectId = randomProjectIdOrDefault();
        ProjectMetadata metadata = ProjectMetadata.builder(projectId)
            .put(IndexMetadata.builder("test").settings(settings(IndexVersion.current())).numberOfShards(1).numberOfReplicas(1))
            .build();

        RoutingTable routingTable = RoutingTable.builder(TestShardRoutingRoleStrategies.DEFAULT_ROLE_ONLY)
            .addAsNew(metadata.index("test"))
            .build();

        ClusterState clusterState = ClusterState.builder(ClusterName.DEFAULT)
            .putProjectMetadata(metadata)
            .routingTable(GlobalRoutingTable.builder().put(projectId, routingTable).build())
            .nodes(
                DiscoveryNodes.builder()
                    .add(newNode("fred", "node1", singletonMap("disk", "ebs")))
                    .add(newNode("barney", "node2", singletonMap("disk", "ephemeral")))
                    .localNodeId("node1")
            )
            .build();

        clusterState = strategy.reroute(clusterState, "reroute", ActionListener.noop());

        clusterState = startInitializingShardsAndReroute(strategy, clusterState);

        // The cluster state is not modified after this point, so resolve the shard routing table and nodes once.
        final var shardRoutingTable = clusterState.globalRoutingTable().routingTable(projectId).index("test").shard(0);
        final DiscoveryNodes nodes = clusterState.nodes();

        // Attribute selectors, exact and wildcarded.
        assertThat(
            getShardNodeIds(shardRoutingTable.onlyNodeSelectorActiveInitializingShardsIt("disk:ebs", nodes)),
            contains("node1")
        );
        assertThat(
            getShardNodeIds(shardRoutingTable.onlyNodeSelectorActiveInitializingShardsIt("dis*:eph*", nodes)),
            contains("node2")
        );

        // Node-name selectors, exact and wildcarded.
        assertThat(getShardNodeIds(shardRoutingTable.onlyNodeSelectorActiveInitializingShardsIt("fred", nodes)), contains("node1"));
        assertThat(getShardNodeIds(shardRoutingTable.onlyNodeSelectorActiveInitializingShardsIt("bar*", nodes)), contains("node2"));

        // Several selectors at once: the union of matching nodes is returned.
        var nodeIds = getShardNodeIds(
            shardRoutingTable.onlyNodeSelectorActiveInitializingShardsIt(new String[] { "disk:eph*", "disk:ebs" }, nodes)
        );
        assertThat(nodeIds, containsInAnyOrder("node1", "node2"));

        // A non-matching selector next to a matching one does not fail the query.
        assertThat(
            getShardNodeIds(shardRoutingTable.onlyNodeSelectorActiveInitializingShardsIt(new String[] { "disk:*", "invalid_name" }, nodes)),
            equalTo(nodeIds) // order is not deterministic but needs to be consistent across the queries
        );

        // Duplicate selectors do not duplicate the returned copies.
        assertThat(
            getShardNodeIds(shardRoutingTable.onlyNodeSelectorActiveInitializingShardsIt(new String[] { "disk:*", "disk:*" }, nodes)),
            equalTo(nodeIds) // order is not deterministic but needs to be consistent across the queries
        );

        // A selector that matches no node at all is rejected.
        expectThrows(
            IllegalArgumentException.class,
            () -> shardRoutingTable.onlyNodeSelectorActiveInitializingShardsIt("welma", nodes)
        );

        // A rejected selector leaves later lookups unaffected.
        assertThat(getShardNodeIds(shardRoutingTable.onlyNodeSelectorActiveInitializingShardsIt("fred", nodes)), contains("node1"));
    }

    /**
     * Drains the given iterator via its spliterator and returns the current node id of each shard copy, in
     * iteration order.
     *
     * @param iterator the shard copies to inspect
     * @return an unmodifiable list of {@link ShardRouting#currentNodeId()} values
     */
    private static List<String> getShardNodeIds(ShardsIterator iterator) {
        return StreamSupport.stream(iterator.spliterator(), false).map(ShardRouting::currentNodeId).toList();
    }

    /**
     * {@code _shards:N} restricts a search to shard {@code N}.
     * {@code _prefer_nodes:...} makes a preferred node's copy come first.
     */
    public void testShardsAndPreferNodeRouting() {
        AllocationService strategy = createAllocationService(
            Settings.builder().put("cluster.routing.allocation.node_concurrent_recoveries", 10).build()
        );

        ProjectId projectId = randomUniqueProjectId();
        Metadata metadata = Metadata.builder()
            .put(
                ProjectMetadata.builder(projectId)
                    .put(IndexMetadata.builder("test").settings(settings(IndexVersion.current())).numberOfShards(5).numberOfReplicas(1))
            )
            .build();

        GlobalRoutingTable routingTable = GlobalRoutingTableTestHelper.buildRoutingTable(metadata, RoutingTable.Builder::addAsNew);

        ClusterState clusterState = ClusterState.builder(ClusterName.DEFAULT).metadata(metadata).routingTable(routingTable).build();

        clusterState = ClusterState.builder(clusterState)
            .nodes(DiscoveryNodes.builder().add(newNode("node1")).add(newNode("node2")).localNodeId("node1"))
            .build();
        clusterState = strategy.reroute(clusterState, "reroute", ActionListener.noop());

        clusterState = startInitializingShardsAndReroute(strategy, clusterState);
        clusterState = startInitializingShardsAndReroute(strategy, clusterState);
        ProjectState project = clusterState.projectState(projectId);

        OperationRouting operationRouting = new OperationRouting(
            Settings.EMPTY,
            new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)
        );

        List<ShardIterator> shardIterators = operationRouting.searchShards(project, new String[] { "test" }, null, "_shards:0");
        assertThat(shardIterators.size(), equalTo(1));
        assertThat(shardIterators.iterator().next().shardId().id(), equalTo(0));

        shardIterators = operationRouting.searchShards(project, new String[] { "test" }, null, "_shards:1");
        assertThat(shardIterators.size(), equalTo(1));
        assertThat(shardIterators.iterator().next().shardId().id(), equalTo(1));

        // Without a node preference, two consecutive searches of shard 0 must start on different copies.
        shardIterators = operationRouting.searchShards(project, new String[] { "test" }, null, "_shards:0|");
        assertThat(shardIterators.size(), equalTo(1));
        assertThat(shardIterators.iterator().next().shardId().id(), equalTo(0));
        String firstRoundNodeId = shardIterators.iterator().next().nextOrNull().currentNodeId();

        shardIterators = operationRouting.searchShards(project, new String[] { "test" }, null, "_shards:0");
        assertThat(shardIterators.size(), equalTo(1));
        assertThat(shardIterators.iterator().next().shardId().id(), equalTo(0));
        assertThat(shardIterators.iterator().next().nextOrNull().currentNodeId(), not(equalTo(firstRoundNodeId)));

        // A single preferred node is returned first.
        shardIterators = operationRouting.searchShards(project, new String[] { "test" }, null, "_shards:0|_prefer_nodes:node1");
        assertThat(shardIterators.size(), equalTo(1));
        assertThat(shardIterators.iterator().next().shardId().id(), equalTo(0));
        assertThat(shardIterators.iterator().next().nextOrNull().currentNodeId(), equalTo("node1"));

        // With both nodes preferred, either may come first, but both copies are returned.
        shardIterators = operationRouting.searchShards(project, new String[] { "test" }, null, "_shards:0|_prefer_nodes:node1,node2");
        assertThat(shardIterators.size(), equalTo(1));
        Iterator<ShardIterator> iterator = shardIterators.iterator();
        final ShardIterator it = iterator.next();
        assertThat(it.shardId().id(), equalTo(0));
        final String firstNodeId = it.nextOrNull().currentNodeId();
        assertThat(firstNodeId, anyOf(equalTo("node1"), equalTo("node2")));
        if ("node1".equals(firstNodeId)) {
            assertThat(it.nextOrNull().currentNodeId(), equalTo("node2"));
        } else {
            assertThat(it.nextOrNull().currentNodeId(), equalTo("node1"));
        }
    }

    /**
     * The preferences {@code _primary}, {@code _primary_first}, {@code _replica} and {@code _replica_first} are
     * rejected by {@link OperationRouting#searchShards} with {@link IllegalArgumentException}. Each one is tried on
     * its own.
     */
    public void testReplicaShardPreferenceIters() {
        OperationRouting operationRouting = new OperationRouting(
            Settings.EMPTY,
            new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)
        );

        ProjectId projectId = randomProjectIdOrDefault();
        Metadata metadata = Metadata.builder()
            .put(
                ProjectMetadata.builder(projectId)
                    .put(IndexMetadata.builder("test").settings(settings(IndexVersion.current())).numberOfShards(2).numberOfReplicas(2))
            )
            .build();

        GlobalRoutingTable routingTable = GlobalRoutingTableTestHelper.buildRoutingTable(metadata, RoutingTable.Builder::addAsNew);

        final ClusterState clusterState = ClusterState.builder(ClusterName.DEFAULT)
            .metadata(metadata)
            .routingTable(routingTable)
            .nodes(DiscoveryNodes.builder().add(newNode("node1")).add(newNode("node2")).add(newNode("node3")).localNodeId("node1"))
            .build();
        ProjectState project = clusterState.projectState(projectId);

        String[] removedPreferences = { "_primary", "_primary_first", "_replica", "_replica_first" };
        for (String pref : removedPreferences) {
            expectThrows(IllegalArgumentException.class, () -> operationRouting.searchShards(project, new String[] { "test" }, null, pref));
        }
    }

}
