/*
 * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.  Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */
package com.sun.crypto.provider;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.math.BigInteger;
import java.security.*;
import java.security.interfaces.ECKey;
import java.security.interfaces.ECPublicKey;
import java.security.interfaces.XECKey;
import java.security.interfaces.XECPublicKey;
import java.security.spec.*;
import java.util.Arrays;
import java.util.Objects;
import javax.crypto.*;
import javax.crypto.spec.SecretKeySpec;
import javax.crypto.spec.HKDFParameterSpec;

import sun.security.jca.JCAUtil;
import sun.security.util.*;

import jdk.internal.access.SharedSecrets;

// Implementing DHKEM defined inside https://www.rfc-editor.org/rfc/rfc9180.html,
// without the AuthEncap and AuthDecap functions
public class DHKEM implements KEMSpi {

    private static final byte[] KEM = new byte[]
            {'K', 'E', 'M'};
    private static final byte[] EAE_PRK = new byte[]
            {'e', 'a', 'e', '_', 'p', 'r', 'k'};
    private static final byte[] SHARED_SECRET = new byte[]
            {'s', 'h', 'a', 'r', 'e', 'd', '_', 's', 'e', 'c', 'r', 'e', 't'};
    private static final byte[] DKP_PRK = new byte[]
            {'d', 'k', 'p', '_', 'p', 'r', 'k'};
    private static final byte[] CANDIDATE = new byte[]
            {'c', 'a', 'n', 'd', 'i', 'd', 'a', 't', 'e'};
    private static final byte[] SK = new byte[]
            {'s', 'k'};
    private static final byte[] HPKE_V1 = new byte[]
            {'H', 'P', 'K', 'E', '-', 'v', '1'};
    private static final byte[] EMPTY = new byte[0];

    private record Handler(Params params, SecureRandom secureRandom,
                           PrivateKey skR, PublicKey pkR)
                implements EncapsulatorSpi, DecapsulatorSpi {

        @Override
        public KEM.Encapsulated engineEncapsulate(int from, int to, String algorithm) {
            Objects.checkFromToIndex(from, to, params.Nsecret);
            Objects.requireNonNull(algorithm, "null algorithm");
            KeyPair kpE = params.generateKeyPair(secureRandom);
            PrivateKey skE = kpE.getPrivate();
            PublicKey pkE = kpE.getPublic();
            byte[] pkEm = params.SerializePublicKey(pkE);
            byte[] pkRm = params.SerializePublicKey(pkR);
            byte[] kem_context = concat(pkEm, pkRm);
            try {
                byte[] dh = params.DH(skE, pkR);
                byte[] key = params.ExtractAndExpand(dh, kem_context);
                return new KEM.Encapsulated(
                        new SecretKeySpec(key, from, to - from, algorithm),
                        pkEm, null);
            } catch (Exception e) {
                throw new ProviderException("internal error", e);
            }
        }

        @Override
        public SecretKey engineDecapsulate(byte[] encapsulation,
                int from, int to, String algorithm) throws DecapsulateException {
            Objects.checkFromToIndex(from, to, params.Nsecret);
            Objects.requireNonNull(algorithm, "null algorithm");
            Objects.requireNonNull(encapsulation, "null encapsulation");
            if (encapsulation.length != params.Npk) {
                throw new DecapsulateException("incorrect encapsulation size");
            }
            try {
                PublicKey pkE = params.DeserializePublicKey(encapsulation);
                byte[] dh = params.DH(skR, pkE);
                byte[] pkRm = params.SerializePublicKey(pkR);
                byte[] kem_context = concat(encapsulation, pkRm);
                byte[] key = params.ExtractAndExpand(dh, kem_context);
                return new SecretKeySpec(key, from, to - from, algorithm);
            } catch (IOException | InvalidKeyException e) {
                throw new DecapsulateException("Cannot decapsulate", e);
            } catch (Exception e) {
                throw new ProviderException("internal error", e);
            }
        }

        @Override
        public int engineSecretSize() {
            return params.Nsecret;
        }

        @Override
        public int engineEncapsulationSize() {
            return params.Npk;
        }
    }

    // Not really a random. For KAT test only. It generates key pair from ikm.
    public static class RFC9180DeriveKeyPairSR extends SecureRandom {

        static final long serialVersionUID = 0L;

        private final byte[] ikm;

        public RFC9180DeriveKeyPairSR(byte[] ikm) {
            super(null, null); // lightest constructor
            this.ikm = ikm;
        }

        public KeyPair derive(Params params) {
            try {
                return params.deriveKeyPair(ikm);
            } catch (Exception e) {
                throw new UnsupportedOperationException(e);
            }
        }

        public KeyPair derive(int kem_id) {
            Params params = Arrays.stream(Params.values())
                    .filter(p -> p.kem_id == kem_id)
                    .findFirst()
                    .orElseThrow();
            return derive(params);
        }
    }

    private enum Params {

        P256(0x10, 32, 32, 2 * 32 + 1,
                "ECDH", "EC", CurveDB.P_256, "HKDF-SHA256"),

        P384(0x11, 48, 48, 2 * 48 + 1,
                "ECDH", "EC", CurveDB.P_384, "HKDF-SHA384"),

        P521(0x12, 64, 66, 2 * 66 + 1,
                "ECDH", "EC", CurveDB.P_521, "HKDF-SHA512"),

        X25519(0x20, 32, 32, 32,
                "XDH", "XDH", NamedParameterSpec.X25519, "HKDF-SHA256"),

        X448(0x21, 64, 56, 56,
                "XDH", "XDH", NamedParameterSpec.X448, "HKDF-SHA512"),
        ;

        private final int kem_id;
        private final int Nsecret;
        private final int Nsk;
        private final int Npk;
        private final String kaAlgorithm;
        private final String keyAlgorithm;
        private final AlgorithmParameterSpec spec;
        private final String hkdfAlgorithm;

        private final byte[] suiteId;

        Params(int kem_id, int Nsecret, int Nsk, int Npk,
                String kaAlgorithm, String keyAlgorithm, AlgorithmParameterSpec spec,
                String hkdfAlgorithm) {
            this.kem_id = kem_id;
            this.spec = spec;
            this.Nsecret = Nsecret;
            this.Nsk = Nsk;
            this.Npk = Npk;
            this.kaAlgorithm = kaAlgorithm;
            this.keyAlgorithm = keyAlgorithm;
            this.hkdfAlgorithm = hkdfAlgorithm;
            suiteId = concat(KEM, I2OSP(kem_id, 2));
        }

        private boolean isEC() {
            return this == P256 || this == P384 || this == P521;
        }

        private KeyPair generateKeyPair(SecureRandom sr) {
            if (sr instanceof RFC9180DeriveKeyPairSR r9) {
                return r9.derive(this);
            }
            try {
                KeyPairGenerator g = KeyPairGenerator.getInstance(keyAlgorithm);
                g.initialize(spec, sr);
                return g.generateKeyPair();
            } catch (Exception e) {
                throw new ProviderException("internal error", e);
            }
        }

        private byte[] SerializePublicKey(PublicKey k) {
            if (isEC()) {
                ECPoint w = ((ECPublicKey) k).getW();
                return ECUtil.encodePoint(w, ((NamedCurve) spec).getCurve());
            } else {
                byte[] uArray = ((XECPublicKey) k).getU().toByteArray();
                ArrayUtil.reverse(uArray);
                return Arrays.copyOf(uArray, Npk);
            }
        }

        private PublicKey DeserializePublicKey(byte[] data)
                throws IOException, NoSuchAlgorithmException, InvalidKeySpecException {
            KeySpec keySpec;
            if (isEC()) {
                NamedCurve curve = (NamedCurve) this.spec;
                keySpec = new ECPublicKeySpec(
                        ECUtil.decodePoint(data, curve.getCurve()), curve);
            } else {
                data = data.clone();
                ArrayUtil.reverse(data);
                keySpec = new XECPublicKeySpec(
                        this.spec, new BigInteger(1, data));
            }
            return KeyFactory.getInstance(keyAlgorithm).generatePublic(keySpec);
        }

        private byte[] DH(PrivateKey skE, PublicKey pkR)
                throws NoSuchAlgorithmException, InvalidKeyException {
            KeyAgreement ka = KeyAgreement.getInstance(kaAlgorithm);
            ka.init(skE);
            ka.doPhase(pkR, true);
            return ka.generateSecret();
        }

        private byte[] ExtractAndExpand(byte[] dh, byte[] kem_context)
                throws NoSuchAlgorithmException, InvalidKeyException {
            KDF hkdf = KDF.getInstance(hkdfAlgorithm);
            SecretKey eae_prk = LabeledExtract(hkdf, suiteId, EAE_PRK, dh);
            try {
                return LabeledExpand(hkdf, suiteId, eae_prk, SHARED_SECRET,
                        kem_context, Nsecret);
            } finally {
                if (eae_prk instanceof SecretKeySpec s) {
                    SharedSecrets.getJavaxCryptoSpecAccess()
                            .clearSecretKeySpec(s);
                }
            }
        }

        private PublicKey getPublicKey(PrivateKey sk)
                throws InvalidKeyException {
            if (!(sk instanceof InternalPrivateKey)) {
                try {
                    KeyFactory kf = KeyFactory.getInstance(keyAlgorithm, "SunEC");
                    sk = (PrivateKey) kf.translateKey(sk);
                } catch (Exception e) {
                    throw new InvalidKeyException("Error translating key", e);
                }
            }
            if (sk instanceof InternalPrivateKey ik) {
                try {
                    return ik.calculatePublicKey();
                } catch (UnsupportedOperationException e) {
                    throw new InvalidKeyException("Error retrieving key", e);
                }
            } else {
                // Should not happen, unless SunEC goes wrong
                throw new ProviderException("Unknown key");
            }
        }

        // For KAT tests only. See RFC9180DeriveKeyPairSR.
        public KeyPair deriveKeyPair(byte[] ikm) throws Exception {
            KDF hkdf = KDF.getInstance(hkdfAlgorithm);
            SecretKey dkp_prk = LabeledExtract(hkdf, suiteId, DKP_PRK, ikm);
            try {
                if (isEC()) {
                    NamedCurve curve = (NamedCurve) spec;
                    BigInteger sk = BigInteger.ZERO;
                    int counter = 0;
                    while (sk.signum() == 0 ||
                            sk.compareTo(curve.getOrder()) >= 0) {
                        if (counter > 255) {
                            throw new RuntimeException();
                        }
                        byte[] bytes = LabeledExpand(hkdf, suiteId, dkp_prk,
                                CANDIDATE, I2OSP(counter, 1), Nsk);
                        // bitmask is defined to be 0xFF for P-256 and P-384,
                        // and 0x01 for P-521
                        if (this == Params.P521) {
                            bytes[0] = (byte) (bytes[0] & 0x01);
                        }
                        sk = new BigInteger(1, (bytes));
                        counter = counter + 1;
                    }
                    PrivateKey k = DeserializePrivateKey(sk.toByteArray());
                    return new KeyPair(getPublicKey(k), k);
                } else {
                    byte[] sk = LabeledExpand(hkdf, suiteId, dkp_prk, SK, EMPTY,
                            Nsk);
                    PrivateKey k = DeserializePrivateKey(sk);
                    return new KeyPair(getPublicKey(k), k);
                }
            } finally {
                if (dkp_prk instanceof SecretKeySpec s) {
                    SharedSecrets.getJavaxCryptoSpecAccess()
                            .clearSecretKeySpec(s);
                }
            }
        }

        private PrivateKey DeserializePrivateKey(byte[] data) throws Exception {
            KeySpec keySpec = isEC()
                    ? new ECPrivateKeySpec(new BigInteger(1, (data)), (NamedCurve) spec)
                    : new XECPrivateKeySpec(spec, data);
            return KeyFactory.getInstance(keyAlgorithm).generatePrivate(keySpec);
        }
    }

    private static SecureRandom getSecureRandom(SecureRandom userSR) {
        return userSR != null ? userSR : JCAUtil.getSecureRandom();
    }

    @Override
    public EncapsulatorSpi engineNewEncapsulator(
            PublicKey pk, AlgorithmParameterSpec spec, SecureRandom secureRandom)
            throws InvalidAlgorithmParameterException, InvalidKeyException {
        if (pk == null) {
            throw new InvalidKeyException("input key is null");
        }
        if (spec != null) {
            throw new InvalidAlgorithmParameterException("no spec needed");
        }
        Params params = paramsFromKey(pk);
        return new Handler(params, getSecureRandom(secureRandom), null, pk);
    }

    @Override
    public DecapsulatorSpi engineNewDecapsulator(PrivateKey sk, AlgorithmParameterSpec spec)
            throws InvalidAlgorithmParameterException, InvalidKeyException {
        if (sk == null) {
            throw new InvalidKeyException("input key is null");
        }
        if (spec != null) {
            throw new InvalidAlgorithmParameterException("no spec needed");
        }
        Params params = paramsFromKey(sk);
        return new Handler(params, null, sk, params.getPublicKey(sk));
    }

    private Params paramsFromKey(Key k) throws InvalidKeyException {
        if (k instanceof ECKey eckey) {
            if (ECUtil.equals(eckey.getParams(), CurveDB.P_256)) {
                return Params.P256;
            } else if (ECUtil.equals(eckey.getParams(), CurveDB.P_384)) {
                return Params.P384;
            } else if (ECUtil.equals(eckey.getParams(), CurveDB.P_521)) {
                return Params.P521;
            }
        } else if (k instanceof XECKey xkey
                && xkey.getParams() instanceof NamedParameterSpec ns) {
            if (ns.getName().equalsIgnoreCase("X25519")) {
                return Params.X25519;
            } else if (ns.getName().equalsIgnoreCase("X448")) {
                return Params.X448;
            }
        }
        throw new InvalidKeyException("Unsupported key");
    }

    private static byte[] concat(byte[]... inputs) {
        ByteArrayOutputStream o = new ByteArrayOutputStream();
        Arrays.stream(inputs).forEach(o::writeBytes);
        return o.toByteArray();
    }

    private static byte[] I2OSP(int n, int w) {
        assert n < 256;
        assert w == 1 || w == 2;
        if (w == 1) {
            return new byte[] { (byte) n };
        } else {
            return new byte[] { (byte) (n >> 8), (byte) n };
        }
    }

    private static SecretKey LabeledExtract(KDF hkdf, byte[] suite_id,
            byte[] label, byte[] ikm) throws InvalidKeyException {
        SecretKeySpec s = new SecretKeySpec(concat(HPKE_V1, suite_id, label,
                ikm), "IKM");
        try {
            HKDFParameterSpec spec =
                    HKDFParameterSpec.ofExtract().addIKM(s).extractOnly();
            return hkdf.deriveKey("Generic", spec);
        } catch (InvalidAlgorithmParameterException |
                 NoSuchAlgorithmException e) {
            throw new InvalidKeyException(e.getMessage(), e);
        } finally {
            SharedSecrets.getJavaxCryptoSpecAccess().clearSecretKeySpec(s);
        }
    }

    private static byte[] LabeledExpand(KDF hkdf, byte[] suite_id,
            SecretKey prk, byte[] label, byte[] info, int L)
            throws InvalidKeyException {
        byte[] labeled_info = concat(I2OSP(L, 2), HPKE_V1, suite_id, label,
                info);
        try {
            return hkdf.deriveData(HKDFParameterSpec.expandOnly(
                    prk, labeled_info, L));
        } catch (InvalidAlgorithmParameterException iape) {
            throw new InvalidKeyException(iape.getMessage(), iape);
        }
    }
}
