// 
// Decompiled by Procyon v0.6.0
// 

package org.bouncycastle.pqc.crypto.mlkem;

import org.bouncycastle.util.Arrays;

class MLKEMIndCpa
{
    private final MLKEMEngine engine;
    private final int kyberK;
    private final int indCpaPublicKeyBytes;
    private final int polyVecBytes;
    private final int indCpaBytes;
    private final int polyVecCompressedBytes;
    private final int polyCompressedBytes;
    private Symmetric symmetric;
    public final int KyberGenerateMatrixNBlocks;
    
    public MLKEMIndCpa(final MLKEMEngine engine) {
        this.engine = engine;
        this.kyberK = engine.getKyberK();
        this.indCpaPublicKeyBytes = engine.getKyberPublicKeyBytes();
        this.polyVecBytes = engine.getKyberPolyVecBytes();
        this.indCpaBytes = engine.getKyberIndCpaBytes();
        this.polyVecCompressedBytes = engine.getKyberPolyVecCompressedBytes();
        this.polyCompressedBytes = engine.getKyberPolyCompressedBytes();
        this.symmetric = engine.getSymmetric();
        this.KyberGenerateMatrixNBlocks = (472 + this.symmetric.xofBlockBytes) / this.symmetric.xofBlockBytes;
    }
    
    byte[][] generateKeyPair(final byte[] array) {
        final PolyVec polyVec = new PolyVec(this.engine);
        final PolyVec polyVec2 = new PolyVec(this.engine);
        final PolyVec polyVec3 = new PolyVec(this.engine);
        final byte[] array2 = new byte[64];
        this.symmetric.hash_g(array2, Arrays.append(array, (byte)this.kyberK));
        final byte[] array3 = new byte[32];
        final byte[] array4 = new byte[32];
        System.arraycopy(array2, 0, array3, 0, 32);
        System.arraycopy(array2, 32, array4, 0, 32);
        byte b = 0;
        final PolyVec[] array5 = new PolyVec[this.kyberK];
        for (int i = 0; i < this.kyberK; ++i) {
            array5[i] = new PolyVec(this.engine);
        }
        this.generateMatrix(array5, array3, false);
        for (int j = 0; j < this.kyberK; ++j) {
            polyVec.getVectorIndex(j).getEta1Noise(array4, b);
            ++b;
        }
        for (int k = 0; k < this.kyberK; ++k) {
            polyVec3.getVectorIndex(k).getEta1Noise(array4, b);
            ++b;
        }
        polyVec.polyVecNtt();
        polyVec3.polyVecNtt();
        for (int l = 0; l < this.kyberK; ++l) {
            PolyVec.pointwiseAccountMontgomery(polyVec2.getVectorIndex(l), array5[l], polyVec, this.engine);
            polyVec2.getVectorIndex(l).convertToMont();
        }
        polyVec2.addPoly(polyVec3);
        polyVec2.reducePoly();
        return new byte[][] { this.packPublicKey(polyVec2, array3), this.packSecretKey(polyVec) };
    }
    
    public byte[] encrypt(final byte[] array, final byte[] array2, final byte[] array3) {
        byte b = 0;
        final PolyVec polyVec = new PolyVec(this.engine);
        final PolyVec polyVec2 = new PolyVec(this.engine);
        final PolyVec polyVec3 = new PolyVec(this.engine);
        final PolyVec polyVec4 = new PolyVec(this.engine);
        final PolyVec[] array4 = new PolyVec[this.engine.getKyberK()];
        final Poly poly = new Poly(this.engine);
        final Poly poly2 = new Poly(this.engine);
        final Poly poly3 = new Poly(this.engine);
        final byte[] unpackPublicKey = this.unpackPublicKey(polyVec2, array);
        poly3.fromMsg(array2);
        for (int i = 0; i < this.kyberK; ++i) {
            array4[i] = new PolyVec(this.engine);
        }
        this.generateMatrix(array4, unpackPublicKey, true);
        for (int j = 0; j < this.kyberK; ++j) {
            polyVec.getVectorIndex(j).getEta1Noise(array3, b);
            ++b;
        }
        for (int k = 0; k < this.kyberK; ++k) {
            polyVec3.getVectorIndex(k).getEta2Noise(array3, b);
            ++b;
        }
        poly.getEta2Noise(array3, b);
        polyVec.polyVecNtt();
        for (int l = 0; l < this.kyberK; ++l) {
            PolyVec.pointwiseAccountMontgomery(polyVec4.getVectorIndex(l), array4[l], polyVec, this.engine);
        }
        PolyVec.pointwiseAccountMontgomery(poly2, polyVec2, polyVec, this.engine);
        polyVec4.polyVecInverseNttToMont();
        poly2.polyInverseNttToMont();
        polyVec4.addPoly(polyVec3);
        poly2.addCoeffs(poly);
        poly2.addCoeffs(poly3);
        polyVec4.reducePoly();
        poly2.reduce();
        return this.packCipherText(polyVec4, poly2);
    }
    
    private byte[] packCipherText(final PolyVec polyVec, final Poly poly) {
        final byte[] array = new byte[this.indCpaBytes];
        System.arraycopy(polyVec.compressPolyVec(), 0, array, 0, this.polyVecCompressedBytes);
        System.arraycopy(poly.compressPoly(), 0, array, this.polyVecCompressedBytes, this.polyCompressedBytes);
        return array;
    }
    
    private void unpackCipherText(final PolyVec polyVec, final Poly poly, final byte[] array) {
        polyVec.decompressPolyVec(Arrays.copyOfRange(array, 0, this.engine.getKyberPolyVecCompressedBytes()));
        poly.decompressPoly(Arrays.copyOfRange(array, this.engine.getKyberPolyVecCompressedBytes(), array.length));
    }
    
    public byte[] packPublicKey(final PolyVec polyVec, final byte[] array) {
        final byte[] array2 = new byte[this.indCpaPublicKeyBytes];
        System.arraycopy(polyVec.toBytes(), 0, array2, 0, this.polyVecBytes);
        System.arraycopy(array, 0, array2, this.polyVecBytes, 32);
        return array2;
    }
    
    public byte[] unpackPublicKey(final PolyVec polyVec, final byte[] array) {
        final byte[] array2 = new byte[32];
        polyVec.fromBytes(array);
        System.arraycopy(array, this.polyVecBytes, array2, 0, 32);
        return array2;
    }
    
    public byte[] packSecretKey(final PolyVec polyVec) {
        return polyVec.toBytes();
    }
    
    public void unpackSecretKey(final PolyVec polyVec, final byte[] array) {
        polyVec.fromBytes(array);
    }
    
    public void generateMatrix(final PolyVec[] array, final byte[] array2, final boolean b) {
        final byte[] array3 = new byte[this.KyberGenerateMatrixNBlocks * this.symmetric.xofBlockBytes + 2];
        for (int i = 0; i < this.kyberK; ++i) {
            for (int j = 0; j < this.kyberK; ++j) {
                if (b) {
                    this.symmetric.xofAbsorb(array2, (byte)i, (byte)j);
                }
                else {
                    this.symmetric.xofAbsorb(array2, (byte)j, (byte)i);
                }
                this.symmetric.xofSqueezeBlocks(array3, 0, this.symmetric.xofBlockBytes * this.KyberGenerateMatrixNBlocks);
                for (int n = this.KyberGenerateMatrixNBlocks * this.symmetric.xofBlockBytes, k = rejectionSampling(array[i].getVectorIndex(j), 0, 256, array3, n); k < 256; k += rejectionSampling(array[i].getVectorIndex(j), k, 256 - k, array3, n)) {
                    final int n2 = n % 3;
                    for (int l = 0; l < n2; ++l) {
                        array3[l] = array3[n - n2 + l];
                    }
                    this.symmetric.xofSqueezeBlocks(array3, n2, this.symmetric.xofBlockBytes * 2);
                    n = n2 + this.symmetric.xofBlockBytes;
                }
            }
        }
    }
    
    private static int rejectionSampling(final Poly poly, final int n, final int n2, final byte[] array, final int n3) {
        int n5;
        for (int n4 = n5 = 0; n5 < n2 && n4 + 3 <= n3; ++n5) {
            final short n6 = (short)(((short)(array[n4] & 0xFF) >> 0 | (short)(array[n4 + 1] & 0xFF) << 8) & 0xFFF);
            final short n7 = (short)(((short)(array[n4 + 1] & 0xFF) >> 4 | (short)(array[n4 + 2] & 0xFF) << 4) & 0xFFF);
            n4 += 3;
            if (n6 < 3329) {
                poly.setCoeffIndex(n + n5, n6);
                ++n5;
            }
            if (n5 < n2 && n7 < 3329) {
                poly.setCoeffIndex(n + n5, n7);
            }
        }
        return n5;
    }
    
    public byte[] decrypt(final byte[] array, final byte[] array2) {
        final byte[] array3 = new byte[MLKEMEngine.getKyberIndCpaMsgBytes()];
        final PolyVec polyVec = new PolyVec(this.engine);
        final PolyVec polyVec2 = new PolyVec(this.engine);
        final Poly poly = new Poly(this.engine);
        final Poly poly2 = new Poly(this.engine);
        this.unpackCipherText(polyVec, poly, array2);
        this.unpackSecretKey(polyVec2, array);
        polyVec.polyVecNtt();
        PolyVec.pointwiseAccountMontgomery(poly2, polyVec2, polyVec, this.engine);
        poly2.polyInverseNttToMont();
        poly2.polySubtract(poly);
        poly2.reduce();
        return poly2.toMsg();
    }
}
