// 
// Decompiled by Procyon v0.6.0
// 

package com.hypixel.hytale.protocol.io;

import com.hypixel.hytale.protocol.PacketRegistry;
import com.hypixel.hytale.protocol.Packet;
import java.nio.ByteBuffer;
import io.netty.buffer.Unpooled;
import com.github.luben.zstd.Zstd;
import java.util.UUID;
import javax.annotation.Nullable;
import java.nio.charset.StandardCharsets;
import javax.annotation.Nonnull;
import io.netty.buffer.ByteBuf;
import java.nio.charset.Charset;

public final class PacketIO
{
    /**
     * Declared as 4; not referenced anywhere else in this class.
     * NOTE(review): presumably the byte size of a frame's length prefix - confirm against the frame codec.
     */
    public static final int FRAME_HEADER_SIZE = 4;
    /** Set to {@link StandardCharsets#UTF_8} by the static initializer of this class. */
    public static final Charset UTF8;
    /** Set to {@link StandardCharsets#US_ASCII} by the static initializer of this class. */
    public static final Charset ASCII;
    /**
     * Zstd compression level handed to every {@code Zstd.compress} call in this class.
     * Resolved once in the static initializer of this class.
     */
    private static final int COMPRESSION_LEVEL;
    
    /** Private so the class cannot be instantiated; every member of this class is static. */
    private PacketIO() {
    }
    
    /**
     * Reads a little-endian 16-bit half-precision value at an absolute index and widens it to a float.
     * Uses an absolute getter, so the buffer's reader index is not moved.
     *
     * @param buf   buffer to read from
     * @param index absolute byte index of the 16-bit value
     * @return the decoded value as a {@code float}
     */
    public static float readHalfLE(@Nonnull final ByteBuf buf, final int index) {
        return halfToFloat(buf.getShortLE(index));
    }
    
    /**
     * Narrows a float to 16-bit half precision and appends it little-endian at the writer index.
     *
     * @param buf   buffer to append to
     * @param value value to encode
     */
    public static void writeHalfLE(@Nonnull final ByteBuf buf, final float value) {
        final short encoded = floatToHalf(value);
        buf.writeShortLE(encoded);
    }
    
    /**
     * Copies {@code length} bytes starting at an absolute offset into a new array.
     * The buffer's reader index is not moved.
     *
     * @param buf    buffer to copy from
     * @param offset absolute byte index of the first byte
     * @param length number of bytes to copy
     * @return a freshly allocated array holding the copied bytes
     */
    @Nonnull
    public static byte[] readBytes(@Nonnull final ByteBuf buf, final int offset, final int length) {
        final byte[] copy = new byte[length];
        buf.getBytes(offset, copy, 0, copy.length);
        return copy;
    }
    
    /**
     * Copies {@code length} bytes starting at an absolute offset into a new array.
     * Behaves the same as {@code readBytes}; the reader index is not moved.
     *
     * @param buf    buffer to copy from
     * @param offset absolute byte index of the first byte
     * @param length number of bytes to copy
     * @return a freshly allocated array holding the copied bytes
     */
    @Nonnull
    public static byte[] readByteArray(@Nonnull final ByteBuf buf, final int offset, final int length) {
        final byte[] copy = new byte[length];
        buf.getBytes(offset, copy, 0, copy.length);
        return copy;
    }
    
    /**
     * Reads {@code length} consecutive little-endian 16-bit values starting at an absolute offset.
     *
     * @param buf    buffer to read from
     * @param offset absolute byte index of the first element
     * @param length number of elements (not bytes) to read
     * @return a new array with the decoded elements
     */
    @Nonnull
    public static short[] readShortArrayLE(@Nonnull final ByteBuf buf, final int offset, final int length) {
        final short[] values = new short[length];
        int position = offset;
        for (int n = 0; n < values.length; n++) {
            values[n] = buf.getShortLE(position);
            position += 2; // each element occupies two bytes
        }
        return values;
    }
    
    /**
     * Reads {@code length} consecutive little-endian 32-bit floats starting at an absolute offset.
     *
     * @param buf    buffer to read from
     * @param offset absolute byte index of the first element
     * @param length number of elements (not bytes) to read
     * @return a new array with the decoded elements
     */
    @Nonnull
    public static float[] readFloatArrayLE(@Nonnull final ByteBuf buf, final int offset, final int length) {
        final float[] values = new float[length];
        int position = offset;
        for (int n = 0; n < values.length; n++) {
            values[n] = buf.getFloatLE(position);
            position += 4; // each element occupies four bytes
        }
        return values;
    }
    
    /**
     * Reads a fixed-width, zero-padded ASCII field.
     * The string ends at the first zero byte, or at {@code length} bytes if there is none.
     *
     * @param buf    buffer to read from
     * @param offset absolute byte index of the field
     * @param length width of the field in bytes
     * @return the decoded text without the padding
     */
    @Nonnull
    public static String readFixedAsciiString(@Nonnull final ByteBuf buf, final int offset, final int length) {
        final byte[] raw = new byte[length];
        buf.getBytes(offset, raw);
        int terminator = 0;
        while (terminator < length && raw[terminator] != 0) {
            terminator++;
        }
        return new String(raw, 0, terminator, StandardCharsets.US_ASCII);
    }
    
    /**
     * Reads a fixed-width, zero-padded UTF-8 field.
     * The string ends at the first zero byte, or at {@code length} bytes if there is none.
     *
     * @param buf    buffer to read from
     * @param offset absolute byte index of the field
     * @param length width of the field in bytes
     * @return the decoded text without the padding
     */
    @Nonnull
    public static String readFixedString(@Nonnull final ByteBuf buf, final int offset, final int length) {
        final byte[] raw = new byte[length];
        buf.getBytes(offset, raw);
        int terminator = 0;
        while (terminator < length && raw[terminator] != 0) {
            terminator++;
        }
        return new String(raw, 0, terminator, StandardCharsets.UTF_8);
    }
    
    /**
     * Reads a VarInt-length-prefixed string at an absolute offset, decoding the payload as UTF-8.
     *
     * @param buf    buffer to read from
     * @param offset absolute byte index of the length prefix
     * @return the decoded string
     */
    @Nonnull
    public static String readVarString(@Nonnull final ByteBuf buf, final int offset) {
        final Charset utf8 = StandardCharsets.UTF_8;
        return readVarString(buf, offset, utf8);
    }
    
    /**
     * Reads a VarInt-length-prefixed string at an absolute offset, decoding the payload as US-ASCII.
     *
     * @param buf    buffer to read from
     * @param offset absolute byte index of the length prefix
     * @return the decoded string
     */
    @Nonnull
    public static String readVarAsciiString(@Nonnull final ByteBuf buf, final int offset) {
        final Charset ascii = StandardCharsets.US_ASCII;
        return readVarString(buf, offset, ascii);
    }
    
    /**
     * Reads a VarInt-length-prefixed string at an absolute offset using the given charset.
     * The reader index of {@code buf} is not moved.
     *
     * <p>The length prefix is validated against the buffer's capacity <em>before</em> the
     * payload array is allocated, so a malformed prefix cannot trigger a huge allocation
     * or a {@link NegativeArraySizeException}.
     *
     * @param buf     buffer to read from
     * @param offset  absolute byte index of the VarInt length prefix
     * @param charset charset used to decode the payload bytes
     * @return the decoded string
     * @throws IndexOutOfBoundsException if the declared length is negative or the payload
     *                                   would extend past the buffer's capacity
     */
    @Nonnull
    public static String readVarString(@Nonnull final ByteBuf buf, final int offset, final Charset charset) {
        final int len = VarInt.peek(buf, offset);
        final int varIntLen = VarInt.length(buf, offset);
        final int dataStart = offset + varIntLen;
        // Absolute getters are bounded by capacity, so anything rejected here would have
        // failed in getBytes anyway - but only after the array had been allocated.
        if (len < 0 || len > buf.capacity() - dataStart) {
            throw new IndexOutOfBoundsException("String length " + len + " at offset " + offset
                    + " exceeds buffer capacity " + buf.capacity());
        }
        final byte[] bytes = new byte[len];
        buf.getBytes(dataStart, bytes);
        return new String(bytes, charset);
    }
    
    /**
     * Computes how many bytes {@code s.getBytes(StandardCharsets.UTF_8)} would produce,
     * without allocating the encoded array.
     *
     * <p>A valid surrogate pair counts as 4 bytes. An unpaired surrogate (high without a
     * following low, or a stray low) counts as 1 byte, because {@link String#getBytes(java.nio.charset.Charset)}
     * replaces malformed input with a single {@code '?'}.
     *
     * @param s string to measure
     * @return the UTF-8 encoded length in bytes
     */
    public static int utf8ByteLength(@Nonnull final String s) {
        final int charCount = s.length();
        int len = 0;
        for (int i = 0; i < charCount; ++i) {
            final char c = s.charAt(i);
            if (c < '\u0080') {
                ++len;
            }
            else if (c < '\u0800') {
                len += 2;
            }
            else if (!Character.isSurrogate(c)) {
                len += 3;
            }
            else if (Character.isHighSurrogate(c) && i + 1 < charCount && Character.isLowSurrogate(s.charAt(i + 1))) {
                // Well-formed pair: one supplementary code point, consumes both chars.
                len += 4;
                ++i;
            }
            else {
                // Unpaired surrogate: the encoder substitutes a single '?' and does not
                // consume the following char.
                ++len;
            }
        }
        return len;
    }
    
    /**
     * Computes the size of a string as the sum of its UTF-8 byte length and the size of
     * the VarInt that encodes that length.
     *
     * @param s string to measure
     * @return prefix size plus payload size, in bytes
     */
    public static int stringSize(@Nonnull final String s) {
        final int payloadBytes = utf8ByteLength(s);
        final int prefixBytes = VarInt.size(payloadBytes);
        return prefixBytes + payloadBytes;
    }
    
    /**
     * Appends a fixed-width byte field: at most {@code length} bytes of {@code data},
     * followed by zero bytes until {@code length} bytes have been written.
     * Data longer than {@code length} is truncated rather than rejected.
     *
     * @param buf    buffer to append to
     * @param data   source bytes
     * @param length width of the field in bytes
     */
    public static void writeFixedBytes(@Nonnull final ByteBuf buf, @Nonnull final byte[] data, final int length) {
        final int copyCount = Math.min(data.length, length);
        buf.writeBytes(data, 0, copyCount);
        int written = data.length;
        while (written < length) {
            buf.writeByte(0);
            written++;
        }
    }
    
    /**
     * Appends a fixed-width ASCII field, zero-padded on the right.
     * A {@code null} value is written as {@code length} zero bytes.
     *
     * @param buf    buffer to append to
     * @param value  text to encode, or {@code null}
     * @param length width of the field in bytes
     * @throws ProtocolException if the encoded text is longer than {@code length}
     */
    public static void writeFixedAsciiString(@Nonnull final ByteBuf buf, @Nullable final String value, final int length) {
        if (value == null) {
            buf.writeZero(length);
            return;
        }
        final byte[] encoded = value.getBytes(StandardCharsets.US_ASCII);
        final int encodedLength = encoded.length;
        if (encodedLength > length) {
            throw new ProtocolException("Fixed ASCII string exceeds length: " + encodedLength + " > " + length);
        }
        buf.writeBytes(encoded);
        buf.writeZero(length - encodedLength);
    }
    
    /**
     * Appends a fixed-width UTF-8 field, zero-padded on the right.
     * A {@code null} value is written as {@code length} zero bytes.
     *
     * @param buf    buffer to append to
     * @param value  text to encode, or {@code null}
     * @param length width of the field in bytes
     * @throws ProtocolException if the encoded text is longer than {@code length}
     */
    public static void writeFixedString(@Nonnull final ByteBuf buf, @Nullable final String value, final int length) {
        if (value == null) {
            buf.writeZero(length);
            return;
        }
        final byte[] encoded = value.getBytes(StandardCharsets.UTF_8);
        final int encodedLength = encoded.length;
        if (encodedLength > length) {
            throw new ProtocolException("Fixed UTF-8 string exceeds length: " + encodedLength + " > " + length);
        }
        buf.writeBytes(encoded);
        buf.writeZero(length - encodedLength);
    }
    
    /**
     * Appends a UTF-8 string as a VarInt byte count followed by the encoded bytes.
     *
     * @param buf       buffer to append to
     * @param value     text to encode
     * @param maxLength maximum allowed encoded size in bytes
     * @throws ProtocolException if the encoded text is longer than {@code maxLength}
     */
    public static void writeVarString(@Nonnull final ByteBuf buf, @Nonnull final String value, final int maxLength) {
        final byte[] encoded = value.getBytes(StandardCharsets.UTF_8);
        final int byteCount = encoded.length;
        if (byteCount > maxLength) {
            throw new ProtocolException("String exceeds max bytes: " + byteCount + " > " + maxLength);
        }
        VarInt.write(buf, byteCount);
        buf.writeBytes(encoded);
    }
    
    /**
     * Appends a US-ASCII string as a VarInt byte count followed by the encoded bytes.
     *
     * @param buf       buffer to append to
     * @param value     text to encode
     * @param maxLength maximum allowed encoded size in bytes
     * @throws ProtocolException if the encoded text is longer than {@code maxLength}
     */
    public static void writeVarAsciiString(@Nonnull final ByteBuf buf, @Nonnull final String value, final int maxLength) {
        final byte[] encoded = value.getBytes(StandardCharsets.US_ASCII);
        final int byteCount = encoded.length;
        if (byteCount > maxLength) {
            throw new ProtocolException("String exceeds max bytes: " + byteCount + " > " + maxLength);
        }
        VarInt.write(buf, byteCount);
        buf.writeBytes(encoded);
    }
    
    /**
     * Reads a 16-byte UUID at an absolute offset: the most significant 64 bits first,
     * then the least significant 64 bits. Both halves are read with {@code getLong},
     * i.e. not with the little-endian accessors used by the other readers in this class.
     *
     * @param buf    buffer to read from
     * @param offset absolute byte index of the first byte
     * @return the decoded UUID
     */
    @Nonnull
    public static UUID readUUID(@Nonnull final ByteBuf buf, final int offset) {
        return new UUID(buf.getLong(offset), buf.getLong(offset + 8));
    }
    
    /**
     * Appends a UUID as 16 bytes: the most significant 64 bits first, then the least
     * significant 64 bits, both written with {@code writeLong}.
     *
     * @param buf   buffer to append to
     * @param value UUID to encode
     */
    public static void writeUUID(@Nonnull final ByteBuf buf, @Nonnull final UUID value) {
        final long high = value.getMostSignificantBits();
        final long low = value.getLeastSignificantBits();
        buf.writeLong(high);
        buf.writeLong(low);
    }
    
    /**
     * Widens a 16-bit half-precision bit pattern (1 sign, 5 exponent, 10 mantissa bits)
     * to a {@code float}.
     *
     * @param half raw half-precision bits
     * @return the equivalent float; every NaN pattern maps to {@link Float#NaN}
     */
    private static float halfToFloat(final short half) {
        final int raw = half & 0xFFFF;
        final int signBit = (raw >>> 15) & 0x1;
        int exponent = (raw >>> 10) & 0x1F;
        int mantissa = raw & 0x3FF;
        if (exponent == 31) {
            // All-ones exponent: infinity when the mantissa is clear, NaN otherwise.
            if (mantissa != 0) {
                return Float.NaN;
            }
            return (signBit == 0) ? Float.POSITIVE_INFINITY : Float.NEGATIVE_INFINITY;
        }
        if (exponent == 0) {
            if (mantissa == 0) {
                return (signBit == 0) ? 0.0f : -0.0f;
            }
            // Subnormal half: shift the mantissa up until the implicit bit appears,
            // lowering the exponent once per shift, then drop the implicit bit.
            exponent = 1;
            while ((mantissa & 0x400) == 0x0) {
                mantissa <<= 1;
                exponent--;
            }
            mantissa &= 0x3FF;
        }
        // Re-bias the exponent from 15 to 127 (+112) and widen the mantissa from 10 to 23 bits.
        final int floatBits = (signBit << 31) | ((exponent + 112) << 23) | (mantissa << 13);
        return Float.intBitsToFloat(floatBits);
    }
    
    /**
     * Narrows a {@code float} to a 16-bit half-precision bit pattern
     * (1 sign, 5 exponent, 10 mantissa bits).
     *
     * <ul>
     *   <li>NaN stays NaN: the top 10 payload bits are kept and the quiet bit is forced,
     *       so the result can never collapse to infinity or zero.</li>
     *   <li>Infinity and every finite magnitude of 65536 or more become infinity.</li>
     *   <li>Magnitudes below 65536 that would round up past the largest finite half are
     *       clamped to {@code 0x7BFF} (65504) instead of becoming infinity.</li>
     *   <li>Magnitudes too small for a half subnormal become a signed zero.</li>
     * </ul>
     *
     * @param f value to encode
     * @return the half-precision bits
     */
    private static short floatToHalf(final float f) {
        final int bits = Float.floatToRawIntBits(f);
        final int sign = (bits >>> 16) & 0x8000;
        final int magnitude = bits & Integer.MAX_VALUE;
        if (magnitude > 0x7F800000) {
            // NaN. Handled first because raw NaN bits can overflow the rounding addition
            // below, and a payload living only in the low 13 bits would otherwise vanish.
            return (short)(sign | 0x7E00 | ((magnitude & 0x7FFFFF) >>> 13));
        }
        if (magnitude >= 0x47800000) {
            // Infinity, or a finite value too large for half precision (>= 65536).
            return (short)(sign | 0x7C00);
        }
        // Add half of the least significant retained bit so the truncating shifts round.
        final int rounded = magnitude + 0x1000;
        if (rounded >= 0x47800000) {
            // Rounding alone would overflow to infinity: clamp to the largest finite half.
            return (short)(sign | 0x7BFF);
        }
        if (rounded >= 0x38800000) {
            // Normal half: re-bias the exponent (127 -> 15) and drop 13 mantissa bits.
            return (short)(sign | ((rounded - 0x38000000) >>> 13));
        }
        if (rounded < 0x33000000) {
            // Too small even for a half subnormal.
            return (short)sign;
        }
        // Half subnormal: restore the implicit leading bit, add a rounding term that
        // depends on how many bits are cut off, then shift into place.
        final int exponent = magnitude >>> 23;
        final int mantissa = (magnitude & 0x7FFFFF) | 0x800000;
        return (short)(sign | ((mantissa + (0x800000 >>> (exponent - 102))) >>> (126 - exponent)));
    }
    
    /**
     * Zstd-compresses the readable bytes of {@code src} into {@code dst} at an absolute
     * offset, using the class-wide compression level. Neither buffer's reader or writer
     * index is changed here.
     *
     * <p>When both buffers are direct, their NIO views are handed to Zstd directly.
     * Otherwise the input is copied to a byte array, compressed, and the result is copied
     * into {@code dst}; on that path {@code maxDstSize} is not consulted.
     *
     * @param src        buffer whose readable bytes are compressed
     * @param dst        buffer receiving the compressed bytes
     * @param dstOffset  absolute index in {@code dst} where output starts
     * @param maxDstSize size of the destination window used on the direct path
     * @return the value returned by Zstd on the direct path, or the compressed array length otherwise
     */
    private static int compressToBuffer(@Nonnull final ByteBuf src, @Nonnull final ByteBuf dst, final int dstOffset, final int maxDstSize) {
        final boolean bothDirect = src.isDirect() && dst.isDirect();
        if (!bothDirect) {
            final byte[] input = new byte[src.readableBytes()];
            src.getBytes(src.readerIndex(), input);
            final byte[] output = Zstd.compress(input, PacketIO.COMPRESSION_LEVEL);
            dst.setBytes(dstOffset, output);
            return output.length;
        }
        final ByteBuffer target = dst.nioBuffer(dstOffset, maxDstSize);
        final ByteBuffer source = src.nioBuffer();
        return Zstd.compress(target, source, PacketIO.COMPRESSION_LEVEL);
    }
    
    /**
     * Decompresses one Zstd frame located at {@code [srcOffset, srcOffset + srcLength)}
     * of {@code src} into a newly allocated buffer. The indices of {@code src} are not moved.
     *
     * <p>The frame's declared content size is checked against {@code maxDecompressedSize}
     * before any output buffer is allocated. A direct source is decompressed into a direct
     * buffer; any other source is copied to a byte array and the result is wrapped.
     *
     * <p>The direct output buffer is released on every failure path, including exceptions
     * raised by the Zstd binding, so a bad frame cannot leak it.
     *
     * @param src                 buffer containing the compressed frame
     * @param srcOffset           absolute index of the first compressed byte
     * @param srcLength           number of compressed bytes
     * @param maxDecompressedSize upper bound for the decompressed size
     * @return a buffer holding the decompressed bytes; the caller must release it
     * @throws ProtocolException if a size limit is exceeded, the frame does not declare a
     *                           usable content size, or Zstd reports an error code
     */
    @Nonnull
    private static ByteBuf decompressFromBuffer(@Nonnull final ByteBuf src, final int srcOffset, final int srcLength, final int maxDecompressedSize) {
        // NOTE(review): this compares the COMPRESSED length with the DECOMPRESSED limit.
        // Incompressible payloads close to the limit compress to slightly more than their
        // original size and would be rejected here - confirm that this is intended.
        if (srcLength > maxDecompressedSize) {
            throw new ProtocolException("Compressed size " + srcLength + " exceeds max decompressed size " + maxDecompressedSize);
        }
        if (src.isDirect()) {
            final ByteBuffer srcNio = src.nioBuffer(srcOffset, srcLength);
            final int contentSize = checkedContentSize(Zstd.getFrameContentSize(srcNio), maxDecompressedSize);
            final ByteBuf dst = Unpooled.directBuffer(contentSize);
            boolean success = false;
            try {
                final ByteBuffer dstNio = dst.nioBuffer(0, contentSize);
                final int result = Zstd.decompress(dstNio, srcNio);
                if (Zstd.isError(result)) {
                    throw new ProtocolException("Zstd decompression failed: " + Zstd.getErrorName(result));
                }
                dst.writerIndex(result);
                success = true;
                return dst;
            }
            finally {
                // Covers the error-code path above as well as anything thrown by the calls.
                if (!success) {
                    dst.release();
                }
            }
        }
        final byte[] srcBytes = new byte[srcLength];
        src.getBytes(srcOffset, srcBytes);
        final int contentSize = checkedContentSize(Zstd.getFrameContentSize(srcBytes), maxDecompressedSize);
        final byte[] decompressed = Zstd.decompress(srcBytes, contentSize);
        return Unpooled.wrappedBuffer(decompressed);
    }

    /** Validates a frame's declared content size against the limit and narrows it to {@code int}. */
    private static int checkedContentSize(final long decompressedSize, final int maxDecompressedSize) {
        if (decompressedSize < 0L) {
            throw new ProtocolException("Invalid Zstd frame or unknown content size");
        }
        if (decompressedSize > maxDecompressedSize) {
            throw new ProtocolException("Decompressed size " + decompressedSize + " exceeds maximum " + maxDecompressedSize);
        }
        return (int)decompressedSize;
    }
    
    /**
     * Serializes a packet and appends it to {@code out} as one frame:
     * a little-endian int payload length, a little-endian int packet id, then the payload.
     * The payload is Zstd-compressed when the packet's registry entry asks for it and the
     * serialized form is non-empty; the length field always holds the size of the bytes
     * actually written after the id.
     *
     * <p>If anything fails after the frame has been started, the writer index of
     * {@code out} is restored to where it was on entry, so no partial frame is left behind.
     *
     * @param packet        packet to serialize
     * @param packetClass   class used to look up the packet id
     * @param out           buffer the frame is appended to
     * @param statsRecorder receives the id, serialized size and compressed size (0 when uncompressed)
     * @throws ProtocolException if the packet type is unknown, the serialized form exceeds the
     *                           packet's maximum size, the payload exceeds the protocol maximum,
     *                           or Zstd reports an error code
     */
    public static void writeFramedPacket(@Nonnull final Packet packet, @Nonnull final Class<? extends Packet> packetClass, @Nonnull final ByteBuf out, @Nonnull final PacketStatsRecorder statsRecorder) {
        // Largest payload length the protocol accepts in the frame's length field.
        final int maxPayloadSize = 1677721600;
        final Integer id = PacketRegistry.getId(packetClass);
        if (id == null) {
            throw new ProtocolException("Unknown packet type: " + packetClass.getName());
        }
        final PacketRegistry.PacketInfo info = PacketRegistry.getById(id);
        final int lengthIndex = out.writerIndex();
        final ByteBuf payloadBuf = Unpooled.buffer(Math.min(info.maxSize(), 65536));
        boolean frameComplete = false;
        try {
            // Length placeholder; patched once the final payload size is known.
            out.writeIntLE(0);
            out.writeIntLE(id);
            packet.serialize(payloadBuf);
            final int serializedSize = payloadBuf.readableBytes();
            if (serializedSize > info.maxSize()) {
                throw new ProtocolException("Packet " + info.name() + " serialized to " + serializedSize + " bytes, exceeds max size " + info.maxSize());
            }
            if (info.compressed() && serializedSize > 0) {
                final int compressBound = (int)Zstd.compressBound(serializedSize);
                out.ensureWritable(compressBound);
                final int compressedSize = compressToBuffer(payloadBuf, out, out.writerIndex(), compressBound);
                if (Zstd.isError(compressedSize)) {
                    throw new ProtocolException("Zstd compression failed: " + Zstd.getErrorName(compressedSize));
                }
                if (compressedSize > maxPayloadSize) {
                    throw new ProtocolException("Packet " + info.name() + " compressed payload size " + compressedSize + " exceeds protocol maximum");
                }
                // compressToBuffer writes by absolute index, so advance the writer manually.
                out.writerIndex(out.writerIndex() + compressedSize);
                out.setIntLE(lengthIndex, compressedSize);
                statsRecorder.recordSend(id, serializedSize, compressedSize);
            }
            else {
                if (serializedSize > maxPayloadSize) {
                    throw new ProtocolException("Packet " + info.name() + " payload size " + serializedSize + " exceeds protocol maximum");
                }
                out.writeBytes(payloadBuf);
                out.setIntLE(lengthIndex, serializedSize);
                statsRecorder.recordSend(id, serializedSize, 0);
            }
            frameComplete = true;
        }
        finally {
            payloadBuf.release();
            if (!frameComplete) {
                // Drop the half-written header/payload so the caller's buffer stays well-formed.
                out.writerIndex(lengthIndex);
            }
        }
    }
    
    /**
     * Reads the little-endian packet id at the reader index, looks it up in the registry
     * and decodes the payload that follows.
     *
     * @param in            buffer positioned at the packet id
     * @param payloadLength number of payload bytes following the id
     * @param statsRecorder receives receive statistics for the decoded packet
     * @return the decoded packet
     * @throws ProtocolException if the id is not registered; the payload bytes are skipped first
     */
    @Nonnull
    public static Packet readFramedPacket(@Nonnull final ByteBuf in, final int payloadLength, @Nonnull final PacketStatsRecorder statsRecorder) {
        final int packetId = in.readIntLE();
        final PacketRegistry.PacketInfo info = PacketRegistry.getById(packetId);
        if (info != null) {
            return readFramedPacketWithInfo(in, payloadLength, info, statsRecorder);
        }
        // Consume the unknown payload so the reader index ends up past this frame.
        in.skipBytes(payloadLength);
        throw new ProtocolException("Unknown packet ID: " + packetId);
    }
    
    /**
     * Decodes one packet payload of {@code payloadLength} bytes starting at the reader
     * index of {@code in}, using the supplied registry entry.
     *
     * <p>Compressed payloads are decompressed into a temporary buffer; uncompressed ones
     * are read as a retained slice; a non-positive length yields an empty payload.
     * For compressed payloads the reader index is advanced past the payload even when
     * decompression fails, whatever the exception type, so the caller's position never
     * ends up inside a frame.
     *
     * @param in            buffer positioned at the first payload byte
     * @param payloadLength number of payload bytes in the frame
     * @param info          registry entry describing the packet
     * @param statsRecorder receives the id, uncompressed size and compressed size (0 when uncompressed)
     * @return the decoded packet
     * @throws ProtocolException if decompression rejects the payload
     */
    @Nonnull
    public static Packet readFramedPacketWithInfo(@Nonnull final ByteBuf in, final int payloadLength, @Nonnull final PacketRegistry.PacketInfo info, @Nonnull final PacketStatsRecorder statsRecorder) {
        int compressedSize = 0;
        ByteBuf payload;
        int uncompressedSize;
        if (info.compressed() && payloadLength > 0) {
            try {
                payload = decompressFromBuffer(in, in.readerIndex(), payloadLength, info.maxSize());
            }
            finally {
                // Decompression reads by absolute index; consume the payload on success
                // and on every kind of failure.
                in.skipBytes(payloadLength);
            }
            uncompressedSize = payload.readableBytes();
            compressedSize = payloadLength;
        }
        else if (payloadLength > 0) {
            payload = in.readRetainedSlice(payloadLength);
            uncompressedSize = payloadLength;
        }
        else {
            payload = Unpooled.EMPTY_BUFFER;
            uncompressedSize = 0;
        }
        try {
            final Packet packet = info.deserialize().apply(payload, 0);
            statsRecorder.recordReceive(info.id(), uncompressedSize, compressedSize);
            return packet;
        }
        finally {
            // The shared empty buffer is not owned here; only release what was allocated/retained above.
            if (payloadLength > 0) {
                payload.release();
            }
        }
    }
    
    // Assigns the blank-final constants declared at the top of the class.
    static {
        UTF8 = StandardCharsets.UTF_8;
        ASCII = StandardCharsets.US_ASCII;
        // Read from the JVM system property "hytale.protocol.compressionLevel";
        // falls back to the Zstd library's default level when the property is absent.
        COMPRESSION_LEVEL = Integer.getInteger("hytale.protocol.compressionLevel", Zstd.defaultCompressionLevel());
    }
}
