// 
// Decompiled by Procyon v0.6.0
// 

package com.google.crypto.tink.subtle;

import java.security.MessageDigest;
import java.math.BigInteger;
import java.security.Key;
import javax.crypto.Cipher;
import java.security.spec.RSAPublicKeySpec;
import java.security.interfaces.RSAPublicKey;
import com.google.crypto.tink.util.SecretBigInteger;
import com.google.crypto.tink.signature.RsaSsaPssPublicKey;
import com.google.crypto.tink.AccessesPartialKey;
import java.security.GeneralSecurityException;
import java.security.NoSuchProviderException;
import com.google.crypto.tink.signature.RsaSsaPssParameters;
import java.security.spec.KeySpec;
import java.security.spec.RSAPrivateCrtKeySpec;
import com.google.crypto.tink.InsecureSecretKeyAccess;
import java.security.interfaces.RSAPrivateCrtKey;
import java.security.KeyFactory;
import com.google.crypto.tink.signature.internal.RsaSsaPssSignConscrypt;
import com.google.crypto.tink.signature.RsaSsaPssPrivateKey;
import com.google.crypto.tink.config.internal.TinkFipsUtil;
import com.google.errorprone.annotations.Immutable;
import com.google.crypto.tink.PublicKeySign;

/**
 * RSA-SSA-PSS {@link PublicKeySign} implementation.
 *
 * <p>{@link #create} first tries {@link RsaSsaPssSignConscrypt}. If that throws
 * {@link NoSuchProviderException}, it falls back to a JCE-only signer that applies the raw RSA
 * private-key operation ({@code RSA/ECB/NOPADDING}) to a locally computed EMSA-PSS encoding.
 */
@Immutable
public final class RsaSsaPssSignJce implements PublicKeySign
{
    /** FIPS compatibility of this primitive: it requires the BoringCrypto module. */
    public static final TinkFipsUtil.AlgorithmFipsCompatibility FIPS =
        TinkFipsUtil.AlgorithmFipsCompatibility.ALGORITHM_REQUIRES_BORINGCRYPTO;

    /** Message suffix used by every variant except LEGACY: nothing is appended. */
    private static final byte[] EMPTY = new byte[0];

    /** Single zero byte appended to the message before hashing for the LEGACY variant. */
    private static final byte[] legacyMessageSuffix = new byte[] { 0 };

    /** The signer every {@link #sign} call is delegated to. */
    private final PublicKeySign sign;

    /**
     * Creates an RSA-SSA-PSS signer for {@code key}.
     *
     * <p>The Conscrypt-backed signer is preferred. When it reports that its provider is missing
     * ({@link NoSuchProviderException}), a JCE-based fallback signer is built instead.
     *
     * @param key the Tink RSA-SSA-PSS private key
     * @return a signer producing {@code outputPrefix || signature}
     * @throws GeneralSecurityException if the key cannot be converted to a JCE key or its
     *     parameters are rejected by the fallback signer
     */
    @AccessesPartialKey
    public static PublicKeySign create(final RsaSsaPssPrivateKey key) throws GeneralSecurityException {
        try {
            return RsaSsaPssSignConscrypt.create(key);
        }
        catch (final NoSuchProviderException ignored) {
            // Conscrypt is unavailable in this environment; use the pure-JCE implementation.
            return createJceFallback(key);
        }
    }

    /**
     * Builds the JCE-only signer: rebuilds {@code key} as an {@link RSAPrivateCrtKey} and wraps
     * it together with the hash, salt, output-prefix and message-suffix settings.
     */
    @AccessesPartialKey
    private static PublicKeySign createJceFallback(final RsaSsaPssPrivateKey key) throws GeneralSecurityException {
        final KeyFactory keyFactory = EngineFactory.KEY_FACTORY.getInstance("RSA");
        final RSAPrivateCrtKeySpec spec = new RSAPrivateCrtKeySpec(
            key.getPublicKey().getModulus(),
            key.getParameters().getPublicExponent(),
            key.getPrivateExponent().getBigInteger(InsecureSecretKeyAccess.get()),
            key.getPrimeP().getBigInteger(InsecureSecretKeyAccess.get()),
            key.getPrimeQ().getBigInteger(InsecureSecretKeyAccess.get()),
            key.getPrimeExponentP().getBigInteger(InsecureSecretKeyAccess.get()),
            key.getPrimeExponentQ().getBigInteger(InsecureSecretKeyAccess.get()),
            key.getCrtCoefficient().getBigInteger(InsecureSecretKeyAccess.get()));
        final RSAPrivateCrtKey privateKey = (RSAPrivateCrtKey) keyFactory.generatePrivate(spec);

        final RsaSsaPssParameters params = key.getParameters();
        final Enums.HashType sigHash =
            (Enums.HashType) RsaSsaPssVerifyJce.HASH_TYPE_CONVERTER.toProtoEnum(params.getSigHashType());
        final Enums.HashType mgf1Hash =
            (Enums.HashType) RsaSsaPssVerifyJce.HASH_TYPE_CONVERTER.toProtoEnum(params.getMgf1HashType());
        final byte[] messageSuffix =
            params.getVariant().equals(RsaSsaPssParameters.Variant.LEGACY) ? legacyMessageSuffix : EMPTY;
        return new InternalImpl(
            privateKey,
            sigHash,
            mgf1Hash,
            params.getSaltLengthBytes(),
            key.getOutputPrefix().toByteArray(),
            messageSuffix);
    }

    /**
     * Maps a subtle {@link Enums.HashType} to the parameters' hash type.
     *
     * @throws GeneralSecurityException for anything other than SHA256, SHA384 or SHA512
     */
    private static RsaSsaPssParameters.HashType getHashType(final Enums.HashType hash) throws GeneralSecurityException {
        switch (hash) {
            case SHA256:
                return RsaSsaPssParameters.HashType.SHA256;
            case SHA384:
                return RsaSsaPssParameters.HashType.SHA384;
            case SHA512:
                return RsaSsaPssParameters.HashType.SHA512;
            default:
                throw new GeneralSecurityException("Unsupported hash: " + hash);
        }
    }

    /** Wraps a private key component as a {@link SecretBigInteger}. */
    private static SecretBigInteger secret(final BigInteger value) {
        return SecretBigInteger.fromBigInteger(value, InsecureSecretKeyAccess.get());
    }

    /**
     * Converts a JCE CRT private key plus PSS settings into a Tink {@link RsaSsaPssPrivateKey}
     * using the NO_PREFIX variant.
     *
     * <p>Static because it is invoked from the constructor and needs no instance state.
     */
    @AccessesPartialKey
    private static RsaSsaPssPrivateKey convertKey(final RSAPrivateCrtKey key, final Enums.HashType sigHash, final Enums.HashType mgf1Hash, final int saltLength) throws GeneralSecurityException {
        final RsaSsaPssParameters parameters = RsaSsaPssParameters.builder()
            .setModulusSizeBits(key.getModulus().bitLength())
            .setPublicExponent(key.getPublicExponent())
            .setSigHashType(getHashType(sigHash))
            .setMgf1HashType(getHashType(mgf1Hash))
            .setSaltLengthBytes(saltLength)
            .setVariant(RsaSsaPssParameters.Variant.NO_PREFIX)
            .build();
        final RsaSsaPssPublicKey publicKey = RsaSsaPssPublicKey.builder()
            .setParameters(parameters)
            .setModulus(key.getModulus())
            .build();
        return RsaSsaPssPrivateKey.builder()
            .setPublicKey(publicKey)
            .setPrimes(secret(key.getPrimeP()), secret(key.getPrimeQ()))
            .setPrivateExponent(secret(key.getPrivateExponent()))
            .setPrimeExponents(secret(key.getPrimeExponentP()), secret(key.getPrimeExponentQ()))
            .setCrtCoefficient(secret(key.getCrtCoefficient()))
            .build();
    }

    /**
     * Creates a signer from a JCE RSA CRT private key. The key is converted to the NO_PREFIX
     * variant and passed to {@link #create}.
     *
     * @param priv the RSA private key
     * @param sigHash hash used for the message digest
     * @param mgf1Hash hash used by MGF1
     * @param saltLength salt length in bytes
     * @throws GeneralSecurityException if the hash types are unsupported or the key is rejected
     */
    public RsaSsaPssSignJce(final RSAPrivateCrtKey priv, final Enums.HashType sigHash, final Enums.HashType mgf1Hash, final int saltLength) throws GeneralSecurityException {
        this.sign = create(convertKey(priv, sigHash, mgf1Hash, saltLength));
    }

    /** Signs {@code data} with the underlying signer. */
    @Override
    public byte[] sign(final byte[] data) throws GeneralSecurityException {
        return this.sign.sign(data);
    }

    /** JCE-only RSA-SSA-PSS signer used when Conscrypt is not available. */
    private static final class InternalImpl implements PublicKeySign
    {
        /** Raw (unpadded) RSA transformation used for the private and public operations. */
        private static final String RAW_RSA_ALGORITHM = "RSA/ECB/NOPADDING";

        /** Trailer byte 0xbc that terminates every EMSA-PSS encoded message. */
        private static final byte PSS_TRAILER = (byte) 0xBC;

        private final RSAPrivateCrtKey privateKey;
        private final RSAPublicKey publicKey;
        private final Enums.HashType sigHash;
        private final Enums.HashType mgf1Hash;
        private final int saltLength;
        /** Bytes prepended to every signature; may be empty. */
        private final byte[] outputPrefix;
        /** Bytes appended to the message before hashing; may be empty. */
        private final byte[] messageSuffix;

        /**
         * Validates the settings and derives the matching public key from {@code priv}.
         *
         * @throws GeneralSecurityException in FIPS-only mode, if {@code sigHash} is rejected, if
         *     {@code sigHash != mgf1Hash}, or if the modulus size / public exponent is rejected
         */
        private InternalImpl(final RSAPrivateCrtKey priv, final Enums.HashType sigHash, final Enums.HashType mgf1Hash, final int saltLength, final byte[] outputPrefix, final byte[] messageSuffix) throws GeneralSecurityException {
            if (TinkFipsUtil.useOnlyFips()) {
                throw new GeneralSecurityException("Can not use RSA PSS in FIPS-mode, as BoringCrypto module is not available.");
            }
            Validators.validateSignatureHash(sigHash);
            if (!sigHash.equals(mgf1Hash)) {
                throw new GeneralSecurityException("sigHash and mgf1Hash must be the same");
            }
            Validators.validateRsaModulusSize(priv.getModulus().bitLength());
            Validators.validateRsaPublicExponent(priv.getPublicExponent());
            this.privateKey = priv;
            final KeyFactory keyFactory = EngineFactory.KEY_FACTORY.getInstance("RSA");
            this.publicKey = (RSAPublicKey) keyFactory.generatePublic(new RSAPublicKeySpec(priv.getModulus(), priv.getPublicExponent()));
            this.sigHash = sigHash;
            this.mgf1Hash = mgf1Hash;
            this.saltLength = saltLength;
            this.outputPrefix = outputPrefix;
            this.messageSuffix = messageSuffix;
        }

        /** Computes the signature without any output prefix. */
        private byte[] noPrefixSign(final byte[] data) throws GeneralSecurityException {
            final int modBits = this.publicKey.getModulus().bitLength();
            // EMSA-PSS is applied with emBits = modBits - 1.
            final byte[] em = this.emsaPssEncode(data, modBits - 1);
            return this.rsasp1(em);
        }

        /** Returns {@code outputPrefix || signature}. */
        @Override
        public byte[] sign(final byte[] data) throws GeneralSecurityException {
            final byte[] signature = this.noPrefixSign(data);
            if (this.outputPrefix.length == 0) {
                return signature;
            }
            return Bytes.concat(this.outputPrefix, signature);
        }

        /**
         * Applies the RSA private-key operation to {@code m} via raw RSA "decryption". The result
         * is then re-encrypted with the public key and compared against {@code m}.
         *
         * @throws IllegalStateException if the round trip does not reproduce {@code m}
         */
        private byte[] rsasp1(final byte[] m) throws GeneralSecurityException {
            final Cipher decryptCipher = EngineFactory.CIPHER.getInstance(RAW_RSA_ALGORITHM);
            decryptCipher.init(Cipher.DECRYPT_MODE, this.privateKey);
            final byte[] c = decryptCipher.doFinal(m);
            final Cipher encryptCipher = EngineFactory.CIPHER.getInstance(RAW_RSA_ALGORITHM);
            encryptCipher.init(Cipher.ENCRYPT_MODE, this.publicKey);
            final byte[] roundTrip = encryptCipher.doFinal(c);
            if (!new BigInteger(1, m).equals(new BigInteger(1, roundTrip))) {
                throw new IllegalStateException("Security bug: RSA signature computation error");
            }
            return c;
        }

        /**
         * EMSA-PSS encoding of {@code message}, where {@code messageSuffix} is hashed after the
         * message.
         *
         * @param message the message to encode
         * @param emBits maximal bit length of the encoded message
         * @return {@code maskedDB || H || 0xbc}, of length {@code ceil(emBits / 8)}
         * @throws GeneralSecurityException if the encoded length is too small for the hash and salt
         */
        private byte[] emsaPssEncode(final byte[] message, final int emBits) throws GeneralSecurityException {
            Validators.validateSignatureHash(this.sigHash);
            final MessageDigest digest = EngineFactory.MESSAGE_DIGEST.getInstance(SubtleUtil.toDigestAlgo(this.sigHash));

            // mHash = Hash(message || messageSuffix)
            digest.update(message);
            if (this.messageSuffix.length != 0) {
                digest.update(this.messageSuffix);
            }
            final byte[] mHash = digest.digest();
            final int hLen = digest.getDigestLength();

            final int emLen = (emBits - 1) / 8 + 1; // ceil(emBits / 8)
            if (emLen < hLen + this.saltLength + 2) {
                throw new GeneralSecurityException("encoding error");
            }

            // H = Hash(0x00 * 8 || mHash || salt)
            final byte[] salt = Random.randBytes(this.saltLength);
            final byte[] mPrime = new byte[8 + hLen + this.saltLength];
            System.arraycopy(mHash, 0, mPrime, 8, hLen);
            System.arraycopy(salt, 0, mPrime, 8 + hLen, salt.length);
            final byte[] h = digest.digest(mPrime);

            // DB = PS (zero bytes) || 0x01 || salt, then masked in place with MGF1(H).
            final int dbLen = emLen - hLen - 1;
            final byte[] maskedDb = new byte[dbLen];
            maskedDb[dbLen - this.saltLength - 1] = 1;
            System.arraycopy(salt, 0, maskedDb, dbLen - this.saltLength, salt.length);
            final byte[] dbMask = SubtleUtil.mgf1(h, dbLen, this.mgf1Hash);
            for (int i = 0; i < dbLen; ++i) {
                maskedDb[i] ^= dbMask[i];
            }

            // Clear the leftmost 8*emLen - emBits bits. emLen = ceil(emBits / 8), so this is at
            // most 7 bits and all of them lie in the first byte.
            final int zeroBits = (int) (emLen * 8L - emBits);
            maskedDb[0] &= (byte) (0xFF >>> zeroBits);

            // EM = maskedDB || H || 0xbc
            final byte[] em = new byte[emLen];
            System.arraycopy(maskedDb, 0, em, 0, dbLen);
            System.arraycopy(h, 0, em, dbLen, h.length);
            em[emLen - 1] = PSS_TRAILER;
            return em;
        }
    }
}
