/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.common.compress;

import org.apache.lucene.tests.util.LineFileDocs;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.test.ESTestCase;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Random;
import java.util.zip.ZipException;

import static org.hamcrest.Matchers.equalTo;

/**
 * Round-trip tests for {@link DeflateCompressor}.
 * <p>
 * Most tests compress and then decompress data through the thread-local streaming API ({@code threadLocalOutputStream} /
 * {@code threadLocalInputStream}) and check that the original bytes come back. The inputs cover random bytes, natural-language
 * text, and runs of repeated long/int/short values. The {@code *Threads} variants run the same workloads via
 * {@code startInParallel}. {@link #testCompressUncompressWithCorruptions()} exercises the whole-buffer
 * {@code compress}/{@code uncompress} API instead.
 */
public class DeflateCompressTests extends ESTestCase {

    private final Compressor compressor = new DeflateCompressor();

    public void testRandom() throws IOException {
        Random r = random();
        for (int i = 0; i < 10; i++) {
            byte[] bytes = new byte[TestUtil.nextInt(r, 1, 100000)];
            r.nextBytes(bytes);
            doTest(bytes);
        }
    }

    public void testRandomThreads() throws Exception {
        startInParallel(randomIntBetween(2, 6), tid -> {
            try {
                for (int i = 0; i < 10; i++) {
                    byte[] bytes = new byte[randomIntBetween(1, 100000)];
                    randomBytesBetween(bytes, Byte.MIN_VALUE, Byte.MAX_VALUE);
                    doTest(bytes);
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        });
    }

    public void testLineDocs() throws IOException {
        Random r = random();
        // try-with-resources so LineFileDocs is closed even when a round trip fails
        try (LineFileDocs lineFileDocs = new LineFileDocs(r)) {
            for (int i = 0; i < 10; i++) {
                doTest(lineDocsBytes(lineFileDocs, TestUtil.nextInt(r, 1, 200)));
            }
        }
    }

    public void testLineDocsThreads() throws Exception {
        int threadCount = randomIntBetween(2, 6);
        startInParallel(threadCount, tid -> {
            // each task gets its own LineFileDocs, closed even when a round trip fails
            try (LineFileDocs lineFileDocs = new LineFileDocs(random())) {
                for (int i = 0; i < 10; i++) {
                    doTest(lineDocsBytes(lineFileDocs, randomIntBetween(1, 200)));
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        });
    }

    public void testRepetitionsL() throws IOException {
        Random r = random();
        for (int i = 0; i < 10; i++) {
            int numLongs = TestUtil.nextInt(r, 1, 10000);
            ByteArrayOutputStream bos = new ByteArrayOutputStream();
            long theValue = r.nextLong();
            for (int j = 0; j < numLongs; j++) {
                // about one value in ten starts a new run; the rest repeat the current value
                if (r.nextInt(10) == 0) {
                    theValue = r.nextLong();
                }
                writeLong(bos, theValue);
            }
            doTest(bos.toByteArray());
        }
    }

    public void testRepetitionsLThreads() throws Exception {
        int threadCount = randomIntBetween(2, 6);
        startInParallel(threadCount, tid -> {
            try {
                for (int i = 0; i < 10; i++) {
                    int numLongs = randomIntBetween(1, 10000);
                    ByteArrayOutputStream bos = new ByteArrayOutputStream();
                    long theValue = randomLong();
                    for (int j = 0; j < numLongs; j++) {
                        if (randomInt(10) == 0) {
                            theValue = randomLong();
                        }
                        writeLong(bos, theValue);
                    }
                    doTest(bos.toByteArray());
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        });
    }

    public void testRepetitionsI() throws IOException {
        Random r = random();
        for (int i = 0; i < 10; i++) {
            int numInts = TestUtil.nextInt(r, 1, 20000);
            ByteArrayOutputStream bos = new ByteArrayOutputStream();
            int theValue = r.nextInt();
            for (int j = 0; j < numInts; j++) {
                // about one value in ten starts a new run; the rest repeat the current value
                if (r.nextInt(10) == 0) {
                    theValue = r.nextInt();
                }
                writeInt(bos, theValue);
            }
            doTest(bos.toByteArray());
        }
    }

    public void testRepetitionsIThreads() throws Exception {
        int threadCount = randomIntBetween(2, 6);
        startInParallel(threadCount, tid -> {
            try {
                for (int i = 0; i < 10; i++) {
                    int numInts = randomIntBetween(1, 20000);
                    ByteArrayOutputStream bos = new ByteArrayOutputStream();
                    int theValue = randomInt();
                    for (int j = 0; j < numInts; j++) {
                        if (randomInt(10) == 0) {
                            theValue = randomInt();
                        }
                        writeInt(bos, theValue);
                    }
                    doTest(bos.toByteArray());
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        });
    }

    public void testRepetitionsS() throws IOException {
        Random r = random();
        for (int i = 0; i < 10; i++) {
            int numShorts = TestUtil.nextInt(r, 1, 40000);
            ByteArrayOutputStream bos = new ByteArrayOutputStream();
            // Random#nextInt's bound is exclusive, so 1 << 16 is needed to reach every 16-bit value (including 0xFFFF)
            short theValue = (short) r.nextInt(1 << 16);
            for (int j = 0; j < numShorts; j++) {
                if (r.nextInt(10) == 0) {
                    theValue = (short) r.nextInt(1 << 16);
                }
                writeShort(bos, theValue);
            }
            doTest(bos.toByteArray());
        }
    }

    public void testRepetitionsSThreads() throws Exception {
        int threadCount = randomIntBetween(2, 6);
        startInParallel(threadCount, tid -> {
            try {
                for (int i = 0; i < 10; i++) {
                    int numShorts = randomIntBetween(1, 40000);
                    ByteArrayOutputStream bos = new ByteArrayOutputStream();
                    short theValue = (short) randomInt(65535);
                    for (int j = 0; j < numShorts; j++) {
                        if (randomInt(10) == 0) {
                            theValue = (short) randomInt(65535);
                        }
                        writeShort(bos, theValue);
                    }
                    doTest(bos.toByteArray());
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        });
    }

    public void testMixed() throws IOException {
        Random r = random();
        // previously this LineFileDocs was never closed
        try (LineFileDocs lineFileDocs = new LineFileDocs(r)) {
            for (int i = 0; i < 2; ++i) {
                ByteArrayOutputStream bos = new ByteArrayOutputStream();
                int prevInt = r.nextInt();
                long prevLong = r.nextLong();
                while (bos.size() < 400000) {
                    switch (r.nextInt(4)) {
                        // feed the written value back so "repeat previous" really repeats the last value written
                        case 0 -> prevInt = addInt(r, prevInt, bos);
                        case 1 -> prevLong = addLong(r, prevLong, bos);
                        case 2 -> addString(lineFileDocs, bos);
                        case 3 -> addBytes(r, bos);
                        default -> throw new IllegalStateException("Random is broken");
                    }
                }
                doTest(bos.toByteArray());
            }
        }
    }

    public void testCompressUncompressWithCorruptions() throws Exception {
        final Random r = random();
        for (int i = 0; i < 10; i++) {
            byte[] bytes = new byte[TestUtil.nextInt(r, 1, 100000)];
            r.nextBytes(bytes);
            // compress a random (possibly empty) slice of the array to exercise non-zero offsets
            final var offset = between(0, bytes.length - 1);
            final var length = between(0, bytes.length - offset);
            final var original = new BytesArray(bytes, offset, length);
            final var compressed = compressor.compress(original);

            if (randomBoolean()) {
                // The compressed reference may be made of several BytesRef chunks: walk them to find the chunk
                // holding corruptIndex, then overwrite that byte with a different value.
                var corruptIndex = between(0, compressed.length() - 1);
                BytesRef bytesRef;
                final var iterator = compressed.iterator();
                while ((bytesRef = iterator.next()) != null) {
                    if (corruptIndex < bytesRef.length) {
                        bytesRef.bytes[bytesRef.offset + corruptIndex] = randomValueOtherThan(
                            bytesRef.bytes[bytesRef.offset + corruptIndex],
                            () -> (byte) (r.nextInt() & 0xff)
                        );
                        break;
                    } else {
                        corruptIndex -= bytesRef.length;
                    }
                }
                // The corruption is not required to be detected; if it is, it must surface as a ZipException.
                try {
                    compressor.uncompress(compressed);
                } catch (ZipException e) {
                    // ok
                }
            } else {
                var uncompressed = compressor.uncompress(compressed);
                assertEquals(original, uncompressed);
            }
        }
    }

    public void testUncompressTooShort() {
        BytesReference bytes = BytesReference.fromByteBuffer(ByteBuffer.wrap(new byte[] { 0x0, 0x0, 0x0 }));
        var e = expectThrows(IOException.class, () -> compressor.uncompress(bytes));
        assertThat(e.getMessage(), equalTo("Input bytes length 3 is less than DEFLATE header size 4"));
    }

    /** Concatenates the UTF-8 encoded {@code "body"} field of the next {@code numDocs} documents from {@code lineFileDocs}. */
    private static byte[] lineDocsBytes(LineFileDocs lineFileDocs, int numDocs) throws IOException {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        for (int j = 0; j < numDocs; j++) {
            addString(lineFileDocs, bos);
        }
        return bos.toByteArray();
    }

    /**
     * Appends a long to {@code bos}: a fresh random value nine times out of ten, otherwise {@code prev}.
     *
     * @return the value that was written, so callers can pass it back in as the next {@code prev}
     */
    private static long addLong(Random r, long prev, ByteArrayOutputStream bos) {
        final long theValue = r.nextInt(10) != 0 ? r.nextLong() : prev;
        writeLong(bos, theValue);
        return theValue;
    }

    /**
     * Appends an int to {@code bos}: a fresh random value nine times out of ten, otherwise {@code prev}.
     *
     * @return the value that was written, so callers can pass it back in as the next {@code prev}
     */
    private static int addInt(Random r, int prev, ByteArrayOutputStream bos) {
        final int theValue = r.nextInt(10) != 0 ? r.nextInt() : prev;
        writeInt(bos, theValue);
        return theValue;
    }

    /** Appends the UTF-8 encoded {@code "body"} field of the next line doc to {@code bos}. */
    private static void addString(LineFileDocs lineFileDocs, ByteArrayOutputStream bos) throws IOException {
        String s = lineFileDocs.nextDoc().get("body");
        bos.write(s.getBytes(StandardCharsets.UTF_8));
    }

    /** Appends between 1 and 10000 random bytes to {@code bos}. */
    private static void addBytes(Random r, ByteArrayOutputStream bos) throws IOException {
        byte[] bytes = new byte[TestUtil.nextInt(r, 1, 10000)];
        r.nextBytes(bytes);
        bos.write(bytes);
    }

    /** Appends {@code value} to {@code out} as 8 bytes, most significant byte first. */
    private static void writeLong(ByteArrayOutputStream out, long value) {
        for (int shift = Long.SIZE - Byte.SIZE; shift >= 0; shift -= Byte.SIZE) {
            out.write((byte) (value >>> shift));
        }
    }

    /** Appends {@code value} to {@code out} as 4 bytes, most significant byte first. */
    private static void writeInt(ByteArrayOutputStream out, int value) {
        for (int shift = Integer.SIZE - Byte.SIZE; shift >= 0; shift -= Byte.SIZE) {
            out.write((byte) (value >>> shift));
        }
    }

    /** Appends {@code value} to {@code out} as 2 bytes, most significant byte first. */
    private static void writeShort(ByteArrayOutputStream out, short value) {
        for (int shift = Short.SIZE - Byte.SIZE; shift >= 0; shift -= Byte.SIZE) {
            out.write((byte) (value >>> shift));
        }
    }

    /**
     * Copies {@code in} to {@code out} until EOF.
     * <p>
     * The copy goes through a randomly sized window of a larger array. The array has random pre- and post-padding and
     * is pre-filled with random junk. Using an offset window instead of a tight buffer exercises the
     * {@code offset}/{@code length} handling of the streams under test.
     */
    private static void copyThroughPaddedBuffer(Random r, InputStream in, OutputStream out) throws IOException {
        final int bufferSize = r.nextBoolean() ? 65535 : TestUtil.nextInt(r, 1, 70000);
        final int prepadding = r.nextInt(70000);
        final int postpadding = r.nextInt(70000);
        final byte[] buffer = new byte[prepadding + bufferSize + postpadding];
        r.nextBytes(buffer); // fill block completely with junk
        int len;
        while ((len = in.read(buffer, prepadding, bufferSize)) != -1) {
            out.write(buffer, prepadding, len);
        }
    }

    /**
     * Round-trips {@code bytes} through the compressor's thread-local output and input streams.
     * The result must equal the original.
     */
    private void doTest(byte[] bytes) throws IOException {
        final Random r = random();

        final ByteArrayOutputStream compressed = new ByteArrayOutputStream();
        try (InputStream rawIn = new ByteArrayInputStream(bytes); OutputStream os = compressor.threadLocalOutputStream(compressed)) {
            copyThroughPaddedBuffer(r, rawIn, os);
        }

        // the decompression stream is now closed too; previously it was left open
        final ByteArrayOutputStream uncompressedOut = new ByteArrayOutputStream();
        try (InputStream in = compressor.threadLocalInputStream(new ByteArrayInputStream(compressed.toByteArray()))) {
            copyThroughPaddedBuffer(r, in, uncompressedOut);
        }

        assertArrayEquals(bytes, uncompressedOut.toByteArray());
    }
}
