// 
// Decompiled by Procyon v0.6.0
// 

package org.bouncycastle.pqc.crypto.mlkem;

import org.bouncycastle.util.Arrays;
import java.security.SecureRandom;

/**
 * ML-KEM (CRYSTALS-Kyber) engine for a fixed module rank {@code k}.
 *
 * <p>Derives every byte-length parameter from {@code k} and implements KEM key generation,
 * encapsulation and decapsulation on top of the IND-CPA scheme in {@link MLKEMIndCpa}.
 *
 * <p>NOTE(review): the buffer layouts below look like FIPS 203 ML-KEM (ek = t || rho,
 * dk = dk_PKE || ek || H(ek) || z, (K, r) = G(m || H(ek)), implicit rejection via J(z || c)).
 * The primitives behind {@code Symmetric.hash_h/hash_g/kdf} are defined elsewhere — confirm there.
 */
class MLKEMEngine
{
    // Randomness for generateKemKeyPair(); stays null until init() is called.
    private SecureRandom random;
    private final MLKEMIndCpa indCpa;
    public static final int KyberN = 256;
    public static final int KyberQ = 3329;
    public static final int KyberQinv = 62209;
    public static final int KyberSymBytes = 32;
    private static final int KyberSharedSecretBytes = 32;
    public static final int KyberPolyBytes = 384;
    private static final int KyberEta2 = 2;
    private static final int KyberIndCpaMsgBytes = 32;
    private final int KyberK;
    private final int KyberPolyVecBytes;
    private final int KyberPolyCompressedBytes;
    private final int KyberPolyVecCompressedBytes;
    private final int KyberEta1;
    private final int KyberIndCpaPublicKeyBytes;
    private final int KyberIndCpaSecretKeyBytes;
    private final int KyberIndCpaBytes;
    private final int KyberPublicKeyBytes;
    private final int KyberSecretKeyBytes;
    private final int KyberCipherTextBytes;
    private final int CryptoBytes;
    private final int CryptoSecretKeyBytes;
    private final int CryptoPublicKeyBytes;
    private final int CryptoCipherTextBytes;
    // Length of the shared secret returned by kemEncrypt/kemDecrypt.
    private final int sessionKeyLength;
    private final Symmetric symmetric;
    
    /** Returns the symmetric primitives (hash_h, hash_g, kdf, ...) used by this engine. */
    public Symmetric getSymmetric() {
        return this.symmetric;
    }
    
    /** Returns the eta2 parameter, which is 2 for every supported rank. */
    public static int getKyberEta2() {
        return KyberEta2;
    }
    
    /** Returns the IND-CPA message length in bytes (32). */
    public static int getKyberIndCpaMsgBytes() {
        return KyberIndCpaMsgBytes;
    }
    
    public int getCryptoCipherTextBytes() {
        return this.CryptoCipherTextBytes;
    }
    
    public int getCryptoPublicKeyBytes() {
        return this.CryptoPublicKeyBytes;
    }
    
    public int getCryptoSecretKeyBytes() {
        return this.CryptoSecretKeyBytes;
    }
    
    public int getCryptoBytes() {
        return this.CryptoBytes;
    }
    
    public int getKyberCipherTextBytes() {
        return this.KyberCipherTextBytes;
    }
    
    public int getKyberSecretKeyBytes() {
        return this.KyberSecretKeyBytes;
    }
    
    public int getKyberIndCpaPublicKeyBytes() {
        return this.KyberIndCpaPublicKeyBytes;
    }
    
    public int getKyberIndCpaSecretKeyBytes() {
        return this.KyberIndCpaSecretKeyBytes;
    }
    
    public int getKyberIndCpaBytes() {
        return this.KyberIndCpaBytes;
    }
    
    public int getKyberPublicKeyBytes() {
        return this.KyberPublicKeyBytes;
    }
    
    public int getKyberPolyCompressedBytes() {
        return this.KyberPolyCompressedBytes;
    }
    
    public int getKyberK() {
        return this.KyberK;
    }
    
    public int getKyberPolyVecBytes() {
        return this.KyberPolyVecBytes;
    }
    
    public int getKyberPolyVecCompressedBytes() {
        return this.KyberPolyVecCompressedBytes;
    }
    
    public int getKyberEta1() {
        return this.KyberEta1;
    }
    
    /**
     * Creates an engine for module rank {@code n} and derives all size parameters from it.
     *
     * @param n the module rank k; only 2, 3 and 4 are supported
     * @throws IllegalArgumentException if {@code n} is not 2, 3 or 4
     */
    public MLKEMEngine(final int n) {
        this.KyberK = n;
        // Rank-dependent parameters: eta1 and the compressed sizes of v (one poly) and u (k polys).
        switch (n) {
            case 2: {
                this.KyberEta1 = 3;
                this.KyberPolyCompressedBytes = 128;
                this.KyberPolyVecCompressedBytes = n * 320;
                break;
            }
            case 3: {
                this.KyberEta1 = 2;
                this.KyberPolyCompressedBytes = 128;
                this.KyberPolyVecCompressedBytes = n * 320;
                break;
            }
            case 4: {
                this.KyberEta1 = 2;
                this.KyberPolyCompressedBytes = 160;
                this.KyberPolyVecCompressedBytes = n * 352;
                break;
            }
            default: {
                throw new IllegalArgumentException("K: " + n + " is not supported for Crystals Kyber");
            }
        }
        this.sessionKeyLength = KyberSharedSecretBytes;
        this.KyberPolyVecBytes = n * KyberPolyBytes;
        // IND-CPA public key = encoded polynomial vector followed by a 32-byte seed.
        this.KyberIndCpaPublicKeyBytes = this.KyberPolyVecBytes + KyberSymBytes;
        this.KyberIndCpaSecretKeyBytes = this.KyberPolyVecBytes;
        this.KyberIndCpaBytes = this.KyberPolyVecCompressedBytes + this.KyberPolyCompressedBytes;
        this.KyberPublicKeyBytes = this.KyberIndCpaPublicKeyBytes;
        // KEM secret key = CPA secret key || CPA public key || H(public key) || z (32 bytes each for the last two).
        this.KyberSecretKeyBytes = this.KyberIndCpaSecretKeyBytes + this.KyberIndCpaPublicKeyBytes + 2 * KyberSymBytes;
        this.KyberCipherTextBytes = this.KyberIndCpaBytes;
        this.CryptoBytes = 32;
        this.CryptoSecretKeyBytes = this.KyberSecretKeyBytes;
        this.CryptoPublicKeyBytes = this.KyberPublicKeyBytes;
        this.CryptoCipherTextBytes = this.KyberCipherTextBytes;
        this.symmetric = new Symmetric.ShakeSymmetric();
        // Constructed last: it receives `this` and may read the sizes computed above.
        this.indCpa = new MLKEMIndCpa(this);
    }
    
    /**
     * Sets the randomness source used by {@link #generateKemKeyPair()}.
     *
     * @param random the random source; must be set before generating key pairs
     */
    public void init(final SecureRandom random) {
        this.random = random;
    }
    
    /**
     * Returns {@code true} when {@code PolyVec.checkModulus(this, array)} is negative.
     *
     * <p>NOTE(review): presumably validates that the encoded polynomial vector's coefficients
     * are reduced modulo q (FIPS 203 modulus check) — confirm against PolyVec.checkModulus.
     */
    boolean checkModulus(final byte[] array) {
        return PolyVec.checkModulus(this, array) < 0;
    }
    
    /**
     * Generates a fresh key pair from two 32-byte random seeds drawn from {@link #random}.
     *
     * @return the same array layout as {@link #generateKemKeyPairInternal(byte[], byte[])}
     * @throws NullPointerException if {@link #init(SecureRandom)} has not been called
     */
    public byte[][] generateKemKeyPair() {
        final byte[] d = new byte[KyberSymBytes];
        final byte[] z = new byte[KyberSymBytes];
        // Draw order (d first, then z) is significant for deterministic random sources.
        this.random.nextBytes(d);
        this.random.nextBytes(z);
        return this.generateKemKeyPairInternal(d, z);
    }
    
    /**
     * Deterministically derives a key pair from the seeds {@code d} and {@code z}.
     *
     * @param array  seed d, passed to the IND-CPA key generation
     * @param array2 seed z; returned as-is (not copied) in slot 4
     * @return {public key without its last 32 bytes, last 32 bytes of the public key
     *         (presumably the seed rho), IND-CPA secret key, hash_h(public key), z, d || z}
     */
    public byte[][] generateKemKeyPairInternal(final byte[] array, final byte[] array2) {
        final byte[][] cpaKeyPair = this.indCpa.generateKeyPair(array);
        final byte[] cpaSecretKey = new byte[this.KyberIndCpaSecretKeyBytes];
        System.arraycopy(cpaKeyPair[1], 0, cpaSecretKey, 0, this.KyberIndCpaSecretKeyBytes);
        final byte[] publicKeyHash = new byte[KyberSymBytes];
        this.symmetric.hash_h(publicKeyHash, cpaKeyPair[0], 0);
        final byte[] cpaPublicKey = new byte[this.KyberIndCpaPublicKeyBytes];
        System.arraycopy(cpaKeyPair[0], 0, cpaPublicKey, 0, this.KyberIndCpaPublicKeyBytes);
        final int seedOffset = cpaPublicKey.length - KyberSymBytes;
        return new byte[][] {
            Arrays.copyOfRange(cpaPublicKey, 0, seedOffset),
            Arrays.copyOfRange(cpaPublicKey, seedOffset, cpaPublicKey.length),
            cpaSecretKey,
            publicKeyHash,
            array2,
            Arrays.concatenate(array, array2)
        };
    }
    
    /**
     * Encapsulates using the supplied 32-byte message.
     *
     * @param mlkemPublicKeyParameters recipient public key
     * @param array                    the 32-byte message m (only the first 32 bytes are read)
     * @return {shared secret (sessionKeyLength bytes), ciphertext}
     */
    byte[][] kemEncrypt(final MLKEMPublicKeyParameters mlkemPublicKeyParameters, final byte[] array) {
        final byte[] publicKey = mlkemPublicKeyParameters.getEncoded();
        // gInput = m || hash_h(publicKey); the third hash_h argument appears to be the output offset.
        final byte[] gInput = new byte[2 * KyberSymBytes];
        final byte[] gOutput = new byte[2 * KyberSymBytes];
        System.arraycopy(array, 0, gInput, 0, KyberSymBytes);
        this.symmetric.hash_h(gInput, publicKey, KyberSymBytes);
        // gOutput = K || r: first half is the shared secret, second half the encryption coins.
        this.symmetric.hash_g(gOutput, gInput);
        final byte[] cipherText = this.indCpa.encrypt(publicKey,
            Arrays.copyOfRange(gInput, 0, KyberSymBytes),
            Arrays.copyOfRange(gOutput, KyberSymBytes, gOutput.length));
        final byte[] sharedSecret = Arrays.copyOfRange(gOutput, 0, this.sessionKeyLength);
        return new byte[][] { sharedSecret, cipherText };
    }
    
    /**
     * Decapsulates {@code array} with implicit rejection.
     *
     * <p>The ciphertext is decrypted, re-encrypted and compared in constant time. On mismatch,
     * including a length mismatch, the returned secret is the kdf of z || ciphertext instead of
     * the real key.
     *
     * @param mlkemPrivateKeyParameters the private key; its encoding is read as
     *        CPA secret key || CPA public key || H(public key) || z
     * @param array the ciphertext; must be at least KyberCipherTextBytes long, otherwise
     *        System.arraycopy throws
     * @return the shared secret (sessionKeyLength bytes)
     */
    byte[] kemDecrypt(final MLKEMPrivateKeyParameters mlkemPrivateKeyParameters, final byte[] array) {
        final byte[] secretKey = mlkemPrivateKeyParameters.getEncoded();
        final byte[] gInput = new byte[2 * KyberSymBytes];
        final byte[] gOutput = new byte[2 * KyberSymBytes];
        // Everything after the CPA secret key: embedded public key, its hash and z.
        final byte[] publicKey = Arrays.copyOfRange(secretKey, this.KyberIndCpaSecretKeyBytes, secretKey.length);
        // gInput = m' || H(public key), where m' is the decrypted message.
        System.arraycopy(this.indCpa.decrypt(secretKey, array), 0, gInput, 0, KyberSymBytes);
        System.arraycopy(secretKey, this.KyberSecretKeyBytes - 2 * KyberSymBytes, gInput, KyberSymBytes, KyberSymBytes);
        this.symmetric.hash_g(gOutput, gInput);
        // Implicit-rejection key: kdf(z || ciphertext), computed unconditionally.
        final byte[] rejectionKey = new byte[KyberSymBytes + this.KyberCipherTextBytes];
        System.arraycopy(secretKey, this.KyberSecretKeyBytes - KyberSymBytes, rejectionKey, 0, KyberSymBytes);
        System.arraycopy(array, 0, rejectionKey, KyberSymBytes, this.KyberCipherTextBytes);
        this.symmetric.kdf(rejectionKey, rejectionKey);
        final byte[] reEncrypted = this.indCpa.encrypt(publicKey,
            Arrays.copyOfRange(gInput, 0, KyberSymBytes),
            Arrays.copyOfRange(gOutput, KyberSymBytes, gOutput.length));
        // Overwrite K with the rejection key iff the ciphertexts differ, without branching.
        this.cmov(gOutput, rejectionKey, KyberSymBytes, this.constantTimeZeroOnEqual(array, reEncrypted));
        return Arrays.copyOfRange(gOutput, 0, this.sessionKeyLength);
    }
    
    /**
     * Copies the first {@code n} bytes of {@code array2} over {@code array} iff {@code n2 != 0},
     * using a mask instead of a branch.
     *
     * @param n2 selector; must be in [0, 2^24] (constantTimeZeroOnEqual yields [0, 255]) so that
     *           (0 - n2) >> 24 is either 0 or all ones
     */
    private void cmov(final byte[] array, final byte[] array2, final int n, final int n2) {
        final int mask = (0 - n2) >> 24;
        for (int i = 0; i != n; ++i) {
            array[i] = (byte)((array2[i] & mask) | (array[i] & ~mask));
        }
    }
    
    /**
     * Compares two byte arrays without data-dependent branching.
     *
     * @param input    the array under test (e.g. the received ciphertext)
     * @param expected the reference array (e.g. the re-encrypted ciphertext)
     * @return 0 if both arrays have the same length and contents; otherwise a value in [1, 255]
     */
    private int constantTimeZeroOnEqual(final byte[] input, final byte[] expected) {
        // Reduce any length mismatch to a guaranteed non-zero low bit. The raw length XOR can
        // have all of its set bits above bit 7 (e.g. 768 ^ 1024 == 0x700), and the final
        // "& 0xFF" would then erase it and report arrays of different lengths as equal.
        final int lengthDiff = input.length ^ expected.length;
        int result = (lengthDiff | -lengthDiff) >>> 31;
        // Lengths are public, so bounding the loop by the shorter one leaks nothing secret and
        // avoids an ArrayIndexOutOfBoundsException when input is shorter than expected.
        final int len = Math.min(input.length, expected.length);
        for (int i = 0; i != len; ++i) {
            result |= input[i] ^ expected[i];
        }
        // Byte XORs are sign-extended to int; keep only the low byte so the result stays in
        // [0, 255], the range cmov's mask computation requires.
        return result & 0xFF;
    }
}
