// 
// Decompiled by Procyon v0.6.0
// 

package org.bouncycastle.pqc.crypto.mlkem;

/**
 * Polynomial with 256 {@code short} coefficients, as used by the ML-KEM engine.
 *
 * <p>Besides coefficient-wise arithmetic, this class compresses, serializes and samples itself
 * using parameters read once from the {@link MLKEMEngine} passed to the constructor. The
 * coefficient array is shared directly through {@link #getCoeffs()} and
 * {@link #setCoeffs(short[])} (no defensive copies; {@link #polyNtt()} relies on this). No
 * synchronization is performed.
 */
class Poly
{
    /** Number of coefficients per polynomial. */
    private static final int N = 256;

    /** ML-KEM modulus q. */
    private static final int Q = 3329;

    /**
     * 2^32 mod Q (== 1353). Multiplied in before {@code Reduce.montgomeryReduce} by
     * {@link #convertToMont()}.
     */
    private static final int MONT_CONVERSION_FACTOR = 1353;

    private short[] coeffs;
    private final MLKEMEngine engine;
    private final int polyCompressedBytes;
    private final int eta1;
    private final int eta2;
    private final Symmetric symmetric;

    /**
     * Creates an all-zero polynomial and caches the engine parameters this class needs.
     *
     * @param engine engine supplying the compressed size, the eta parameters and the symmetric
     *     primitives used for noise sampling
     */
    public Poly(final MLKEMEngine engine) {
        this.coeffs = new short[N];
        this.engine = engine;
        this.polyCompressedBytes = engine.getKyberPolyCompressedBytes();
        this.eta1 = engine.getKyberEta1();
        this.eta2 = MLKEMEngine.getKyberEta2();
        this.symmetric = engine.getSymmetric();
    }

    /** Returns coefficient {@code n}. */
    public short getCoeffIndex(final int n) {
        return this.coeffs[n];
    }

    /** Returns the backing coefficient array itself (not a copy). */
    public short[] getCoeffs() {
        return this.coeffs;
    }

    /** Sets coefficient {@code n} to {@code n2}. */
    public void setCoeffIndex(final int n, final short n2) {
        this.coeffs[n] = n2;
    }

    /** Replaces the backing coefficient array with {@code coeffs} (stored without copying). */
    public void setCoeffs(final short[] coeffs) {
        this.coeffs = coeffs;
    }

    /** Applies {@code Ntt.ntt} to the coefficients, then Barrett-reduces every coefficient. */
    public void polyNtt() {
        this.setCoeffs(Ntt.ntt(this.getCoeffs()));
        this.reduce();
    }

    /**
     * Applies {@code Ntt.invNtt} to the coefficients. The name suggests the result is left in
     * Montgomery form; NOTE(review): confirm against {@code Ntt.invNtt}.
     */
    public void polyInverseNttToMont() {
        this.setCoeffs(Ntt.invNtt(this.getCoeffs()));
    }

    /** Replaces every coefficient with {@code Reduce.barretReduce} of itself. */
    public void reduce() {
        for (int i = 0; i < N; ++i) {
            this.setCoeffIndex(i, Reduce.barretReduce(this.getCoeffIndex(i)));
        }
    }

    /**
     * Writes into {@code poly} the base multiplication of {@code poly2} and {@code poly3}.
     *
     * <p>Coefficients are processed in pairs. Pair {@code (4i, 4i+1)} is multiplied using zeta
     * {@code Ntt.nttZetas[64 + i]}, and pair {@code (4i+2, 4i+3)} using its negation.
     * Presumably both inputs are in the NTT domain; NOTE(review): confirm with callers.
     *
     * @param poly destination polynomial (may alias an input only if {@code Ntt.baseMult} allows)
     * @param poly2 first operand
     * @param poly3 second operand
     */
    public static void baseMultMontgomery(final Poly poly, final Poly poly2, final Poly poly3) {
        for (int i = 0; i < N / 4; ++i) {
            final short zeta = Ntt.nttZetas[64 + i];
            Ntt.baseMult(poly, 4 * i,
                    poly2.getCoeffIndex(4 * i), poly2.getCoeffIndex(4 * i + 1),
                    poly3.getCoeffIndex(4 * i), poly3.getCoeffIndex(4 * i + 1),
                    zeta);
            Ntt.baseMult(poly, 4 * i + 2,
                    poly2.getCoeffIndex(4 * i + 2), poly2.getCoeffIndex(4 * i + 3),
                    poly3.getCoeffIndex(4 * i + 2), poly3.getCoeffIndex(4 * i + 3),
                    (short)(-1 * zeta));
        }
    }

    /**
     * Adds {@code poly} to this polynomial coefficient-wise. No modular reduction is applied;
     * each sum is simply truncated to {@code short}.
     */
    public void addCoeffs(final Poly poly) {
        for (int i = 0; i < N; ++i) {
            this.setCoeffIndex(i, (short)(this.getCoeffIndex(i) + poly.getCoeffIndex(i)));
        }
    }

    /**
     * Replaces every coefficient {@code c} with {@code Reduce.montgomeryReduce(c * (2^32 mod Q))}.
     * Assuming montgomeryReduce multiplies by 2^-16 mod Q, this maps c to c * 2^16 mod Q
     * (Montgomery form). TODO confirm against {@code Reduce}.
     */
    public void convertToMont() {
        for (int i = 0; i < N; ++i) {
            this.setCoeffIndex(i, Reduce.montgomeryReduce(this.getCoeffIndex(i) * MONT_CONVERSION_FACTOR));
        }
    }

    /**
     * Compresses this polynomial to {@code polyCompressedBytes} bytes.
     *
     * <p>This first applies {@link #conditionalSubQ()} (mutating this polynomial). It then maps
     * each coefficient to 4 bits (128-byte output) or 5 bits (160-byte output) and packs eight
     * coefficients into 4 or 5 bytes respectively.
     *
     * @return the packed bytes
     * @throws IllegalStateException if the engine's compressed size is neither 128 nor 160
     */
    public byte[] compressPoly() {
        final byte[] t = new byte[8];
        final byte[] out = new byte[this.polyCompressedBytes];
        int pos = 0;
        this.conditionalSubQ();
        if (this.polyCompressedBytes == 128) {
            for (int i = 0; i < N / 8; ++i) {
                for (int j = 0; j < 8; ++j) {
                    // Fixed-point approximation of round(16 * c / Q) mod 16, with 80635 ~ 2^28 / Q.
                    // The int multiply may overflow. That is harmless here: it wraps mod 2^32,
                    // and only bits 28..31 are kept.
                    final int d = (this.getCoeffIndex(8 * i + j) << 4) + 1665;
                    t[j] = (byte)((d * 80635 >> 28) & 0xF);
                }
                out[pos + 0] = (byte)(t[0] | t[1] << 4);
                out[pos + 1] = (byte)(t[2] | t[3] << 4);
                out[pos + 2] = (byte)(t[4] | t[5] << 4);
                out[pos + 3] = (byte)(t[6] | t[7] << 4);
                pos += 4;
            }
        }
        else {
            if (this.polyCompressedBytes != 160) {
                throw new IllegalStateException("PolyCompressedBytes is neither 128 or 160!");
            }
            for (int i = 0; i < N / 8; ++i) {
                for (int j = 0; j < 8; ++j) {
                    // Fixed-point approximation of round(32 * c / Q) mod 32, with 40318 ~ 2^27 / Q.
                    // Any int overflow wraps mod 2^32, and only bits 27..31 are kept.
                    final int d = (this.getCoeffIndex(8 * i + j) << 5) + 1664;
                    t[j] = (byte)((d * 40318 >> 27) & 0x1F);
                }
                // Pack eight 5-bit values into 40 bits, least significant bits first.
                out[pos + 0] = (byte)(t[0] >> 0 | t[1] << 5);
                out[pos + 1] = (byte)(t[1] >> 3 | t[2] << 2 | t[3] << 7);
                out[pos + 2] = (byte)(t[3] >> 1 | t[4] << 4);
                out[pos + 3] = (byte)(t[4] >> 4 | t[5] << 1 | t[6] << 6);
                out[pos + 4] = (byte)(t[6] >> 2 | t[7] << 3);
                pos += 5;
            }
        }
        return out;
    }

    /**
     * Inverse of {@link #compressPoly()}. Unpacks 4-bit (128-byte input) or 5-bit (160-byte
     * input) values and scales each back by Q with rounding.
     *
     * <p>The compressed size is the value cached from the engine at construction time, the
     * same one {@link #compressPoly()} uses.
     *
     * @param array packed input. At least 128 or 160 bytes are read, depending on the
     *     compressed size.
     * @throws IllegalStateException if the engine's compressed size is neither 128 nor 160
     */
    public void decompressPoly(final byte[] array) {
        if (this.polyCompressedBytes == 128) {
            for (int i = 0; i < N / 2; ++i) {
                final int b = array[i] & 0xFF;
                // (x * Q + 8) >> 4 == round(x * Q / 16)
                this.setCoeffIndex(2 * i + 0, (short)(((b & 0xF) * Q + 8) >> 4));
                this.setCoeffIndex(2 * i + 1, (short)(((b >> 4) * Q + 8) >> 4));
            }
        }
        else {
            if (this.polyCompressedBytes != 160) {
                throw new IllegalStateException("PolyCompressedBytes is neither 128 or 160!");
            }
            final byte[] t = new byte[8];
            int pos = 0;
            for (int j = 0; j < N / 8; ++j) {
                // Unpack eight 5-bit values from 5 bytes; excess high bits are masked below.
                t[0] = (byte)((array[pos + 0] & 0xFF) >> 0);
                t[1] = (byte)((array[pos + 0] & 0xFF) >> 5 | (array[pos + 1] & 0xFF) << 3);
                t[2] = (byte)((array[pos + 1] & 0xFF) >> 2);
                t[3] = (byte)((array[pos + 1] & 0xFF) >> 7 | (array[pos + 2] & 0xFF) << 1);
                t[4] = (byte)((array[pos + 2] & 0xFF) >> 4 | (array[pos + 3] & 0xFF) << 4);
                t[5] = (byte)((array[pos + 3] & 0xFF) >> 1);
                t[6] = (byte)((array[pos + 3] & 0xFF) >> 6 | (array[pos + 4] & 0xFF) << 2);
                t[7] = (byte)((array[pos + 4] & 0xFF) >> 3);
                pos += 5;
                for (int k = 0; k < 8; ++k) {
                    // (x * Q + 16) >> 5 == round(x * Q / 32)
                    this.setCoeffIndex(8 * j + k, (short)(((t[k] & 0x1F) * Q + 16) >> 5));
                }
            }
        }
    }

    /**
     * Serializes the coefficients as 12-bit values, two per three bytes (384 bytes total),
     * least significant bits first.
     *
     * <p>This applies {@link #conditionalSubQ()} first (mutating this polynomial). Presumably
     * that leaves every coefficient in [0, Q), so it fits in 12 bits; TODO confirm against
     * {@code Reduce.conditionalSubQ}.
     *
     * @return a new 384-byte array
     */
    public byte[] toBytes() {
        this.conditionalSubQ();
        final byte[] out = new byte[384];
        for (int i = 0; i < N / 2; ++i) {
            final short c0 = this.coeffs[2 * i + 0];
            final short c1 = this.coeffs[2 * i + 1];
            out[3 * i + 0] = (byte)(c0 >> 0);
            out[3 * i + 1] = (byte)(c0 >> 8 | c1 << 4);
            out[3 * i + 2] = (byte)(c1 >> 4);
        }
        return out;
    }

    /**
     * Inverse of {@link #toBytes()}. Reads 384 bytes as 256 packed 12-bit coefficients. Values
     * are masked to 12 bits but not otherwise range-checked (see {@link #checkModulus}).
     *
     * @param array input holding at least 384 bytes
     */
    public void fromBytes(final byte[] array) {
        for (int i = 0; i < N / 2; ++i) {
            final int b0 = array[3 * i + 0] & 0xFF;
            final int b1 = array[3 * i + 1] & 0xFF;
            final int b2 = array[3 * i + 2] & 0xFF;
            this.coeffs[2 * i + 0] = (short)((b0 >> 0 | b1 << 8) & 0xFFF);
            this.coeffs[2 * i + 1] = (short)((b1 >> 4 | b2 << 4) & 0xFFF);
        }
    }

    /**
     * Decodes this polynomial to a message of one bit per coefficient, packed LSB-first into
     * 32 bytes.
     *
     * <p>This applies {@link #conditionalSubQ()} first (mutating this polynomial). A coefficient
     * c decodes to 1 iff {@code Q/4 < c < Q - Q/4}, using integer division.
     *
     * @return a new array of {@code MLKEMEngine.getKyberIndCpaMsgBytes()} bytes
     */
    public byte[] toMsg() {
        final int lower = Q / 4;          // 832
        final int upper = Q - lower;      // 2497
        final byte[] msg = new byte[MLKEMEngine.getKyberIndCpaMsgBytes()];
        this.conditionalSubQ();
        for (int i = 0; i < N / 8; ++i) {
            for (int j = 0; j < 8; ++j) {
                final int c = this.getCoeffIndex(8 * i + j);
                // Both differences are negative (sign bit set) iff lower < c < upper. The bit is
                // obtained without branching on the coefficient value.
                final int bit = ((lower - c) & (c - upper)) >>> 31;
                msg[i] |= (byte)(bit << j);
            }
        }
        return msg;
    }

    /**
     * Encodes a 32-byte message. Bit {@code j} of byte {@code i} (LSB-first) becomes coefficient
     * {@code 8*i + j}, which is set to {@code (Q + 1) / 2} (1665) for a 1-bit and 0 for a 0-bit.
     * A mask is used instead of a branch on the bit.
     *
     * @param array message of exactly N/8 = 32 bytes
     * @throws IllegalArgumentException if {@code array.length != 32}
     */
    public void fromMsg(final byte[] array) {
        if (array.length != N / 8) {
            throw new IllegalArgumentException("KYBER_INDCPA_MSGBYTES must be equal to KYBER_N/8 bytes!");
        }
        for (int i = 0; i < N / 8; ++i) {
            for (int j = 0; j < 8; ++j) {
                // all ones when the bit is 1, zero when it is 0
                final int mask = -(((array[i] & 0xFF) >> j) & 1);
                this.setCoeffIndex(8 * i + j, (short)(mask & ((Q + 1) / 2)));
            }
        }
    }

    /** Replaces every coefficient with {@code Reduce.conditionalSubQ} of itself. */
    public void conditionalSubQ() {
        for (int i = 0; i < N; ++i) {
            this.setCoeffIndex(i, Reduce.conditionalSubQ(this.getCoeffIndex(i)));
        }
    }

    /**
     * Overwrites this polynomial with CBD noise for parameter eta1, derived from
     * {@code symmetric.prf(·, array, b)}.
     */
    public void getEta1Noise(final byte[] array, final byte b) {
        this.sampleNoise(array, b, this.eta1);
    }

    /**
     * Overwrites this polynomial with CBD noise for parameter eta2, derived from
     * {@code symmetric.prf(·, array, b)}.
     */
    public void getEta2Noise(final byte[] array, final byte b) {
        this.sampleNoise(array, b, this.eta2);
    }

    /** Shared body of the noise samplers: expands N*eta/4 PRF bytes and feeds them to the CBD. */
    private void sampleNoise(final byte[] prfKey, final byte prfNonce, final int eta) {
        final byte[] buf = new byte[N * eta / 4];
        this.symmetric.prf(buf, prfKey, prfNonce);
        CBD.mlkemCBD(this, buf, eta);
    }

    /**
     * Replaces this polynomial with {@code poly - this}, coefficient-wise and truncated to
     * {@code short}. Note the operand order: the argument is the minuend, not the subtrahend.
     */
    public void polySubtract(final Poly poly) {
        for (int i = 0; i < N; ++i) {
            this.setCoeffIndex(i, (short)(poly.getCoeffIndex(i) - this.getCoeffIndex(i)));
        }
    }

    /** Returns the coefficients formatted as {@code [c0, c1, ..., c255]}. */
    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder();
        sb.append("[");
        for (int i = 0; i < this.coeffs.length; ++i) {
            if (i != 0) {
                sb.append(", ");
            }
            sb.append(this.coeffs[i]);
        }
        sb.append("]");
        return sb.toString();
    }

    /**
     * Checks every 12-bit value of a 384-byte {@link #toBytes()}-style encoding starting at
     * {@code array[n]}.
     *
     * <p>The {@code Reduce.checkModulus} results for all 256 values are combined with bitwise
     * AND, starting from -1. Presumably -1 means every value is below Q; NOTE(review): confirm
     * against {@code Reduce.checkModulus}.
     *
     * @param array buffer containing the encoding
     * @param n offset of the encoding within {@code array}
     * @return the AND of all per-coefficient check results
     */
    static int checkModulus(final byte[] array, final int n) {
        int result = -1;
        for (int i = 0; i < N / 2; ++i) {
            final int b0 = array[n + 3 * i + 0] & 0xFF;
            final int b1 = array[n + 3 * i + 1] & 0xFF;
            final int b2 = array[n + 3 * i + 2] & 0xFF;
            result &= Reduce.checkModulus((short)((b0 >> 0 | b1 << 8) & 0xFFF));
            result &= Reduce.checkModulus((short)((b1 >> 4 | b2 << 4) & 0xFFF));
        }
        return result;
    }
}
