// 
// Decompiled by Procyon v0.6.0
// 

package com.hypixel.hytale.component.spatial;

import it.unimi.dsi.fastutil.objects.ObjectListIterator;
import java.util.Comparator;
import com.hypixel.hytale.math.vector.Vector3d;
import it.unimi.dsi.fastutil.objects.ObjectArrayList;
import javax.annotation.Nullable;
import java.util.function.Predicate;
import javax.annotation.Nonnull;
import java.util.List;

/**
 * Three-dimensional k-d tree over {@link Vector3d} positions.
 *
 * <p>The tree is rebuilt from scratch by {@link #rebuild(SpatialData)}; there is no incremental
 * insert or remove. Entries whose positions are equal (per {@code Vector3d.equals}) and adjacent
 * in the sorted input share a single node, which stores all of their payloads in one list.
 *
 * <p>Splitting axes cycle with depth in the order x, z, y (axis index 0, 1, 2). For a node split
 * on some axis, the {@code one} subtree holds positions whose component on that axis compares
 * greater than the node's (per {@link Double#compare(double, double)}), and the {@code two}
 * subtree holds everything else (less than or equal).
 *
 * <p>Nodes and payload lists are pooled and reused across rebuilds.
 *
 * @param <T> payload type stored at each position
 */
public class KDTree<T> implements SpatialStructure<T> {
    /** Number of splitting axes the tree cycles through while descending. */
    private static final int AXIS_COUNT = 3;

    /** Every node ever allocated; the first {@link #nodePoolIndex} are in use by the current tree. */
    @Nonnull
    private final List<Node<T>> nodePool;
    private int nodePoolIndex;
    /** Every payload list ever allocated; the first {@link #dataListPoolIndex} are in use. */
    @Nonnull
    private final List<List<T>> dataListPool;
    private int dataListPoolIndex;
    /** Entry count reported by the {@code SpatialData} used for the last successful rebuild. */
    private int size;
    /** Applied to every payload before it is appended to a caller-supplied result list. */
    @Nonnull
    private final Predicate<T> collectionFilter;
    @Nullable
    private Node<T> root;

    /**
     * Creates an empty tree.
     *
     * @param collectionFilter predicate a payload must satisfy to be added to the results of the
     *                         {@code collect*} and {@code ordered*} queries; it is not consulted
     *                         by {@link #closest(Vector3d)}
     */
    public KDTree(@Nonnull final Predicate<T> collectionFilter) {
        this.nodePool = new ObjectArrayList<Node<T>>();
        this.nodePoolIndex = 0;
        this.dataListPool = new ObjectArrayList<List<T>>();
        this.dataListPoolIndex = 0;
        this.collectionFilter = collectionFilter;
    }

    /** Returns the number of entries supplied to the last rebuild (0 for an empty tree). */
    @Override
    public int size() {
        return this.size;
    }

    /**
     * Discards the current tree and builds a new one from {@code spatialData}.
     *
     * <p>Calls {@code spatialData.sortMorton()} and then reads entries through
     * {@code getSortedIndex}. Equal positions are merged into one node only when they are
     * adjacent in that sorted order (presumably the Morton sort guarantees this; the code here
     * relies on it).
     *
     * @param spatialData source of positions and payloads
     */
    @Override
    public void rebuild(@Nonnull final SpatialData<T> spatialData) {
        this.root = null;
        this.size = 0;
        final int spatialDataSize = spatialData.size();
        // Recycle the pools before the empty-input check so that rebuilding with no entries
        // also releases the payload references held by the previous tree's lists.
        for (int i = 0; i < this.dataListPoolIndex; ++i) {
            this.dataListPool.get(i).clear();
        }
        this.nodePoolIndex = 0;
        this.dataListPoolIndex = 0;
        if (spatialDataSize == 0) {
            return;
        }
        spatialData.sortMorton();
        // root is null here, so the first build0 step installs the median as the root.
        this.build0(spatialData, 0, spatialDataSize);
        this.size = spatialDataSize;
    }

    /**
     * Finds the node nearest to {@code point} and returns the first payload stored in it.
     *
     * <p>The collection filter is not applied here.
     *
     * @param point query position
     * @return first payload of the nearest node, or {@code null} if no node was selected
     *         (for example when the tree is empty)
     */
    @Nullable
    @Override
    public T closest(@Nonnull final Vector3d point) {
        final ClosestState<T> closestState = new ClosestState<T>(null, Double.MAX_VALUE);
        this.closest0(closestState, this.root, point, 0);
        if (closestState.node == null) {
            return null;
        }
        return closestState.node.data.getFirst();
    }

    /**
     * Appends every filtered payload whose position is strictly closer than {@code radius} to
     * {@code center}.
     *
     * @param center  sphere centre
     * @param radius  sphere radius; the comparison is on squared distance and is exclusive
     * @param results list the matching payloads are appended to
     */
    @Override
    public void collect(@Nonnull final Vector3d center, final double radius, @Nonnull final List<T> results) {
        final double distanceSq = radius * radius;
        this.collect0(results, this.root, center, distanceSq, 0);
    }

    /**
     * Appends every filtered payload inside a y-aligned cylinder centred on {@code center}.
     *
     * <p>A position matches when its y offset from the centre is at most {@code height / 2} and
     * its squared horizontal (x/z) distance is at most {@code radius * radius}; both bounds are
     * inclusive.
     *
     * @param center  centre of the cylinder
     * @param radius  horizontal radius
     * @param height  total height; half of it extends above and half below {@code center}
     * @param results list the matching payloads are appended to
     */
    @Override
    public void collectCylinder(@Nonnull final Vector3d center, final double radius, final double height, @Nonnull final List<T> results) {
        final double radiusSq = radius * radius;
        final double halfHeight = height / 2.0;
        this.collectCylinder0(results, this.root, center, radiusSq, halfHeight, radius, 0);
    }

    /**
     * Appends every filtered payload whose position lies inside the axis-aligned box
     * {@code [min, max]} (bounds inclusive on every axis).
     *
     * @param min     corner with the smallest coordinates
     * @param max     corner with the largest coordinates
     * @param results list the matching payloads are appended to
     */
    @Override
    public void collectBox(@Nonnull final Vector3d min, @Nonnull final Vector3d max, @Nonnull final List<T> results) {
        this.collectBox0(results, this.root, min, max, 0);
    }

    /**
     * Like {@link #collect(Vector3d, double, List)}, but appends the payloads in ascending order
     * of their node's squared distance to {@code center}.
     *
     * @param center  sphere centre
     * @param radius  sphere radius (exclusive)
     * @param results list the matching payloads are appended to
     */
    @Override
    public void ordered(@Nonnull final Vector3d center, final double radius, @Nonnull final List<T> results) {
        final double distanceSq = radius * radius;
        final ObjectArrayList<OrderedEntry<T>> entryResults = new ObjectArrayList<OrderedEntry<T>>();
        this.ordered0(entryResults, this.root, center, distanceSq, 0);
        this.flushOrdered(entryResults, results);
    }

    /**
     * Appends the filtered payloads located inside the cuboid
     * {@code center ± (xSearchRadius, YSearchRadius, zSearchRadius)} (bounds inclusive), in
     * ascending order of their node's squared distance to {@code center}.
     *
     * @param center        centre of the cuboid
     * @param xSearchRadius half-extent along x
     * @param YSearchRadius half-extent along y
     * @param zSearchRadius half-extent along z
     * @param results       list the matching payloads are appended to
     */
    @Override
    public void ordered3DAxis(@Nonnull final Vector3d center, final double xSearchRadius, final double YSearchRadius, final double zSearchRadius, @Nonnull final List<T> results) {
        final ObjectArrayList<OrderedEntry<T>> entryResults = new ObjectArrayList<OrderedEntry<T>>();
        this._internal_ordered3DAxis(entryResults, this.root, center, xSearchRadius, YSearchRadius, zSearchRadius, 0);
        this.flushOrdered(entryResults, results);
    }

    /**
     * Returns a multi-line description of the tree: a {@code KDTree(size=N)} header followed by
     * the recursive dump of the root, or the text {@code null} when there is no root.
     */
    @Nonnull
    @Override
    public String dump() {
        return "KDTree(size=" + this.size + ")\n" + ((this.root == null) ? null : this.root.dump(0));
    }

    /** Sorts the entries by squared distance and appends their filtered payloads to {@code results}. */
    private void flushOrdered(@Nonnull final ObjectArrayList<OrderedEntry<T>> entryResults, @Nonnull final List<T> results) {
        entryResults.sort(Comparator.comparingDouble(o -> o.distanceSq));
        for (final OrderedEntry<T> entry : entryResults) {
            this.addFiltered(entry.values, results);
        }
    }

    /** Appends each element of {@code values} accepted by the collection filter to {@code results}. */
    private void addFiltered(@Nonnull final List<T> values, @Nonnull final List<T> results) {
        for (int i = 0, bound = values.size(); i < bound; ++i) {
            final T value = values.get(i);
            if (this.collectionFilter.test(value)) {
                results.add(value);
            }
        }
    }

    /** Returns a node from the pool (reset to the given contents), growing the pool if needed. */
    @Nonnull
    private Node<T> getPooledNode(final Vector3d vector, final List<T> data) {
        if (this.nodePoolIndex < this.nodePool.size()) {
            final Node<T> node = this.nodePool.get(this.nodePoolIndex++);
            node.reset(vector, data);
            return node;
        }
        final Node<T> node = new Node<T>(vector, data);
        this.nodePool.add(node);
        ++this.nodePoolIndex;
        return node;
    }

    /**
     * Returns the next payload list from the pool, growing the pool if needed. Reused lists are
     * expected to have been cleared already by {@link #rebuild(SpatialData)}.
     */
    private List<T> getPooledDataList() {
        if (this.dataListPoolIndex < this.dataListPool.size()) {
            return this.dataListPool.get(this.dataListPoolIndex++);
        }
        final ObjectArrayList<T> list = new ObjectArrayList<T>(1);
        this.dataListPool.add(list);
        ++this.dataListPoolIndex;
        return list;
    }

    /**
     * Inserts the sorted entries in the half-open range {@code [start, end)}.
     *
     * <p>The median entry, together with all directly adjacent entries at an equal position, is
     * turned into one node; the remaining entries on either side are then inserted recursively.
     * The very first node created after a reset becomes the root.
     *
     * @param spatialData sorted source data
     * @param start       first sorted index to insert (inclusive)
     * @param end         sorted index to stop at (exclusive); must be greater than {@code start}
     */
    private void build0(@Nonnull final SpatialData<T> spatialData, final int start, final int end) {
        final int mid = start + (end - start) / 2;
        final int sortedIndex = spatialData.getSortedIndex(mid);
        final Vector3d vector = spatialData.getVector(sortedIndex);
        final List<T> list = this.getPooledDataList();
        list.add(spatialData.getData(sortedIndex));

        // Absorb neighbours below the median that sit at exactly the same position.
        int left;
        for (left = mid - 1; left >= start; --left) {
            final int leftSortedIndex = spatialData.getSortedIndex(left);
            if (!spatialData.getVector(leftSortedIndex).equals(vector)) {
                break;
            }
            list.add(spatialData.getData(leftSortedIndex));
        }
        // Absorb neighbours above the median that sit at exactly the same position.
        int right;
        for (right = mid + 1; right < end; ++right) {
            final int rightSortedIndex = spatialData.getSortedIndex(right);
            if (!spatialData.getVector(rightSortedIndex).equals(vector)) {
                break;
            }
            list.add(spatialData.getData(rightSortedIndex));
        }

        if (this.root == null) {
            this.root = this.getPooledNode(vector, list);
        }
        else {
            this.put0(this.root, vector, list, 0);
        }

        // left and right now point at the first entries NOT merged into this node.
        if (start < left + 1) {
            this.build0(spatialData, start, left + 1);
        }
        if (right < end) {
            this.build0(spatialData, right, end);
        }
    }

    /**
     * Descends from {@code node} and attaches a new leaf for {@code vector}.
     *
     * @param node   subtree root to insert below
     * @param vector position of the new node
     * @param list   payloads of the new node
     * @param axis   splitting axis of {@code node} (0, 1 or 2)
     */
    private void put0(@Nonnull final Node<T> node, @Nonnull final Vector3d vector, @Nonnull final List<T> list, final int axis) {
        final int nextAxis = (axis + 1) % AXIS_COUNT;
        if (compare(node.vector, vector, axis) < 0) {
            // New position is greater than the node on this axis: goes into "one".
            if (node.one == null) {
                node.one = this.getPooledNode(vector, list);
            }
            else {
                this.put0(node.one, vector, list, nextAxis);
            }
        }
        else if (node.two == null) {
            node.two = this.getPooledNode(vector, list);
        }
        else {
            this.put0(node.two, vector, list, nextAxis);
        }
    }

    /**
     * Recursive nearest-neighbour search.
     *
     * @param closestState best candidate found so far; updated in place
     * @param node         subtree to search, may be {@code null}
     * @param vector       query position
     * @param depth        depth of {@code node}; determines its splitting axis
     */
    private void closest0(@Nonnull final ClosestState<T> closestState, @Nullable final Node<T> node, @Nonnull final Vector3d vector, final int depth) {
        if (node == null) {
            return;
        }
        if (vector.equals(node.vector)) {
            // Exact hit: nothing can be closer, so stop exploring this subtree.
            closestState.distanceSq = 0.0;
            closestState.node = node;
            return;
        }
        final int axis = depth % AXIS_COUNT;
        final int side = compare(node.vector, vector, axis);
        final double distanceSq = node.vector.distanceSquaredTo(vector);
        if (distanceSq < closestState.distanceSq) {
            closestState.node = node;
            closestState.distanceSq = distanceSq;
        }
        final int newDepth = depth + 1;
        final Node<T> near = (side < 0) ? node.one : node.two;
        final Node<T> far = (side < 0) ? node.two : node.one;
        this.closest0(closestState, near, vector, newDepth);
        // The far side can only hold a better candidate if the splitting plane is closer to
        // the QUERY point than the best match so far. Measuring from the current best node
        // instead (as this code once did) can wrongly prune the subtree that contains the
        // true nearest neighbour.
        final double planeDistance = Math.abs(get(vector, axis) - get(node.vector, axis));
        if (planeDistance * planeDistance < closestState.distanceSq) {
            this.closest0(closestState, far, vector, newDepth);
        }
    }

    /**
     * Recursive sphere query backing {@link #collect(Vector3d, double, List)}.
     *
     * @param results    list the matching payloads are appended to
     * @param node       subtree to search, may be {@code null}
     * @param vector     sphere centre
     * @param distanceSq squared radius (exclusive)
     * @param depth      depth of {@code node}; determines its splitting axis
     */
    private void collect0(@Nonnull final List<T> results, @Nullable final Node<T> node, @Nonnull final Vector3d vector, final double distanceSq, final int depth) {
        if (node == null) {
            return;
        }
        final int axis = depth % AXIS_COUNT;
        final int side = compare(node.vector, vector, axis);
        if (node.vector.distanceSquaredTo(vector) < distanceSq) {
            this.addFiltered(node.data, results);
        }
        final int newDepth = depth + 1;
        final Node<T> near = (side < 0) ? node.one : node.two;
        final Node<T> far = (side < 0) ? node.two : node.one;
        this.collect0(results, near, vector, distanceSq, newDepth);
        // Only cross the splitting plane when the sphere actually reaches over it.
        final double planeDistance = Math.abs(get(vector, axis) - get(node.vector, axis));
        if (planeDistance * planeDistance < distanceSq) {
            this.collect0(results, far, vector, distanceSq, newDepth);
        }
    }

    /**
     * Recursive cylinder query backing {@link #collectCylinder(Vector3d, double, double, List)}.
     *
     * @param results    list the matching payloads are appended to
     * @param node       subtree to search, may be {@code null}
     * @param center     centre of the cylinder
     * @param radiusSq   squared horizontal radius (inclusive)
     * @param halfHeight maximum absolute y offset from the centre (inclusive)
     * @param radius     horizontal radius, used for pruning on the x and z axes
     * @param depth      depth of {@code node}; determines its splitting axis
     */
    private void collectCylinder0(@Nonnull final List<T> results, @Nullable final Node<T> node, @Nonnull final Vector3d center, final double radiusSq, final double halfHeight, final double radius, final int depth) {
        if (node == null) {
            return;
        }
        final int axis = depth % AXIS_COUNT;
        final int side = compare(node.vector, center, axis);
        final double dy = node.vector.y - center.y;
        if (Math.abs(dy) <= halfHeight) {
            final double dx = node.vector.x - center.x;
            final double dz = node.vector.z - center.z;
            if (dx * dx + dz * dz <= radiusSq) {
                this.addFiltered(node.data, results);
            }
        }
        final int newDepth = depth + 1;
        final Node<T> near = (side < 0) ? node.one : node.two;
        final Node<T> far = (side < 0) ? node.two : node.one;
        this.collectCylinder0(results, near, center, radiusSq, halfHeight, radius, newDepth);
        // Axis 2 is y, where the cylinder extends by halfHeight; x and z extend by radius.
        final double axisRadius = (axis == 2) ? halfHeight : radius;
        if (Math.abs(get(center, axis) - get(node.vector, axis)) <= axisRadius) {
            this.collectCylinder0(results, far, center, radiusSq, halfHeight, radius, newDepth);
        }
    }

    /**
     * Recursive box query backing {@link #collectBox(Vector3d, Vector3d, List)}.
     *
     * @param results list the matching payloads are appended to
     * @param node    subtree to search, may be {@code null}
     * @param min     lower box corner (inclusive)
     * @param max     upper box corner (inclusive)
     * @param depth   depth of {@code node}; determines its splitting axis
     */
    private void collectBox0(@Nonnull final List<T> results, @Nullable final Node<T> node, @Nonnull final Vector3d min, @Nonnull final Vector3d max, final int depth) {
        if (node == null) {
            return;
        }
        final Vector3d position = node.vector;
        final boolean insideX = position.x >= min.x && position.x <= max.x;
        final boolean insideY = position.y >= min.y && position.y <= max.y;
        final boolean insideZ = position.z >= min.z && position.z <= max.z;
        if (insideX && insideY && insideZ) {
            this.addFiltered(node.data, results);
        }
        final int axis = depth % AXIS_COUNT;
        final int newDepth = depth + 1;
        final double plane = get(position, axis);
        // "one" holds the greater components, so it matters only if the box reaches the plane
        // from above; "two" holds the less-or-equal components.
        if (get(max, axis) >= plane) {
            this.collectBox0(results, node.one, min, max, newDepth);
        }
        if (get(min, axis) <= plane) {
            this.collectBox0(results, node.two, min, max, newDepth);
        }
    }

    /**
     * Recursive sphere query backing {@link #ordered(Vector3d, double, List)}; records each
     * matching node together with its squared distance instead of appending payloads directly.
     *
     * @param results    list the matching entries are appended to (unsorted)
     * @param node       subtree to search, may be {@code null}
     * @param vector     sphere centre
     * @param distanceSq squared radius (exclusive)
     * @param depth      depth of {@code node}; determines its splitting axis
     */
    private void ordered0(@Nonnull final List<OrderedEntry<T>> results, @Nullable final Node<T> node, @Nonnull final Vector3d vector, final double distanceSq, final int depth) {
        if (node == null) {
            return;
        }
        final int axis = depth % AXIS_COUNT;
        final int side = compare(node.vector, vector, axis);
        final double nodeDistanceSq = node.vector.distanceSquaredTo(vector);
        if (nodeDistanceSq < distanceSq) {
            results.add(new OrderedEntry<T>(nodeDistanceSq, node.data));
        }
        final int newDepth = depth + 1;
        final Node<T> near = (side < 0) ? node.one : node.two;
        final Node<T> far = (side < 0) ? node.two : node.one;
        this.ordered0(results, near, vector, distanceSq, newDepth);
        final double planeDistance = Math.abs(get(vector, axis) - get(node.vector, axis));
        if (planeDistance * planeDistance < distanceSq) {
            this.ordered0(results, far, vector, distanceSq, newDepth);
        }
    }

    /**
     * Recursive cuboid query backing
     * {@link #ordered3DAxis(Vector3d, double, double, double, List)}.
     *
     * @param results       list the matching entries are appended to (unsorted)
     * @param node          subtree to search, may be {@code null}
     * @param center        centre of the cuboid
     * @param xSearchRadius half-extent along x (inclusive)
     * @param ySearchRadius half-extent along y (inclusive)
     * @param zSearchRadius half-extent along z (inclusive)
     * @param depth         depth of {@code node}; determines its splitting axis
     */
    private void _internal_ordered3DAxis(@Nonnull final List<OrderedEntry<T>> results, @Nullable final Node<T> node, @Nonnull final Vector3d center, final double xSearchRadius, final double ySearchRadius, final double zSearchRadius, final int depth) {
        if (node == null) {
            return;
        }
        final Vector3d position = node.vector;
        final boolean insideX = position.x >= center.x - xSearchRadius && position.x <= center.x + xSearchRadius;
        final boolean insideY = position.y >= center.y - ySearchRadius && position.y <= center.y + ySearchRadius;
        final boolean insideZ = position.z >= center.z - zSearchRadius && position.z <= center.z + zSearchRadius;
        if (insideX && insideY && insideZ) {
            results.add(new OrderedEntry<T>(position.distanceSquaredTo(center), node.data));
        }
        final int axis = depth % AXIS_COUNT;
        final int newDepth = depth + 1;
        final int side = compare(position, center, axis);
        final Node<T> primary = (side < 0) ? node.one : node.two;
        final Node<T> secondary = (side < 0) ? node.two : node.one;
        this._internal_ordered3DAxis(results, primary, center, xSearchRadius, ySearchRadius, zSearchRadius, newDepth);
        // Pick the half-extent that belongs to this node's splitting axis (0 = x, 1 = z, 2 = y).
        final double radius = (axis == 0) ? xSearchRadius : ((axis == 1) ? zSearchRadius : ySearchRadius);
        if (Math.abs(get(center, axis) - get(position, axis)) <= radius) {
            this._internal_ordered3DAxis(results, secondary, center, xSearchRadius, ySearchRadius, zSearchRadius, newDepth);
        }
    }

    /**
     * Compares two vectors on a single axis using {@link Double#compare(double, double)}.
     *
     * @param axis 0 for x, 1 for z, 2 for y
     * @throws IllegalArgumentException for any other axis value
     */
    private static int compare(@Nonnull final Vector3d v1, @Nonnull final Vector3d v2, final int axis) {
        return switch (axis) {
            case 0 -> Double.compare(v1.x, v2.x);
            case 1 -> Double.compare(v1.z, v2.z);
            case 2 -> Double.compare(v1.y, v2.y);
            default -> throw new IllegalArgumentException("Invalid axis: " + axis);
        };
    }

    /**
     * Returns the component of {@code v} on the given axis.
     *
     * @param axis 0 for x, 1 for z, 2 for y
     * @throws IllegalArgumentException for any other axis value
     */
    private static double get(@Nonnull final Vector3d v, final int axis) {
        return switch (axis) {
            case 0 -> v.x;
            case 1 -> v.z;
            case 2 -> v.y;
            default -> throw new IllegalArgumentException("Invalid axis: " + axis);
        };
    }

    /** Tree node: one position, the payloads stored at it, and up to two children. */
    private static class Node<T>
    {
        private Vector3d vector;
        private List<T> data;
        /** Subtree of positions that compare greater than this node on its splitting axis. */
        @Nullable
        private Node<T> one;
        /** Subtree of positions that compare less than or equal to this node on its splitting axis. */
        @Nullable
        private Node<T> two;

        public Node(final Vector3d vector, final List<T> data) {
            this.vector = vector;
            this.data = data;
        }

        /** Re-initialises a pooled node with new contents and detaches both children. */
        public void reset(final Vector3d vector, final List<T> data) {
            this.vector = vector;
            this.data = data;
            this.one = null;
            this.two = null;
        }

        /**
         * Renders this node and its descendants; child lines are indented by {@code depth} spaces.
         *
         * @param depth indentation used for the {@code one=} / {@code two=} lines of this node
         */
        @Nonnull
        public String dump(final int depth) {
            final int nextDepth = depth + 1;
            return "vector=" + String.valueOf(this.vector) + ", data=" + String.valueOf(this.data) + ",\n" + " ".repeat(depth) + "one=" + ((this.one == null) ? null : this.one.dump(nextDepth)) + ",\n" + " ".repeat(depth) + "two=" + ((this.two == null) ? null : this.two.dump(nextDepth));
        }
    }

    /** Mutable accumulator for the nearest-neighbour search: best node so far and its squared distance. */
    private static class ClosestState<T>
    {
        private Node<T> node;
        private double distanceSq;

        public ClosestState(final Node<T> node, final double distanceSq) {
            this.node = node;
            this.distanceSq = distanceSq;
        }
    }

    /** A node's payload list paired with the node's squared distance to the query centre. */
    private static class OrderedEntry<T>
    {
        private final double distanceSq;
        private final List<T> values;

        public OrderedEntry(final double distanceSq, final List<T> values) {
            this.distanceSq = distanceSq;
            this.values = values;
        }
    }
}
