// 
// Decompiled by Procyon v0.6.0
// 

package org.bouncycastle.pqc.crypto.sphincsplus;

import org.bouncycastle.util.Pack;
import org.bouncycastle.util.Arrays;

class WotsPlus
{
    /** ADRS type set before deriving a chain's secret start value with PRF. */
    private static final int ADRS_TYPE_WOTS_PRF = 5;
    /** ADRS type set before walking a hash chain with F. */
    private static final int ADRS_TYPE_WOTS_HASH = 0;
    /** ADRS type set before compressing the chain ends with T_l. */
    private static final int ADRS_TYPE_WOTS_PK = 1;

    private final SPHINCSPlusEngine engine;
    private final int w;
    
    WotsPlus(final SPHINCSPlusEngine engine) {
        this.engine = engine;
        this.w = this.engine.WOTS_W;
    }
    
    /**
     * Generates a WOTS+ public key.
     *
     * <p>For every one of the {@code WOTS_LEN} chains, derives the secret start value with
     * PRF, walks the chain its full length ({@code w - 1} steps) and finally compresses all
     * chain ends with T_l.
     *
     * @param skSeed secret seed (second argument to PRF)
     * @param pkSeed public seed passed to PRF, F and T_l
     * @param paramAdrs address template; it is only copied and read, never modified
     * @return the compressed public key produced by T_l
     */
    byte[] pkGen(final byte[] skSeed, final byte[] pkSeed, final ADRS paramAdrs) {
        final ADRS pkAdrs = new ADRS(paramAdrs);
        final byte[][] chainEnds = new byte[this.engine.WOTS_LEN][];
        for (int i = 0; i < this.engine.WOTS_LEN; ++i) {
            // Fresh copy per chain so no address state leaks between chains.
            final ADRS chainAdrs = new ADRS(paramAdrs);
            final byte[] start = this.deriveChainStart(skSeed, pkSeed, paramAdrs, chainAdrs, i);
            chainEnds[i] = this.chain(start, 0, this.w - 1, pkSeed, chainAdrs);
        }
        pkAdrs.setTypeAndClear(ADRS_TYPE_WOTS_PK);
        pkAdrs.setKeyPairAddress(paramAdrs.getKeyPairAddress());
        return this.engine.T_l(pkSeed, pkAdrs, Arrays.concatenate(chainEnds));
    }
    
    /**
     * Applies F {@code steps} times to {@code input}, starting at chain position {@code start}.
     *
     * <p>Before each application the hash address of {@code adrs} is set to the current chain
     * position, so {@code adrs} is modified by this call.
     *
     * @param input chain value at position {@code start}
     * @param start chain position of {@code input}
     * @param steps number of F applications
     * @param pkSeed public seed passed to F
     * @param adrs address used (and updated) for each F call
     * @return a copy of {@code input} when {@code steps == 0}; {@code null} when
     *     {@code start + steps} would pass the chain end {@code w - 1}; otherwise the value
     *     at position {@code start + steps}
     */
    byte[] chain(final byte[] input, final int start, final int steps, final byte[] pkSeed, final ADRS adrs) {
        if (steps == 0) {
            return Arrays.clone(input);
        }
        if (start + steps > this.w - 1) {
            return null;
        }
        byte[] value = input;
        for (int i = 0; i < steps; ++i) {
            adrs.setHashAddress(start + i);
            value = this.engine.F(pkSeed, adrs, value);
        }
        return value;
    }
    
    /**
     * Produces a WOTS+ signature of {@code message}.
     *
     * <p>Each chain {@code j} is walked from its secret start value by as many steps as the
     * j-th base-w digit of the message-plus-checksum.
     *
     * @param message message to sign (its first bytes are split into {@code WOTS_LEN1} digits)
     * @param skSeed secret seed (second argument to PRF)
     * @param pkSeed public seed passed to PRF and F
     * @param paramAdrs address template; it is only copied and read, never modified
     * @return the concatenation of the {@code WOTS_LEN} chain values
     */
    public byte[] sign(final byte[] message, final byte[] skSeed, final byte[] pkSeed, final ADRS paramAdrs) {
        final ADRS chainAdrs = new ADRS(paramAdrs);
        final int[] digits = this.messageDigits(message);
        final byte[][] sigParts = new byte[this.engine.WOTS_LEN][];
        for (int j = 0; j < this.engine.WOTS_LEN; ++j) {
            final byte[] start = this.deriveChainStart(skSeed, pkSeed, paramAdrs, chainAdrs, j);
            sigParts[j] = this.chain(start, 0, digits[j], pkSeed, chainAdrs);
        }
        return Arrays.concatenate(sigParts);
    }
    
    /**
     * Splits bytes of {@code input} into {@code outLen} base-{@code base} digits, most
     * significant bits first.
     *
     * <p>NOTE(review): the bit bookkeeping assumes {@code engine.WOTS_LOGW} divides 8
     * (w in {4, 16, 256}); other values would produce a negative shift.
     *
     * @param input source bytes
     * @param inOff index of the first byte to consume
     * @param base the base, expected to equal 2^WOTS_LOGW
     * @param out destination digit array
     * @param outOff index in {@code out} of the first digit written
     * @param outLen number of digits to produce
     */
    void base_w(final byte[] input, int inOff, final int base, final int[] out, int outOff, final int outLen) {
        int total = 0;
        int bits = 0;
        for (int i = 0; i < outLen; ++i) {
            if (bits == 0) {
                // Mask so sign extension of the byte can never leak into a digit.
                total = input[inOff++] & 0xFF;
                bits += 8;
            }
            bits -= this.engine.WOTS_LOGW;
            out[outOff++] = (total >>> bits) & (base - 1);
        }
    }
    
    /**
     * Recomputes the WOTS+ public key from a signature.
     *
     * <p>Each signature part {@code j} is walked from position {@code digits[j]} to the chain
     * end {@code w - 1}, and the chain ends are compressed with T_l.
     *
     * <p>Note: this method modifies {@code adrs}. It sets the chain address for every chain,
     * and {@link #chain} sets its hash address.
     *
     * @param sig concatenation of {@code WOTS_LEN} values of {@code N} bytes each
     * @param message the signed message
     * @param pkSeed public seed passed to F and T_l
     * @param adrs address prepared by the caller for chain hashing; mutated as described
     * @return the compressed public key produced by T_l
     */
    public byte[] pkFromSig(final byte[] sig, final byte[] message, final byte[] pkSeed, final ADRS adrs) {
        // Copy before the loop mutates adrs.
        final ADRS pkAdrs = new ADRS(adrs);
        final int[] digits = this.messageDigits(message);
        final byte[] sigPart = new byte[this.engine.N];
        final byte[][] chainEnds = new byte[this.engine.WOTS_LEN][];
        for (int j = 0; j < this.engine.WOTS_LEN; ++j) {
            adrs.setChainAddress(j);
            // sigPart is reused; chain() returns either a clone or F's output, never sigPart itself
            // (assuming F returns a new array).
            System.arraycopy(sig, j * this.engine.N, sigPart, 0, this.engine.N);
            chainEnds[j] = this.chain(sigPart, digits[j], this.w - 1 - digits[j], pkSeed, adrs);
        }
        pkAdrs.setTypeAndClear(ADRS_TYPE_WOTS_PK);
        pkAdrs.setKeyPairAddress(adrs.getKeyPairAddress());
        return this.engine.T_l(pkSeed, pkAdrs, Arrays.concatenate(chainEnds));
    }

    /**
     * Computes the {@code WOTS_LEN} base-w digits used by both signing and verification.
     *
     * <p>The digits are {@code WOTS_LEN1} digits of the message followed by
     * {@code WOTS_LEN2} digits of the checksum. Sharing this helper guarantees {@link #sign}
     * and {@link #pkFromSig} derive identical digits.
     */
    private int[] messageDigits(final byte[] message) {
        final int[] digits = new int[this.engine.WOTS_LEN];
        this.base_w(message, 0, this.w, digits, 0, this.engine.WOTS_LEN1);
        int csum = 0;
        for (int i = 0; i < this.engine.WOTS_LEN1; ++i) {
            csum += this.w - 1 - digits[i];
        }
        final int csumBits = this.engine.WOTS_LEN2 * this.engine.WOTS_LOGW;
        // Left-align the checksum bits within their byte encoding. The outer "% 8" makes the
        // shift 0 when csumBits is already a whole number of bytes. Without it, a shift of 8
        // would push the checksum out of the bytes read below.
        csum <<= (8 - csumBits % 8) % 8;
        final int csumBytes = (csumBits + 7) / 8;
        this.base_w(Pack.intToBigEndian(csum), 4 - csumBytes, this.w, digits, this.engine.WOTS_LEN1, this.engine.WOTS_LEN2);
        return digits;
    }

    /**
     * Derives the secret start value of chain {@code chainIndex} with PRF.
     *
     * <p>It then re-targets {@code adrs} at that chain for hashing: hash type, key pair of
     * {@code paramAdrs}, chain {@code chainIndex}, hash address 0.
     */
    private byte[] deriveChainStart(final byte[] skSeed, final byte[] pkSeed, final ADRS paramAdrs,
                                    final ADRS adrs, final int chainIndex) {
        adrs.setTypeAndClear(ADRS_TYPE_WOTS_PRF);
        adrs.setKeyPairAddress(paramAdrs.getKeyPairAddress());
        adrs.setChainAddress(chainIndex);
        adrs.setHashAddress(0);
        final byte[] start = this.engine.PRF(pkSeed, skSeed, adrs);
        adrs.setTypeAndClear(ADRS_TYPE_WOTS_HASH);
        adrs.setKeyPairAddress(paramAdrs.getKeyPairAddress());
        adrs.setChainAddress(chainIndex);
        adrs.setHashAddress(0);
        return start;
    }
}
