// 
// Decompiled by Procyon v0.6.0
// 

package org.bouncycastle.pqc.crypto.bike;

import org.bouncycastle.util.Arrays;
import org.bouncycastle.util.Bytes;
import java.security.SecureRandom;
import org.bouncycastle.crypto.digests.SHA3Digest;
import org.bouncycastle.crypto.Xof;
import org.bouncycastle.crypto.digests.SHAKEDigest;

/**
 * Core BIKE KEM operations: key generation, encapsulation and decapsulation
 * (using the bit-flipping decoder implemented in {@link #BGFDecoder}).
 *
 * <p>Polynomial arithmetic is delegated to {@link BIKERing}; byte/bit helpers
 * and seeded sampling are delegated to {@code BIKEUtils}.
 */
class BIKEEngine
{
    private final int r;
    private final int w;
    private final int hw;
    private final int t;
    private final int nbIter;
    private final int tau;
    private final BIKERing bikeRing;
    private final int L_BYTE;
    private final int R_BYTE;
    private final int R2_BYTE;

    /**
     * @param r      ring size (polynomial length in bits)
     * @param w      weight parameter; each of h0/h1 is sampled with weight {@code w / 2}
     * @param t      weight parameter passed when sampling the error vector in {@link #functionH}
     * @param l      bit length of the session key / message (stored as {@code l / 8} bytes)
     * @param nbIter total number of decoder iterations (first one is the masked BFIter round)
     * @param tau    offset below the flip threshold used to build the second ("gray") mask
     */
    public BIKEEngine(final int r, final int w, final int t, final int l, final int nbIter, final int tau) {
        this.r = r;
        this.w = w;
        this.t = t;
        this.nbIter = nbIter;
        this.tau = tau;
        this.hw = this.w / 2;
        this.L_BYTE = l / 8;
        this.R_BYTE = (r + 7) >>> 3;
        this.R2_BYTE = (2 * r + 7) >>> 3;
        this.bikeRing = new BIKERing(r);
    }

    /** @return the session key length in bytes ({@code l / 8}). */
    public int getSessionKeySize() {
        return this.L_BYTE;
    }

    /**
     * Expands {@code m} with SHAKE256 into a 2r-bit vector (2 * R_BYTE bytes), sampled via
     * {@code BIKEUtils.generateRandomByteArray} with weight argument {@code t}.
     */
    private byte[] functionH(final byte[] m) {
        final byte[] out = new byte[2 * this.R_BYTE];
        final SHAKEDigest digest = new SHAKEDigest(256);
        digest.update(m, 0, m.length);
        BIKEUtils.generateRandomByteArray(out, 2 * this.r, this.t, digest);
        return out;
    }

    /** Writes the first L_BYTE bytes of SHA3-384(e0 || e1) into {@code out}. */
    private void functionL(final byte[] e0, final byte[] e1, final byte[] out) {
        final byte[] hash = new byte[48];
        final SHA3Digest digest = new SHA3Digest(384);
        digest.update(e0, 0, e0.length);
        digest.update(e1, 0, e1.length);
        digest.doFinal(hash, 0);
        System.arraycopy(hash, 0, out, 0, this.L_BYTE);
    }

    /** Writes the first L_BYTE bytes of SHA3-384(m || c0 || c1) into {@code out}. */
    private void functionK(final byte[] m, final byte[] c0, final byte[] c1, final byte[] out) {
        final byte[] hash = new byte[48];
        final SHA3Digest digest = new SHA3Digest(384);
        digest.update(m, 0, m.length);
        digest.update(c0, 0, c0.length);
        digest.update(c1, 0, c1.length);
        digest.doFinal(hash, 0);
        System.arraycopy(hash, 0, out, 0, this.L_BYTE);
    }

    /**
     * Generates a key pair.
     *
     * <p>64 random bytes are drawn: the first L_BYTE seed SHAKE256 to sample {@code h0} and
     * {@code h1} (weight argument {@code hw} each); the bytes after that are copied into
     * {@code sigma} (so L_BYTE + sigma.length must not exceed 64).
     * The public value is {@code h = h0^-1 * h1} in the ring.
     */
    public void genKeyPair(final byte[] h0, final byte[] h1, final byte[] sigma, final byte[] h, final SecureRandom random) {
        final byte[] seeds = new byte[64];
        random.nextBytes(seeds);

        final SHAKEDigest digest = new SHAKEDigest(256);
        digest.update(seeds, 0, this.L_BYTE);
        BIKEUtils.generateRandomByteArray(h0, this.r, this.hw, digest);
        BIKEUtils.generateRandomByteArray(h1, this.r, this.hw, digest);

        final long[] h0Element = this.bikeRing.create();
        final long[] h1Element = this.bikeRing.create();
        this.bikeRing.decodeBytes(h0, h0Element);
        this.bikeRing.decodeBytes(h1, h1Element);

        final long[] hElement = this.bikeRing.create();
        this.bikeRing.inv(h0Element, hElement);
        this.bikeRing.multiply(hElement, h1Element, hElement);
        this.bikeRing.encodeBytes(hElement, h);

        System.arraycopy(seeds, this.L_BYTE, sigma, 0, sigma.length);
    }

    /**
     * Encapsulates against public value {@code h}.
     *
     * <p>Draws a random m, derives (e0, e1) = H(m), and outputs
     * {@code c0 = e0 + e1 * h}, {@code c1 = L(e0, e1) xor m} and {@code k = K(m, c0, c1)}.
     */
    public void encaps(final byte[] c0, final byte[] c1, final byte[] k, final byte[] h, final SecureRandom random) {
        final byte[] m = new byte[this.L_BYTE];
        random.nextBytes(m);

        final byte[] eBytes = this.functionH(m);
        final byte[] e0Bytes = new byte[this.R_BYTE];
        final byte[] e1Bytes = new byte[this.R_BYTE];
        this.splitEBytes(eBytes, e0Bytes, e1Bytes);

        final long[] e0Element = this.bikeRing.create();
        final long[] e1Element = this.bikeRing.create();
        this.bikeRing.decodeBytes(e0Bytes, e0Element);
        this.bikeRing.decodeBytes(e1Bytes, e1Element);

        final long[] c0Element = this.bikeRing.create();
        this.bikeRing.decodeBytes(h, c0Element);
        this.bikeRing.multiply(c0Element, e1Element, c0Element);
        this.bikeRing.add(c0Element, e0Element, c0Element);
        this.bikeRing.encodeBytes(c0Element, c0);

        this.functionL(e0Bytes, e1Bytes, c1);
        Bytes.xorTo(this.L_BYTE, m, c1);
        this.functionK(m, c0, c1, k);
    }

    /**
     * Decapsulates {@code (c0, c1)} into {@code k}.
     *
     * <p>The error vector e' is decoded from the syndrome {@code c0 * h0}, m' = L(e') xor c1 is
     * recovered, and {@code k = K(m', c0, c1)} is output only if decoding succeeded AND
     * {@code H(m')} reproduces e'; otherwise {@code k = K(sigma, c0, c1)} (implicit rejection).
     * A decoding failure no longer surfaces as an exception: the whole pipeline still runs on
     * the decoder's (unsuccessful) output and the rejection branch is taken.
     */
    public void decaps(final byte[] k, final byte[] h0, final byte[] h1, final byte[] sigma, final byte[] c0, final byte[] c1) {
        final int[] h0Compact = new int[this.hw];
        final int[] h1Compact = new int[this.hw];
        this.convertToCompact(h0Compact, h0);
        this.convertToCompact(h1Compact, h1);

        final byte[] syndrome = this.computeSyndrome(c0, h0);
        final byte[] ePrimeBits = new byte[2 * this.r];
        final boolean decoded = this.BGFDecoder(syndrome, h0Compact, h1Compact, ePrimeBits);

        final byte[] ePrimeBytes = new byte[2 * this.R_BYTE];
        BIKEUtils.fromBitArrayToByteArray(ePrimeBytes, ePrimeBits, 0, 2 * this.r);
        final byte[] e0Bytes = new byte[this.R_BYTE];
        final byte[] e1Bytes = new byte[this.R_BYTE];
        this.splitEBytes(ePrimeBytes, e0Bytes, e1Bytes);

        final byte[] mPrime = new byte[this.L_BYTE];
        this.functionL(e0Bytes, e1Bytes, mPrime);
        Bytes.xorTo(this.L_BYTE, c1, mPrime);

        // Always compute H(m') and the comparison so the work done does not depend on
        // whether decoding succeeded; combine with non-short-circuit '&'.
        final boolean matches = Arrays.areEqual(ePrimeBytes, 0, this.R2_BYTE, this.functionH(mPrime), 0, this.R2_BYTE);
        if (decoded & matches) {
            this.functionK(mPrime, c0, c1, k);
        }
        else {
            this.functionK(sigma, c0, c1, k);
        }
    }

    /** Returns the transposed bit encoding of {@code c0 * h0} in the ring. */
    private byte[] computeSyndrome(final byte[] c0, final byte[] h0) {
        final long[] s = this.bikeRing.create();
        final long[] h0Element = this.bikeRing.create();
        this.bikeRing.decodeBytes(c0, s);
        this.bikeRing.decodeBytes(h0, h0Element);
        this.bikeRing.multiply(s, h0Element, s);
        return this.bikeRing.encodeBitsTransposed(s);
    }

    /**
     * Bit-flipping decoder.
     *
     * <p>Runs one BFIter round (producing black/gray masks), two masked rounds over those masks,
     * then {@code nbIter - 1} plain rounds. The syndrome {@code s} is updated in place and the
     * decoded error bits (one byte per bit) are accumulated into {@code e}, which must be a
     * zeroed array of length 2r.
     *
     * @return {@code true} iff the residual syndrome has Hamming weight zero
     */
    private boolean BGFDecoder(final byte[] s, final int[] h0Compact, final int[] h1Compact, final byte[] e) {
        final int[] h0CompactCol = this.getColumnFromCompactVersion(h0Compact);
        final int[] h1CompactCol = this.getColumnFromCompactVersion(h1Compact);

        final byte[] black = new byte[2 * this.r];
        final byte[] gray = new byte[2 * this.r];
        final byte[] ctrs = new byte[this.r];
        final int maskedThreshold = (this.hw + 1) / 2 + 1;

        this.BFIter(s, e, this.threshold(BIKEUtils.getHammingWeight(s), this.r),
            h0Compact, h1Compact, h0CompactCol, h1CompactCol, black, gray, ctrs);
        this.BFMaskedIter(s, e, black, maskedThreshold, h0Compact, h1Compact, h0CompactCol, h1CompactCol);
        this.BFMaskedIter(s, e, gray, maskedThreshold, h0Compact, h1Compact, h0CompactCol, h1CompactCol);

        for (int i = 1; i < this.nbIter; ++i) {
            this.BFIter2(s, e, this.threshold(BIKEUtils.getHammingWeight(s), this.r),
                h0Compact, h1Compact, h0CompactCol, h1CompactCol, ctrs);
        }
        return BIKEUtils.getHammingWeight(s) == 0;
    }

    /**
     * First decoder round: flips every position whose counter is >= T, records those positions
     * in {@code black}, and records positions whose counter is >= T - tau in {@code gray}.
     * The syndrome is then updated for all flipped positions.
     */
    private void BFIter(final byte[] s, final byte[] e, final int T, final int[] h0Compact, final int[] h1Compact,
                        final int[] h0CompactCol, final int[] h1CompactCol,
                        final byte[] black, final byte[] gray, final byte[] ctrs) {
        for (int half = 0; half < 2; ++half) {
            this.ctrAll(half == 0 ? h0CompactCol : h1CompactCol, s, ctrs);
            final int base = half * this.r;
            for (int i = 0; i < this.r; ++i) {
                final int ctr = ctrs[i] & 0xFF;
                // ((x >> 31) + 1) is 1 when x >= 0 and 0 otherwise, computed without a branch.
                final int flip = ((ctr - T) >> 31) + 1;
                final int grayBit = ((ctr - (T - this.tau)) >> 31) + 1;
                e[this.errorIndex(base + i)] ^= (byte)flip;
                black[base + i] = (byte)flip;
                gray[base + i] = (byte)grayBit;
            }
        }
        for (int k = 0; k < 2 * this.r; ++k) {
            this.recomputeSyndrome(s, k, h0Compact, h1Compact, black[k] != 0);
        }
    }

    /** Plain decoder round: flips every position whose counter is >= T, then updates the syndrome. */
    private void BFIter2(final byte[] s, final byte[] e, final int T, final int[] h0Compact, final int[] h1Compact,
                         final int[] h0CompactCol, final int[] h1CompactCol, final byte[] ctrs) {
        final byte[] flips = new byte[2 * this.r];
        for (int half = 0; half < 2; ++half) {
            this.ctrAll(half == 0 ? h0CompactCol : h1CompactCol, s, ctrs);
            final int base = half * this.r;
            for (int i = 0; i < this.r; ++i) {
                final int flip = (((ctrs[i] & 0xFF) - T) >> 31) + 1;
                e[this.errorIndex(base + i)] ^= (byte)flip;
                flips[base + i] = (byte)flip;
            }
        }
        for (int k = 0; k < 2 * this.r; ++k) {
            this.recomputeSyndrome(s, k, h0Compact, h1Compact, flips[k] != 0);
        }
    }

    /**
     * Masked decoder round: only positions with {@code mask[pos] == 1} are examined; those whose
     * counter is >= T are flipped, then the syndrome is updated for the flipped positions.
     */
    private void BFMaskedIter(final byte[] s, final byte[] e, final byte[] mask, final int T,
                              final int[] h0Compact, final int[] h1Compact,
                              final int[] h0CompactCol, final int[] h1CompactCol) {
        final byte[] flips = new byte[2 * this.r];
        for (int half = 0; half < 2; ++half) {
            final int[] cols = half == 0 ? h0CompactCol : h1CompactCol;
            final int base = half * this.r;
            for (int i = 0; i < this.r; ++i) {
                if (mask[base + i] == 1) {
                    final boolean flip = this.ctr(cols, s, i) >= T;
                    this.updateNewErrorIndex(e, base + i, flip);
                    flips[base + i] = (byte)(flip ? 1 : 0);
                }
            }
        }
        for (int k = 0; k < 2 * this.r; ++k) {
            this.recomputeSyndrome(s, k, h0Compact, h1Compact, flips[k] != 0);
        }
    }

    /**
     * Flip threshold for a syndrome of the given weight, selected by the ring size {@code r}.
     *
     * @throws IllegalArgumentException if {@code r} is not 12323, 24659 or 40973
     */
    private int threshold(final int syndromeWeight, final int r) {
        switch (r) {
            case 12323:
                return thresholdFromParameters(syndromeWeight, 0.0069722, 13.53, 36);
            case 24659:
                return thresholdFromParameters(syndromeWeight, 0.005265, 15.2588, 52);
            case 40973:
                return thresholdFromParameters(syndromeWeight, 0.00402312, 17.8785, 69);
            default:
                throw new IllegalArgumentException();
        }
    }

    /** Returns {@code max(minimum, floor(slope * syndromeWeight + intercept))}. */
    private static int thresholdFromParameters(final int syndromeWeight, final double slope, final double intercept, final int minimum) {
        return Math.max(minimum, (int)Math.floor(slope * syndromeWeight + intercept));
    }

    /**
     * Counter for column {@code j}: the number of set syndrome bits at positions
     * {@code (col + j) mod r} over the hw column indices. Manually unrolled by four.
     */
    private int ctr(final int[] hCompactCol, final byte[] s, final int j) {
        int count = 0;
        int i = 0;
        for (; i <= this.hw - 4; i += 4) {
            // idx + ((idx >> 31) & r) adds r only when idx is negative: a branch-free "mod r".
            final int i0 = hCompactCol[i] + j - this.r;
            final int i1 = hCompactCol[i + 1] + j - this.r;
            final int i2 = hCompactCol[i + 2] + j - this.r;
            final int i3 = hCompactCol[i + 3] + j - this.r;
            count += (s[i0 + ((i0 >> 31) & this.r)] & 0xFF)
                + (s[i1 + ((i1 >> 31) & this.r)] & 0xFF)
                + (s[i2 + ((i2 >> 31) & this.r)] & 0xFF)
                + (s[i3 + ((i3 >> 31) & this.r)] & 0xFF);
        }
        for (; i < this.hw; ++i) {
            final int idx = hCompactCol[i] + j - this.r;
            count += s[idx + ((idx >> 31) & this.r)] & 0xFF;
        }
        return count;
    }

    /**
     * Computes all r counters at once: {@code ctrs[j]} becomes the sum over columns of
     * {@code s[(col + j) mod r]}. Counters are stored in bytes and read back with {@code & 0xFF}.
     */
    private void ctrAll(final int[] hCompactCol, final byte[] s, final byte[] ctrs) {
        // The first column initialises ctrs as a rotated copy of s.
        final int first = hCompactCol[0];
        final int firstTail = this.r - first;
        System.arraycopy(s, first, ctrs, 0, firstTail);
        System.arraycopy(s, 0, ctrs, firstTail, first);

        for (int i = 1; i < this.hw; ++i) {
            final int pos = hCompactCol[i];
            final int split = this.r - pos;

            // ctrs[j] += s[pos + j] for j < split (no wrap-around).
            int j = 0;
            for (; j <= split - 4; j += 4) {
                ctrs[j] += s[pos + j];
                ctrs[j + 1] += s[pos + j + 1];
                ctrs[j + 2] += s[pos + j + 2];
                ctrs[j + 3] += s[pos + j + 3];
            }
            for (; j < split; ++j) {
                ctrs[j] += s[pos + j];
            }

            // ctrs[k] += s[k - split] for k >= split (wrapped part of the rotation).
            int k = split;
            for (; k <= this.r - 4; k += 4) {
                ctrs[k] += s[k - split];
                ctrs[k + 1] += s[k + 1 - split];
                ctrs[k + 2] += s[k + 2 - split];
                ctrs[k + 3] += s[k + 3 - split];
            }
            for (; k < this.r; ++k) {
                ctrs[k] += s[k - split];
            }
        }
    }

    /**
     * Writes the indices of the set bits of {@code h} (little-endian bit order within each byte,
     * bits 0..r-1) into {@code compact} using masked, branch-free writes. The write slot only
     * advances on a set bit and wraps modulo hw.
     */
    private void convertToCompact(final int[] compact, final byte[] h) {
        int slot = 0;
        for (int i = 0; i < this.R_BYTE; ++i) {
            for (int bit = 0; bit < 8 && i * 8 + bit != this.r; ++bit) {
                final int isSet = (h[i] >> bit) & 1;
                final int mask = -isSet;
                compact[slot] = ((i * 8 + bit) & mask) | (compact[slot] & ~mask);
                slot = (slot + isSet) % this.hw;
            }
        }
    }

    /** Converts a compact index list into the column form consumed by {@link #ctr}/{@link #ctrAll}. */
    private int[] getColumnFromCompactVersion(final int[] compact) {
        final int[] col = new int[this.hw];
        if (compact[0] == 0) {
            col[0] = 0;
            for (int i = 1; i < this.hw; ++i) {
                col[i] = this.r - compact[this.hw - i];
            }
        }
        else {
            for (int i = 0; i < this.hw; ++i) {
                col[i] = this.r - compact[this.hw - 1 - i];
            }
        }
        return col;
    }

    /**
     * XORs {@code flipped} (as 0/1) into every syndrome bit affected by position {@code pos}:
     * positions below r use {@code h0Compact}, positions from r use {@code h1Compact}.
     * The loop always runs so the work does not depend on {@code flipped}.
     */
    private void recomputeSyndrome(final byte[] s, final int pos, final int[] h0Compact, final int[] h1Compact, final boolean flipped) {
        final byte bit = (byte)(flipped ? 1 : 0);
        if (pos < this.r) {
            for (int i = 0; i < this.hw; ++i) {
                final int hi = h0Compact[i];
                if (hi <= pos) {
                    s[pos - hi] ^= bit;
                }
                else {
                    s[this.r + pos - hi] ^= bit;
                }
            }
        }
        else {
            final int p = pos - this.r;
            for (int i = 0; i < this.hw; ++i) {
                final int hi = h1Compact[i];
                if (hi <= p) {
                    s[p - hi] ^= bit;
                }
                else {
                    s[this.r - hi + p] ^= bit;
                }
            }
        }
    }

    /**
     * Splits a packed 2r-bit vector into its first r bits ({@code e0}) and the following r bits
     * ({@code e1}), re-aligning {@code e1} to start at bit 0 since r is not a multiple of 8
     * in general.
     */
    private void splitEBytes(final byte[] e, final byte[] e0, final byte[] e1) {
        final int partial = this.r & 0x7;
        System.arraycopy(e, 0, e0, 0, this.R_BYTE - 1);
        final byte boundary = e[this.R_BYTE - 1];
        final byte highMask = (byte)(-1 << partial);
        e0[this.R_BYTE - 1] = (byte)(boundary & ~highMask);

        int carry = (byte)(boundary & highMask);
        for (int i = 0; i < this.R_BYTE; ++i) {
            final byte next = e[this.R_BYTE + i];
            e1[i] = (byte)((next << (8 - partial)) | ((carry & 0xFF) >>> partial));
            carry = next;
        }
    }

    /** XORs {@code flip} (as 0/1) into the error bit corresponding to column position {@code n}. */
    private void updateNewErrorIndex(final byte[] e, final int n, final boolean flip) {
        e[this.errorIndex(n)] ^= (byte)(flip ? 1 : 0);
    }

    /**
     * Maps a column position in [0, 2r) to its error-vector index: 0 and r map to themselves,
     * n in (0, r) maps to r - n, and n in (r, 2r) maps to 3r - n.
     */
    private int errorIndex(final int n) {
        if (n == 0 || n == this.r) {
            return n;
        }
        return n > this.r ? 3 * this.r - n : this.r - n;
    }
}
