// 
// Decompiled by Procyon v0.6.0
// 

package org.bouncycastle.pqc.crypto.slhdsa;

import org.bouncycastle.crypto.digests.SHAKEDigest;
import org.bouncycastle.crypto.Xof;
import org.bouncycastle.util.Bytes;
import org.bouncycastle.crypto.DerivationParameters;
import org.bouncycastle.crypto.params.MGFParameters;
import org.bouncycastle.crypto.CipherParameters;
import org.bouncycastle.crypto.params.KeyParameter;
import org.bouncycastle.util.Pack;
import org.bouncycastle.util.Arrays;
import org.bouncycastle.crypto.digests.SHA512Digest;
import org.bouncycastle.crypto.digests.SHA256Digest;
import org.bouncycastle.util.Memoable;
import org.bouncycastle.crypto.Digest;
import org.bouncycastle.crypto.generators.MGF1BytesGenerator;
import org.bouncycastle.crypto.macs.HMac;

abstract class SLHDSAEngine
{
    // Output length in bytes: both engines return N-byte values from F, H, T_l and PRF_msg.
    final int N;
    // The w value handed to the constructor; only 16 and 256 are accepted there.
    final int WOTS_W;
    // 4 when WOTS_W is 16, 8 when WOTS_W is 256 (i.e. log2 of WOTS_W).
    final int WOTS_LOGW;
    // WOTS_LEN1 + WOTS_LEN2.
    final int WOTS_LEN;
    // 8 * N / WOTS_LOGW.
    final int WOTS_LEN1;
    // Looked up in the constructor from WOTS_W and N rather than computed.
    final int WOTS_LEN2;
    // Constructor argument d; H / D is the bit width of the int index produced by H_msg.
    // NOTE(review): presumably the number of hypertree layers -- confirm against the spec.
    final int D;
    // Constructor argument a; H_msg reserves ceil(A * K / 8) bytes for the digest portion.
    final int A;
    // Constructor argument k; only used together with A (see above).
    final int K;
    // Constructor argument h; H - H / D is the bit width of the long index produced by H_msg.
    final int H;
    // h / d using integer division.
    final int H_PRIME;
    
    public SLHDSAEngine(final int n, final int wots_W, final int d, final int a, final int k, final int h) {
        this.N = n;
        if (wots_W == 16) {
            this.WOTS_LOGW = 4;
            this.WOTS_LEN1 = 8 * this.N / this.WOTS_LOGW;
            if (this.N <= 8) {
                this.WOTS_LEN2 = 2;
            }
            else if (this.N <= 136) {
                this.WOTS_LEN2 = 3;
            }
            else {
                if (this.N > 256) {
                    throw new IllegalArgumentException("cannot precompute SPX_WOTS_LEN2 for n outside {2, .., 256}");
                }
                this.WOTS_LEN2 = 4;
            }
        }
        else {
            if (wots_W != 256) {
                throw new IllegalArgumentException("wots_w assumed 16 or 256");
            }
            this.WOTS_LOGW = 8;
            this.WOTS_LEN1 = 8 * this.N / this.WOTS_LOGW;
            if (this.N <= 1) {
                this.WOTS_LEN2 = 1;
            }
            else {
                if (this.N > 256) {
                    throw new IllegalArgumentException("cannot precompute SPX_WOTS_LEN2 for n outside {2, .., 256}");
                }
                this.WOTS_LEN2 = 2;
            }
        }
        this.WOTS_W = wots_W;
        this.WOTS_LEN = this.WOTS_LEN1 + this.WOTS_LEN2;
        this.D = d;
        this.A = a;
        this.K = k;
        this.H = h;
        this.H_PRIME = h / d;
    }
    
    /**
     * Prepares the engine for use with the given seed bytes.
     * Sha2Engine absorbs p0 plus zero padding into its digests and snapshots their state
     * for later reuse; Shake256Engine's implementation is empty.
     *
     * @param p0 seed bytes (NOTE(review): presumably the public seed -- confirm with callers)
     */
    abstract void init(final byte[] p0);
    
    /**
     * Hashes one input under an address and returns N bytes.
     * Shake256Engine hashes p0, p1.value and p2 in that order. Sha2Engine does not read p0;
     * it starts from the state captured by init and hashes a 22-byte compressed form of
     * p1 followed by p2.
     *
     * @param p0 seed bytes (unused by Sha2Engine)
     * @param p1 address whose {@code value} bytes are hashed
     * @param p2 input bytes
     * @return N bytes of hash output
     */
    abstract byte[] F(final byte[] p0, final ADRS p1, final byte[] p2);
    
    /**
     * Two-input variant of the address-bound hash: same shape as F, with p2 and then p3
     * absorbed after the address.
     *
     * @param p0 seed bytes (unused by Sha2Engine)
     * @param p1 address whose {@code value} bytes are hashed
     * @param p2 first input
     * @param p3 second input
     * @return N bytes of hash output
     */
    abstract byte[] H(final byte[] p0, final ADRS p1, final byte[] p2, final byte[] p3);
    
    /**
     * Hashes p0, p1, p2, the optional p3 and p4 (in that order) and splits the result into
     * a ceil(A * K / 8)-byte digest, a long index of H - H / D bits and an int index of
     * H / D bits.
     *
     * @param p0 first hashed input
     * @param p1 second hashed input
     * @param p2 third hashed input
     * @param p3 optional input; skipped by both engines when {@code null}
     * @param p4 final hashed input
     * @return the digest bytes together with the two indices
     */
    abstract IndexedDigest H_msg(final byte[] p0, final byte[] p1, final byte[] p2, final byte[] p3, final byte[] p4);
    
    /**
     * Hashes an input under an address and returns N bytes. Both engines absorb the same
     * sequence as in F; Sha2Engine uses its message digest here instead of SHA-256.
     *
     * @param p0 seed bytes (unused by Sha2Engine)
     * @param p1 address whose {@code value} bytes are hashed
     * @param p2 input bytes
     * @return N bytes of hash output
     */
    abstract byte[] T_l(final byte[] p0, final ADRS p1, final byte[] p2);
    
    /**
     * Address-bound pseudo-random function. Note the argument order: the address is last.
     * Shake256Engine hashes p0, p2.value and p1 and returns N bytes. Sha2Engine does not
     * read p0 and returns p1.length bytes.
     *
     * @param p0 seed bytes (unused by Sha2Engine)
     * @param p1 key/seed bytes that are hashed after the address
     * @param p2 address whose {@code value} bytes are hashed
     * @return the derived bytes
     */
    abstract byte[] PRF(final byte[] p0, final byte[] p1, final ADRS p2);
    
    /**
     * Keyed function over p1, the optional p2 and p3, returning N bytes. Sha2Engine uses
     * p0 as an HMAC key; Shake256Engine absorbs p0 ahead of the other inputs.
     *
     * @param p0 key bytes
     * @param p1 first input
     * @param p2 optional input; skipped by both engines when {@code null}
     * @param p3 final input
     * @return N bytes of output
     */
    abstract byte[] PRF_msg(final byte[] p0, final byte[] p1, final byte[] p2, final byte[] p3);
    
    /**
     * Engine backed by SHA-2.
     *
     * SHA-256 always serves F and PRF. The "message" digest, the HMAC and the MGF1
     * generator use SHA-256 when n is 16 and SHA-512 otherwise. init() absorbs the seed
     * once and snapshots the digest states; F, H, T_l and PRF restart from those
     * snapshots instead of re-hashing the seed.
     *
     * Instances keep scratch buffers and digest state, so they are not safe for
     * concurrent use.
     */
    static class Sha2Engine extends SLHDSAEngine
    {
        private final HMac treeHMac;
        private final MGF1BytesGenerator mgf1;
        private final byte[] hmacBuf;
        private final Digest msgDigest;
        private final byte[] msgDigestBuf;
        // Number of bytes the seed is zero-padded to for msgDigest: 64 or 128.
        private final int bl;
        private final Digest sha256;
        private final byte[] sha256Buf;
        // Digest states captured by init(); null until init() has run.
        private Memoable msgMemo;
        private Memoable sha256Memo;

        /**
         * Builds the engine; all six arguments are forwarded unchanged to the base class.
         *
         * @param n  output length in bytes; 16 selects SHA-256 for the message digest,
         *           HMAC and MGF1, any other value selects SHA-512
         * @param n2 forwarded as wots_W
         * @param n3 forwarded as d
         * @param n4 forwarded as a
         * @param n5 forwarded as k
         * @param n6 forwarded as h
         */
        public Sha2Engine(final int n, final int n2, final int n3, final int n4, final int n5, final int n6) {
            super(n, n2, n3, n4, n5, n6);
            this.sha256 = new SHA256Digest();
            this.sha256Buf = new byte[this.sha256.getDigestSize()];

            final boolean useSha256 = (n == 16);
            this.bl = useSha256 ? 64 : 128;
            // Three independent instances: each consumer keeps its own running state.
            this.msgDigest = pickDigest(useSha256);
            this.treeHMac = new HMac(pickDigest(useSha256));
            this.mgf1 = new MGF1BytesGenerator(pickDigest(useSha256));

            this.hmacBuf = new byte[this.treeHMac.getMacSize()];
            this.msgDigestBuf = new byte[this.msgDigest.getDigestSize()];
        }

        /** Returns a fresh SHA-256 digest when asked to, otherwise a fresh SHA-512 digest. */
        private static Digest pickDigest(final boolean useSha256) {
            if (useSha256) {
                return new SHA256Digest();
            }
            return new SHA512Digest();
        }

        /** Feeds the whole array into the digest. */
        private static void absorb(final Digest digest, final byte[] data) {
            digest.update(data, 0, data.length);
        }

        /**
         * Absorbs the seed plus zero padding into both digests and stores a copy of each
         * resulting state. Both digests are reset afterwards.
         *
         * msgDigest is padded with bl - N zero bytes and sha256 with 64 - array.length, so
         * the padded prefix is bl (respectively 64) bytes long when array.length == N.
         *
         * @param array seed bytes
         */
        @Override
        void init(final byte[] array) {
            final byte[] zeros = new byte[this.bl];

            absorb(this.msgDigest, array);
            this.msgDigest.update(zeros, 0, this.bl - this.N);
            this.msgMemo = ((Memoable)this.msgDigest).copy();
            this.msgDigest.reset();

            absorb(this.sha256, array);
            this.sha256.update(zeros, 0, 64 - array.length);
            this.sha256Memo = ((Memoable)this.sha256).copy();
            this.sha256.reset();
        }

        /**
         * SHA-256 over (state from init, compressed address, array2), truncated to N bytes.
         *
         * @param array  not read here; the seed was already absorbed by init
         * @param adrs   address, hashed in its 22-byte compressed form
         * @param array2 input bytes
         * @return the first N bytes of the digest
         */
        public byte[] F(final byte[] array, final ADRS adrs, final byte[] array2) {
            final byte[] shortAdrs = this.compressedADRS(adrs);
            final Digest digest = this.sha256;
            ((Memoable)digest).reset(this.sha256Memo);
            absorb(digest, shortAdrs);
            absorb(digest, array2);
            digest.doFinal(this.sha256Buf, 0);
            return Arrays.copyOfRange(this.sha256Buf, 0, this.N);
        }

        /**
         * Message digest over (state from init, compressed address, array2, array3),
         * truncated to N bytes.
         *
         * @param array  not read here; the seed was already absorbed by init
         * @param adrs   address, hashed in its 22-byte compressed form
         * @param array2 first input
         * @param array3 second input
         * @return the first N bytes of the digest
         */
        public byte[] H(final byte[] array, final ADRS adrs, final byte[] array2, final byte[] array3) {
            final byte[] shortAdrs = this.compressedADRS(adrs);
            final Digest digest = this.msgDigest;
            ((Memoable)digest).reset(this.msgMemo);
            absorb(digest, shortAdrs);
            absorb(digest, array2);
            absorb(digest, array3);
            digest.doFinal(this.msgDigestBuf, 0);
            return Arrays.copyOfRange(this.msgDigestBuf, 0, this.N);
        }

        /**
         * Hashes the inputs with the message digest, expands (array, array2, digest) with
         * MGF1 and splits the expanded bytes into digest bytes, a long index and an int
         * index, laid out in that order.
         *
         * NOTE(review): the long and int values are presumably the tree and leaf indices of
         * the specification -- confirm against IndexedDigest.
         *
         * @param array  first hashed input, also the first part of the MGF1 seed
         * @param array2 second hashed input, also the second part of the MGF1 seed
         * @param array3 third hashed input
         * @param array4 optional input; skipped when {@code null}
         * @param array5 final hashed input
         * @return the digest bytes and the two masked indices
         */
        @Override
        IndexedDigest H_msg(final byte[] array, final byte[] array2, final byte[] array3, final byte[] array4, final byte[] array5) {
            final int leafBits = this.H / this.D;
            final int treeBits = this.H - leafBits;
            final int mdBytes = (this.A * this.K + 7) / 8;
            final int leafBytes = (leafBits + 7) / 8;
            final int treeBytes = (treeBits + 7) / 8;

            final byte[] blank = new byte[mdBytes + leafBytes + treeBytes];
            final byte[] inner = new byte[this.msgDigest.getDigestSize()];

            final Digest digest = this.msgDigest;
            absorb(digest, array);
            absorb(digest, array2);
            absorb(digest, array3);
            if (array4 != null) {
                absorb(digest, array4);
            }
            absorb(digest, array5);
            digest.doFinal(inner, 0);

            // "blank" is all zeros, so this is just the MGF1 output of the right length.
            final byte[] expanded = this.bitmask(Arrays.concatenate(array, array2, inner), blank);

            // Right-align the big-endian index bytes in a fixed-width buffer, then mask
            // off the unused high bits.
            final byte[] wide = new byte[8];
            System.arraycopy(expanded, mdBytes, wide, 8 - treeBytes, treeBytes);
            final long treeMask = -1L >>> (64 - treeBits);
            final long treeIndex = Pack.bigEndianToLong(wide, 0) & treeMask;

            final byte[] narrow = new byte[4];
            System.arraycopy(expanded, mdBytes + treeBytes, narrow, 4 - leafBytes, leafBytes);
            final int leafMask = -1 >>> (32 - leafBits);
            final int leafIndex = Pack.bigEndianToInt(narrow, 0) & leafMask;

            return new IndexedDigest(treeIndex, leafIndex, Arrays.copyOfRange(expanded, 0, mdBytes));
        }

        /**
         * Message digest over (state from init, compressed address, array2), truncated to
         * N bytes.
         *
         * @param array  not read here; the seed was already absorbed by init
         * @param adrs   address, hashed in its 22-byte compressed form
         * @param array2 input bytes
         * @return the first N bytes of the digest
         */
        public byte[] T_l(final byte[] array, final ADRS adrs, final byte[] array2) {
            final byte[] shortAdrs = this.compressedADRS(adrs);
            final Digest digest = this.msgDigest;
            ((Memoable)digest).reset(this.msgMemo);
            absorb(digest, shortAdrs);
            absorb(digest, array2);
            digest.doFinal(this.msgDigestBuf, 0);
            return Arrays.copyOfRange(this.msgDigestBuf, 0, this.N);
        }

        /**
         * SHA-256 over (state from init, compressed address, array2). The result length is
         * array2.length, not N.
         *
         * @param array  not read here; the seed was already absorbed by init
         * @param array2 key/seed bytes hashed after the address
         * @param adrs   address, hashed in its 22-byte compressed form
         * @return the first array2.length bytes of the digest buffer
         */
        @Override
        byte[] PRF(final byte[] array, final byte[] array2, final ADRS adrs) {
            final int outLen = array2.length;
            final Digest digest = this.sha256;
            ((Memoable)digest).reset(this.sha256Memo);
            final byte[] shortAdrs = this.compressedADRS(adrs);
            absorb(digest, shortAdrs);
            absorb(digest, array2);
            digest.doFinal(this.sha256Buf, 0);
            return Arrays.copyOfRange(this.sha256Buf, 0, outLen);
        }

        /**
         * HMAC keyed with array over (array2, optional array3, array4), truncated to N bytes.
         *
         * @param array  HMAC key
         * @param array2 first input
         * @param array3 optional input; skipped when {@code null}
         * @param array4 final input
         * @return the first N bytes of the MAC
         */
        public byte[] PRF_msg(final byte[] array, final byte[] array2, final byte[] array3, final byte[] array4) {
            final HMac mac = this.treeHMac;
            mac.init(new KeyParameter(array));
            mac.update(array2, 0, array2.length);
            if (array3 != null) {
                mac.update(array3, 0, array3.length);
            }
            mac.update(array4, 0, array4.length);
            mac.doFinal(this.hmacBuf, 0);
            return Arrays.copyOfRange(this.hmacBuf, 0, this.N);
        }

        /**
         * Builds the 22-byte compressed address from adrs.value: byte 3, bytes 8..15,
         * byte 19 and bytes 20..31, concatenated in that order.
         */
        private byte[] compressedADRS(final ADRS adrs) {
            final byte[] packed = new byte[22];
            final byte[] full = adrs.value;
            System.arraycopy(full, 3, packed, 0, 1);
            System.arraycopy(full, 8, packed, 1, 8);
            System.arraycopy(full, 19, packed, 9, 1);
            System.arraycopy(full, 20, packed, 10, 12);
            return packed;
        }

        /**
         * Generates array2.length bytes with MGF1 seeded by array and XORs array2 into them.
         *
         * @param array  MGF1 seed
         * @param array2 bytes combined into the mask; left unmodified
         * @return a new array holding the mask XOR array2
         */
        protected byte[] bitmask(final byte[] array, final byte[] array2) {
            final int len = array2.length;
            final byte[] mask = new byte[len];
            this.mgf1.init(new MGFParameters(array));
            this.mgf1.generateBytes(mask, 0, mask.length);
            Bytes.xorTo(len, array2, mask);
            return mask;
        }

        /**
         * Two-input form of bitmask: the mask covers array2 followed by array3.
         *
         * @param array  MGF1 seed
         * @param array2 bytes combined into the first part of the mask
         * @param array3 bytes combined into the remainder of the mask
         * @return a new array of length array2.length + array3.length
         */
        protected byte[] bitmask(final byte[] array, final byte[] array2, final byte[] array3) {
            final int firstLen = array2.length;
            final byte[] mask = new byte[firstLen + array3.length];
            this.mgf1.init(new MGFParameters(array));
            this.mgf1.generateBytes(mask, 0, mask.length);
            Bytes.xorTo(firstLen, array2, mask);
            Bytes.xorTo(array3.length, array3, 0, mask, firstLen);
            return mask;
        }

        /**
         * Same as the two-argument bitmask, but always uses a freshly created
         * MGF1-SHA-256 generator instead of the engine's configured one.
         *
         * @param array  MGF1 seed
         * @param array2 bytes combined into the mask; left unmodified
         * @return a new array holding the mask XOR array2
         */
        protected byte[] bitmask256(final byte[] array, final byte[] array2) {
            final int len = array2.length;
            final byte[] mask = new byte[len];
            final MGF1BytesGenerator sha256Mgf = new MGF1BytesGenerator(new SHA256Digest());
            sha256Mgf.init(new MGFParameters(array));
            sha256Mgf.generateBytes(mask, 0, mask.length);
            Bytes.xorTo(len, array2, mask);
            return mask;
        }
    }
    
    /**
     * Engine backed by SHAKE256.
     *
     * Every primitive absorbs its inputs into one shared XOF instance and squeezes the
     * required number of bytes. The instance keeps running digest state, so it is not
     * safe for concurrent use.
     */
    static class Shake256Engine extends SLHDSAEngine
    {
        private final Xof treeDigest;
        private final Xof maskDigest;

        /**
         * Builds the engine; all six arguments are forwarded unchanged to the base class.
         *
         * @param n  forwarded as n (output length in bytes)
         * @param n2 forwarded as wots_W
         * @param n3 forwarded as d
         * @param n4 forwarded as a
         * @param n5 forwarded as k
         * @param n6 forwarded as h
         */
        public Shake256Engine(final int n, final int n2, final int n3, final int n4, final int n5, final int n6) {
            super(n, n2, n3, n4, n5, n6);
            this.treeDigest = new SHAKEDigest(256);
            this.maskDigest = new SHAKEDigest(256);
        }

        /** Feeds the whole array into treeDigest. */
        private void absorb(final byte[] data) {
            this.treeDigest.update(data, 0, data.length);
        }

        /**
         * Nothing to prepare: every primitive in this engine absorbs the seed itself.
         *
         * @param array ignored
         */
        @Override
        void init(final byte[] array) {
        }

        /**
         * SHAKE256 over (array, adrs.value, array2), N bytes of output.
         *
         * @param array  seed bytes
         * @param adrs   address whose full value is hashed
         * @param array2 input bytes
         * @return N bytes of output
         */
        @Override
        byte[] F(final byte[] array, final ADRS adrs, final byte[] array2) {
            final byte[] out = new byte[this.N];
            this.absorb(array);
            this.absorb(adrs.value);
            this.absorb(array2);
            this.treeDigest.doFinal(out, 0, out.length);
            return out;
        }

        /**
         * SHAKE256 over (array, adrs.value, array2, array3), N bytes of output.
         *
         * @param array  seed bytes
         * @param adrs   address whose full value is hashed
         * @param array2 first input
         * @param array3 second input
         * @return N bytes of output
         */
        @Override
        byte[] H(final byte[] array, final ADRS adrs, final byte[] array2, final byte[] array3) {
            final byte[] out = new byte[this.N];
            this.absorb(array);
            this.absorb(adrs.value);
            this.absorb(array2);
            this.absorb(array3);
            this.treeDigest.doFinal(out, 0, out.length);
            return out;
        }

        /**
         * Squeezes ceil(A*K/8) + ceil((H - H/D)/8) + ceil((H/D)/8) bytes from SHAKE256 over
         * the inputs and splits them into digest bytes, a long index and an int index,
         * laid out in that order.
         *
         * NOTE(review): the long and int values are presumably the tree and leaf indices of
         * the specification -- confirm against IndexedDigest.
         *
         * @param array  first hashed input
         * @param array2 second hashed input
         * @param array3 third hashed input
         * @param array4 optional input; skipped when {@code null}
         * @param array5 final hashed input
         * @return the digest bytes and the two masked indices
         */
        @Override
        IndexedDigest H_msg(final byte[] array, final byte[] array2, final byte[] array3, final byte[] array4, final byte[] array5) {
            final int leafBits = this.H / this.D;
            final int treeBits = this.H - leafBits;
            final int mdBytes = (this.A * this.K + 7) / 8;
            final int leafBytes = (leafBits + 7) / 8;
            final int treeBytes = (treeBits + 7) / 8;

            final byte[] squeezed = new byte[mdBytes + leafBytes + treeBytes];
            this.absorb(array);
            this.absorb(array2);
            this.absorb(array3);
            if (array4 != null) {
                this.absorb(array4);
            }
            this.absorb(array5);
            this.treeDigest.doFinal(squeezed, 0, squeezed.length);

            // Right-align the big-endian index bytes in a fixed-width buffer, then mask
            // off the unused high bits.
            final byte[] wide = new byte[8];
            System.arraycopy(squeezed, mdBytes, wide, 8 - treeBytes, treeBytes);
            final long treeMask = -1L >>> (64 - treeBits);
            final long treeIndex = Pack.bigEndianToLong(wide, 0) & treeMask;

            final byte[] narrow = new byte[4];
            System.arraycopy(squeezed, mdBytes + treeBytes, narrow, 4 - leafBytes, leafBytes);
            final int leafMask = -1 >>> (32 - leafBits);
            final int leafIndex = Pack.bigEndianToInt(narrow, 0) & leafMask;

            return new IndexedDigest(treeIndex, leafIndex, Arrays.copyOfRange(squeezed, 0, mdBytes));
        }

        /**
         * SHAKE256 over (array, adrs.value, array2), N bytes of output.
         *
         * @param array  seed bytes
         * @param adrs   address whose full value is hashed
         * @param array2 input bytes
         * @return N bytes of output
         */
        @Override
        byte[] T_l(final byte[] array, final ADRS adrs, final byte[] array2) {
            final byte[] out = new byte[this.N];
            this.absorb(array);
            this.absorb(adrs.value);
            this.absorb(array2);
            this.treeDigest.doFinal(out, 0, out.length);
            return out;
        }

        /**
         * SHAKE256 over (array, adrs.value, array2), N bytes of output. Note that the
         * address is the last parameter but is hashed second.
         *
         * @param array  seed bytes, hashed first
         * @param array2 key/seed bytes, hashed last
         * @param adrs   address whose full value is hashed
         * @return N bytes of output
         */
        @Override
        byte[] PRF(final byte[] array, final byte[] array2, final ADRS adrs) {
            this.absorb(array);
            this.absorb(adrs.value);
            this.absorb(array2);
            final byte[] out = new byte[this.N];
            this.treeDigest.doFinal(out, 0, this.N);
            return out;
        }

        /**
         * SHAKE256 over (array, array2, optional array3, array4), N bytes of output.
         *
         * @param array  key bytes, hashed first
         * @param array2 first input
         * @param array3 optional input; skipped when {@code null}
         * @param array4 final input
         * @return N bytes of output
         */
        public byte[] PRF_msg(final byte[] array, final byte[] array2, final byte[] array3, final byte[] array4) {
            this.absorb(array);
            this.absorb(array2);
            if (array3 != null) {
                this.absorb(array3);
            }
            this.absorb(array4);
            final byte[] out = new byte[this.N];
            this.treeDigest.doFinal(out, 0, out.length);
            return out;
        }

        /**
         * Squeezes array2.length bytes from maskDigest over (array, adrs.value) and XORs
         * array2 into them.
         *
         * @param array  seed bytes
         * @param adrs   address whose full value is hashed
         * @param array2 bytes combined into the mask; left unmodified
         * @return a new array holding the mask XOR array2
         */
        protected byte[] bitmask(final byte[] array, final ADRS adrs, final byte[] array2) {
            final int len = array2.length;
            final byte[] mask = new byte[len];
            final Xof xof = this.maskDigest;
            xof.update(array, 0, array.length);
            xof.update(adrs.value, 0, adrs.value.length);
            xof.doFinal(mask, 0, mask.length);
            Bytes.xorTo(len, array2, mask);
            return mask;
        }

        /**
         * Two-input form of bitmask: the mask covers array2 followed by array3.
         *
         * @param array  seed bytes
         * @param adrs   address whose full value is hashed
         * @param array2 bytes combined into the first part of the mask
         * @param array3 bytes combined into the remainder of the mask
         * @return a new array of length array2.length + array3.length
         */
        protected byte[] bitmask(final byte[] array, final ADRS adrs, final byte[] array2, final byte[] array3) {
            final int firstLen = array2.length;
            final byte[] mask = new byte[firstLen + array3.length];
            final Xof xof = this.maskDigest;
            xof.update(array, 0, array.length);
            xof.update(adrs.value, 0, adrs.value.length);
            xof.doFinal(mask, 0, mask.length);
            Bytes.xorTo(firstLen, array2, mask);
            Bytes.xorTo(array3.length, array3, 0, mask, firstLen);
            return mask;
        }
    }
}
