// 
// Decompiled by Procyon v0.6.0
// 

package com.google.crypto.tink.subtle;

import java.security.PublicKey;
import java.security.Key;
import javax.crypto.KeyAgreement;
import java.security.spec.AlgorithmParameterSpec;
import java.security.KeyPairGenerator;
import java.security.KeyPair;
import java.security.spec.ECPrivateKeySpec;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.ECPublicKeySpec;
import java.security.spec.KeySpec;
import java.security.spec.X509EncodedKeySpec;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import com.google.crypto.tink.internal.BigIntegerEncoding;
import java.security.spec.ECPoint;
import java.util.Arrays;
import java.security.InvalidAlgorithmParameterException;
import java.math.BigInteger;
import java.security.spec.EllipticCurve;
import java.security.interfaces.ECPrivateKey;
import java.security.GeneralSecurityException;
import java.security.interfaces.ECPublicKey;
import com.google.crypto.tink.internal.EllipticCurvesUtil;
import java.security.spec.ECParameterSpec;

/**
 * Static helpers for elliptic-curve cryptography on top of the JCE:
 *
 * <ul>
 *   <li>NIST curve parameters;
 *   <li>EC point encoding and decoding;
 *   <li>conversion of ECDSA signatures between IEEE P1363 and DER;
 *   <li>construction of EC keys;
 *   <li>ECDH shared-secret computation with validation.
 * </ul>
 *
 * <p>Named curves are limited to those listed in {@link CurveType}. This class is not
 * instantiable.
 */
public final class EllipticCurves
{
    /** Returns the domain parameters of NIST P-256 (delegates to {@code EllipticCurvesUtil}). */
    public static ECParameterSpec getNistP256Params() {
        return EllipticCurvesUtil.NIST_P256_PARAMS;
    }
    
    /** Returns the domain parameters of NIST P-384 (delegates to {@code EllipticCurvesUtil}). */
    public static ECParameterSpec getNistP384Params() {
        return EllipticCurvesUtil.NIST_P384_PARAMS;
    }
    
    /** Returns the domain parameters of NIST P-521 (delegates to {@code EllipticCurvesUtil}). */
    public static ECParameterSpec getNistP521Params() {
        return EllipticCurvesUtil.NIST_P521_PARAMS;
    }
    
    /**
     * Checks that the point of {@code key} lies on the curve taken from the key's own parameters.
     *
     * @throws GeneralSecurityException if {@code EllipticCurvesUtil.checkPointOnCurve} rejects the
     *     point
     */
    static void checkPublicKey(final ECPublicKey key) throws GeneralSecurityException {
        EllipticCurvesUtil.checkPointOnCurve(key.getW(), key.getParams().getCurve());
    }
    
    /** Returns whether {@code spec} is one of the NIST curve specs known to {@code EllipticCurvesUtil}. */
    public static boolean isNistEcParameterSpec(final ECParameterSpec spec) {
        return EllipticCurvesUtil.isNistEcParameterSpec(spec);
    }
    
    /** Returns whether the two specs describe the same parameters, as decided by {@code EllipticCurvesUtil}. */
    public static boolean isSameEcParameterSpec(final ECParameterSpec one, final ECParameterSpec two) {
        return EllipticCurvesUtil.isSameEcParameterSpec(one, two);
    }
    
    /**
     * Validates {@code publicKey} for use with {@code privateKey}.
     *
     * <p>Both keys must have the same curve parameters, and the public point must lie on the
     * private key's curve.
     *
     * @throws GeneralSecurityException if either check fails
     */
    public static void validatePublicKey(final ECPublicKey publicKey, final ECPrivateKey privateKey) throws GeneralSecurityException {
        validatePublicKeySpec(publicKey, privateKey);
        EllipticCurvesUtil.checkPointOnCurve(publicKey.getW(), privateKey.getParams().getCurve());
    }
    
    /**
     * Checks that both keys use the same EC parameters.
     *
     * @throws GeneralSecurityException if the parameters differ, or if reading or comparing them
     *     throws {@link IllegalArgumentException} or {@link NullPointerException} (kept as cause)
     */
    static void validatePublicKeySpec(final ECPublicKey publicKey, final ECPrivateKey privateKey) throws GeneralSecurityException {
        try {
            final ECParameterSpec publicKeySpec = publicKey.getParams();
            final ECParameterSpec privateKeySpec = privateKey.getParams();
            if (!isSameEcParameterSpec(publicKeySpec, privateKeySpec)) {
                throw new GeneralSecurityException("invalid public key spec");
            }
        }
        catch (final IllegalArgumentException | NullPointerException ex) {
            throw new GeneralSecurityException(ex);
        }
    }
    
    /** Returns the prime modulus p of {@code curve} (delegates to {@code EllipticCurvesUtil}). */
    public static BigInteger getModulus(final EllipticCurve curve) throws GeneralSecurityException {
        return EllipticCurvesUtil.getModulus(curve);
    }
    
    /**
     * Returns the number of bits of the largest field element, i.e. the bit length of
     * {@code p - 1}, where p is the curve's modulus.
     */
    public static int fieldSizeInBits(final EllipticCurve curve) throws GeneralSecurityException {
        return getModulus(curve).subtract(BigInteger.ONE).bitLength();
    }
    
    /** Returns {@link #fieldSizeInBits} rounded up to whole bytes; this is the size of one encoded coordinate. */
    public static int fieldSizeInBytes(final EllipticCurve curve) throws GeneralSecurityException {
        return (fieldSizeInBits(curve) + 7) / 8;
    }
    
    /**
     * Returns a square root of {@code x} modulo {@code p}: some {@code r} with
     * {@code r * r == x (mod p)}.
     *
     * <p>The algorithm depends on p mod 4:
     *
     * <ul>
     *   <li>For {@code p == 3 (mod 4)}, the candidate is {@code x^((p + 1) / 4) mod p}.
     *   <li>For {@code p == 1 (mod 4)}, Cipolla's algorithm is used. It finds {@code a} such that
     *       {@code d = a^2 - x} is a quadratic non-residue, then evaluates
     *       {@code (a + sqrt(d))^((p + 1) / 2)} in {@code F_p[sqrt(d)]}.
     * </ul>
     *
     * <p>Every candidate is verified by squaring before it is returned.
     *
     * @param x the value whose root is wanted; reduced modulo {@code p} first
     * @param p the modulus; must be positive and odd, and is expected to be prime
     * @throws InvalidAlgorithmParameterException if {@code p} is not positive, is even (for
     *     {@code x != 0 mod p}), or is detected to be composite
     * @throws GeneralSecurityException if {@code x} has no square root modulo {@code p}
     */
    private static BigInteger modSqrt(BigInteger x, final BigInteger p) throws GeneralSecurityException {
        if (p.signum() != 1) {
            throw new InvalidAlgorithmParameterException("p must be positive");
        }
        x = x.mod(p);
        BigInteger squareRoot = null;
        if (x.equals(BigInteger.ZERO)) {
            return BigInteger.ZERO;
        }
        if (p.testBit(0) && p.testBit(1)) {
            // p == 3 (mod 4): x^((p + 1) / 4) is a root whenever x is a quadratic residue.
            final BigInteger q = p.add(BigInteger.ONE).shiftRight(2);
            squareRoot = x.modPow(q, p);
        }
        else if (p.testBit(0) && !p.testBit(1)) {
            // p == 1 (mod 4): Cipolla's algorithm.
            BigInteger a = BigInteger.ONE;
            final BigInteger q2 = p.subtract(BigInteger.ONE).shiftRight(1);
            int tries = 0;
            while (true) {
                final BigInteger d = a.multiply(a).subtract(x).mod(p);
                if (d.equals(BigInteger.ZERO)) {
                    // a^2 == x, so a itself is a root.
                    return a;
                }
                // Euler's criterion: d^((p - 1) / 2) is 1 for residues and p - 1 for non-residues.
                final BigInteger t = d.modPow(q2, p);
                if (t.add(BigInteger.ONE).equals(p)) {
                    // d is a non-residue. Compute (a + w)^((p + 1) / 2), where w^2 = d.
                    // The running value is tracked as u + v*w, using left-to-right
                    // square-and-multiply; the top exponent bit is consumed by the start value a + w.
                    final BigInteger q3 = p.add(BigInteger.ONE).shiftRight(1);
                    BigInteger u = a;
                    BigInteger v = BigInteger.ONE;
                    for (int bit = q3.bitLength() - 2; bit >= 0; --bit) {
                        // Square: (u + v*w)^2 = (u^2 + v^2*d) + 2uv*w.
                        BigInteger tmp = u.multiply(v);
                        u = u.multiply(u).add(v.multiply(v).mod(p).multiply(d)).mod(p);
                        v = tmp.add(tmp).mod(p);
                        if (q3.testBit(bit)) {
                            // Multiply by (a + w): (ua + vd) + (u + av)*w.
                            tmp = u.multiply(a).add(v.multiply(d)).mod(p);
                            v = a.multiply(v).add(u).mod(p);
                            u = tmp;
                        }
                    }
                    squareRoot = u;
                    break;
                }
                if (!t.equals(BigInteger.ONE)) {
                    // Modulo a prime, Euler's criterion can only yield 1 or p - 1.
                    throw new InvalidAlgorithmParameterException("p is not prime");
                }
                a = a.add(BigInteger.ONE);
                // Not finding a non-residue after many attempts is suspicious; verify p once.
                if (++tries == 128 && !p.isProbablePrime(80)) {
                    throw new InvalidAlgorithmParameterException("p is not prime");
                }
            }
        }
        if (squareRoot == null) {
            // Neither branch above ran, so p is even; no root is computed for such moduli.
            throw new InvalidAlgorithmParameterException("p must be odd");
        }
        if (squareRoot.multiply(squareRoot).mod(p).compareTo(x) != 0) {
            throw new GeneralSecurityException("Could not find a modular square root");
        }
        return squareRoot;
    }
    
    /**
     * Returns the y-coordinate whose least significant bit equals {@code lsb} for the curve point
     * with x-coordinate {@code x}.
     *
     * <p>y is obtained by solving {@code y^2 = x^3 + a*x + b (mod p)}.
     *
     * @throws GeneralSecurityException if the right-hand side has no square root modulo p (so
     *     {@code x} is not the x-coordinate of a curve point) or the modulus is unusable
     */
    private static BigInteger computeY(final BigInteger x, final boolean lsb, final EllipticCurve curve) throws GeneralSecurityException {
        final BigInteger p = getModulus(curve);
        final BigInteger a = curve.getA();
        final BigInteger b = curve.getB();
        // Horner form of x^3 + a*x + b.
        final BigInteger rhs = x.multiply(x).add(a).multiply(x).add(b).mod(p);
        BigInteger y = modSqrt(rhs, p);
        if (lsb != y.testBit(0)) {
            // The other root is p - y; the outer mod keeps y == 0 mapped to 0.
            y = p.subtract(y).mod(p);
        }
        return y;
    }
    
    /**
     * Returns the y-coordinate with parity {@code lsb} for x-coordinate {@code x} on {@code curve}.
     *
     * @deprecated Retained for backward compatibility. It only delegates to the internal
     *     y-recovery that is used when decoding compressed points.
     */
    @Deprecated
    public static BigInteger getY(final BigInteger x, final boolean lsb, final EllipticCurve curve) throws GeneralSecurityException {
        return computeY(x, lsb, curve);
    }
    
    /**
     * Converts an unsigned big-endian integer to its minimal two's-complement big-endian form,
     * which is the content of a DER INTEGER.
     *
     * <p>Leading zero bytes are stripped, keeping at least one byte. A single zero byte is then
     * prepended if the remaining top bit is set.
     *
     * @param bs a non-empty unsigned big-endian integer
     */
    private static byte[] toMinimalSignedNumber(final byte[] bs) {
        int start;
        for (start = 0; start < bs.length && bs[start] == 0; ++start) {}
        if (start == bs.length) {
            // All zeros: keep one zero byte so the value is still encoded.
            start = bs.length - 1;
        }
        int extraZero = 0;
        if ((bs[start] & 0x80) == 0x80) {
            // Without a leading zero the value would read as negative.
            extraZero = 1;
        }
        final byte[] res = new byte[bs.length - start + extraZero];
        System.arraycopy(bs, start, res, extraZero, bs.length - start);
        return res;
    }
    
    /**
     * Converts an ECDSA signature from IEEE P1363 form to DER.
     *
     * <p>The input is {@code r || s}, two halves of equal length. The output is the DER encoding
     * of {@code SEQUENCE { INTEGER r, INTEGER s }}.
     *
     * @param ieee the P1363 signature; its length must be even, non-zero and at most 132 bytes
     * @throws GeneralSecurityException if {@code ieee} has an invalid length
     */
    public static byte[] ecdsaIeee2Der(final byte[] ieee) throws GeneralSecurityException {
        if (ieee.length % 2 != 0 || ieee.length == 0 || ieee.length > 132) {
            throw new GeneralSecurityException("Invalid IEEE_P1363 encoding");
        }
        final byte[] r = toMinimalSignedNumber(Arrays.copyOf(ieee, ieee.length / 2));
        final byte[] s = toMinimalSignedNumber(Arrays.copyOfRange(ieee, ieee.length / 2, ieee.length));
        int offset = 0;
        // SEQUENCE content: two INTEGERs, each encoded as tag byte, one length byte, then the value.
        final int length = 2 + r.length + 1 + 1 + s.length;
        byte[] der;
        if (length >= 128) {
            // Long-form length (0x81 followed by one length byte). The 132-byte input limit keeps
            // the length below 256.
            der = new byte[length + 3];
            der[offset++] = 0x30;
            der[offset++] = (byte) 0x81;
            der[offset++] = (byte) length;
        }
        else {
            der = new byte[length + 2];
            der[offset++] = 0x30;
            der[offset++] = (byte) length;
        }
        der[offset++] = 0x02;
        der[offset++] = (byte) r.length;
        System.arraycopy(r, 0, der, offset, r.length);
        offset += r.length;
        der[offset++] = 0x02;
        der[offset++] = (byte) s.length;
        System.arraycopy(s, 0, der, offset, s.length);
        return der;
    }
    
    /**
     * Converts a DER-encoded ECDSA signature to IEEE P1363 form ({@code r || s}).
     *
     * <p>r and s are right-aligned, zero-padded big-endian values. r occupies the first
     * {@code ieeeLength / 2} bytes and s the remainder.
     *
     * @param der the DER signature; must pass {@link #isValidDerEncoding}
     * @param ieeeLength the total length of the output, normally twice the field size in bytes
     * @throws GeneralSecurityException if {@code der} is not valid DER, {@code ieeeLength} is
     *     negative, or r or s does not fit into its half of the output
     */
    public static byte[] ecdsaDer2Ieee(final byte[] der, final int ieeeLength) throws GeneralSecurityException {
        if (!isValidDerEncoding(der)) {
            throw new GeneralSecurityException("Invalid DER encoding");
        }
        if (ieeeLength < 0) {
            throw new GeneralSecurityException("Invalid IEEE_P1363 length: " + ieeeLength);
        }
        final byte[] ieee = new byte[ieeeLength];
        final int half = ieeeLength / 2;
        final int length = der[1] & 0xFF;
        // Skip the SEQUENCE tag and its length. A total length of 128 or more uses the long
        // form 0x81 LL, which takes one extra byte.
        int offset = 2;
        if (length >= 128) {
            ++offset;
        }
        // Skip r's INTEGER tag.
        ++offset;
        // Read the length unsigned; isValidDerEncoding guarantees a short-form length byte.
        final int rLength = der[offset++] & 0xFF;
        // A leading 0x00 only keeps the INTEGER positive and is not part of the value.
        int extraZero = (der[offset] == 0) ? 1 : 0;
        final int rValueLength = rLength - extraZero;
        if (rValueLength > half) {
            // Copying would start before index 0 of ieee.
            throw new GeneralSecurityException("r is too large for the requested IEEE_P1363 length");
        }
        System.arraycopy(der, offset + extraZero, ieee, half - rValueLength, rValueLength);
        // Skip r's value bytes and s's INTEGER tag.
        offset += rLength + 1;
        final int sLength = der[offset++] & 0xFF;
        extraZero = (der[offset] == 0) ? 1 : 0;
        final int sValueLength = sLength - extraZero;
        if (sValueLength > ieeeLength - half) {
            // Copying would overwrite the bytes of r.
            throw new GeneralSecurityException("s is too large for the requested IEEE_P1363 length");
        }
        System.arraycopy(der, offset + extraZero, ieee, ieeeLength - sValueLength, sValueLength);
        return ieee;
    }
    
    /**
     * Returns whether {@code sig} is a strict DER encoding of {@code SEQUENCE { INTEGER r, INTEGER s }}.
     *
     * <p>The following must hold:
     *
     * <ul>
     *   <li>The outer length is definite and uses the minimal form: short form, or {@code 0x81 LL}
     *       with {@code LL >= 128}.
     *   <li>It matches the array length exactly, with no trailing data.
     *   <li>Both INTEGERs have short-form lengths (below 0x80) and non-empty values.
     *   <li>Both values are non-negative and minimally encoded.
     * </ul>
     */
    public static boolean isValidDerEncoding(final byte[] sig) {
        // Smallest possible encoding: 30 06 02 01 rr 02 01 ss.
        if (sig.length < 8) {
            return false;
        }
        if (sig[0] != 0x30) {
            return false;
        }
        int totalLen = sig[1] & 0xFF;
        int totalLenLen = 1;
        if (totalLen == 0x81) {
            // Long form with one length byte; DER only allows it for lengths of 128 or more.
            totalLenLen = 2;
            totalLen = sig[2] & 0xFF;
            if (totalLen < 128) {
                return false;
            }
        }
        else if (totalLen == 0x80 || totalLen > 0x81) {
            // Indefinite length (0x80) or a multi-byte long form.
            return false;
        }
        if (totalLen != sig.length - 1 - totalLenLen) {
            return false;
        }
        final int rTag = 1 + totalLenLen;
        if (sig[rTag] != 0x02) {
            return false;
        }
        final int rLen = sig[rTag + 1] & 0xFF;
        final int rStart = rTag + 2;
        final int sTag = rStart + rLen;
        // r's value, s's tag and s's length byte must all lie inside sig.
        if (sTag + 1 >= sig.length) {
            return false;
        }
        // A length byte of 0x80 or more would introduce a long-form length, which is not supported.
        if (rLen == 0 || rLen >= 128) {
            return false;
        }
        // Negative r.
        if ((sig[rStart] & 0xFF) >= 128) {
            return false;
        }
        // Non-minimal r: a leading zero that is not needed to clear the sign bit.
        if (rLen > 1 && sig[rStart] == 0 && (sig[rStart + 1] & 0xFF) < 128) {
            return false;
        }
        if (sig[sTag] != 0x02) {
            return false;
        }
        final int sLen = sig[sTag + 1] & 0xFF;
        final int sStart = sTag + 2;
        if (sStart + sLen != sig.length) {
            return false;
        }
        if (sLen == 0 || sLen >= 128) {
            return false;
        }
        if ((sig[sStart] & 0xFF) >= 128) {
            return false;
        }
        return sLen <= 1 || sig[sStart] != 0 || (sig[sStart + 1] & 0xFF) >= 128;
    }
    
    /**
     * Returns the size in bytes of a point on {@code curve} encoded in {@code format}.
     *
     * @throws GeneralSecurityException if {@code format} is not a known point format
     */
    public static int encodingSizeInBytes(final EllipticCurve curve, final PointFormatType format) throws GeneralSecurityException {
        final int coordinateSize = fieldSizeInBytes(curve);
        switch (format) {
            case UNCOMPRESSED:
                return 2 * coordinateSize + 1;
            case DO_NOT_USE_CRUNCHY_UNCOMPRESSED:
                return 2 * coordinateSize;
            case COMPRESSED:
                return coordinateSize + 1;
            default:
                throw new GeneralSecurityException("unknown EC point format");
        }
    }
    
    /** Same as {@link #pointDecode(EllipticCurve, PointFormatType, byte[])}. */
    public static ECPoint ecPointDecode(final EllipticCurve curve, final PointFormatType format, final byte[] encoded) throws GeneralSecurityException {
        return pointDecode(curve, format, encoded);
    }
    
    /** Decodes a point on the named curve {@code curveType}; see {@link #pointDecode(EllipticCurve, PointFormatType, byte[])}. */
    public static ECPoint pointDecode(final CurveType curveType, final PointFormatType format, final byte[] encoded) throws GeneralSecurityException {
        return pointDecode(getCurveSpec(curveType).getCurve(), format, encoded);
    }
    
    /**
     * Decodes a point on {@code curve} from {@code encoded} in the given format.
     *
     * <p>Each coordinate is {@code fieldSizeInBytes(curve)} bytes, big-endian:
     *
     * <ul>
     *   <li>UNCOMPRESSED: {@code 0x04 || x || y}.
     *   <li>DO_NOT_USE_CRUNCHY_UNCOMPRESSED: {@code x || y}, without a prefix byte.
     *   <li>COMPRESSED: {@code 0x02 || x} (even y) or {@code 0x03 || x} (odd y).
     * </ul>
     *
     * <p>Uncompressed points are checked to lie on the curve. For compressed points, y is
     * recovered from the curve equation, and decoding fails if no such y exists.
     *
     * @throws GeneralSecurityException if the encoding has the wrong size or prefix, x is out of
     *     range, or the point is not on the curve
     */
    public static ECPoint pointDecode(final EllipticCurve curve, final PointFormatType format, final byte[] encoded) throws GeneralSecurityException {
        final int coordinateSize = fieldSizeInBytes(curve);
        switch (format) {
            case UNCOMPRESSED: {
                if (encoded.length != 2 * coordinateSize + 1) {
                    throw new GeneralSecurityException("invalid point size");
                }
                if (encoded[0] != 0x04) {
                    throw new GeneralSecurityException("invalid point format");
                }
                final BigInteger x = new BigInteger(1, Arrays.copyOfRange(encoded, 1, coordinateSize + 1));
                final BigInteger y = new BigInteger(1, Arrays.copyOfRange(encoded, coordinateSize + 1, encoded.length));
                final ECPoint point = new ECPoint(x, y);
                EllipticCurvesUtil.checkPointOnCurve(point, curve);
                return point;
            }
            case DO_NOT_USE_CRUNCHY_UNCOMPRESSED: {
                if (encoded.length != 2 * coordinateSize) {
                    throw new GeneralSecurityException("invalid point size");
                }
                final BigInteger x = new BigInteger(1, Arrays.copyOf(encoded, coordinateSize));
                final BigInteger y = new BigInteger(1, Arrays.copyOfRange(encoded, coordinateSize, encoded.length));
                final ECPoint point = new ECPoint(x, y);
                EllipticCurvesUtil.checkPointOnCurve(point, curve);
                return point;
            }
            case COMPRESSED: {
                final BigInteger p = getModulus(curve);
                if (encoded.length != coordinateSize + 1) {
                    throw new GeneralSecurityException("compressed point has wrong length");
                }
                // The prefix byte carries the parity of y.
                boolean lsb;
                if (encoded[0] == 0x02) {
                    lsb = false;
                }
                else {
                    if (encoded[0] != 0x03) {
                        throw new GeneralSecurityException("invalid format");
                    }
                    lsb = true;
                }
                final BigInteger x = new BigInteger(1, Arrays.copyOfRange(encoded, 1, encoded.length));
                if (x.signum() == -1 || x.compareTo(p) >= 0) {
                    throw new GeneralSecurityException("x is out of range");
                }
                // computeY throws if x^3 + a*x + b has no root, so the result is on the curve.
                final BigInteger y = computeY(x, lsb, curve);
                return new ECPoint(x, y);
            }
            default:
                throw new GeneralSecurityException("invalid format:" + format);
        }
    }
    
    /** Encodes a point on the named curve {@code curveType}; see {@link #pointEncode(EllipticCurve, PointFormatType, ECPoint)}. */
    public static byte[] pointEncode(final CurveType curveType, final PointFormatType format, final ECPoint point) throws GeneralSecurityException {
        return pointEncode(getCurveSpec(curveType).getCurve(), format, point);
    }
    
    /**
     * Encodes {@code point} in the given format, using the layouts described on
     * {@link #pointDecode(EllipticCurve, PointFormatType, byte[])}.
     *
     * @throws GeneralSecurityException if the point is not on {@code curve} or the format is unknown
     */
    public static byte[] pointEncode(final EllipticCurve curve, final PointFormatType format, final ECPoint point) throws GeneralSecurityException {
        EllipticCurvesUtil.checkPointOnCurve(point, curve);
        final int coordinateSize = fieldSizeInBytes(curve);
        switch (format) {
            case UNCOMPRESSED: {
                final byte[] encoded = new byte[2 * coordinateSize + 1];
                final byte[] x = BigIntegerEncoding.toBigEndianBytes(point.getAffineX());
                final byte[] y = BigIntegerEncoding.toBigEndianBytes(point.getAffineY());
                // The order of these copies matters. If y's encoding has one extra leading byte,
                // that byte lands on x's last slot and is then overwritten by the x copy. An extra
                // leading byte of x lands on the prefix slot, which is set last.
                System.arraycopy(y, 0, encoded, 1 + 2 * coordinateSize - y.length, y.length);
                System.arraycopy(x, 0, encoded, 1 + coordinateSize - x.length, x.length);
                encoded[0] = 0x04;
                return encoded;
            }
            case DO_NOT_USE_CRUNCHY_UNCOMPRESSED: {
                final byte[] encoded = new byte[2 * coordinateSize];
                byte[] x = BigIntegerEncoding.toBigEndianBytes(point.getAffineX());
                // There is no prefix slot to absorb an extra leading byte, so keep only the low bytes.
                if (x.length > coordinateSize) {
                    x = Arrays.copyOfRange(x, x.length - coordinateSize, x.length);
                }
                byte[] y = BigIntegerEncoding.toBigEndianBytes(point.getAffineY());
                if (y.length > coordinateSize) {
                    y = Arrays.copyOfRange(y, y.length - coordinateSize, y.length);
                }
                System.arraycopy(y, 0, encoded, 2 * coordinateSize - y.length, y.length);
                System.arraycopy(x, 0, encoded, coordinateSize - x.length, x.length);
                return encoded;
            }
            case COMPRESSED: {
                final byte[] encoded = new byte[coordinateSize + 1];
                final byte[] x = BigIntegerEncoding.toBigEndianBytes(point.getAffineX());
                // An extra leading byte of x lands on the prefix slot, which is set next.
                System.arraycopy(x, 0, encoded, 1 + coordinateSize - x.length, x.length);
                encoded[0] = (byte) (point.getAffineY().testBit(0) ? 0x03 : 0x02);
                return encoded;
            }
            default:
                throw new GeneralSecurityException("invalid format:" + format);
        }
    }
    
    /**
     * Returns the domain parameters of the named curve.
     *
     * @throws NoSuchAlgorithmException if the curve is not supported
     */
    public static ECParameterSpec getCurveSpec(final CurveType curve) throws NoSuchAlgorithmException {
        switch (curve) {
            case NIST_P256:
                return getNistP256Params();
            case NIST_P384:
                return getNistP384Params();
            case NIST_P521:
                return getNistP521Params();
            default:
                throw new NoSuchAlgorithmException("curve not implemented:" + curve);
        }
    }
    
    /** Parses an X.509 {@code SubjectPublicKeyInfo} encoding into an EC public key. */
    public static ECPublicKey getEcPublicKey(final byte[] x509PublicKey) throws GeneralSecurityException {
        final KeyFactory kf = EngineFactory.KEY_FACTORY.getInstance("EC");
        return (ECPublicKey) kf.generatePublic(new X509EncodedKeySpec(x509PublicKey));
    }
    
    /** Builds an EC public key on the named curve from an encoded point. */
    public static ECPublicKey getEcPublicKey(final CurveType curve, final PointFormatType pointFormat, final byte[] publicKey) throws GeneralSecurityException {
        return getEcPublicKey(getCurveSpec(curve), pointFormat, publicKey);
    }
    
    /**
     * Builds an EC public key with parameters {@code spec} from an encoded point.
     *
     * <p>The point is decoded, and therefore validated, by
     * {@link #pointDecode(EllipticCurve, PointFormatType, byte[])}.
     */
    public static ECPublicKey getEcPublicKey(final ECParameterSpec spec, final PointFormatType pointFormat, final byte[] publicKey) throws GeneralSecurityException {
        final ECPoint point = pointDecode(spec.getCurve(), pointFormat, publicKey);
        final ECPublicKeySpec pubSpec = new ECPublicKeySpec(point, spec);
        final KeyFactory kf = EngineFactory.KEY_FACTORY.getInstance("EC");
        return (ECPublicKey) kf.generatePublic(pubSpec);
    }
    
    /**
     * Builds an EC public key on the named curve from unsigned big-endian affine coordinates.
     *
     * @throws GeneralSecurityException if the point is not on the curve
     */
    public static ECPublicKey getEcPublicKey(final CurveType curve, final byte[] x, final byte[] y) throws GeneralSecurityException {
        final ECParameterSpec ecParams = getCurveSpec(curve);
        final BigInteger pubX = new BigInteger(1, x);
        final BigInteger pubY = new BigInteger(1, y);
        final ECPoint w = new ECPoint(pubX, pubY);
        EllipticCurvesUtil.checkPointOnCurve(w, ecParams.getCurve());
        final ECPublicKeySpec spec = new ECPublicKeySpec(w, ecParams);
        final KeyFactory kf = EngineFactory.KEY_FACTORY.getInstance("EC");
        return (ECPublicKey) kf.generatePublic(spec);
    }
    
    /** Parses a PKCS#8 encoding into an EC private key. */
    public static ECPrivateKey getEcPrivateKey(final byte[] pkcs8PrivateKey) throws GeneralSecurityException {
        final KeyFactory kf = EngineFactory.KEY_FACTORY.getInstance("EC");
        return (ECPrivateKey) kf.generatePrivate(new PKCS8EncodedKeySpec(pkcs8PrivateKey));
    }
    
    /**
     * Builds an EC private key on the named curve.
     *
     * <p>{@code keyValue} is decoded with {@code BigIntegerEncoding.fromUnsignedBigEndianBytes}.
     */
    public static ECPrivateKey getEcPrivateKey(final CurveType curve, final byte[] keyValue) throws GeneralSecurityException {
        final ECParameterSpec ecParams = getCurveSpec(curve);
        final BigInteger privValue = BigIntegerEncoding.fromUnsignedBigEndianBytes(keyValue);
        final ECPrivateKeySpec spec = new ECPrivateKeySpec(privValue, ecParams);
        final KeyFactory kf = EngineFactory.KEY_FACTORY.getInstance("EC");
        return (ECPrivateKey) kf.generatePrivate(spec);
    }
    
    /** Generates a fresh EC key pair on the named curve. */
    public static KeyPair generateKeyPair(final CurveType curve) throws GeneralSecurityException {
        return generateKeyPair(getCurveSpec(curve));
    }
    
    /** Generates a fresh EC key pair with parameters {@code spec}. */
    public static KeyPair generateKeyPair(final ECParameterSpec spec) throws GeneralSecurityException {
        final KeyPairGenerator keyGen = EngineFactory.KEY_PAIR_GENERATOR.getInstance("EC");
        keyGen.initialize(spec);
        return keyGen.generateKeyPair();
    }
    
    /**
     * Checks that an ECDH shared secret is the x-coordinate of a point on the private key's curve.
     *
     * <p>Read as an unsigned big-endian integer, the secret must be below the field modulus, and
     * {@code x^3 + a*x + b} must have a square root modulo p.
     *
     * @throws GeneralSecurityException if either check fails
     */
    static void validateSharedSecret(final byte[] secret, final ECPrivateKey privateKey) throws GeneralSecurityException {
        final EllipticCurve privateKeyCurve = privateKey.getParams().getCurve();
        final BigInteger x = new BigInteger(1, secret);
        if (x.signum() == -1 || x.compareTo(getModulus(privateKeyCurve)) >= 0) {
            throw new GeneralSecurityException("shared secret is out of range");
        }
        // Only the exception matters here (it is thrown when no y exists); y itself is unused.
        computeY(x, true, privateKeyCurve);
    }
    
    /**
     * Performs ECDH between {@code myPrivateKey} and {@code peerPublicKey}.
     *
     * <p>The peer key must use the same parameters as the private key.
     *
     * @return the raw secret from the JCE {@code "ECDH"} key agreement, validated by
     *     {@link #validateSharedSecret}
     */
    public static byte[] computeSharedSecret(final ECPrivateKey myPrivateKey, final ECPublicKey peerPublicKey) throws GeneralSecurityException {
        validatePublicKeySpec(peerPublicKey, myPrivateKey);
        return computeSharedSecret(myPrivateKey, peerPublicKey.getW());
    }
    
    /**
     * Performs ECDH between {@code myPrivateKey} and the peer point {@code publicPoint}.
     *
     * <p>The peer point is first checked to lie on the private key's curve.
     *
     * @return the raw secret from the JCE {@code "ECDH"} key agreement, validated by
     *     {@link #validateSharedSecret}
     * @throws GeneralSecurityException if the point is invalid, the key agreement fails (an
     *     {@link IllegalStateException} is wrapped), or the secret fails validation
     */
    public static byte[] computeSharedSecret(final ECPrivateKey myPrivateKey, final ECPoint publicPoint) throws GeneralSecurityException {
        EllipticCurvesUtil.checkPointOnCurve(publicPoint, myPrivateKey.getParams().getCurve());
        final ECParameterSpec privSpec = myPrivateKey.getParams();
        final ECPublicKeySpec publicKeySpec = new ECPublicKeySpec(publicPoint, privSpec);
        final KeyFactory kf = EngineFactory.KEY_FACTORY.getInstance("EC");
        final PublicKey publicKey = kf.generatePublic(publicKeySpec);
        final KeyAgreement ka = EngineFactory.KEY_AGREEMENT.getInstance("ECDH");
        ka.init(myPrivateKey);
        try {
            ka.doPhase(publicKey, true);
            final byte[] secret = ka.generateSecret();
            validateSharedSecret(secret, myPrivateKey);
            return secret;
        }
        catch (final IllegalStateException ex) {
            throw new GeneralSecurityException(ex);
        }
    }
    
    /** Not instantiable: all members are static. */
    private EllipticCurves() {
    }
    
    /** Encodings for EC points; the byte layouts are documented on {@code pointDecode}. */
    public enum PointFormatType
    {
        /** {@code 0x04 || x || y}. */
        UNCOMPRESSED, 
        /** {@code 0x02 || x} or {@code 0x03 || x}, where the prefix gives the parity of y. */
        COMPRESSED, 
        /** {@code x || y} without a prefix byte. */
        DO_NOT_USE_CRUNCHY_UNCOMPRESSED;
    }
    
    /** Named curves supported by {@link EllipticCurves#getCurveSpec}. */
    public enum CurveType
    {
        NIST_P256, 
        NIST_P384, 
        NIST_P521;
    }
    
    /**
     * ECDSA signature encodings. {@code ecdsaIeee2Der} and {@code ecdsaDer2Ieee} convert between
     * them.
     */
    public enum EcdsaEncoding
    {
        /** Fixed-length {@code r || s}. */
        IEEE_P1363, 
        /** ASN.1 DER {@code SEQUENCE { INTEGER r, INTEGER s }}. */
        DER;
    }
}
