// 
// Decompiled by Procyon v0.6.0
// 

package com.nimbusds.jose.crypto;

import java.util.List;
import com.nimbusds.jose.EncryptionMethod;
import com.nimbusds.jose.util.Base64URL;
import java.util.Map;
import com.nimbusds.jose.util.JSONObjectUtils;
import com.nimbusds.jose.crypto.impl.JWEHeaderValidation;
import com.nimbusds.jose.UnprotectedHeader;
import com.nimbusds.jose.util.JSONArrayUtils;
import com.nimbusds.jose.Payload;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.crypto.impl.AAD;
import com.nimbusds.jose.JWECryptoParts;
import com.nimbusds.jose.JWEHeader;
import java.util.Iterator;
import com.nimbusds.jose.jwk.KeyType;
import com.nimbusds.jose.JWEAlgorithm;
import com.nimbusds.jose.jwk.JWK;
import javax.crypto.SecretKey;
import com.nimbusds.jose.KeyLengthException;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.shaded.jcip.ThreadSafe;
import com.nimbusds.jose.JWEEncrypter;
import com.nimbusds.jose.crypto.impl.MultiCryptoProvider;

/**
 * JSON Web Encryption (JWE) encrypter for multiple recipients.
 *
 * <p>A single content encryption key (CEK), obtained from
 * {@link MultiCryptoProvider}, is handed to a per-recipient encrypter for
 * every key in the supplied {@link JWKSet}. Each JWK must declare its key
 * encryption algorithm ({@code alg}). The accepted key type / algorithm pairs
 * are:
 *
 * <ul>
 *     <li>{@link KeyType#RSA} with an algorithm in {@link RSAEncrypter#SUPPORTED_ALGORITHMS};
 *     <li>{@link KeyType#EC} with an algorithm in {@link ECDHEncrypter#SUPPORTED_ALGORITHMS};
 *     <li>{@link KeyType#OCT} with an algorithm in {@link AESEncrypter#SUPPORTED_ALGORITHMS}
 *         or {@link DirectEncrypter#SUPPORTED_ALGORITHMS};
 *     <li>{@link KeyType#OKP} with an algorithm in {@link X25519Encrypter#SUPPORTED_ALGORITHMS}.
 * </ul>
 *
 * <p>The cipher text, initialisation vector and authentication tag that are
 * returned are those produced for the first recipient. When more than one
 * recipient is produced, the returned encrypted key is the Base64URL encoding
 * of the JSON object {@code {"recipients": [...]}}. Each entry holds the
 * recipient's unprotected {@code header} and, unless the algorithm is
 * {@code dir}, its {@code encrypted_key}.
 */
@ThreadSafe
public class MultiEncrypter extends MultiCryptoProvider implements JWEEncrypter
{
    /**
     * JWK members copied into each recipient's unprotected header.
     */
    private static final String[] RECIPIENT_HEADER_PARAMS = {
        "kid", "alg", "x5u", "x5t", "x5t#S256", "x5c"
    };

    /**
     * The recipient keys, validated by the constructor.
     */
    private final JWKSet keys;

    /**
     * Creates a multi-recipient encrypter.
     *
     * <p>If the set contains an octet sequence key with algorithm {@code dir},
     * that key is used as the CEK. Otherwise no CEK is passed to the
     * superclass.
     *
     * @param keys The recipient keys. Must not be {@code null}.
     *
     * @throws KeyLengthException If the CEK length is rejected by the
     *                            superclass.
     */
    public MultiEncrypter(final JWKSet keys) throws KeyLengthException {
        this(keys, findDirectCEK(keys));
    }

    /**
     * Creates a multi-recipient encrypter with the given CEK.
     *
     * @param keys                 The recipient keys. Must not be
     *                             {@code null}. Every JWK must specify a
     *                             supported key encryption algorithm.
     * @param contentEncryptionKey The content encryption key, passed to the
     *                             superclass. It must equal any {@code dir}
     *                             octet key in the set.
     *
     * @throws KeyLengthException       If the CEK length is rejected by the
     *                                  superclass.
     * @throws IllegalArgumentException If the key set is {@code null}, a JWK
     *                                  lacks an algorithm, a {@code dir} key
     *                                  differs from the CEK, or a key type /
     *                                  algorithm pair is unsupported.
     */
    public MultiEncrypter(final JWKSet keys, final SecretKey contentEncryptionKey) throws KeyLengthException {
        super(contentEncryptionKey);
        if (keys == null) {
            throw new IllegalArgumentException("The JWK set must not be null");
        }
        for (final JWK jwk : keys.getKeys()) {
            if (jwk.getAlgorithm() == null) {
                throw new IllegalArgumentException("Each JWK must specify a key encryption algorithm");
            }
            final KeyType kty = jwk.getKeyType();
            final JWEAlgorithm alg = JWEAlgorithm.parse(jwk.getAlgorithm().toString());
            // DirectEncrypter encrypts with its own key rather than the shared
            // CEK, so a "dir" key must be identical to the CEK for every
            // recipient to share the same cipher text.
            if (JWEAlgorithm.DIR.equals(alg)
                    && KeyType.OCT.equals(kty)
                    && !jwk.toOctetSequenceKey().toSecretKey("AES").equals(contentEncryptionKey)) {
                throw new IllegalArgumentException("Bad CEK");
            }
            if (!isSupportedKeyAlgorithm(kty, alg)) {
                throw new IllegalArgumentException("Unsupported key encryption algorithm: " + alg);
            }
        }
        this.keys = keys;
    }

    /**
     * Returns the first octet sequence key in the set whose algorithm is
     * {@code dir}, as an AES secret key.
     *
     * @param keys The key set. May be {@code null}.
     *
     * @return The direct CEK, or {@code null} if the set is {@code null} or
     *         contains no such key.
     */
    private static SecretKey findDirectCEK(final JWKSet keys) {
        if (keys == null) {
            return null;
        }
        for (final JWK jwk : keys.getKeys()) {
            if (JWEAlgorithm.DIR.equals(jwk.getAlgorithm()) && KeyType.OCT.equals(jwk.getKeyType())) {
                return jwk.toOctetSequenceKey().toSecretKey("AES");
            }
        }
        return null;
    }

    /**
     * Tells whether a recipient key of the given type can be used with the
     * given key encryption algorithm.
     *
     * @param kty The key type.
     * @param alg The key encryption algorithm.
     *
     * @return {@code true} if a matching single-recipient encrypter exists.
     */
    private static boolean isSupportedKeyAlgorithm(final KeyType kty, final JWEAlgorithm alg) {
        if (KeyType.RSA.equals(kty)) {
            return RSAEncrypter.SUPPORTED_ALGORITHMS.contains(alg);
        }
        if (KeyType.EC.equals(kty)) {
            return ECDHEncrypter.SUPPORTED_ALGORITHMS.contains(alg);
        }
        if (KeyType.OCT.equals(kty)) {
            return AESEncrypter.SUPPORTED_ALGORITHMS.contains(alg)
                || DirectEncrypter.SUPPORTED_ALGORITHMS.contains(alg);
        }
        if (KeyType.OKP.equals(kty)) {
            return X25519Encrypter.SUPPORTED_ALGORITHMS.contains(alg);
        }
        return false;
    }

    /**
     * Creates the single-recipient encrypter for a key.
     *
     * @param key The recipient key.
     * @param alg The key encryption algorithm from the recipient header.
     * @param cek The shared content encryption key. It is not used for
     *            {@code dir}, where the key itself is the CEK.
     *
     * @return The encrypter, or {@code null} if the key type / algorithm pair
     *         is unsupported.
     *
     * @throws JOSEException If the key cannot be converted or the encrypter
     *                       cannot be created.
     */
    private static JWEEncrypter createRecipientEncrypter(final JWK key, final JWEAlgorithm alg, final SecretKey cek)
            throws JOSEException {
        final KeyType kty = key.getKeyType();
        if (KeyType.RSA.equals(kty) && RSAEncrypter.SUPPORTED_ALGORITHMS.contains(alg)) {
            return new RSAEncrypter(key.toRSAKey().toRSAPublicKey(), cek);
        }
        if (KeyType.EC.equals(kty) && ECDHEncrypter.SUPPORTED_ALGORITHMS.contains(alg)) {
            return new ECDHEncrypter(key.toECKey().toECPublicKey(), cek);
        }
        if (KeyType.OCT.equals(kty) && AESEncrypter.SUPPORTED_ALGORITHMS.contains(alg)) {
            return new AESEncrypter(key.toOctetSequenceKey().toSecretKey("AES"), cek);
        }
        if (KeyType.OCT.equals(kty) && DirectEncrypter.SUPPORTED_ALGORITHMS.contains(alg)) {
            return new DirectEncrypter(key.toOctetSequenceKey().toSecretKey("AES"));
        }
        if (KeyType.OKP.equals(kty) && X25519Encrypter.SUPPORTED_ALGORITHMS.contains(alg)) {
            return new X25519Encrypter(key.toOctetKeyPair().toPublicJWK(), cek);
        }
        return null;
    }

    /**
     * Joins the shared protected header with an unprotected header built from
     * the recipient key's {@link #RECIPIENT_HEADER_PARAMS} members.
     *
     * @param header The shared JWE header.
     * @param key    The recipient key.
     *
     * @return The combined recipient header.
     *
     * @throws JOSEException If the headers cannot be joined. The original
     *                       exception is kept as the cause.
     */
    private static JWEHeader buildRecipientHeader(final JWEHeader header, final JWK key) throws JOSEException {
        final Map<String, Object> keyMap = key.toJSONObject();
        final UnprotectedHeader.Builder unprotected = new UnprotectedHeader.Builder();
        for (final String param : RECIPIENT_HEADER_PARAMS) {
            if (keyMap.containsKey(param)) {
                unprotected.param(param, keyMap.get(param));
            }
        }
        try {
            return (JWEHeader) header.join(unprotected.build());
        } catch (final Exception e) {
            throw new JOSEException(e.getMessage(), e);
        }
    }

    /**
     * Builds the JSON entry describing one recipient.
     *
     * @param header   The shared JWE header. Its parameters are removed from
     *                 the recipient header.
     * @param alg      The recipient's key encryption algorithm.
     * @param jweParts The parts produced by the recipient's encrypter.
     *
     * @return A JSON object containing {@code header} and, unless {@code alg}
     *         is {@code dir}, {@code encrypted_key}.
     */
    private static Map<String, Object> toRecipientJSON(final JWEHeader header,
                                                       final JWEAlgorithm alg,
                                                       final JWECryptoParts jweParts) {
        final Map<String, Object> recipientHeaderMap = jweParts.getHeader().toJSONObject();
        for (final String param : header.getIncludedParams()) {
            recipientHeaderMap.remove(param);
        }
        final Map<String, Object> recipient = JSONObjectUtils.newJSONObject();
        recipient.put("header", recipientHeaderMap);
        // With "dir" the key itself is the CEK, so there is no encrypted key.
        if (!JWEAlgorithm.DIR.equals(alg)) {
            recipient.put("encrypted_key", jweParts.getEncryptedKey().toString());
        }
        return recipient;
    }

    /**
     * Encrypts the specified clear text, using the AAD computed from the
     * header.
     *
     * @param header    The JWE header.
     * @param clearText The clear text.
     *
     * @return The JWE crypto parts.
     *
     * @throws JOSEException If encryption failed.
     *
     * @deprecated Use {@link #encrypt(JWEHeader, byte[], byte[])}.
     */
    @Deprecated
    public JWECryptoParts encrypt(final JWEHeader header, final byte[] clearText) throws JOSEException {
        return this.encrypt(header, clearText, AAD.compute(header));
    }

    /**
     * Encrypts the clear text for every recipient key.
     *
     * <p>The first recipient encrypts the real clear text, and its cipher
     * text, IV and tag are returned. Later recipients encrypt an empty payload
     * only to obtain their own encrypted key.
     *
     * @param header    The shared JWE header.
     * @param clearText The clear text.
     * @param aad       The additional authenticated data. Must not be
     *                  {@code null}.
     *
     * @return The JWE crypto parts. With several recipients, the encrypted key
     *         is the Base64URL-encoded {@code {"recipients": [...]}} JSON.
     *
     * @throws JOSEException If the AAD is missing, no recipient could be
     *                       produced, or a recipient encryption failed.
     */
    @Override
    public JWECryptoParts encrypt(final JWEHeader header, final byte[] clearText, final byte[] aad) throws JOSEException {
        if (aad == null) {
            throw new JOSEException("Missing JWE additional authenticated data (AAD)");
        }
        final EncryptionMethod enc = header.getEncryptionMethod();
        final SecretKey cek = this.getCEK(enc);
        Payload payload = new Payload(clearText);
        final List<Object> recipients = JSONArrayUtils.newJSONArray();
        JWECryptoParts firstParts = null;

        for (final JWK key : this.keys.getKeys()) {
            final JWEHeader recipientHeader = buildRecipientHeader(header, key);
            final JWEAlgorithm alg = JWEHeaderValidation.getAlgorithmAndEnsureNotNull(recipientHeader);
            final JWEEncrypter encrypter = createRecipientEncrypter(key, alg, cek);
            if (encrypter == null) {
                // Unsupported pairs are rejected by the constructor. The skip
                // is kept as a defensive fallback.
                continue;
            }
            final JWECryptoParts jweParts = encrypter.encrypt(recipientHeader, payload.toBytes(), aad);
            recipients.add(toRecipientJSON(header, alg, jweParts));
            if (firstParts == null) {
                // The first recipient's output carries the shared cipher text.
                // Later recipients only need their encrypted key.
                firstParts = jweParts;
                payload = new Payload("");
            }
        }

        if (firstParts == null) {
            // Without any recipient there is no cipher text, IV or tag to return.
            throw new JOSEException("No recipient key in the JWK set could be used for encryption");
        }

        Base64URL encryptedKey = firstParts.getEncryptedKey();
        if (recipients.size() > 1) {
            final Map<String, Object> jweJsonObject = JSONObjectUtils.newJSONObject();
            jweJsonObject.put("recipients", recipients);
            encryptedKey = Base64URL.encode(JSONObjectUtils.toJSONString(jweJsonObject));
        }
        return new JWECryptoParts(
            header,
            encryptedKey,
            firstParts.getInitializationVector(),
            firstParts.getCipherText(),
            firstParts.getAuthenticationTag());
    }
}
