// 
// Decompiled by Procyon v0.6.0
// 

package com.google.crypto.tink.hybrid.internal;

import com.google.crypto.tink.internal.BigIntegerEncoding;
import com.google.crypto.tink.AccessesPartialKey;
import com.google.crypto.tink.hybrid.HpkePublicKey;
import java.security.GeneralSecurityException;
import com.google.crypto.tink.subtle.Bytes;
import javax.annotation.concurrent.GuardedBy;
import java.math.BigInteger;
import javax.annotation.concurrent.ThreadSafe;

/**
 * Holds the per-session state of an HPKE encryption context: the AEAD key, the base nonce, the
 * encapsulated key that was produced or consumed while setting the context up, and a message
 * sequence number.
 *
 * <p>Every call to {@code seal} or {@code open} uses the nonce
 * {@code baseNonce XOR bigEndian(sequenceNumber)}. The sequence number is the only mutable state
 * and is guarded by the instance monitor.
 */
@ThreadSafe
public final class HpkeContext
{
    /** Empty input keying material, passed to the KDF where this class supplies no PSK / PSK id. */
    private static final byte[] EMPTY_IKM = new byte[0];
    private final HpkeAead aead;
    /** Largest representable sequence number, {@code 2^(8 * nonceLength) - 1}. */
    private final BigInteger maxSequenceNumber;
    private final byte[] key;
    private final byte[] baseNonce;
    private final byte[] encapsulatedKey;
    @GuardedBy("this")
    private BigInteger sequenceNumber;
    
    /**
     * Stores the given values as-is (the arrays are not copied) and starts the sequence number at
     * zero.
     */
    private HpkeContext(final byte[] encapsulatedKey, final byte[] key, final byte[] baseNonce, final BigInteger maxSequenceNumber, final HpkeAead aead) {
        this.encapsulatedKey = encapsulatedKey;
        this.key = key;
        this.baseNonce = baseNonce;
        this.sequenceNumber = BigInteger.ZERO;
        this.maxSequenceNumber = maxSequenceNumber;
        this.aead = aead;
    }
    
    /**
     * Derives the AEAD key and base nonce from {@code sharedSecret} and builds the context.
     *
     * <p>The derivation hashes an empty PSK id and {@code info} with {@code labeledExtract},
     * concatenates {@code mode}, both hashes into the key-schedule context, extracts a secret from
     * {@code sharedSecret}, and expands that secret into a key of {@code aead.getKeyLength()}
     * bytes and a base nonce of {@code aead.getNonceLength()} bytes. No exporter secret is derived.
     *
     * @param mode            mode identifier bytes that are mixed into the key-schedule context
     * @param encapsulatedKey encapsulated key to remember in the returned context
     * @param sharedSecret    KEM shared secret
     * @param kem             KEM, used here only for its identifier
     * @param kdf             KDF performing the labeled extract/expand steps
     * @param aead            AEAD whose identifier, key length and nonce length are used
     * @param info            application-supplied context information
     * @return a context whose sequence number is zero
     * @throws GeneralSecurityException if any KDF or helper call fails
     */
    static HpkeContext createContext(final byte[] mode, final byte[] encapsulatedKey, final byte[] sharedSecret, final HpkeKem kem, final HpkeKdf kdf, final HpkeAead aead, final byte[] info) throws GeneralSecurityException {
        final byte[] suiteId = HpkeUtil.hpkeSuiteId(kem.getKemId(), kdf.getKdfId(), aead.getAeadId());
        final byte[] pskIdHash = kdf.labeledExtract(HpkeUtil.EMPTY_SALT, HpkeContext.EMPTY_IKM, "psk_id_hash", suiteId);
        final byte[] infoHash = kdf.labeledExtract(HpkeUtil.EMPTY_SALT, info, "info_hash", suiteId);
        final byte[] keyScheduleContext = Bytes.concat(new byte[][] { mode, pskIdHash, infoHash });
        final byte[] secret = kdf.labeledExtract(sharedSecret, HpkeContext.EMPTY_IKM, "secret", suiteId);
        final byte[] key = kdf.labeledExpand(secret, keyScheduleContext, "key", suiteId, aead.getKeyLength());
        final byte[] baseNonce = kdf.labeledExpand(secret, keyScheduleContext, "base_nonce", suiteId, aead.getNonceLength());
        final BigInteger maxSeqNo = maxSequenceNumber(aead.getNonceLength());
        return new HpkeContext(encapsulatedKey, key, baseNonce, maxSeqNo, aead);
    }
    
    /**
     * Creates a context by encapsulating to {@code recipientPublicKey} and running the key
     * schedule with {@code HpkeUtil.BASE_MODE}.
     *
     * @throws GeneralSecurityException if encapsulation or key derivation fails
     */
    static HpkeContext createSenderContext(final byte[] recipientPublicKey, final HpkeKem kem, final HpkeKdf kdf, final HpkeAead aead, final byte[] info) throws GeneralSecurityException {
        final HpkeKemEncapOutput encapOutput = kem.encapsulate(recipientPublicKey);
        final byte[] encapsulatedKey = encapOutput.getEncapsulatedKey();
        final byte[] sharedSecret = encapOutput.getSharedSecret();
        return createContext(HpkeUtil.BASE_MODE, encapsulatedKey, sharedSecret, kem, kdf, aead, info);
    }
    
    /**
     * Creates a context by running {@code kem.authEncapsulate} with the recipient's public key
     * bytes and {@code senderPrivateKey}, then running the key schedule with
     * {@code HpkeUtil.AUTH_MODE}.
     *
     * @throws GeneralSecurityException if encapsulation or key derivation fails
     */
    @AccessesPartialKey
    public static HpkeContext createAuthSenderContext(final HpkePublicKey recipientPublicKey, final HpkeKem kem, final HpkeKdf kdf, final HpkeAead aead, final byte[] info, final HpkeKemPrivateKey senderPrivateKey) throws GeneralSecurityException {
        final HpkeKemEncapOutput encapOutput = kem.authEncapsulate(recipientPublicKey.getPublicKeyBytes().toByteArray(), senderPrivateKey);
        final byte[] encapsulatedKey = encapOutput.getEncapsulatedKey();
        final byte[] sharedSecret = encapOutput.getSharedSecret();
        return createContext(HpkeUtil.AUTH_MODE, encapsulatedKey, sharedSecret, kem, kdf, aead, info);
    }
    
    /**
     * Creates a context by decapsulating {@code encapsulatedKey} with {@code recipientPrivateKey}
     * and running the key schedule with {@code HpkeUtil.BASE_MODE}.
     *
     * @throws GeneralSecurityException if decapsulation or key derivation fails
     */
    public static HpkeContext createRecipientContext(final byte[] encapsulatedKey, final HpkeKemPrivateKey recipientPrivateKey, final HpkeKem kem, final HpkeKdf kdf, final HpkeAead aead, final byte[] info) throws GeneralSecurityException {
        final byte[] sharedSecret = kem.decapsulate(encapsulatedKey, recipientPrivateKey);
        return createContext(HpkeUtil.BASE_MODE, encapsulatedKey, sharedSecret, kem, kdf, aead, info);
    }
    
    /**
     * Creates a context by running {@code kem.authDecapsulate} with {@code encapsulatedKey},
     * {@code recipientPrivateKey} and the sender's public key bytes, then running the key schedule
     * with {@code HpkeUtil.AUTH_MODE}.
     *
     * @throws GeneralSecurityException if decapsulation or key derivation fails
     */
    @AccessesPartialKey
    public static HpkeContext createAuthRecipientContext(final byte[] encapsulatedKey, final HpkeKemPrivateKey recipientPrivateKey, final HpkeKem kem, final HpkeKdf kdf, final HpkeAead aead, final byte[] info, final HpkePublicKey senderPublicKey) throws GeneralSecurityException {
        final byte[] sharedSecret = kem.authDecapsulate(encapsulatedKey, recipientPrivateKey, senderPublicKey.getPublicKeyBytes().toByteArray());
        return createContext(HpkeUtil.AUTH_MODE, encapsulatedKey, sharedSecret, kem, kdf, aead, info);
    }
    
    /** Returns {@code 2^(8 * nonceLength) - 1}. */
    private static BigInteger maxSequenceNumber(final int nonceLength) {
        return BigInteger.ONE.shiftLeft(8 * nonceLength).subtract(BigInteger.ONE);
    }
    
    /**
     * Rejects further use of the context once the sequence number has reached its maximum.
     *
     * @throws GeneralSecurityException if the sequence number is at or above the maximum
     */
    @GuardedBy("this")
    private void ensureMessageLimitNotReached() throws GeneralSecurityException {
        if (this.sequenceNumber.compareTo(this.maxSequenceNumber) >= 0) {
            throw new GeneralSecurityException("message limit reached");
        }
    }
    
    /**
     * Advances the sequence number by one.
     *
     * @throws GeneralSecurityException if the sequence number is already at its maximum
     */
    @GuardedBy("this")
    private void incrementSequenceNumber() throws GeneralSecurityException {
        this.ensureMessageLimitNotReached();
        this.sequenceNumber = this.sequenceNumber.add(BigInteger.ONE);
    }
    
    /**
     * Returns the nonce for the current sequence number: the base nonce XORed with the sequence
     * number encoded big-endian in {@code aead.getNonceLength()} bytes.
     */
    @GuardedBy("this")
    private byte[] computeNonce() throws GeneralSecurityException {
        return Bytes.xor(this.baseNonce, BigIntegerEncoding.toBigEndianBytesOfFixedLength(this.sequenceNumber, this.aead.getNonceLength()));
    }
    
    /**
     * Atomically reserves the nonce for the current sequence number and advances the counter, so
     * that two concurrent callers never receive the same nonce.
     *
     * @throws GeneralSecurityException if the message limit has been reached
     */
    private synchronized byte[] computeNonceAndIncrementSequenceNumber() throws GeneralSecurityException {
        final byte[] nonce = this.computeNonce();
        this.incrementSequenceNumber();
        return nonce;
    }
    
    /** Returns the AEAD key held by this context (the internal array, not a copy). */
    byte[] getKey() {
        return this.key;
    }
    
    /** Returns the base nonce held by this context (the internal array, not a copy). */
    byte[] getBaseNonce() {
        return this.baseNonce;
    }
    
    /** Returns the encapsulated key held by this context (the internal array, not a copy). */
    public byte[] getEncapsulatedKey() {
        return this.encapsulatedKey;
    }
    
    /**
     * Encrypts {@code plaintext} with the next nonce and advances the sequence number.
     *
     * @param plaintext      data to encrypt
     * @param associatedData data to authenticate but not encrypt
     * @return whatever {@code aead.seal} returns for this key and nonce
     * @throws GeneralSecurityException if the message limit is reached or the AEAD fails
     */
    public byte[] seal(final byte[] plaintext, final byte[] associatedData) throws GeneralSecurityException {
        final byte[] nonce = this.computeNonceAndIncrementSequenceNumber();
        return this.aead.seal(this.key, nonce, plaintext, associatedData);
    }
    
    /**
     * Same as {@link #seal(byte[], byte[])} but forwards {@code ciphertextOffset} to the AEAD.
     *
     * @throws GeneralSecurityException if the message limit is reached or the AEAD fails
     */
    byte[] seal(final byte[] plaintext, final int ciphertextOffset, final byte[] associatedData) throws GeneralSecurityException {
        final byte[] nonce = this.computeNonceAndIncrementSequenceNumber();
        return this.aead.seal(this.key, nonce, plaintext, ciphertextOffset, associatedData);
    }
    
    /**
     * Decrypts {@code ciphertext} with the current nonce; equivalent to
     * {@code open(ciphertext, 0, associatedData)}.
     *
     * @param ciphertext     data to decrypt
     * @param associatedData data that was authenticated alongside the ciphertext
     * @return the plaintext returned by the AEAD
     * @throws GeneralSecurityException if the message limit is reached or decryption fails; in
     *         either case the sequence number is left unchanged
     */
    public byte[] open(final byte[] ciphertext, final byte[] associatedData) throws GeneralSecurityException {
        return this.open(ciphertext, 0, associatedData);
    }
    
    /**
     * Decrypts {@code ciphertext} with the current nonce, forwarding {@code ciphertextOffset} to
     * the AEAD, and advances the sequence number only if decryption succeeds.
     *
     * <p>A ciphertext that fails to decrypt must not consume a sequence number: otherwise a single
     * corrupted or forged message would shift the nonce for every later message and make all of
     * them undecryptable. This matches ContextR.Open in RFC 9180, section 5.2.
     *
     * @throws GeneralSecurityException if the message limit is reached or decryption fails; in
     *         either case the sequence number is left unchanged
     */
    byte[] open(final byte[] ciphertext, final int ciphertextOffset, final byte[] associatedData) throws GeneralSecurityException {
        // The monitor is held across the AEAD call so that "read counter, decrypt, advance
        // counter" is one atomic step with respect to other seal/open calls on this context.
        synchronized (this) {
            // Fail up front once the counter is exhausted, before touching the ciphertext.
            this.ensureMessageLimitNotReached();
            final byte[] nonce = this.computeNonce();
            final byte[] plaintext = this.aead.open(this.key, nonce, ciphertext, ciphertextOffset, associatedData);
            // Reached only when the AEAD accepted the ciphertext.
            this.incrementSequenceNumber();
            return plaintext;
        }
    }
}
