/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.search.profile.query;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.Matches;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Weight;
import org.elasticsearch.search.profile.Timer;

import java.io.IOException;

/**
 * {@link Weight} wrapper used for query profiling.
 * <p>
 * Time spent creating the delegate's {@link ScorerSupplier}, obtaining its {@link Scorer} and
 * computing its cost is recorded under {@link QueryTimingType#BUILD_SCORER}. Every {@link Scorer}
 * handed out is wrapped in a {@link ProfileScorer} so that timings are computed while scoring
 * as well. Calls to {@link #count(LeafReaderContext)} are recorded under
 * {@link QueryTimingType#COUNT_WEIGHT}.
 */
public final class ProfileWeight extends Weight {

    /** Weight of the query being profiled; all actual work is delegated to it. */
    private final Weight subQueryWeight;
    /** Breakdown that receives the timings recorded by this weight and its scorers. */
    private final QueryProfileBreakdown profile;

    /**
     * @param query          the query this weight belongs to
     * @param subQueryWeight the weight to delegate to
     * @param profile        the breakdown in which timings are recorded
     */
    public ProfileWeight(Query query, Weight subQueryWeight, QueryProfileBreakdown profile) {
        super(query);
        this.subQueryWeight = subQueryWeight;
        this.profile = profile;
    }

    /**
     * Returns a profiled {@link ScorerSupplier} for the given segment.
     * <p>
     * A single {@link QueryTimingType#BUILD_SCORER} timer records the delegate's
     * {@code scorerSupplier} call, as well as the returned supplier's {@link ScorerSupplier#get(long)}
     * and {@link ScorerSupplier#cost()} calls. The scorer returned by {@code get} is wrapped in a
     * {@link ProfileScorer}.
     *
     * @param context the segment to build a scorer supplier for
     * @return the profiling supplier, or {@code null} if the delegate returned {@code null}
     */
    @Override
    public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
        final Timer timer = profile.getNewTimer(QueryTimingType.BUILD_SCORER);
        timer.start();
        final ScorerSupplier subQueryScorerSupplier;
        try {
            subQueryScorerSupplier = subQueryWeight.scorerSupplier(context);
        } finally {
            timer.stop();
        }
        if (subQueryScorerSupplier == null) {
            // Nothing to wrap: propagate the delegate's null result unchanged.
            return null;
        }

        return new ScorerSupplier() {

            /** Builds the delegate scorer (timed as BUILD_SCORER) and wraps it for profiling. */
            @Override
            public Scorer get(long loadCost) throws IOException {
                timer.start();
                try {
                    return new ProfileScorer(subQueryScorerSupplier.get(loadCost), profile);
                } finally {
                    timer.stop();
                }
            }

            @Override
            public BulkScorer bulkScorer() throws IOException {
                // Deliberately use the default bulk scorer instead of the delegate's specialized one.
                // Lucene's specialized BulkScorers do everything in one pass (finding matches, scoring
                // them and calling the collector), which makes it impossible to see where time is
                // spent, and showing that is the whole purpose of query profiling.
                // The default bulk scorer pulls a scorer and iterates over its matches. For some
                // queries, such as disjunctions, this can be a significantly different execution path,
                // but for most queries it is what happens anyway.
                return super.bulkScorer();
            }

            /** Returns the delegate's cost estimate; the call is timed as BUILD_SCORER. */
            @Override
            public long cost() {
                timer.start();
                try {
                    return subQueryScorerSupplier.cost();
                } finally {
                    timer.stop();
                }
            }

            /** Forwarded to the delegate supplier without timing. */
            @Override
            public void setTopLevelScoringClause() throws IOException {
                subQueryScorerSupplier.setTopLevelScoringClause();
            }
        };
    }

    /** Delegates to the wrapped weight; explanations are not timed. */
    @Override
    public Explanation explain(LeafReaderContext context, int doc) throws IOException {
        return subQueryWeight.explain(context, doc);
    }

    /**
     * Delegates to the wrapped weight, recording the time taken under
     * {@link QueryTimingType#COUNT_WEIGHT}.
     */
    @Override
    public int count(LeafReaderContext context) throws IOException {
        Timer timer = profile.getNewTimer(QueryTimingType.COUNT_WEIGHT);
        timer.start();
        try {
            return subQueryWeight.count(context);
        } finally {
            timer.stop();
        }
    }

    /**
     * Always returns {@code false}, so this profiling wrapper is never reported as cacheable.
     */
    // NOTE(review): presumably so that cached results cannot bypass the profiling wrappers and
    // hide timings; confirm.
    @Override
    public boolean isCacheable(LeafReaderContext ctx) {
        return false;
    }

    /** Delegates to the wrapped weight; match computation is not timed. */
    @Override
    public Matches matches(LeafReaderContext context, int doc) throws IOException {
        return subQueryWeight.matches(context, doc);
    }
}
