// 
// Decompiled by Procyon v0.6.0
// 

package com.nimbusds.jose.crypto.impl;

import com.nimbusds.jose.jwk.Curve;
import com.nimbusds.jose.crypto.utils.ECChecks;
import com.nimbusds.jose.jwk.OctetKeyPair;
import java.security.PrivateKey;
import java.security.Provider;
import java.security.interfaces.ECPublicKey;
import java.security.interfaces.ECPrivateKey;
import javax.crypto.spec.SecretKeySpec;
import com.nimbusds.jose.util.ByteUtils;
import com.nimbusds.jose.util.Base64URL;
import java.nio.charset.StandardCharsets;
import javax.crypto.SecretKey;
import com.nimbusds.jose.JWEHeader;
import com.nimbusds.jose.EncryptionMethod;
import com.nimbusds.jose.JOSEException;
import java.util.Collection;
import java.util.Objects;
import com.nimbusds.jose.JWEAlgorithm;

/**
 * Static helpers for ECDH-1PU key agreement in JWE: resolution of the
 * algorithm mode and shared key length, Concat KDF based derivation of the
 * shared key, and computation of the combined shared secret
 * {@code Z = Ze || Zs} for EC and OKP (X25519) keys.
 *
 * <p>This class is not instantiable.
 */
public final class ECDH1PU
{
    /**
     * Resolves the ECDH algorithm mode for an ECDH-1PU JWE algorithm.
     *
     * @param alg the JWE algorithm; must not be {@code null}
     * @return {@link ECDH.AlgorithmMode#DIRECT} for {@code ECDH-1PU};
     *         {@link ECDH.AlgorithmMode#KW} for {@code ECDH-1PU+A128KW},
     *         {@code ECDH-1PU+A192KW} and {@code ECDH-1PU+A256KW}
     * @throws JOSEException if {@code alg} is not one of the ECDH-1PU
     *         algorithms above
     */
    public static ECDH.AlgorithmMode resolveAlgorithmMode(final JWEAlgorithm alg) throws JOSEException {
        Objects.requireNonNull(alg, "The parameter \"alg\" must not be null");
        if (alg.equals(JWEAlgorithm.ECDH_1PU)) {
            return ECDH.AlgorithmMode.DIRECT;
        }
        if (alg.equals(JWEAlgorithm.ECDH_1PU_A128KW) || alg.equals(JWEAlgorithm.ECDH_1PU_A192KW) || alg.equals(JWEAlgorithm.ECDH_1PU_A256KW)) {
            return ECDH.AlgorithmMode.KW;
        }
        // List the ECDH-1PU algorithms (not the ECDH-ES ones) so the message
        // names what this class actually accepts.
        throw new JOSEException(AlgorithmSupportMessage.unsupportedJWEAlgorithm(alg, ECDH1PUCryptoProvider.SUPPORTED_ALGORITHMS));
    }

    /**
     * Returns the length of the key to derive with the Concat KDF.
     *
     * @param alg the JWE algorithm; must not be {@code null}
     * @param enc the JWE encryption method; must not be {@code null}
     * @return in direct mode ({@code ECDH-1PU}) the content encryption key
     *         bit length of {@code enc}; otherwise 128, 192 or 256 for the
     *         A128KW, A192KW and A256KW variants respectively
     * @throws JOSEException if {@code alg} is not an ECDH-1PU algorithm, or
     *         in direct mode if {@code enc} reports a key length of 0
     */
    public static int sharedKeyLength(final JWEAlgorithm alg, final EncryptionMethod enc) throws JOSEException {
        Objects.requireNonNull(alg, "The parameter \"alg\" must not be null");
        Objects.requireNonNull(enc, "The parameter \"enc\" must not be null");
        if (alg.equals(JWEAlgorithm.ECDH_1PU)) {
            final int length = enc.cekBitLength();
            // A zero length means the encryption method is not recognised.
            if (length == 0) {
                throw new JOSEException("Unsupported JWE encryption method " + enc);
            }
            return length;
        }
        if (alg.equals(JWEAlgorithm.ECDH_1PU_A128KW)) {
            return 128;
        }
        if (alg.equals(JWEAlgorithm.ECDH_1PU_A192KW)) {
            return 192;
        }
        if (alg.equals(JWEAlgorithm.ECDH_1PU_A256KW)) {
            return 256;
        }
        throw new JOSEException(AlgorithmSupportMessage.unsupportedJWEAlgorithm(alg, ECDH1PUCryptoProvider.SUPPORTED_ALGORITHMS));
    }

    /**
     * Derives the shared key from {@code Z} with the Concat KDF. The KDF
     * inputs are:
     * <ul>
     *   <li>AlgorithmID: the {@code enc} name in direct mode, or the
     *       {@code alg} name in key wrapping mode (length-prefixed, US-ASCII)
     *   <li>PartyUInfo / PartyVInfo: the header's agreement party U / V info
     *       (length-prefixed)
     *   <li>SuppPubInfo: the shared key length
     *   <li>SuppPrivInfo: empty
     * </ul>
     *
     * @param header    the JWE header; must not be {@code null}
     * @param Z         the shared secret; must not be {@code null}
     * @param concatKDF the Concat KDF to use; must not be {@code null}
     * @return the derived shared key
     * @throws JOSEException if the algorithm or encryption method is
     *         unsupported, or the derivation fails
     */
    public static SecretKey deriveSharedKey(final JWEHeader header, final SecretKey Z, final ConcatKDF concatKDF) throws JOSEException {
        Objects.requireNonNull(header, "The parameter \"header\" must not be null");
        Objects.requireNonNull(Z, "The parameter \"Z\" must not be null");
        Objects.requireNonNull(concatKDF, "The parameter \"concatKDF\" must not be null");
        final int sharedKeyLength = sharedKeyLength(header.getAlgorithm(), header.getEncryptionMethod());
        final byte[] algID = encodeAlgID(header);
        return concatKDF.deriveKey(Z, sharedKeyLength, algID, ConcatKDF.encodeDataWithLength(header.getAgreementPartyUInfo()), ConcatKDF.encodeDataWithLength(header.getAgreementPartyVInfo()), ConcatKDF.encodeIntData(sharedKeyLength), ConcatKDF.encodeNoData());
    }

    /**
     * Derives the shared key from {@code Z} with the Concat KDF, binding the
     * given authentication tag into the derivation. The inputs are those of
     * {@link #deriveSharedKey(JWEHeader, SecretKey, ConcatKDF)} followed by
     * the length-prefixed {@code tag}.
     *
     * @param header    the JWE header; must not be {@code null}
     * @param Z         the shared secret; must not be {@code null}
     * @param tag       the authentication tag; must not be {@code null}
     * @param concatKDF the Concat KDF to use; must not be {@code null}
     * @return the derived shared key
     * @throws JOSEException if the algorithm or encryption method is
     *         unsupported, or the derivation fails
     */
    public static SecretKey deriveSharedKey(final JWEHeader header, final SecretKey Z, final Base64URL tag, final ConcatKDF concatKDF) throws JOSEException {
        Objects.requireNonNull(header, "The parameter \"header\" must not be null");
        Objects.requireNonNull(Z, "The parameter \"Z\" must not be null");
        Objects.requireNonNull(tag, "The parameter \"tag\" must not be null");
        Objects.requireNonNull(concatKDF, "The parameter \"concatKDF\" must not be null");
        final int sharedKeyLength = sharedKeyLength(header.getAlgorithm(), header.getEncryptionMethod());
        final byte[] algID = encodeAlgID(header);
        return concatKDF.deriveKey(Z, sharedKeyLength, algID, ConcatKDF.encodeDataWithLength(header.getAgreementPartyUInfo()), ConcatKDF.encodeDataWithLength(header.getAgreementPartyVInfo()), ConcatKDF.encodeIntData(sharedKeyLength), ConcatKDF.encodeNoData(), ConcatKDF.encodeDataWithLength(tag));
    }

    /**
     * Resolves and encodes the Concat KDF AlgorithmID for the header. The
     * value is the {@code enc} name in direct mode or the {@code alg} name in
     * key wrapping mode, as length-prefixed US-ASCII bytes.
     */
    private static byte[] encodeAlgID(final JWEHeader header) throws JOSEException {
        final ECDH.AlgorithmMode algMode = resolveAlgorithmMode(header.getAlgorithm());
        final String algID;
        if (algMode == ECDH.AlgorithmMode.DIRECT) {
            algID = header.getEncryptionMethod().getName();
        } else if (algMode == ECDH.AlgorithmMode.KW) {
            algID = header.getAlgorithm().getName();
        } else {
            // Defensive: resolveAlgorithmMode currently returns only DIRECT or KW.
            throw new JOSEException("Unsupported JWE ECDH algorithm mode: " + algMode);
        }
        return ConcatKDF.encodeDataWithLength(algID.getBytes(StandardCharsets.US_ASCII));
    }

    /**
     * Combines the two partial shared secrets into {@code Z = Ze || Zs}, with
     * the encoded bytes of {@code Ze} first.
     *
     * @param Ze the shared secret computed with the ephemeral key; must not
     *           be {@code null}
     * @param Zs the shared secret computed with the static keys; must not be
     *           {@code null}
     * @return the concatenated secret, with algorithm name {@code "AES"}
     */
    public static SecretKey deriveZ(final SecretKey Ze, final SecretKey Zs) {
        Objects.requireNonNull(Ze, "The parameter \"Ze\" must not be null");
        Objects.requireNonNull(Zs, "The parameter \"Zs\" must not be null");
        final byte[] encodedKey = ByteUtils.concat(new byte[][] { Ze.getEncoded(), Zs.getEncoded() });
        return new SecretKeySpec(encodedKey, 0, encodedKey.length, "AES");
    }

    /**
     * Computes {@code Z} on the sender side for EC keys. Both private keys
     * are first validated against {@code publicKey} with
     * {@link #validateSameCurve(ECPrivateKey, ECPublicKey)}. Then
     * {@code Ze = ECDH(publicKey, epk)} and
     * {@code Zs = ECDH(publicKey, privateKey)}.
     *
     * @param privateKey the private key used to compute {@code Zs}
     * @param publicKey  the peer public key used for both computations
     * @param epk        the (ephemeral) private key used to compute {@code Ze}
     * @param provider   the JCA provider passed to
     *                   {@code ECDH.deriveSharedSecret}
     * @return {@code Ze || Zs}
     * @throws JOSEException if validation or key agreement fails
     */
    public static SecretKey deriveSenderZ(final ECPrivateKey privateKey, final ECPublicKey publicKey, final ECPrivateKey epk, final Provider provider) throws JOSEException {
        validateSameCurve(privateKey, publicKey);
        validateSameCurve(epk, publicKey);
        final SecretKey Ze = ECDH.deriveSharedSecret(publicKey, epk, provider);
        final SecretKey Zs = ECDH.deriveSharedSecret(publicKey, privateKey, provider);
        return deriveZ(Ze, Zs);
    }

    /**
     * Computes {@code Z} on the sender side for OKP keys. Both private keys
     * are first validated against {@code publicKey} with
     * {@link #validateSameCurve(OctetKeyPair, OctetKeyPair)}. Then
     * {@code Ze = ECDH(publicKey, epk)} and
     * {@code Zs = ECDH(publicKey, privateKey)}.
     *
     * @param privateKey the private OKP used to compute {@code Zs}
     * @param publicKey  the peer public OKP used for both computations
     * @param epk        the (ephemeral) private OKP used to compute {@code Ze}
     * @return {@code Ze || Zs}
     * @throws JOSEException if validation or key agreement fails
     */
    public static SecretKey deriveSenderZ(final OctetKeyPair privateKey, final OctetKeyPair publicKey, final OctetKeyPair epk) throws JOSEException {
        validateSameCurve(privateKey, publicKey);
        validateSameCurve(epk, publicKey);
        final SecretKey Ze = ECDH.deriveSharedSecret(publicKey, epk);
        final SecretKey Zs = ECDH.deriveSharedSecret(publicKey, privateKey);
        return deriveZ(Ze, Zs);
    }

    /**
     * Computes {@code Z} on the recipient side for EC keys. Both public keys
     * are first validated against {@code privateKey} with
     * {@link #validateSameCurve(ECPrivateKey, ECPublicKey)}. Then
     * {@code Ze = ECDH(epk, privateKey)} and
     * {@code Zs = ECDH(publicKey, privateKey)}.
     *
     * @param privateKey the local private key used for both computations
     * @param publicKey  the peer public key used to compute {@code Zs}
     * @param epk        the (ephemeral) public key used to compute {@code Ze}
     * @param provider   the JCA provider passed to
     *                   {@code ECDH.deriveSharedSecret}
     * @return {@code Ze || Zs}
     * @throws JOSEException if validation or key agreement fails
     */
    public static SecretKey deriveRecipientZ(final ECPrivateKey privateKey, final ECPublicKey publicKey, final ECPublicKey epk, final Provider provider) throws JOSEException {
        validateSameCurve(privateKey, publicKey);
        validateSameCurve(privateKey, epk);
        final SecretKey Ze = ECDH.deriveSharedSecret(epk, privateKey, provider);
        final SecretKey Zs = ECDH.deriveSharedSecret(publicKey, privateKey, provider);
        return deriveZ(Ze, Zs);
    }

    /**
     * Computes {@code Z} on the recipient side for OKP keys. Both public keys
     * are first validated against {@code privateKey} with
     * {@link #validateSameCurve(OctetKeyPair, OctetKeyPair)}. Then
     * {@code Ze = ECDH(epk, privateKey)} and
     * {@code Zs = ECDH(publicKey, privateKey)}.
     *
     * @param privateKey the local private OKP used for both computations
     * @param publicKey  the peer public OKP used to compute {@code Zs}
     * @param epk        the (ephemeral) public OKP used to compute {@code Ze}
     * @return {@code Ze || Zs}
     * @throws JOSEException if validation or key agreement fails
     */
    public static SecretKey deriveRecipientZ(final OctetKeyPair privateKey, final OctetKeyPair publicKey, final OctetKeyPair epk) throws JOSEException {
        validateSameCurve(privateKey, publicKey);
        validateSameCurve(privateKey, epk);
        final SecretKey Ze = ECDH.deriveSharedSecret(epk, privateKey);
        final SecretKey Zs = ECDH.deriveSharedSecret(publicKey, privateKey);
        return deriveZ(Ze, Zs);
    }

    /**
     * Checks that an EC private and public key use the same elliptic curve.
     * It also checks, via {@code ECChecks.isPointOnCurve}, that the public
     * point lies on the private key's curve.
     *
     * @param privateKey the EC private key; must not be {@code null}
     * @param publicKey  the EC public key; must not be {@code null}
     * @throws JOSEException if the curves differ or the point check fails
     */
    public static void validateSameCurve(final ECPrivateKey privateKey, final ECPublicKey publicKey) throws JOSEException {
        Objects.requireNonNull(privateKey, "The parameter \"privateKey\" must not be null");
        Objects.requireNonNull(publicKey, "The parameter \"publicKey\" must not be null");
        if (!privateKey.getParams().getCurve().equals(publicKey.getParams().getCurve())) {
            throw new JOSEException("Curve of public key does not match curve of private key");
        }
        if (!ECChecks.isPointOnCurve(publicKey, privateKey)) {
            throw new JOSEException("Invalid public EC key: Point(s) not on the expected curve");
        }
    }

    /**
     * Checks an OKP key pair for use with ECDH-1PU. All of the following must
     * hold:
     * <ul>
     *   <li>{@code privateKey} contains private material
     *   <li>{@code publicKey} does not contain private material
     *   <li>{@code publicKey} uses {@code X25519}
     *   <li>both keys use the same curve
     * </ul>
     *
     * @param privateKey the private OKP; must not be {@code null}
     * @param publicKey  the public OKP; must not be {@code null}
     * @throws JOSEException if any of the checks fails
     */
    public static void validateSameCurve(final OctetKeyPair privateKey, final OctetKeyPair publicKey) throws JOSEException {
        Objects.requireNonNull(privateKey, "The parameter \"privateKey\" must not be null");
        Objects.requireNonNull(publicKey, "The parameter \"publicKey\" must not be null");
        if (!privateKey.isPrivate()) {
            throw new JOSEException("OKP private key should be a private key");
        }
        if (publicKey.isPrivate()) {
            throw new JOSEException("OKP public key should not be a private key");
        }
        if (!publicKey.getCurve().equals(Curve.X25519)) {
            throw new JOSEException("Only supports OctetKeyPairs with crv=X25519");
        }
        if (!privateKey.getCurve().equals(publicKey.getCurve())) {
            throw new JOSEException("Curve of public key does not match curve of private key");
        }
    }

    /** Prevents instantiation of this static utility class. */
    private ECDH1PU() {
    }
}
