// 
// Decompiled by Procyon v0.6.0
// 

package com.google.crypto.tink.jwt;

import com.google.crypto.tink.proto.OutputPrefixType;
import java.security.GeneralSecurityException;
import com.google.gson.JsonObject;
import java.util.Optional;
import java.security.InvalidAlgorithmParameterException;
import com.google.crypto.tink.subtle.Base64;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.CharacterCodingException;
import java.nio.ByteBuffer;
import com.google.crypto.tink.internal.Util;

/**
 * Package-private helpers for creating and parsing tokens in JWS compact serialization
 * ({@code header.payload.signature}, every segment URL-safe base64 encoded).
 *
 * <p>Decoding is strict: only the characters {@code A-Z a-z 0-9 - _} are accepted, so padding,
 * whitespace and line breaks are rejected before the data reaches the base64 decoder.
 */
final class JwtFormat
{
    /** Message used whenever a token does not consist of exactly three dot-separated segments. */
    private static final String NOT_COMPACT_JWS = "only tokens in JWS compact serialization format are supported";

    /** Static utility class; not instantiable. */
    private JwtFormat() {
    }

    /**
     * Returns {@code true} if {@code c} is one of the 64 characters of the URL-safe base64
     * alphabet ({@code a-z}, {@code A-Z}, {@code 0-9}, {@code '-'}, {@code '_'}).
     */
    static boolean isValidUrlsafeBase64Char(final char c) {
        return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '-' || c == '_';
    }

    /**
     * Checks that {@code data} is well-formed UTF-8.
     *
     * @throws JwtInvalidException if the decoder reports a coding error; the message is the
     *     decoder's message
     */
    static void validateUtf8(final byte[] data) throws JwtInvalidException {
        // A freshly created decoder is used so that decoding problems surface as
        // CharacterCodingException.
        final CharsetDecoder decoder = Util.UTF_8.newDecoder();
        try {
            decoder.decode(ByteBuffer.wrap(data));
        }
        catch (final CharacterCodingException ex) {
            throw new JwtInvalidException(ex.getMessage());
        }
    }

    /**
     * Decodes URL-safe base64 after verifying that every character belongs to the URL-safe
     * alphabet (see {@link #isValidUrlsafeBase64Char(char)}).
     *
     * @throws JwtInvalidException if a character is outside the alphabet or the decoder rejects
     *     the input
     */
    static byte[] strictUrlSafeDecode(final String encodedData) throws JwtInvalidException {
        for (int i = 0; i < encodedData.length(); ++i) {
            if (!isValidUrlsafeBase64Char(encodedData.charAt(i))) {
                throw new JwtInvalidException("invalid encoding");
            }
        }
        try {
            return Base64.urlSafeDecode(encodedData);
        }
        catch (final IllegalArgumentException ex) {
            throw new JwtInvalidException("invalid encoding: " + ex);
        }
    }

    /**
     * Accepts only the algorithm names listed below.
     *
     * @throws InvalidAlgorithmParameterException for any other name
     */
    private static void validateAlgorithm(final String algo) throws InvalidAlgorithmParameterException {
        switch (algo) {
            case "HS256":
            case "HS384":
            case "HS512":
            case "ES256":
            case "ES384":
            case "ES512":
            case "RS256":
            case "RS384":
            case "RS512":
            case "PS256":
            case "PS384":
            case "PS512": {
                return;
            }
            default: {
                throw new InvalidAlgorithmParameterException("invalid algorithm: " + algo);
            }
        }
    }

    /**
     * Builds the encoded header segment of a token.
     *
     * <p>Properties are added in the order {@code kid} (if present), {@code alg}, {@code typ}
     * (if present); the resulting JSON text is UTF-8 encoded and then URL-safe base64 encoded.
     *
     * @throws InvalidAlgorithmParameterException if {@code algorithm} is not a supported name
     */
    static String createHeader(final String algorithm, final Optional<String> typeHeader, final Optional<String> kid) throws InvalidAlgorithmParameterException {
        validateAlgorithm(algorithm);
        final JsonObject header = new JsonObject();
        if (kid.isPresent()) {
            header.addProperty("kid", kid.get());
        }
        header.addProperty("alg", algorithm);
        if (typeHeader.isPresent()) {
            header.addProperty("typ", typeHeader.get());
        }
        return Base64.urlSafeEncode(header.toString().getBytes(Util.UTF_8));
    }

    /**
     * Requires the header to carry a string {@code kid} equal to {@code expectedKid}.
     *
     * @throws JwtInvalidException if {@code kid} is missing, not a string, or different
     */
    private static void validateKidInHeader(final String expectedKid, final JsonObject parsedHeader) throws JwtInvalidException {
        final String kid = getStringHeader(parsedHeader, "kid");
        if (!kid.equals(expectedKid)) {
            throw new JwtInvalidException("invalid kid in header");
        }
    }

    /**
     * Validates a parsed header against the algorithm and optional kid of a key.
     *
     * <ul>
     *   <li>{@code alg} must be a string equal to {@code algorithmFromKey}.
     *   <li>Any header containing {@code crit} is rejected.
     *   <li>If the header has no {@code kid}, it is accepted only when {@code allowKidAbsent}.
     *   <li>If the header has a {@code kid} and {@code kidFromKey} is present, both must match;
     *       if {@code kidFromKey} is empty the header's {@code kid} is not inspected.
     * </ul>
     *
     * @throws InvalidAlgorithmParameterException if the algorithms differ
     * @throws JwtInvalidException for every other violation
     */
    static void validateHeader(final JsonObject parsedHeader, final String algorithmFromKey, final Optional<String> kidFromKey, final boolean allowKidAbsent) throws GeneralSecurityException {
        final String receivedAlgorithm = getStringHeader(parsedHeader, "alg");
        if (!receivedAlgorithm.equals(algorithmFromKey)) {
            throw new InvalidAlgorithmParameterException(String.format("invalid algorithm; expected %s, got %s", algorithmFromKey, receivedAlgorithm));
        }
        if (parsedHeader.has("crit")) {
            throw new JwtInvalidException("all tokens with crit headers are rejected");
        }
        if (!parsedHeader.has("kid")) {
            if (allowKidAbsent) {
                return;
            }
            throw new JwtInvalidException("missing kid in header");
        }
        if (kidFromKey.isPresent()) {
            validateKidInHeader(kidFromKey.get(), parsedHeader);
        }
    }

    /**
     * Validates a parsed header against an expected algorithm and at most one of a Tink kid or
     * a custom kid.
     *
     * <ul>
     *   <li>{@code expectedAlgorithm} must be a supported name and equal to the header's
     *       {@code alg}.
     *   <li>Any header containing {@code crit} is rejected.
     *   <li>{@code tinkKid} and {@code customKid} must not both be present.
     *   <li>With a {@code tinkKid}, the header must contain a matching {@code kid}.
     *   <li>With a {@code customKid}, a {@code kid} in the header must match, but it may be
     *       absent.
     * </ul>
     *
     * @throws InvalidAlgorithmParameterException if the algorithm is unsupported or differs
     * @throws JwtInvalidException for every other violation
     */
    static void validateHeader(final String expectedAlgorithm, final Optional<String> tinkKid, final Optional<String> customKid, final JsonObject parsedHeader) throws InvalidAlgorithmParameterException, JwtInvalidException {
        validateAlgorithm(expectedAlgorithm);
        final String algorithm = getStringHeader(parsedHeader, "alg");
        if (!algorithm.equals(expectedAlgorithm)) {
            throw new InvalidAlgorithmParameterException(String.format("invalid algorithm; expected %s, got %s", expectedAlgorithm, algorithm));
        }
        if (parsedHeader.has("crit")) {
            throw new JwtInvalidException("all tokens with crit headers are rejected");
        }
        if (tinkKid.isPresent() && customKid.isPresent()) {
            throw new JwtInvalidException("custom_kid can only be set for RAW keys.");
        }
        final boolean headerHasKid = parsedHeader.has("kid");
        if (tinkKid.isPresent()) {
            if (!headerHasKid) {
                throw new JwtInvalidException("missing kid in header");
            }
            validateKidInHeader(tinkKid.get(), parsedHeader);
        }
        if (customKid.isPresent() && headerHasKid) {
            validateKidInHeader(customKid.get(), parsedHeader);
        }
    }

    /**
     * Returns the {@code typ} header if the header has one.
     *
     * @throws JwtInvalidException if {@code typ} exists but is not a string
     */
    static Optional<String> getTypeHeader(final JsonObject header) throws JwtInvalidException {
        if (header.has("typ")) {
            return Optional.of(getStringHeader(header, "typ"));
        }
        return Optional.empty();
    }

    /**
     * Returns the value of the header member {@code name}, which must exist and be a JSON string.
     *
     * @throws JwtInvalidException if the member is missing or is not a string primitive
     */
    static String getStringHeader(final JsonObject header, final String name) throws JwtInvalidException {
        if (!header.has(name)) {
            throw new JwtInvalidException("header " + name + " does not exist");
        }
        if (!header.get(name).isJsonPrimitive() || !header.get(name).getAsJsonPrimitive().isString()) {
            throw new JwtInvalidException("header " + name + " is not a string");
        }
        return header.get(name).getAsString();
    }

    /**
     * Strictly base64url-decodes a segment and returns it as text after verifying that the bytes
     * are well-formed UTF-8. Shared by {@link #decodeHeader(String)} and
     * {@link #decodePayload(String)}.
     */
    private static String decodeUtf8Segment(final String encodedSegment) throws JwtInvalidException {
        final byte[] data = strictUrlSafeDecode(encodedSegment);
        validateUtf8(data);
        return new String(data, Util.UTF_8);
    }

    /**
     * Decodes the encoded header segment into its JSON text.
     *
     * @throws JwtInvalidException if the segment is not strict URL-safe base64 or not UTF-8
     */
    static String decodeHeader(final String headerStr) throws JwtInvalidException {
        return decodeUtf8Segment(headerStr);
    }

    /** Encodes a JSON payload as URL-safe base64 of its UTF-8 bytes. */
    static String encodePayload(final String jsonPayload) {
        return Base64.urlSafeEncode(jsonPayload.getBytes(Util.UTF_8));
    }

    /**
     * Decodes the encoded payload segment into its JSON text.
     *
     * @throws JwtInvalidException if the segment is not strict URL-safe base64 or not UTF-8
     */
    static String decodePayload(final String payloadStr) throws JwtInvalidException {
        return decodeUtf8Segment(payloadStr);
    }

    /** Encodes a signature or MAC as URL-safe base64. */
    static String encodeSignature(final byte[] signature) {
        return Base64.urlSafeEncode(signature);
    }

    /**
     * Decodes the signature segment.
     *
     * @throws JwtInvalidException if the segment is not strict URL-safe base64
     */
    static byte[] decodeSignature(final String signatureStr) throws JwtInvalidException {
        return strictUrlSafeDecode(signatureStr);
    }

    /**
     * Derives the {@code kid} header value from a key id and output prefix type.
     *
     * @return empty for {@code RAW}; for {@code TINK} the URL-safe base64 encoding of the key id
     *     written as a 4-byte big-endian integer
     * @throws JwtInvalidException for any other prefix type
     */
    static Optional<String> getKid(final int keyId, final OutputPrefixType prefix) throws JwtInvalidException {
        if (prefix == OutputPrefixType.RAW) {
            return Optional.empty();
        }
        if (prefix == OutputPrefixType.TINK) {
            // ByteBuffer defaults to big-endian byte order.
            final byte[] bigEndianKeyId = ByteBuffer.allocate(4).putInt(keyId).array();
            return Optional.of(Base64.urlSafeEncode(bigEndianKeyId));
        }
        throw new JwtInvalidException("unsupported output prefix type");
    }

    /**
     * Inverse of {@link #getKid(int, OutputPrefixType)} for {@code TINK} kids.
     *
     * @return the key id if {@code kid} decodes to exactly four bytes, otherwise empty; a
     *     {@code kid} that is not decodable base64 also yields empty
     */
    static Optional<Integer> getKeyId(final String kid) {
        final byte[] encodedKeyId;
        try {
            encodedKeyId = Base64.urlSafeDecode(kid);
        }
        catch (final IllegalArgumentException ex) {
            // The decoder signals malformed input with an unchecked exception (the same one
            // strictUrlSafeDecode handles). A kid that cannot be decoded is simply not a
            // Tink key id, so report it the same way as a wrong-length kid.
            return Optional.empty();
        }
        if (encodedKeyId.length != 4) {
            return Optional.empty();
        }
        return Optional.of(ByteBuffer.wrap(encodedKeyId).getInt());
    }

    /**
     * Splits a signed compact token into its parts and decodes them.
     *
     * <p>The token must be ASCII and consist of exactly three dot-separated segments. The
     * signature is decoded first, then header and payload.
     *
     * @return the unsigned part ({@code header.payload}, still encoded), the decoded signature
     *     bytes, and the decoded header and payload text
     * @throws JwtInvalidException if the token is not ASCII, does not have three segments, or a
     *     segment cannot be decoded
     */
    static Parts splitSignedCompact(final String signedCompact) throws JwtInvalidException {
        validateASCII(signedCompact);
        final int sigPos = signedCompact.lastIndexOf('.');
        if (sigPos < 0) {
            throw new JwtInvalidException(NOT_COMPACT_JWS);
        }
        final String unsignedCompact = signedCompact.substring(0, sigPos);
        final String encodedMac = signedCompact.substring(sigPos + 1);
        final byte[] mac = decodeSignature(encodedMac);
        final int payloadPos = unsignedCompact.indexOf('.');
        if (payloadPos < 0) {
            throw new JwtInvalidException(NOT_COMPACT_JWS);
        }
        final String encodedHeader = unsignedCompact.substring(0, payloadPos);
        final String encodedPayload = unsignedCompact.substring(payloadPos + 1);
        // ">= 0", not "> 0": a dot at index 0 of the payload segment (token "h..p.s") is
        // also an extra separator and must be reported as a malformed token.
        if (encodedPayload.indexOf('.') >= 0) {
            throw new JwtInvalidException(NOT_COMPACT_JWS);
        }
        final String header = decodeHeader(encodedHeader);
        final String payload = decodePayload(encodedPayload);
        return new Parts(unsignedCompact, mac, header, payload);
    }

    /**
     * Creates the unsigned part of a token: encoded header, a dot, encoded payload.
     *
     * @throws InvalidAlgorithmParameterException if {@code algorithm} is not a supported name
     */
    static String createUnsignedCompact(final String algorithm, final Optional<String> kid, final RawJwt rawJwt) throws InvalidAlgorithmParameterException, JwtInvalidException {
        final String jsonPayload = rawJwt.getJsonPayload();
        final Optional<String> typeHeader = rawJwt.hasTypeHeader() ? Optional.of(rawJwt.getTypeHeader()) : Optional.empty();
        return createHeader(algorithm, typeHeader, kid) + "." + encodePayload(jsonPayload);
    }

    /** Appends a dot and the encoded signature to an unsigned compact token. */
    static String createSignedCompact(final String unsignedCompact, final byte[] signature) {
        return unsignedCompact + "." + encodeSignature(signature);
    }

    /**
     * Checks that every char of {@code data} is in the ASCII range {@code 0x00-0x7F}.
     *
     * @throws JwtInvalidException on the first char above {@code 0x7F}
     */
    static void validateASCII(final String data) throws JwtInvalidException {
        for (int i = 0; i < data.length(); ++i) {
            // Compare the whole code unit. Testing only bit 7 (c & 0x80) would accept chars
            // such as U+0100 whose bit 7 happens to be clear.
            if (data.charAt(i) > 0x7F) {
                throw new JwtInvalidException("Non ascii character");
            }
        }
    }

    /** Result of {@link #splitSignedCompact(String)}. */
    static class Parts
    {
        /** Encoded {@code header.payload} exactly as it appeared in the token. */
        String unsignedCompact;
        /** Decoded bytes of the last segment. */
        byte[] signatureOrMac;
        /** Decoded header text. */
        String header;
        /** Decoded payload text. */
        String payload;

        Parts(final String unsignedCompact, final byte[] signatureOrMac, final String header, final String payload) {
            this.unsignedCompact = unsignedCompact;
            this.signatureOrMac = signatureOrMac;
            this.header = header;
            this.payload = payload;
        }
    }
}
