// 
// Decompiled by Procyon v0.6.0
// 

package org.bouncycastle.pqc.crypto.mlkem;

import org.bouncycastle.util.Arrays;

/**
 * A vector of {@code kyberK} polynomials ({@link Poly}) belonging to one
 * {@link MLKEMEngine} parameter set.
 *
 * <p>The element-wise operations (NTT, inverse NTT, reduction, addition, conditional
 * subtraction) forward to the matching {@link Poly} method for each entry, in index order.
 * The class also implements the byte encodings of a whole vector:
 * <ul>
 *   <li>{@link #toBytes()} / {@link #fromBytes(byte[])}: 384 bytes per polynomial;</li>
 *   <li>{@link #compressPolyVec()} / {@link #decompressPolyVec(byte[])}: the 256
 *       coefficients of each polynomial packed least-significant-bit first at either
 *       10 bits (320 bytes) or 11 bits (352 bytes) each, chosen by
 *       {@code engine.getKyberPolyVecCompressedBytes()}.</li>
 * </ul>
 */
class PolyVec
{
    /** Bytes per polynomial in the toBytes/fromBytes/checkModulus encoding. */
    private static final int POLY_BYTES = 384;

    /** Bytes per polynomial when each of its 256 coefficients is packed into 10 bits. */
    private static final int COMPRESSED_10_BYTES = 320;

    /** Bytes per polynomial when each of its 256 coefficients is packed into 11 bits. */
    private static final int COMPRESSED_11_BYTES = 352;

    /** Coefficients per polynomial walked by the compression routines. */
    private static final int COEFFS = 256;

    /** Modulus used when expanding compressed values back to full-size coefficients. */
    private static final int Q = 3329;

    /** Message of the exception raised when the engine reports an unsupported compressed size. */
    private static final String BAD_COMPRESSED_SIZE =
        "Kyber PolyVecCompressedBytes neither 320 * KyberK or 352 * KyberK!";

    /** The polynomials of this vector (package-private); allocated with length {@code kyberK}. */
    Poly[] vec;
    /** Engine this vector was built for; queried for the compressed-size parameter. */
    private MLKEMEngine engine;
    /** Number of polynomials, cached from {@code engine.getKyberK()}. */
    private int kyberK;
    /** Length of the {@link #toBytes()} output, cached from {@code engine.getKyberPolyVecBytes()}. */
    private int polyVecBytes;

    /**
     * Builds a vector of {@code engine.getKyberK()} freshly constructed polynomials.
     *
     * @param engine engine supplying the parameter set; also handed to every {@link Poly}
     */
    public PolyVec(final MLKEMEngine engine) {
        this.engine = engine;
        this.kyberK = engine.getKyberK();
        this.polyVecBytes = engine.getKyberPolyVecBytes();
        final Poly[] polys = new Poly[this.kyberK];
        int idx = 0;
        while (idx < polys.length) {
            polys[idx++] = new Poly(engine);
        }
        this.vec = polys;
    }

    /**
     * Unsupported: a vector cannot be created without an engine.
     *
     * @throws Exception always
     */
    public PolyVec() throws Exception {
        throw new Exception("Requires Parameter");
    }

    /**
     * Returns the polynomial at position {@code n} itself (not a copy).
     *
     * @param n index in {@code [0, kyberK)}
     * @return the stored polynomial
     */
    public Poly getVectorIndex(final int n) {
        final Poly entry = this.vec[n];
        return entry;
    }

    /** Runs {@link Poly#polyNtt()} on each polynomial. */
    public void polyVecNtt() {
        int idx = 0;
        while (idx < this.kyberK) {
            this.getVectorIndex(idx++).polyNtt();
        }
    }

    /** Runs {@link Poly#polyInverseNttToMont()} on each polynomial. */
    public void polyVecInverseNttToMont() {
        int idx = 0;
        while (idx < this.kyberK) {
            this.getVectorIndex(idx++).polyInverseNttToMont();
        }
    }

    /**
     * Compresses every coefficient to 10 or 11 bits and packs the results into a new array.
     *
     * <p>Side effect: {@link #conditionalSubQ()} is applied to this vector before any
     * coefficient is read, so the vector is modified in place (even if the size check
     * below then fails).
     *
     * @return a new array of {@code engine.getKyberPolyVecCompressedBytes()} bytes
     * @throws RuntimeException if that size is neither {@code 320 * kyberK} nor {@code 352 * kyberK}
     */
    public byte[] compressPolyVec() {
        this.conditionalSubQ();
        final int outLen = this.engine.getKyberPolyVecCompressedBytes();
        final byte[] out = new byte[outLen];
        if (outLen == this.kyberK * COMPRESSED_10_BYTES) {
            this.pack10(out);
        }
        else if (outLen == this.kyberK * COMPRESSED_11_BYTES) {
            this.pack11(out);
        }
        else {
            throw new RuntimeException(BAD_COMPRESSED_SIZE);
        }
        return out;
    }

    /** Emits 4 coefficients as 5 bytes (10 bits each, LSB first) for every polynomial. */
    private void pack10(final byte[] out) {
        int pos = 0;
        for (int p = 0; p < this.kyberK; ++p) {
            final Poly src = this.getVectorIndex(p);
            for (int c = 0; c < COEFFS; c += 4) {
                final int t0 = compress10(src.getCoeffIndex(c));
                final int t1 = compress10(src.getCoeffIndex(c + 1));
                final int t2 = compress10(src.getCoeffIndex(c + 2));
                final int t3 = compress10(src.getCoeffIndex(c + 3));
                out[pos++] = (byte)t0;
                out[pos++] = (byte)(t0 >> 8 | t1 << 2);
                out[pos++] = (byte)(t1 >> 6 | t2 << 4);
                out[pos++] = (byte)(t2 >> 4 | t3 << 6);
                out[pos++] = (byte)(t3 >> 2);
            }
        }
    }

    /** Emits 8 coefficients as 11 bytes (11 bits each, LSB first) for every polynomial. */
    private void pack11(final byte[] out) {
        int pos = 0;
        for (int p = 0; p < this.kyberK; ++p) {
            final Poly src = this.getVectorIndex(p);
            for (int c = 0; c < COEFFS; c += 8) {
                final int t0 = compress11(src.getCoeffIndex(c));
                final int t1 = compress11(src.getCoeffIndex(c + 1));
                final int t2 = compress11(src.getCoeffIndex(c + 2));
                final int t3 = compress11(src.getCoeffIndex(c + 3));
                final int t4 = compress11(src.getCoeffIndex(c + 4));
                final int t5 = compress11(src.getCoeffIndex(c + 5));
                final int t6 = compress11(src.getCoeffIndex(c + 6));
                final int t7 = compress11(src.getCoeffIndex(c + 7));
                out[pos++] = (byte)t0;
                out[pos++] = (byte)(t0 >> 8 | t1 << 3);
                out[pos++] = (byte)(t1 >> 5 | t2 << 6);
                out[pos++] = (byte)(t2 >> 2);
                out[pos++] = (byte)(t2 >> 10 | t3 << 1);
                out[pos++] = (byte)(t3 >> 7 | t4 << 4);
                out[pos++] = (byte)(t4 >> 4 | t5 << 7);
                out[pos++] = (byte)(t5 >> 1);
                out[pos++] = (byte)(t5 >> 9 | t6 << 2);
                out[pos++] = (byte)(t6 >> 6 | t7 << 5);
                out[pos++] = (byte)(t7 >> 3);
            }
        }
    }

    /**
     * Maps a coefficient to 10 bits using a fixed-point reciprocal (1290167 is about 2^32 / 3329).
     * Intended as round(coeff * 2^10 / 3329) mod 2^10; presumably exact only for the
     * coefficient range left by conditionalSubQ (confirm in Poly).
     */
    private static int compress10(final long coeff) {
        return (int)((((coeff << 10) + 1665L) * 1290167L >> 32) & 0x3FFL);
    }

    /**
     * Maps a coefficient to 11 bits using a fixed-point reciprocal (645084 is about 2^31 / 3329).
     * Intended as round(coeff * 2^11 / 3329) mod 2^11; same range caveat as {@link #compress10}.
     */
    private static int compress11(final long coeff) {
        return (int)((((coeff << 11) + 1664L) * 645084L >> 31) & 0x7FFL);
    }

    /**
     * Unpacks a byte array produced by {@link #compressPolyVec()} into this vector,
     * overwriting every coefficient.
     *
     * <p>The input length is not checked up front. A short array surfaces as an array-index
     * exception part-way through, leaving earlier groups already written.
     *
     * @param array packed input, {@code engine.getKyberPolyVecCompressedBytes()} bytes expected
     * @throws RuntimeException if that size is neither {@code 320 * kyberK} nor {@code 352 * kyberK}
     */
    public void decompressPolyVec(final byte[] array) {
        final int inLen = this.engine.getKyberPolyVecCompressedBytes();
        if (inLen == this.kyberK * COMPRESSED_10_BYTES) {
            this.unpack10(array);
        }
        else if (inLen == this.kyberK * COMPRESSED_11_BYTES) {
            this.unpack11(array);
        }
        else {
            throw new RuntimeException(BAD_COMPRESSED_SIZE);
        }
    }

    /** Reads 5 bytes, then stores 4 expanded 10-bit values, repeatedly for every polynomial. */
    private void unpack10(final byte[] in) {
        int pos = 0;
        for (int p = 0; p < this.kyberK; ++p) {
            final Poly dst = this.vec[p];
            for (int c = 0; c < COEFFS; c += 4) {
                // Read the whole group before touching any coefficient.
                final int b0 = in[pos] & 0xFF;
                final int b1 = in[pos + 1] & 0xFF;
                final int b2 = in[pos + 2] & 0xFF;
                final int b3 = in[pos + 3] & 0xFF;
                final int b4 = in[pos + 4] & 0xFF;
                pos += 5;
                dst.setCoeffIndex(c, decompress(b0 | b1 << 8, 10));
                dst.setCoeffIndex(c + 1, decompress(b1 >>> 2 | b2 << 6, 10));
                dst.setCoeffIndex(c + 2, decompress(b2 >>> 4 | b3 << 4, 10));
                dst.setCoeffIndex(c + 3, decompress(b3 >>> 6 | b4 << 2, 10));
            }
        }
    }

    /** Reads 11 bytes, then stores 8 expanded 11-bit values, repeatedly for every polynomial. */
    private void unpack11(final byte[] in) {
        int pos = 0;
        for (int p = 0; p < this.kyberK; ++p) {
            final Poly dst = this.vec[p];
            for (int c = 0; c < COEFFS; c += 8) {
                // Read the whole group before touching any coefficient.
                final int b0 = in[pos] & 0xFF;
                final int b1 = in[pos + 1] & 0xFF;
                final int b2 = in[pos + 2] & 0xFF;
                final int b3 = in[pos + 3] & 0xFF;
                final int b4 = in[pos + 4] & 0xFF;
                final int b5 = in[pos + 5] & 0xFF;
                final int b6 = in[pos + 6] & 0xFF;
                final int b7 = in[pos + 7] & 0xFF;
                final int b8 = in[pos + 8] & 0xFF;
                final int b9 = in[pos + 9] & 0xFF;
                final int b10 = in[pos + 10] & 0xFF;
                pos += 11;
                dst.setCoeffIndex(c, decompress(b0 | b1 << 8, 11));
                dst.setCoeffIndex(c + 1, decompress(b1 >>> 3 | b2 << 5, 11));
                dst.setCoeffIndex(c + 2, decompress(b2 >>> 6 | b3 << 2 | b4 << 10, 11));
                dst.setCoeffIndex(c + 3, decompress(b4 >>> 1 | b5 << 7, 11));
                dst.setCoeffIndex(c + 4, decompress(b5 >>> 4 | b6 << 4, 11));
                dst.setCoeffIndex(c + 5, decompress(b6 >>> 7 | b7 << 1 | b8 << 9, 11));
                dst.setCoeffIndex(c + 6, decompress(b8 >>> 2 | b9 << 6, 11));
                dst.setCoeffIndex(c + 7, decompress(b9 >>> 5 | b10 << 3, 11));
            }
        }
    }

    /**
     * Expands the low {@code bits} bits of {@code value} to {@code (x * 3329 + 2^(bits-1)) >> bits},
     * i.e. x * 3329 / 2^bits rounded half up.
     */
    private static short decompress(final int value, final int bits) {
        final int x = value & ((1 << bits) - 1);
        final int half = 1 << (bits - 1);
        return (short)((x * Q + half) >> bits);
    }

    /**
     * Sets {@code poly} to the sum over i of {@code Poly.baseMultMontgomery(polyVec[i], polyVec2[i])},
     * then calls {@link Poly#reduce()} on it.
     *
     * <p>Entry 0 is written by {@code baseMultMontgomery(poly, ...)} directly, presumably
     * overwriting {@code poly} (confirm in Poly). The remaining entries go through a scratch
     * polynomial and are accumulated with {@link Poly#addCoeffs(Poly)}.
     *
     * @param poly        destination polynomial
     * @param polyVec     left operand vector
     * @param polyVec2    right operand vector
     * @param mlkemEngine engine giving the vector length and used to build the scratch polynomial
     */
    public static void pointwiseAccountMontgomery(final Poly poly, final PolyVec polyVec, final PolyVec polyVec2, final MLKEMEngine mlkemEngine) {
        final Poly scratch = new Poly(mlkemEngine);
        Poly.baseMultMontgomery(poly, polyVec.getVectorIndex(0), polyVec2.getVectorIndex(0));
        final int k = mlkemEngine.getKyberK();
        for (int i = 1; i < k; ++i) {
            Poly.baseMultMontgomery(scratch, polyVec.getVectorIndex(i), polyVec2.getVectorIndex(i));
            poly.addCoeffs(scratch);
        }
        poly.reduce();
    }

    /** Runs {@link Poly#reduce()} on each polynomial. */
    public void reducePoly() {
        for (int idx = 0, count = this.kyberK; idx < count; ++idx) {
            this.getVectorIndex(idx).reduce();
        }
    }

    /**
     * Adds {@code polyVec} into this vector entry by entry via {@link Poly#addCoeffs(Poly)}.
     *
     * @param polyVec addend; only its first {@code kyberK} entries (of this vector) are read
     */
    public void addPoly(final PolyVec polyVec) {
        for (int idx = 0, count = this.kyberK; idx < count; ++idx) {
            final Poly addend = polyVec.getVectorIndex(idx);
            this.getVectorIndex(idx).addCoeffs(addend);
        }
    }

    /**
     * Serializes the vector by concatenating the first 384 bytes of each {@link Poly#toBytes()}.
     *
     * @return a new array of length {@code engine.getKyberPolyVecBytes()}
     */
    public byte[] toBytes() {
        final byte[] out = new byte[this.polyVecBytes];
        for (int idx = 0, offset = 0; idx < this.kyberK; ++idx, offset += POLY_BYTES) {
            final byte[] encoded = this.vec[idx].toBytes();
            System.arraycopy(encoded, 0, out, offset, POLY_BYTES);
        }
        return out;
    }

    /**
     * Loads polynomial i from bytes {@code [i * 384, (i + 1) * 384)} of {@code array}.
     * Each slice is copied before being passed to {@link Poly#fromBytes(byte[])}.
     *
     * @param array serialized vector, laid out as produced by {@link #toBytes()}
     */
    public void fromBytes(final byte[] array) {
        for (int idx = 0, offset = 0; idx < this.kyberK; ++idx, offset += POLY_BYTES) {
            final byte[] slice = Arrays.copyOfRange(array, offset, offset + POLY_BYTES);
            this.getVectorIndex(idx).fromBytes(slice);
        }
    }

    /** Runs {@link Poly#conditionalSubQ()} on each polynomial. */
    public void conditionalSubQ() {
        for (int idx = 0, count = this.kyberK; idx < count; ++idx) {
            this.getVectorIndex(idx).conditionalSubQ();
        }
    }

    /** Renders the vector as {@code [p0, p1, ...]} using each polynomial's own toString. */
    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder("[");
        for (int idx = 0; idx < this.kyberK; ++idx) {
            if (idx > 0) {
                sb.append(", ");
            }
            sb.append(this.vec[idx].toString());
        }
        return sb.append(']').toString();
    }

    /**
     * Bitwise-ANDs {@link Poly#checkModulus(byte[], int)} over the 384-byte slice of each of
     * the {@code getKyberK()} polynomials, starting from -1 (all bits set).
     *
     * <p>NOTE(review): presumably the result stays -1 only when every slice passes. Confirm
     * against Poly.checkModulus.
     *
     * @param mlkemEngine engine giving the number of polynomials
     * @param array       serialized vector to inspect
     * @return the combined mask
     */
    static int checkModulus(final MLKEMEngine mlkemEngine, final byte[] array) {
        final int k = mlkemEngine.getKyberK();
        int result = -1;
        for (int idx = 0, offset = 0; idx < k; ++idx, offset += POLY_BYTES) {
            result &= Poly.checkModulus(array, offset);
        }
        return result;
    }
}
