// 
// Decompiled by Procyon v0.6.0
// 

package org.bouncycastle.pqc.crypto.ntruprime;

import org.bouncycastle.crypto.digests.SHA512Digest;
import org.bouncycastle.crypto.modes.CTRModeCipher;
import org.bouncycastle.crypto.CipherParameters;
import org.bouncycastle.crypto.params.ParametersWithIV;
import org.bouncycastle.crypto.params.KeyParameter;
import org.bouncycastle.crypto.BlockCipher;
import org.bouncycastle.crypto.modes.SICBlockCipher;
import org.bouncycastle.crypto.engines.AESEngine;
import java.security.SecureRandom;

/**
 * Static arithmetic, sampling and (de)serialisation helpers shared by the classes of this
 * package.
 * <p>
 * Several routines are written without data-dependent branches (they combine values through
 * all-ones / all-zero masks instead of {@code if} statements). That structure is deliberate and
 * should be preserved when editing.
 * <p>
 * NOTE(review): the method names and constants suggest these are ports of routines from the
 * NTRU Prime reference implementation; the correspondences mentioned in individual comments
 * should be confirmed against that reference.
 */
class Utils
{
    /**
     * Draws four bytes from {@code secureRandom} and assembles them into a 32-bit word, first
     * byte least significant.
     *
     * @param secureRandom source of random bytes
     * @return the assembled word; all 32 bits are random, so the value may be negative when
     *         interpreted as a signed {@code int}
     */
    protected static int getRandomUnsignedInteger(final SecureRandom secureRandom) {
        final byte[] bytes = new byte[4];
        secureRandom.nextBytes(bytes);
        return littleEndianToInt(bytes, 0);
    }

    /**
     * Fills {@code array} with random values from {@code {-1, 0, 1}}.
     *
     * @param secureRandom source of randomness (one 32-bit word is drawn per entry)
     * @param array        output; every entry is overwritten
     */
    protected static void getRandomSmallPolynomial(final SecureRandom secureRandom, final byte[] array) {
        for (int i = 0; i < array.length; ++i) {
            // Keep 30 random bits, scale by 3 and take the top two bits: a value in {0, 1, 2}.
            final int scaled = (getRandomUnsignedInteger(secureRandom) & 0x3FFFFFFF) * 3;
            array[i] = (byte) ((scaled >>> 30) - 1);
        }
    }

    /**
     * Reduces {@code n} modulo {@code n2} into a range centred on zero: the value is shifted up
     * by {@code (n2 - 1) / 2}, reduced with {@link #getSignedDivMod(int, int)} and shifted back.
     *
     * @param n  value to reduce
     * @param n2 modulus
     * @return the residue of {@code n}, offset so that it starts at {@code -(n2 - 1) / 2}
     */
    protected static int getModFreeze(final int n, final int n2) {
        final int halfRange = (n2 - 1) / 2;
        return getSignedDivMod(n + halfRange, n2)[1] - halfRange;
    }

    /**
     * Runs a fixed number ({@code 2 * n - 1}) of masked division steps on the polynomial in
     * {@code array} with coefficients reduced modulo 3, writes the resulting candidate
     * reciprocal to {@code array2} and reports whether the final degree counter is zero.
     * <p>
     * NOTE(review): this appears to correspond to {@code R3_recip} of the NTRU Prime reference
     * code, i.e. inversion modulo {@code x^n - x - 1} over GF(3) - confirm.
     *
     * @param array  input coefficients; the first {@code n} entries are read
     * @param array2 output; the first {@code n} entries are written (only meaningful when the
     *               method returns {@code true})
     * @param n      number of coefficients
     * @return {@code true} if the counter {@code delta} ended at zero
     */
    protected static boolean isInvertiblePolynomialInR3(final byte[] array, final byte[] array2, final int n) {
        final byte[] f = new byte[n + 1];
        final byte[] g = new byte[n + 1];
        final byte[] r = new byte[n + 1];
        final byte[] v = new byte[n + 1];
        r[0] = 1;
        f[0] = 1;
        f[n - 1] = -1;
        f[n] = -1;
        // g holds the input with its coefficient order reversed.
        for (int i = 0; i < n; ++i) {
            g[n - 1 - i] = array[i];
        }
        int delta = 1;
        for (int loop = 0; loop < 2 * n - 1; ++loop) {
            // v <- x * v
            System.arraycopy(v, 0, v, 1, n);
            v[0] = 0;
            final int sign = -g[0] * f[0];
            // swap is all-ones when delta > 0 and g[0] != 0, otherwise zero.
            final int swap = checkLessThanZero(-delta) & checkNotEqualToZero(g[0]);
            // Conditionally negate delta, then increment it.
            delta ^= swap & (delta ^ -delta);
            ++delta;
            // Masked swap of (f, g) and (v, r).
            for (int i = 0; i < n + 1; ++i) {
                final int fg = swap & (f[i] ^ g[i]);
                f[i] ^= (byte) fg;
                g[i] ^= (byte) fg;
                final int vr = swap & (v[i] ^ r[i]);
                v[i] ^= (byte) vr;
                r[i] ^= (byte) vr;
            }
            for (int i = 0; i < n + 1; ++i) {
                g[i] = (byte) getModFreeze(g[i] + sign * f[i], 3);
            }
            for (int i = 0; i < n + 1; ++i) {
                r[i] = (byte) getModFreeze(r[i] + sign * v[i], 3);
            }
            // g <- g / x (the constant term has just been cancelled).
            System.arraycopy(g, 1, g, 0, n);
            g[n] = 0;
        }
        final byte scale = f[0];
        for (int i = 0; i < n; ++i) {
            array2[i] = (byte) (scale * v[n - 1 - i]);
        }
        return delta == 0;
    }

    /**
     * Branch-free compare-and-exchange: after the call {@code array[n]} holds the smaller and
     * {@code array[n2]} the larger of the two entries, compared as unsigned 32-bit values.
     *
     * @param array array to operate on
     * @param n     index that receives the minimum
     * @param n2    index that receives the maximum
     */
    protected static void minmax(final int[] array, final int n, final int n2) {
        final int a = array[n];
        final int b = array[n2];
        final int ab = a ^ b;
        final int diff = b - a;
        // Top bit is set exactly when b is below a in unsigned order (corrects for overflow of
        // the subtraction; the MIN_VALUE term flips the comparison from signed to unsigned).
        final int below = diff ^ (ab & (diff ^ b ^ Integer.MIN_VALUE));
        final int swapMask = -(below >>> 31) & ab;
        array[n] = a ^ swapMask;
        array[n2] = b ^ swapMask;
    }

    /**
     * Sorts the first {@code n} entries of {@code array} in place using a fixed sequence of
     * {@link #minmax(int[], int, int)} operations that depends only on {@code n}, never on the
     * data.
     *
     * @param array values to sort
     * @param n     number of leading entries to sort; nothing happens when {@code n < 2}
     */
    protected static void cryptoSort(final int[] array, final int n) {
        if (n < 2) {
            return;
        }
        // Largest power of two 'top' with top < n (found by doubling while top < n - top).
        int top = 1;
        while (top < n - top) {
            top += top;
        }
        for (int p = top; p > 0; p >>>= 1) {
            for (int i = 0; i < n - p; ++i) {
                if ((i & p) == 0) {
                    minmax(array, i, i + p);
                }
            }
            for (int q = top; q > p; q >>>= 1) {
                for (int i = 0; i < n - q; ++i) {
                    if ((i & p) == 0) {
                        minmax(array, i + p, i + q);
                    }
                }
            }
        }
    }

    /**
     * Turns {@code n} random words into a vector with exactly {@code n2} entries from
     * {@code {-1, 1}} and {@code n - n2} zero entries, positioned by sorting.
     * <p>
     * The first {@code n2} words get bit 0 cleared (low two bits {@code 00} or {@code 10}), the
     * remaining words get low two bits {@code 01}; after sorting, each entry's low two bits
     * minus one give the coefficient.
     *
     * @param array  output; the first {@code n} entries are written
     * @param array2 random words; modified and sorted in place
     * @param n      total number of coefficients
     * @param n2     number of non-zero coefficients
     */
    protected static void sortGenerateShortPolynomial(final byte[] array, final int[] array2, final int n, final int n2) {
        for (int i = 0; i < n2; ++i) {
            array2[i] &= ~1;
        }
        for (int i = n2; i < n; ++i) {
            array2[i] = (array2[i] & ~2) | 1;
        }
        cryptoSort(array2, n);
        for (int i = 0; i < n; ++i) {
            array[i] = (byte) ((array2[i] & 0x3) - 1);
        }
    }

    /**
     * Draws {@code n} random words and converts them with
     * {@link #sortGenerateShortPolynomial(byte[], int[], int, int)}.
     *
     * @param secureRandom source of randomness
     * @param array        output; the first {@code n} entries are written
     * @param n            total number of coefficients
     * @param n2           number of non-zero coefficients
     */
    protected static void getRandomShortPolynomial(final SecureRandom secureRandom, final byte[] array, final int n, final int n2) {
        final int[] words = new int[n];
        for (int i = 0; i < n; ++i) {
            words[i] = getRandomUnsignedInteger(secureRandom);
        }
        sortGenerateShortPolynomial(array, words, n, n2);
    }

    /**
     * Raises {@code n} to the power {@code n2 - 2} modulo {@code n2} by repeated multiplication
     * (a fixed {@code n2 - 3} multiplications, independent of the value of {@code n}).
     * <p>
     * NOTE(review): for prime {@code n2} this is the multiplicative inverse of {@code n};
     * primality is not checked here.
     *
     * @param n  value to invert
     * @param n2 modulus
     * @return {@code n^(n2 - 2)} reduced with {@link #getModFreeze(int, int)}
     */
    protected static int getInverseInRQ(final int n, final int n2) {
        int power = n;
        for (int i = 1; i < n2 - 2; ++i) {
            power = getModFreeze(n * power, n2);
        }
        return power;
    }

    /**
     * Same masked division-step loop as {@link #isInvertiblePolynomialInR3(byte[], byte[], int)},
     * but with coefficients reduced modulo {@code n2} and with the accumulator seeded with the
     * inverse of 3, so that the result written to {@code array} carries a factor of one third.
     * <p>
     * NOTE(review): this appears to correspond to {@code Rq_recip3} of the NTRU Prime reference
     * code - confirm. Unlike the R3 variant, no success indicator is returned.
     *
     * @param array  output; the first {@code n} entries are written
     * @param array2 input coefficients; the first {@code n} entries are read
     * @param n      number of coefficients
     * @param n2     coefficient modulus
     */
    protected static void getOneThirdInverseInRQ(final short[] array, final byte[] array2, final int n, final int n2) {
        final short[] f = new short[n + 1];
        final short[] g = new short[n + 1];
        final short[] r = new short[n + 1];
        final short[] v = new short[n + 1];
        r[0] = (short) getInverseInRQ(3, n2);
        f[0] = 1;
        f[n - 1] = -1;
        f[n] = -1;
        // g holds the input with its coefficient order reversed.
        for (int i = 0; i < n; ++i) {
            g[n - 1 - i] = array2[i];
        }
        int delta = 1;
        for (int loop = 0; loop < 2 * n - 1; ++loop) {
            // v <- x * v
            System.arraycopy(v, 0, v, 1, n);
            v[0] = 0;
            // swap is all-ones when delta > 0 and g[0] != 0, otherwise zero.
            final int swap = checkLessThanZero(-delta) & checkNotEqualToZero(g[0]);
            delta ^= swap & (delta ^ -delta);
            ++delta;
            // Masked swap of (f, g) and (v, r).
            for (int i = 0; i < n + 1; ++i) {
                final int fg = swap & (f[i] ^ g[i]);
                f[i] ^= (short) fg;
                g[i] ^= (short) fg;
                final int vr = swap & (v[i] ^ r[i]);
                v[i] ^= (short) vr;
                r[i] ^= (short) vr;
            }
            final short f0 = f[0];
            final short g0 = g[0];
            for (int i = 0; i < n + 1; ++i) {
                g[i] = (short) getModFreeze(f0 * g[i] - g0 * f[i], n2);
            }
            for (int i = 0; i < n + 1; ++i) {
                r[i] = (short) getModFreeze(f0 * r[i] - g0 * v[i], n2);
            }
            // g <- g / x (the constant term has just been cancelled).
            System.arraycopy(g, 1, g, 0, n);
            g[n] = 0;
        }
        final int scale = getInverseInRQ(f[0], n2);
        for (int i = 0; i < n; ++i) {
            array[i] = (short) getModFreeze(scale * v[n - 1 - i], n2);
        }
    }

    /**
     * Multiplies {@code array2} by {@code array3}, reducing coefficients modulo {@code n2} and
     * folding every product term of degree {@code k >= n} back onto degrees {@code k - n} and
     * {@code k - n + 1} (i.e. using {@code x^n = x + 1}).
     *
     * @param array  output; the first {@code n} entries are written
     * @param array2 first factor ({@code n} coefficients)
     * @param array3 second factor ({@code n} coefficients)
     * @param n      number of coefficients
     * @param n2     coefficient modulus
     */
    protected static void multiplicationInRQ(final short[] array, final short[] array2, final byte[] array3, final int n, final int n2) {
        final short[] product = new short[n + n - 1];
        for (int i = 0; i < n; ++i) {
            short sum = 0;
            for (int j = 0; j <= i; ++j) {
                sum = (short) getModFreeze(sum + array2[j] * array3[i - j], n2);
            }
            product[i] = sum;
        }
        for (int i = n; i < n + n - 1; ++i) {
            short sum = 0;
            for (int j = i - n + 1; j < n; ++j) {
                sum = (short) getModFreeze(sum + array2[j] * array3[i - j], n2);
            }
            product[i] = sum;
        }
        for (int i = n + n - 2; i >= n; --i) {
            product[i - n] = (short) getModFreeze(product[i - n] + product[i], n2);
            product[i - n + 1] = (short) getModFreeze(product[i - n + 1] + product[i], n2);
        }
        System.arraycopy(product, 0, array, 0, n);
    }

    /**
     * Recursively serialises the values {@code array2[0..n)} with per-entry moduli
     * {@code array3[0..n)} into {@code array}, starting at offset {@code n2}.
     * <p>
     * Adjacent entries are merged pairwise into one value with the product modulus; whole bytes
     * are emitted while that modulus is at least 16384, and the remainders are encoded by the
     * recursive call. {@link #decode} performs the inverse.
     *
     * @param array  output buffer
     * @param array2 values to encode (callers pass values in {@code [0, modulus)})
     * @param array3 moduli, one per value
     * @param n      number of values
     * @param n2     write offset into {@code array}
     */
    private static void encode(final byte[] array, final short[] array2, final short[] array3, final int n, int n2) {
        if (n == 1) {
            short r = array2[0];
            for (short m = array3[0]; m > 1; m = (short) ((m + 255) >>> 8)) {
                array[n2++] = (byte) r;
                r >>>= 8;
            }
        }
        else if (n > 1) {
            final int half = (n + 1) / 2;
            final short[] nextValues = new short[half];
            final short[] nextModuli = new short[half];
            int i;
            for (i = 0; i < n - 1; i += 2) {
                final short m0 = array3[i];
                int r = array2[i] + array2[i + 1] * m0;
                int m;
                for (m = array3[i + 1] * m0; m >= 16384; m = (m + 255) >>> 8) {
                    array[n2++] = (byte) r;
                    r >>>= 8;
                }
                nextValues[i / 2] = (short) r;
                nextModuli[i / 2] = (short) m;
            }
            // An odd trailing entry is passed to the next level unchanged.
            if (i < n) {
                nextValues[i / 2] = array2[i];
                nextModuli[i / 2] = array3[i];
            }
            encode(array, nextValues, nextModuli, half, n2);
        }
    }

    /**
     * Serialises {@code n} coefficients: each is shifted up by {@code (n2 - 1) / 2} and encoded
     * with modulus {@code n2}.
     *
     * @param array  output buffer (written from offset 0)
     * @param array2 coefficients
     * @param n      number of coefficients
     * @param n2     coefficient modulus
     */
    protected static void getEncodedPolynomial(final byte[] array, final short[] array2, final int n, final int n2) {
        final short[] values = new short[n];
        final short[] moduli = new short[n];
        for (int i = 0; i < n; ++i) {
            values[i] = (short) (array2[i] + (n2 - 1) / 2);
            moduli[i] = (short) n2;
        }
        encode(array, values, moduli, n, 0);
    }

    /**
     * Packs coefficients, each incremented by one, four to a byte (two bits each, first
     * coefficient in the lowest bits), followed by one extra byte holding a single coefficient.
     * <p>
     * Exactly {@code 4 * (n / 4) + 1} coefficients are read, so the layout is only complete when
     * {@code n % 4 == 1}. NOTE(review): presumably true for every parameter set - confirm.
     *
     * @param array  output; {@code n / 4 + 1} bytes are written
     * @param array2 coefficients (callers pass values in {@code {-1, 0, 1}})
     * @param n      number of coefficients
     */
    protected static void getEncodedSmallPolynomial(final byte[] array, final byte[] array2, final int n) {
        int in = 0;
        int out = 0;
        for (int i = 0; i < n / 4; ++i) {
            int packed = array2[in++] + 1;
            packed += (array2[in++] + 1) << 2;
            packed += (array2[in++] + 1) << 4;
            packed += (array2[in++] + 1) << 6;
            array[out++] = (byte) packed;
        }
        array[out] = (byte) (array2[in] + 1);
    }

    /**
     * Encrypts {@code array} with AES in counter (SIC) mode and stores the result in
     * {@code array2}; with an all-zero input this yields the raw key stream.
     * <p>
     * The AES variant is determined by the length of {@code array4}; this method does not
     * enforce the 256-bit key its name suggests.
     *
     * @param array  input block ({@code array2.length} bytes are processed)
     * @param array2 output buffer
     * @param array3 initial counter block / IV
     * @param array4 AES key
     */
    private static void generateAES256CTRStream(final byte[] array, final byte[] array2, final byte[] array3, final byte[] array4) {
        final CTRModeCipher cipher = SICBlockCipher.newInstance(AESEngine.newInstance());
        cipher.init(true, new ParametersWithIV(new KeyParameter(array4), array3));
        cipher.processBytes(array, 0, array2.length, array2, 0);
    }

    /**
     * Fills {@code array} with 32-bit words taken, least-significant byte first, from the
     * AES-CTR key stream generated under key {@code array2} with an all-zero 16-byte IV.
     *
     * @param array  output words
     * @param array2 AES key
     */
    protected static void expand(final int[] array, final byte[] array2) {
        final byte[] zeros = new byte[array.length * 4];
        final byte[] stream = new byte[array.length * 4];
        generateAES256CTRStream(zeros, stream, new byte[16], array2);
        for (int i = 0; i < array.length; ++i) {
            array[i] = littleEndianToInt(stream, i * 4);
        }
    }

    /**
     * Returns the remainder part of {@link #getUnsignedDivMod(int, int)}.
     */
    private static int getUnsignedMod(final int n, final int n2) {
        return getUnsignedDivMod(n, n2)[1];
    }

    /**
     * Derives {@code n} coefficients from a seed: each word produced by
     * {@link #expand(int[], byte[])} is reduced modulo {@code n2} (as an unsigned value) and
     * shifted down by {@code (n2 - 1) / 2}.
     *
     * @param array  output; the first {@code n} entries are written
     * @param array2 seed, used as the AES key
     * @param n      number of coefficients
     * @param n2     coefficient modulus
     */
    protected static void generatePolynomialInRQFromSeed(final short[] array, final byte[] array2, final int n, final int n2) {
        final int[] words = new int[n];
        expand(words, array2);
        for (int i = 0; i < n; ++i) {
            array[i] = (short) (getUnsignedMod(words[i], n2) - (n2 - 1) / 2);
        }
    }

    /**
     * Rounds every coefficient to a multiple of three by subtracting its centred residue
     * modulo 3.
     *
     * @param array  output; {@code array.length} entries are written
     * @param array2 input coefficients
     */
    protected static void roundPolynomial(final short[] array, final short[] array2) {
        for (int i = 0; i < array.length; ++i) {
            array[i] = (short) (array2[i] - getModFreeze(array2[i], 3));
        }
    }

    /**
     * Serialises rounded coefficients: each is shifted up by {@code (n2 - 1) / 2}, scaled by
     * {@code 10923 / 2^15} (approximately one third) and encoded with modulus
     * {@code (n2 + 2) / 3}.
     *
     * @param array  output buffer (written from offset 0)
     * @param array2 coefficients (expected to be multiples of three, as produced by
     *               {@link #roundPolynomial(short[], short[])})
     * @param n      number of coefficients
     * @param n2     coefficient modulus
     */
    protected static void getRoundedEncodedPolynomial(final byte[] array, final short[] array2, final int n, final int n2) {
        final short[] values = new short[n];
        final short[] moduli = new short[n];
        for (int i = 0; i < n; ++i) {
            values[i] = (short) (((array2[i] + (n2 - 1) / 2) * 10923) >>> 15);
            moduli[i] = (short) ((n2 + 2) / 3);
        }
        encode(array, values, moduli, n, 0);
    }

    /**
     * Computes SHA-512 over {@code array} followed by {@code array2}.
     *
     * @param array  prefix bytes
     * @param array2 message bytes
     * @return the 64-byte digest
     */
    protected static byte[] getHashWithPrefix(final byte[] array, final byte[] array2) {
        final byte[] hash = new byte[64];
        final SHA512Digest digest = new SHA512Digest();
        // Feeding the two parts in sequence is equivalent to hashing their concatenation and
        // avoids copying the inputs into a temporary buffer.
        digest.update(array, 0, array.length);
        digest.update(array2, 0, array2.length);
        digest.doFinal(hash, 0);
        return hash;
    }

    /**
     * Inverse of {@link #encode}: reads bytes from {@code array2} starting at {@code n3} and
     * reconstructs {@code n} values with moduli {@code array3}, writing them to {@code array}
     * starting at {@code n2}.
     *
     * @param array  output values
     * @param array2 encoded bytes
     * @param array3 moduli, one per value
     * @param n      number of values
     * @param n2     write offset into {@code array}
     * @param n3     read offset into {@code array2}
     */
    private static void decode(final short[] array, final byte[] array2, final short[] array3, final int n, int n2, int n3) {
        if (n == 1) {
            final short m = array3[0];
            if (m == 1) {
                array[n2] = 0;
            }
            else if (m <= 256) {
                array[n2] = (short) getUnsignedMod(bToUnsignedInt(array2[n3]), m);
            }
            else {
                // Both bytes are unsigned; widening the high byte with sign extension would
                // produce a different residue whenever that byte is >= 0x80.
                final int value = bToUnsignedInt(array2[n3]) + (bToUnsignedInt(array2[n3 + 1]) << 8);
                array[n2] = (short) getUnsignedMod(value, m);
            }
        }
        else if (n > 1) {
            final int half = (n + 1) / 2;
            final short[] upperValues = new short[half];
            final short[] upperModuli = new short[half];
            final short[] bottomValues = new short[n / 2];
            final int[] bottomRadix = new int[n / 2];
            int i;
            for (i = 0; i < n - 1; i += 2) {
                final int m = array3[i] * array3[i + 1];
                if (m > 4194048) {
                    // 4194048 = 256 * 16383: the encoder emitted two bytes for this pair.
                    bottomRadix[i / 2] = 65536;
                    bottomValues[i / 2] = (short) (bToUnsignedInt(array2[n3]) + 256 * bToUnsignedInt(array2[n3 + 1]));
                    n3 += 2;
                    upperModuli[i / 2] = (short) ((((m + 255) >>> 8) + 255) >>> 8);
                }
                else if (m >= 16384) {
                    // One byte was emitted for this pair.
                    bottomRadix[i / 2] = 256;
                    bottomValues[i / 2] = (short) bToUnsignedInt(array2[n3]);
                    ++n3;
                    upperModuli[i / 2] = (short) ((m + 255) >>> 8);
                }
                else {
                    // No byte was emitted; the pair lives entirely in the next level.
                    bottomRadix[i / 2] = 1;
                    bottomValues[i / 2] = 0;
                    upperModuli[i / 2] = (short) m;
                }
            }
            if (i < n) {
                upperModuli[i / 2] = array3[i];
            }
            decode(upperValues, array2, upperModuli, half, n2, n3);
            int j;
            for (j = 0; j < n - 1; j += 2) {
                final int merged = sToUnsignedInt(bottomValues[j / 2]) + bottomRadix[j / 2] * sToUnsignedInt(upperValues[j / 2]);
                final int[] divMod = getUnsignedDivMod(merged, array3[j]);
                array[n2++] = (short) divMod[1];
                array[n2++] = (short) getUnsignedMod(divMod[0], array3[j + 1]);
            }
            if (j < n) {
                array[n2] = upperValues[j / 2];
            }
        }
    }

    /**
     * Inverse of {@link #getEncodedPolynomial(byte[], short[], int, int)}: decodes {@code n}
     * values with modulus {@code n2} and shifts each down by {@code (n2 - 1) / 2}.
     *
     * @param array  output; the first {@code n} entries are written
     * @param array2 encoded bytes
     * @param n      number of coefficients
     * @param n2     coefficient modulus
     */
    protected static void getDecodedPolynomial(final short[] array, final byte[] array2, final int n, final int n2) {
        final short[] values = new short[n];
        final short[] moduli = new short[n];
        for (int i = 0; i < n; ++i) {
            moduli[i] = (short) n2;
        }
        decode(values, array2, moduli, n, 0, 0);
        for (int i = 0; i < n; ++i) {
            array[i] = (short) (values[i] - (n2 - 1) / 2);
        }
    }

    /**
     * Fills {@code array} with random bits, one bit (0 or 1) per entry, taken
     * least-significant bit first from the generated random bytes.
     *
     * @param secureRandom source of randomness
     * @param array        output; every entry is overwritten
     */
    protected static void getRandomInputs(final SecureRandom secureRandom, final byte[] array) {
        // Round up so that a length that is not a multiple of eight still has a source byte
        // for its trailing bits; for multiples of eight the byte count is unchanged.
        final byte[] bytes = new byte[(array.length + 7) / 8];
        secureRandom.nextBytes(bytes);
        for (int i = 0; i < array.length; ++i) {
            array[i] = (byte) (0x1 & (bytes[i >>> 3] >>> (i & 0x7)));
        }
    }

    /**
     * Packs one-bit entries eight to a byte, least-significant bit first. Bits are OR-ed into
     * {@code array}, so the caller must supply a zeroed buffer.
     *
     * @param array  output; {@code ceil(array2.length / 8)} bytes are touched
     * @param array2 bit values (0 or 1), one per entry
     */
    protected static void getEncodedInputs(final byte[] array, final byte[] array2) {
        for (int i = 0; i < array2.length; ++i) {
            array[i >>> 3] |= (byte) (array2[i] << (i & 0x7));
        }
    }

    /**
     * Inverse of {@link #getRoundedEncodedPolynomial(byte[], short[], int, int)}: decodes
     * {@code n} values with modulus {@code (n2 + 2) / 3}, multiplies each by three and shifts it
     * down by {@code (n2 - 1) / 2}.
     *
     * @param array  output; the first {@code n} entries are written
     * @param array2 encoded bytes
     * @param n      number of coefficients
     * @param n2     coefficient modulus
     */
    protected static void getRoundedDecodedPolynomial(final short[] array, final byte[] array2, final int n, final int n2) {
        final short[] values = new short[n];
        final short[] moduli = new short[n];
        for (int i = 0; i < n; ++i) {
            moduli[i] = (short) ((n2 + 2) / 3);
        }
        decode(values, array2, moduli, n, 0, 0);
        for (int i = 0; i < n; ++i) {
            array[i] = (short) (values[i] * 3 - (n2 - 1) / 2);
        }
    }

    /**
     * For every index computes
     * {@code (n3 * (freeze(array2[i] + array3[i] * ((n - 1) / 2), n) + n2) + 16384) >>> 15},
     * where {@code freeze} is {@link #getModFreeze(int, int)}.
     * <p>
     * NOTE(review): presumably the "Top" rounding of NTRU LPRime with {@code n = q},
     * {@code n2 = tau0}, {@code n3 = tau1} - confirm against the callers.
     *
     * @param array  output; {@code array.length} entries are written
     * @param array2 coefficients
     * @param array3 per-index multipliers (callers appear to pass message bits)
     * @param n      modulus
     * @param n2     additive offset
     * @param n3     multiplicative scale
     */
    protected static void top(final byte[] array, final short[] array2, final byte[] array3, final int n, final int n2, final int n3) {
        for (int i = 0; i < array.length; ++i) {
            final int centred = getModFreeze(array2[i] + array3[i] * ((n - 1) / 2), n);
            array[i] = (byte) ((n3 * (centred + n2) + 16384) >>> 15);
        }
    }

    /**
     * Packs pairs of 4-bit values into bytes: the even-indexed value goes into the low nibble,
     * the odd-indexed value into the high nibble.
     *
     * @param array  output; {@code array.length} bytes are written
     * @param array2 values to pack; {@code 2 * array.length} entries are read
     */
    protected static void getTopEncodedPolynomial(final byte[] array, final byte[] array2) {
        for (int i = 0; i < array.length; ++i) {
            array[i] = (byte) (array2[2 * i] + (array2[2 * i + 1] << 4));
        }
    }

    /**
     * Inverse of {@link #getEncodedSmallPolynomial(byte[], byte[], int)}: unpacks four two-bit
     * fields per byte (lowest bits first), subtracting one from each, followed by a single
     * coefficient from one extra byte.
     *
     * @param array  output; {@code 4 * (n / 4) + 1} entries are written
     * @param array2 packed bytes; {@code n / 4 + 1} bytes are read
     * @param n      number of coefficients
     */
    protected static void getDecodedSmallPolynomial(final byte[] array, final byte[] array2, final int n) {
        int out = 0;
        int in = 0;
        for (int i = 0; i < n / 4; ++i) {
            final int packed = bToUnsignedInt(array2[in++]);
            array[out++] = (byte) ((packed & 0x3) - 1);
            array[out++] = (byte) (((packed >>> 2) & 0x3) - 1);
            array[out++] = (byte) (((packed >>> 4) & 0x3) - 1);
            array[out++] = (byte) (((packed >>> 6) & 0x3) - 1);
        }
        array[out] = (byte) ((bToUnsignedInt(array2[in]) & 0x3) - 1);
    }

    /**
     * Multiplies every coefficient by the scalar {@code n}, reducing modulo {@code n2}.
     *
     * @param array  output; {@code array2.length} entries are written
     * @param array2 input coefficients
     * @param n      scalar factor
     * @param n2     coefficient modulus
     */
    protected static void scalarMultiplicationInRQ(final short[] array, final short[] array2, final int n, final int n2) {
        for (int i = 0; i < array2.length; ++i) {
            array[i] = (short) getModFreeze(n * array2[i], n2);
        }
    }

    /**
     * Reduces every coefficient to its centred residue modulo 3.
     *
     * @param array  output; {@code array2.length} entries are written
     * @param array2 input coefficients
     */
    protected static void transformRQToR3(final byte[] array, final short[] array2) {
        for (int i = 0; i < array2.length; ++i) {
            array[i] = (byte) getModFreeze(array2[i], 3);
        }
    }

    /**
     * Same as {@link #multiplicationInRQ(short[], short[], byte[], int, int)} but with
     * coefficients reduced modulo 3.
     *
     * @param array  output; the first {@code n} entries are written
     * @param array2 first factor ({@code n} coefficients)
     * @param array3 second factor ({@code n} coefficients)
     * @param n      number of coefficients
     */
    protected static void multiplicationInR3(final byte[] array, final byte[] array2, final byte[] array3, final int n) {
        final byte[] product = new byte[n + n - 1];
        for (int i = 0; i < n; ++i) {
            byte sum = 0;
            for (int j = 0; j <= i; ++j) {
                sum = (byte) getModFreeze(sum + array2[j] * array3[i - j], 3);
            }
            product[i] = sum;
        }
        for (int i = n; i < n + n - 1; ++i) {
            byte sum = 0;
            for (int j = i - n + 1; j < n; ++j) {
                sum = (byte) getModFreeze(sum + array2[j] * array3[i - j], 3);
            }
            product[i] = sum;
        }
        for (int i = n + n - 2; i >= n; --i) {
            product[i - n] = (byte) getModFreeze(product[i - n] + product[i], 3);
            product[i - n + 1] = (byte) getModFreeze(product[i - n + 1] + product[i], 3);
        }
        System.arraycopy(product, 0, array, 0, n);
    }

    /**
     * Copies {@code array2} to {@code array} if it has exactly {@code n2} odd entries;
     * otherwise writes a fixed fallback vector ({@code n2} ones followed by zeros). The
     * selection is done with a mask rather than a branch.
     *
     * @param array  output; the first {@code n} entries are written
     * @param array2 candidate vector (all entries are inspected for the weight count)
     * @param n      number of coefficients
     * @param n2     required number of odd entries
     */
    protected static void checkForSmallPolynomial(final byte[] array, final byte[] array2, final int n, final int n2) {
        int weight = 0;
        for (int i = 0; i != array2.length; ++i) {
            weight += (array2[i] & 0x1);
        }
        // All-ones when the weight is wrong, zero when it matches.
        final int mask = checkNotEqualToZero(weight - n2);
        for (int i = 0; i < n2; ++i) {
            array[i] = (byte) (((array2[i] ^ 0x1) & ~mask) ^ 0x1);
        }
        for (int i = n2; i < n; ++i) {
            array[i] = (byte) (array2[i] & ~mask);
        }
    }

    /**
     * Masked copy: where bits of {@code n} are set, the corresponding bits of {@code array}
     * are replaced by those of {@code array2}; with {@code n == 0} nothing changes and with
     * {@code n == -1} the arrays become equal.
     *
     * @param array  destination, updated in place
     * @param array2 source; at least {@code array.length} entries
     * @param n      selection mask
     */
    protected static void updateDiffMask(final byte[] array, final byte[] array2, final int n) {
        for (int i = 0; i < array.length; ++i) {
            array[i] ^= (byte) (n & (array[i] ^ array2[i]));
        }
    }

    /**
     * Inverse of {@link #getTopEncodedPolynomial(byte[], byte[])}: splits every byte into its
     * low nibble (even output index) and high nibble (odd output index), both in the range
     * {@code 0..15}.
     *
     * @param array  output; {@code 2 * array2.length} entries are written
     * @param array2 packed bytes
     */
    protected static void getTopDecodedPolynomial(final byte[] array, final byte[] array2) {
        for (int i = 0; i < array2.length; ++i) {
            // Convert to an unsigned value first: shifting the sign-extended byte would turn
            // high nibbles 8..15 into negative values.
            final int packed = bToUnsignedInt(array2[i]);
            array[2 * i] = (byte) (packed & 0xF);
            array[2 * i + 1] = (byte) (packed >>> 4);
        }
    }

    /**
     * For every index writes 1 if
     * {@code freeze(freeze(n4 * array3[i] - n3, n) - array2[i] + 4 * n2 + 1, n)} is negative and
     * 0 otherwise, where {@code freeze} is {@link #getModFreeze(int, int)}.
     * <p>
     * NOTE(review): presumably the "Right" reconstruction of NTRU LPRime with {@code n = q},
     * {@code n2 = w}, {@code n3 = tau2}, {@code n4 = tau3} - confirm against the callers.
     *
     * @param array  output bits; {@code array.length} entries are written
     * @param array2 coefficients
     * @param array3 4-bit values as produced by {@link #getTopDecodedPolynomial(byte[], byte[])}
     * @param n      modulus
     * @param n2     weight term (enters as {@code 4 * n2 + 1})
     * @param n3     subtractive offset
     * @param n4     multiplicative scale
     */
    protected static void right(final byte[] array, final short[] array2, final byte[] array3, final int n, final int n2, final int n3, final int n4) {
        for (int i = 0; i < array.length; ++i) {
            final int approx = getModFreeze(n4 * array3[i] - n3, n);
            final int difference = getModFreeze(approx - array2[i] + 4 * n2 + 1, n);
            array[i] = (byte) (-checkLessThanZero(difference));
        }
    }

    /**
     * Divides {@code n}, interpreted as an unsigned 32-bit value, by {@code n2} using two
     * multiply-and-shift estimation rounds with {@code floor(2^31 / n2)} followed by one masked
     * correction step, so that only a single division (by the modulus alone) is performed.
     *
     * @param n  dividend (unsigned interpretation)
     * @param n2 divisor; callers pass small positive moduli
     * @return {@code {quotient, remainder}}
     * @throws IllegalStateException if a result does not fit into an {@code int}
     */
    private static int[] getUnsignedDivMod(final int n, final int n2) {
        final long x = iToUnsignedLong(n);
        final long v = iToUnsignedLong(Integer.MIN_VALUE) / n2;
        // First estimate.
        final long q1 = (x * v) >>> 31;
        final long r1 = x - q1 * n2;
        // Second estimate on what is left.
        final long q2 = (r1 * v) >>> 31;
        final long r2 = r1 - q2 * n2;
        final int quotient = (int) (q1 + q2);
        // Tentatively subtract the divisor once more; undo it via the mask if that went negative.
        final long remainder = r2 - n2;
        final long undo = -(remainder >>> 63);
        return new int[] { toIntExact(quotient + 1L + undo), toIntExact(remainder + (undo & (long) n2)) };
    }

    /**
     * Signed counterpart of {@link #getUnsignedDivMod(int, int)}: biases {@code n} by
     * {@code 2^31}, divides both the biased value and the bias itself, and subtracts the
     * results, finally moving the remainder into {@code [0, n2)}.
     *
     * @param n  dividend
     * @param n2 divisor; callers pass small positive moduli
     * @return {@code {quotient, remainder}} with a non-negative remainder
     */
    private static int[] getSignedDivMod(final int n, final int n2) {
        final int[] biased = getUnsignedDivMod(toIntExact(-2147483648L + iToUnsignedLong(n)), n2);
        final int[] bias = getUnsignedDivMod(Integer.MIN_VALUE, n2);
        final int quotient = toIntExact(iToUnsignedLong(biased[0]) - iToUnsignedLong(bias[0]));
        final int remainder = toIntExact(iToUnsignedLong(biased[1]) - iToUnsignedLong(bias[1]));
        final int negative = -(remainder >>> 31);
        return new int[] { quotient + negative, remainder + (negative & n2) };
    }

    /**
     * Returns {@code -1} (all bits set) if {@code n < 0}, otherwise {@code 0}.
     */
    private static int checkLessThanZero(final int n) {
        return -(n >>> 31);
    }

    /**
     * Returns {@code -1} (all bits set) if {@code n != 0}, otherwise {@code 0}.
     */
    private static int checkNotEqualToZero(final int n) {
        // Negating the zero-extended value sets bit 63 exactly when n is non-zero.
        return -(int) (-iToUnsignedLong(n) >>> 63);
    }

    /**
     * Reads four bytes starting at {@code offset} as a 32-bit word, first byte least
     * significant.
     */
    private static int littleEndianToInt(final byte[] bytes, final int offset) {
        return bToUnsignedInt(bytes[offset])
            + (bToUnsignedInt(bytes[offset + 1]) << 8)
            + (bToUnsignedInt(bytes[offset + 2]) << 16)
            + (bToUnsignedInt(bytes[offset + 3]) << 24);
    }

    /**
     * Zero-extends a byte to an {@code int} in {@code 0..255}.
     */
    static int bToUnsignedInt(final byte b) {
        return b & 0xFF;
    }

    /**
     * Zero-extends a short to an {@code int} in {@code 0..65535}.
     */
    static int sToUnsignedInt(final short n) {
        return n & 0xFFFF;
    }

    /**
     * Zero-extends an int to a {@code long} in {@code 0..2^32-1}.
     */
    static long iToUnsignedLong(final int n) {
        return (long) n & 0xFFFFFFFFL;
    }

    /**
     * Narrows {@code n} to an {@code int}.
     *
     * @throws IllegalStateException if {@code n} does not fit into an {@code int}
     */
    static int toIntExact(final long n) {
        final int narrowed = (int) n;
        if (narrowed != n) {
            throw new IllegalStateException("value out of integer range");
        }
        return narrowed;
    }
}
