// 
// Decompiled by Procyon v0.6.0
// 

package org.bouncycastle.pqc.crypto.ntruprime;

import org.bouncycastle.util.Arrays;
import org.bouncycastle.crypto.EncapsulatedSecretExtractor;

/**
 * {@link EncapsulatedSecretExtractor} backed by a {@link SNTRUPrimePrivateKeyParameters}.
 * <p>
 * {@link #extractSecret(byte[])} decodes the supplied encapsulation with the private key,
 * rebuilds the encapsulation it would expect for the recovered value, compares the two and
 * derives the session key from a hash whose prefix byte and input depend on the outcome of
 * that comparison.
 */
public class SNTRUPrimeKEMExtractor implements EncapsulatedSecretExtractor
{
    private final SNTRUPrimePrivateKeyParameters privateKey;

    /**
     * Creates an extractor bound to the given private key.
     *
     * @param privateKey the private key whose parameters and components are read by
     *                   {@link #extractSecret(byte[])}
     */
    public SNTRUPrimeKEMExtractor(final SNTRUPrimePrivateKeyParameters privateKey) {
        this.privateKey = privateKey;
    }

    /**
     * Recovers the session key carried by an encapsulation.
     *
     * @param array the encapsulation to process
     * @return the first {@code getSessionKeySize() / 8} bytes of the final session hash
     */
    @Override
    public byte[] extractSecret(final byte[] array) {
        final SNTRUPrimeParameters parameters = this.privateKey.getParameters();
        final int p = parameters.getP();
        final int q = parameters.getQ();
        final int w = parameters.getW();
        final int roundedPolynomialBytes = parameters.getRoundedPolynomialBytes();

        // Decode the F and Ginv components of the private key.
        final byte[] f = new byte[p];
        Utils.getDecodedSmallPolynomial(f, this.privateKey.getF(), p);
        final byte[] gInv = new byte[p];
        Utils.getDecodedSmallPolynomial(gInv, this.privateKey.getGinv(), p);

        // Decode the rounded polynomial carried at the front of the encapsulation.
        final short[] c = new short[p];
        Utils.getRoundedDecodedPolynomial(c, array, p, q);

        // c * f in Rq, scaled by 3, mapped to R3, then multiplied by Ginv in R3.
        final short[] cf = new short[p];
        Utils.multiplicationInRQ(cf, c, f, p, q);
        final short[] cf3 = new short[p];
        Utils.scalarMultiplicationInRQ(cf3, cf, 3, q);
        final byte[] e = new byte[p];
        Utils.transformRQToR3(e, cf3);
        final byte[] ev = new byte[p];
        Utils.multiplicationInR3(ev, e, gInv, p);

        // Presumably validates ev against the weight parameter w and writes the accepted
        // (or a fallback) small polynomial into r - confirm in Utils.checkForSmallPolynomial.
        final byte[] r = new byte[p];
        Utils.checkForSmallPolynomial(r, ev, p, w);
        final byte[] encodedR = new byte[(p + 3) / 4];
        Utils.getEncodedSmallPolynomial(encodedR, r, p);

        // Rebuild the polynomial part of the encapsulation from the stored public key and r.
        final short[] h = new short[p];
        Utils.getDecodedPolynomial(h, this.privateKey.getPk(), p, q);
        final short[] hr = new short[p];
        Utils.multiplicationInRQ(hr, h, r, p, q);
        final short[] roundedHr = new short[p];
        Utils.roundPolynomial(roundedHr, hr);
        final byte[] encodedRounded = new byte[roundedPolynomialBytes];
        Utils.getRoundedEncodedPolynomial(encodedRounded, roundedHr, p, q);

        // Confirmation digest: hash(2, firstHalf(hash(3, encodedR)) | key hash).
        // The key hash is fetched once instead of on every use.
        final byte[] keyHash = this.privateKey.getHash();
        final byte[] confirmInput =
            halfDigestFollowedBy(Utils.getHashWithPrefix(new byte[] { 3 }, encodedR), keyHash);
        final byte[] confirmDigest = Utils.getHashWithPrefix(new byte[] { 2 }, confirmInput);

        // Expected encapsulation: encoded rounded polynomial | first half of the confirmation digest.
        final int confirmLength = confirmDigest.length / 2;
        final byte[] expected = new byte[encodedRounded.length + confirmLength];
        System.arraycopy(encodedRounded, 0, expected, 0, encodedRounded.length);
        System.arraycopy(confirmDigest, 0, expected, encodedRounded.length, confirmLength);

        // mask is 0 when the supplied encapsulation matches the recomputed one, -1 otherwise.
        // The recomputed value depends on private-key material, so the comparison must not
        // exit early on the first differing byte; use the constant-time variant.
        final int mask = Arrays.constantTimeAreEqual(expected, array) ? 0 : -1;

        // Presumably leaves encodedR unchanged for mask 0 and blends in rho for mask -1 -
        // confirm in Utils.updateDiffMask.
        Utils.updateDiffMask(encodedR, this.privateKey.getRho(), mask);

        // Session hash input: firstHalf(hash(3, encodedR)) | expected encapsulation.
        // NOTE(review): this uses the recomputed encapsulation, not the received one. The two
        // are identical when mask is 0 but differ on the rejection path - confirm against the
        // NTRU Prime specification which of the two is intended here.
        final byte[] sessionInput =
            halfDigestFollowedBy(Utils.getHashWithPrefix(new byte[] { 3 }, encodedR), expected);

        // Prefix byte is 1 on a match (mask 0) and 0 on a mismatch (mask -1).
        final byte[] sessionDigest =
            Utils.getHashWithPrefix(new byte[] { (byte)(mask + 1) }, sessionInput);
        return Arrays.copyOfRange(sessionDigest, 0, parameters.getSessionKeySize() / 8);
    }

    /**
     * Returns the length in bytes of an encapsulation for this key's parameter set.
     *
     * @return {@code getRoundedPolynomialBytes() + 32}
     */
    @Override
    public int getEncapsulationLength() {
        // NOTE(review): 32 presumably matches the half-digest appended in extractSecret
        // (confirmDigest.length / 2) - holds if Utils.getHashWithPrefix returns 64 bytes.
        return this.privateKey.getParameters().getRoundedPolynomialBytes() + 32;
    }

    /**
     * Concatenates the first half of {@code digest} with all of {@code tail}.
     */
    private static byte[] halfDigestFollowedBy(final byte[] digest, final byte[] tail) {
        final int half = digest.length / 2;
        final byte[] joined = new byte[half + tail.length];
        System.arraycopy(digest, 0, joined, 0, half);
        System.arraycopy(tail, 0, joined, half, tail.length);
        return joined;
    }
}
