// 
// Decompiled by Procyon v0.6.0
// 

package org.bouncycastle.pqc.crypto.slhdsa;

import java.math.BigInteger;
import java.util.ArrayDeque;
import java.util.LinkedList;

import org.bouncycastle.util.Arrays;

/**
 * FORS (Forest Of Random Subsets) operations used by SLH-DSA: signing a message digest
 * and recomputing the FORS public key from a signature.
 *
 * <p>The forest consists of {@code engine.K} binary trees, each with {@code 2^engine.A}
 * leaves. Leaves and nodes are addressed with a single global tree index, so tree
 * {@code i} occupies leaf indices {@code [i << A, (i + 1) << A)}.
 */
class Fors
{
    /** ADRS type used when deriving a FORS secret value with PRF (FIPS 205 FORS_PRF). */
    private static final int FORS_PRF = 6;
    /** ADRS type used when hashing FORS tree nodes with F/H (FIPS 205 FORS_TREE). */
    private static final int FORS_TREE = 3;
    /** ADRS type used when compressing the FORS roots into the public key (FIPS 205 FORS_ROOTS). */
    private static final int FORS_ROOTS = 4;

    SLHDSAEngine engine;

    public Fors(final SLHDSAEngine engine) {
        this.engine = engine;
    }

    /**
     * Computes the root of the FORS subtree of height {@code targetHeight} whose leftmost
     * leaf has global index {@code startIndex}, using a stack-based tree-hash traversal.
     *
     * @param skSeed       secret seed used to derive the leaf secret values
     * @param startIndex   global index of the leftmost leaf; must be a multiple of
     *                     {@code 2^targetHeight}
     * @param targetHeight height of the subtree whose root is returned (0 = a single leaf)
     * @param pkSeed       public seed passed to PRF/F/H
     * @param adrs         template address; it is copied and not modified
     * @return the subtree root, or {@code null} if {@code startIndex} is not aligned to
     *         {@code 2^targetHeight}
     */
    byte[] treehash(final byte[] skSeed, final int startIndex, final int targetHeight,
                    final byte[] pkSeed, final ADRS adrs) {
        if ((startIndex >>> targetHeight) << targetHeight != startIndex) {
            return null;
        }
        final ArrayDeque<NodeEntry> stack = new ArrayDeque<NodeEntry>();
        final ADRS nodeAdrs = new ADRS(adrs);
        final int leafCount = 1 << targetHeight;
        for (int leaf = 0; leaf < leafCount; ++leaf) {
            final int leafIndex = startIndex + leaf;

            // Derive the leaf secret value, then hash it into the leaf node.
            nodeAdrs.setTypeAndClear(FORS_PRF);
            nodeAdrs.setKeyPairAddress(adrs.getKeyPairAddress());
            nodeAdrs.setTreeHeight(0);
            nodeAdrs.setTreeIndex(leafIndex);
            final byte[] sk = engine.PRF(pkSeed, skSeed, nodeAdrs);
            nodeAdrs.changeType(FORS_TREE);
            byte[] node = engine.F(pkSeed, nodeAdrs, sk);

            // Merge with completed left siblings while the stack top has the same height.
            // Stack heights are stored one above the ADRS height of the node (leaf = 1).
            int height = 1;
            int nodeIndex = leafIndex;
            nodeAdrs.setTreeHeight(height);
            while (!stack.isEmpty() && stack.peek().nodeHeight == height) {
                // The current node is a right child here, so its parent is (index - 1) / 2.
                nodeIndex = (nodeIndex - 1) / 2;
                nodeAdrs.setTreeIndex(nodeIndex);
                node = engine.H(pkSeed, nodeAdrs, stack.pop().nodeValue, node);
                nodeAdrs.setTreeHeight(++height);
            }
            stack.push(new NodeEntry(node, height));
        }
        return stack.peek().nodeValue;
    }

    /**
     * Produces a FORS signature over a message digest.
     *
     * @param md     message digest; its leading {@code ceil(K * A / 8)} bytes select one
     *               leaf per tree
     * @param skSeed secret seed
     * @param pkSeed public seed
     * @param adrs   address template; only its key-pair address is read, and it is not modified
     * @return {@code engine.K} entries, each holding a revealed secret value and its
     *         {@code engine.A}-node authentication path
     */
    public SIG_FORS[] sign(final byte[] md, final byte[] skSeed, final byte[] pkSeed, final ADRS adrs) {
        final int a = engine.A;
        final int k = engine.K;
        final ADRS skAdrs = new ADRS(adrs);
        final int[] indices = base2B(md, a, k);
        final SIG_FORS[] sigFors = new SIG_FORS[k];
        for (int i = 0; i < k; ++i) {
            final int idx = indices[i];
            final int treeOffset = i << a;

            // Reveal the secret value of the selected leaf in tree i.
            skAdrs.setTypeAndClear(FORS_PRF);
            skAdrs.setKeyPairAddress(adrs.getKeyPairAddress());
            skAdrs.setTreeHeight(0);
            skAdrs.setTreeIndex(treeOffset + idx);
            final byte[] sk = engine.PRF(pkSeed, skSeed, skAdrs);
            skAdrs.changeType(FORS_TREE);

            // Authentication path: at each height j, the root of the sibling subtree of
            // the path node. Its leftmost leaf flips bit j of idx and clears the lower bits.
            final byte[][] authPath = new byte[a][];
            for (int j = 0; j < a; ++j) {
                final int siblingStart = treeOffset + (((idx >>> j) ^ 1) << j);
                authPath[j] = treehash(skSeed, siblingStart, j, pkSeed, skAdrs);
            }
            sigFors[i] = new SIG_FORS(sk, authPath);
        }
        return sigFors;
    }

    /**
     * Recomputes the FORS public key from a FORS signature and message digest.
     *
     * <p>Note: this method overwrites the tree height and tree index of {@code adrs}.
     * The caller is presumably expected to have set its type to FORS_TREE and its
     * key-pair address beforehand; this method does not set them (TODO confirm at call sites).
     *
     * @param sigFors signature entries, one per tree ({@code engine.K} entries)
     * @param md      message digest that selects one leaf per tree
     * @param pkSeed  public seed
     * @param adrs    address used for the tree hashing (mutated, see above)
     * @return the compressed FORS public key
     */
    public byte[] pkFromSig(final SIG_FORS[] sigFors, final byte[] md, final byte[] pkSeed, final ADRS adrs) {
        final int a = engine.A;
        final int k = engine.K;
        final byte[][] roots = new byte[k][];
        final int[] indices = base2B(md, a, k);
        for (int i = 0; i < k; ++i) {
            final int idx = indices[i];
            final byte[] sk = sigFors[i].getSK();

            // Hash the revealed secret value into its leaf.
            int nodeIndex = (i << a) + idx;
            adrs.setTreeHeight(0);
            adrs.setTreeIndex(nodeIndex);
            byte[] node = engine.F(pkSeed, adrs, sk);

            // Climb to the root. Bit j of idx says whether the path node is a left (0)
            // or right (1) child at height j, i.e. which side the sibling goes on.
            final byte[][] authPath = sigFors[i].getAuthPath();
            for (int j = 0; j < a; ++j) {
                nodeIndex >>>= 1;
                adrs.setTreeHeight(j + 1);
                adrs.setTreeIndex(nodeIndex);
                if (((idx >>> j) & 1) == 0) {
                    node = engine.H(pkSeed, adrs, node, authPath[j]);
                } else {
                    node = engine.H(pkSeed, adrs, authPath[j], node);
                }
            }
            roots[i] = node;
        }

        // Compress all tree roots into the FORS public key.
        final ADRS rootsAdrs = new ADRS(adrs);
        rootsAdrs.setTypeAndClear(FORS_ROOTS);
        rootsAdrs.setKeyPairAddress(adrs.getKeyPairAddress());
        return engine.T_l(pkSeed, rootsAdrs, Arrays.concatenate(roots));
    }

    /**
     * Splits a byte string into {@code n} big-endian integers of {@code exponent} bits
     * each, reading bits most-significant first (FIPS 205 {@code base_2b}).
     *
     * @param array    input bytes; must contain at least {@code ceil(n * exponent / 8)} bytes
     * @param exponent bits per output value, in {@code [0, 31]}
     * @param n        number of output values
     * @return {@code n} values, each in {@code [0, 2^exponent)}
     * @throws IllegalArgumentException if {@code exponent} is outside {@code [0, 31]}
     * @throws ArrayIndexOutOfBoundsException if {@code array} is too short
     */
    static int[] base2B(final byte[] array, final int exponent, final int n) {
        if (exponent < 0 || exponent > 31) {
            throw new IllegalArgumentException("exponent must be in [0, 31]: " + exponent);
        }
        final int[] out = new int[n];
        final long valueMask = (1L << exponent) - 1;
        int in = 0;
        int bits = 0;   // number of not-yet-consumed bits held in total
        long total = 0; // holds at most exponent + 7 bits, so it always fits in a long
        for (int o = 0; o < n; ++o) {
            while (bits < exponent) {
                total = (total << 8) | (array[in++] & 0xFF);
                bits += 8;
            }
            bits -= exponent;
            out[o] = (int)((total >>> bits) & valueMask);
            // Drop the bits just consumed so the accumulator stays bounded.
            total &= (1L << bits) - 1;
        }
        return out;
    }
}
