// 
// Decompiled by Procyon v0.6.0
// 

package com.google.crypto.tink.hybrid.internal;

import java.security.interfaces.ECPublicKey;
import java.security.interfaces.ECPrivateKey;
import java.security.KeyPair;
import com.google.crypto.tink.subtle.Bytes;
import java.security.GeneralSecurityException;
import com.google.crypto.tink.subtle.EllipticCurves;
import com.google.errorprone.annotations.Immutable;

@Immutable
final class NistCurvesHpkeKem implements HpkeKem
{
    private final EllipticCurves.CurveType curve;
    private final HkdfHpkeKdf hkdf;
    
    static NistCurvesHpkeKem fromCurve(final EllipticCurves.CurveType curve) throws GeneralSecurityException {
        switch (curve) {
            case NIST_P256: {
                return new NistCurvesHpkeKem(new HkdfHpkeKdf("HmacSha256"), EllipticCurves.CurveType.NIST_P256);
            }
            case NIST_P384: {
                return new NistCurvesHpkeKem(new HkdfHpkeKdf("HmacSha384"), EllipticCurves.CurveType.NIST_P384);
            }
            case NIST_P521: {
                return new NistCurvesHpkeKem(new HkdfHpkeKdf("HmacSha512"), EllipticCurves.CurveType.NIST_P521);
            }
            default: {
                throw new GeneralSecurityException("invalid curve type: " + curve);
            }
        }
    }
    
    private NistCurvesHpkeKem(final HkdfHpkeKdf hkdf, final EllipticCurves.CurveType curve) {
        this.hkdf = hkdf;
        this.curve = curve;
    }
    
    byte[] deriveKemSharedSecret(final byte[] dhSharedSecret, final byte[] senderEphemeralPublicKey, final byte[] recipientPublicKey) throws GeneralSecurityException {
        final byte[] kemContext = Bytes.concat(new byte[][] { senderEphemeralPublicKey, recipientPublicKey });
        return this.extractAndExpand(dhSharedSecret, kemContext);
    }
    
    byte[] deriveKemSharedSecret(final byte[] dhSharedSecret, final byte[] senderEphemeralPublicKey, final byte[] recipientPublicKey, final byte[] senderPublicKey) throws GeneralSecurityException {
        final byte[] kemContext = Bytes.concat(new byte[][] { senderEphemeralPublicKey, recipientPublicKey, senderPublicKey });
        return this.extractAndExpand(dhSharedSecret, kemContext);
    }
    
    private byte[] extractAndExpand(final byte[] dhSharedSecret, final byte[] kemContext) throws GeneralSecurityException {
        final byte[] kemSuiteID = HpkeUtil.kemSuiteId(this.getKemId());
        return this.hkdf.extractAndExpand(null, dhSharedSecret, "eae_prk", kemContext, "shared_secret", kemSuiteID, this.hkdf.getMacLength());
    }
    
    HpkeKemEncapOutput encapsulate(final byte[] recipientPublicKey, final KeyPair senderEphemeralKeyPair) throws GeneralSecurityException {
        final ECPublicKey recipientECPublicKey = EllipticCurves.getEcPublicKey(this.curve, EllipticCurves.PointFormatType.UNCOMPRESSED, recipientPublicKey);
        final byte[] dhSharedSecret = EllipticCurves.computeSharedSecret((ECPrivateKey)senderEphemeralKeyPair.getPrivate(), recipientECPublicKey);
        final byte[] senderPublicKey = EllipticCurves.pointEncode(this.curve, EllipticCurves.PointFormatType.UNCOMPRESSED, ((ECPublicKey)senderEphemeralKeyPair.getPublic()).getW());
        final byte[] kemSharedSecret = this.deriveKemSharedSecret(dhSharedSecret, senderPublicKey, recipientPublicKey);
        return new HpkeKemEncapOutput(kemSharedSecret, senderPublicKey);
    }
    
    @Override
    public HpkeKemEncapOutput encapsulate(final byte[] recipientPublicKey) throws GeneralSecurityException {
        final KeyPair keyPair = EllipticCurves.generateKeyPair(this.curve);
        return this.encapsulate(recipientPublicKey, keyPair);
    }
    
    HpkeKemEncapOutput authEncapsulate(final byte[] recipientPublicKey, final KeyPair senderEphemeralKeyPair, final HpkeKemPrivateKey senderPrivateKey) throws GeneralSecurityException {
        final ECPublicKey recipientECPublicKey = EllipticCurves.getEcPublicKey(this.curve, EllipticCurves.PointFormatType.UNCOMPRESSED, recipientPublicKey);
        final ECPrivateKey privateKey = EllipticCurves.getEcPrivateKey(this.curve, senderPrivateKey.getSerializedPrivate().toByteArray());
        final byte[] dhSharedSecret = Bytes.concat(new byte[][] { EllipticCurves.computeSharedSecret((ECPrivateKey)senderEphemeralKeyPair.getPrivate(), recipientECPublicKey), EllipticCurves.computeSharedSecret(privateKey, recipientECPublicKey) });
        final byte[] senderEphemeralPublicKey = EllipticCurves.pointEncode(this.curve, EllipticCurves.PointFormatType.UNCOMPRESSED, ((ECPublicKey)senderEphemeralKeyPair.getPublic()).getW());
        final byte[] kemSharedSecret = this.deriveKemSharedSecret(dhSharedSecret, senderEphemeralPublicKey, recipientPublicKey, senderPrivateKey.getSerializedPublic().toByteArray());
        return new HpkeKemEncapOutput(kemSharedSecret, senderEphemeralPublicKey);
    }
    
    @Override
    public HpkeKemEncapOutput authEncapsulate(final byte[] recipientPublicKey, final HpkeKemPrivateKey senderPrivateKey) throws GeneralSecurityException {
        final KeyPair keyPair = EllipticCurves.generateKeyPair(this.curve);
        return this.authEncapsulate(recipientPublicKey, keyPair, senderPrivateKey);
    }
    
    @Override
    public byte[] decapsulate(final byte[] encapsulatedKey, final HpkeKemPrivateKey recipientPrivateKey) throws GeneralSecurityException {
        final ECPrivateKey privateKey = EllipticCurves.getEcPrivateKey(this.curve, recipientPrivateKey.getSerializedPrivate().toByteArray());
        final ECPublicKey publicKey = EllipticCurves.getEcPublicKey(this.curve, EllipticCurves.PointFormatType.UNCOMPRESSED, encapsulatedKey);
        final byte[] dhSharedSecret = EllipticCurves.computeSharedSecret(privateKey, publicKey);
        return this.deriveKemSharedSecret(dhSharedSecret, encapsulatedKey, recipientPrivateKey.getSerializedPublic().toByteArray());
    }
    
    @Override
    public byte[] authDecapsulate(final byte[] encapsulatedKey, final HpkeKemPrivateKey recipientPrivateKey, final byte[] senderPublicKey) throws GeneralSecurityException {
        final ECPrivateKey privateKey = EllipticCurves.getEcPrivateKey(this.curve, recipientPrivateKey.getSerializedPrivate().toByteArray());
        final ECPublicKey senderEphemeralPublicKey = EllipticCurves.getEcPublicKey(this.curve, EllipticCurves.PointFormatType.UNCOMPRESSED, encapsulatedKey);
        final byte[] dhSharedSecret = Bytes.concat(new byte[][] { EllipticCurves.computeSharedSecret(privateKey, senderEphemeralPublicKey), EllipticCurves.computeSharedSecret(privateKey, EllipticCurves.getEcPublicKey(this.curve, EllipticCurves.PointFormatType.UNCOMPRESSED, senderPublicKey)) });
        return this.deriveKemSharedSecret(dhSharedSecret, encapsulatedKey, recipientPrivateKey.getSerializedPublic().toByteArray(), senderPublicKey);
    }
    
    @Override
    public byte[] getKemId() throws GeneralSecurityException {
        switch (this.curve) {
            case NIST_P256: {
                return HpkeUtil.P256_HKDF_SHA256_KEM_ID;
            }
            case NIST_P384: {
                return HpkeUtil.P384_HKDF_SHA384_KEM_ID;
            }
            case NIST_P521: {
                return HpkeUtil.P521_HKDF_SHA512_KEM_ID;
            }
            default: {
                throw new GeneralSecurityException("Could not determine HPKE KEM ID");
            }
        }
    }
}
