#pragma once
#include <sparkstack.hpp>
#include <uvector.hpp>
#include "organizer.hpp"
#include "primitive.hpp"

/**
 * @brief A plain 3D vector with double-precision components
 *
 * Base class of Direction3D and Point3D. A default-constructed vector is (0, 0, 0).
 */
class Vector3D
{
public:
    Vector3D() = default;

    /**
     * @brief Construct a vector from its components
     * @param[in] x The x component
     * @param[in] y The y component
     * @param[in] z The z component
     */
    Vector3D(const double x, const double y, const double z) : m_x(x), m_y(y), m_z(z) {}

    /**
     * @brief Euclidean length of the vector
     */
    double length() const
    {
        return std::sqrt(this->m_x * this->m_x + this->m_y * this->m_y + this->m_z * this->m_z);
    }

    /**
     * @brief Component-wise division by a scalar
     * @throws std::runtime_error If scalar is exactly zero
     */
    Vector3D operator/(const double scalar) const
    {
        if (scalar == 0)
        {
            throw std::runtime_error("Division by zero error.");
        }

        return Vector3D(this->m_x / scalar, this->m_y / scalar, this->m_z / scalar);
    }

    /**
     * @brief Component-wise multiplication by a scalar
     */
    Vector3D operator*(const double scalar) const
    {
        return Vector3D(this->m_x * scalar, this->m_y * scalar, this->m_z * scalar);
    }

    /**
     * @brief Cross product this x other
     */
    Vector3D cross(const Vector3D& other) const
    {
        return Vector3D(this->m_y * other.m_z - this->m_z * other.m_y,
                        this->m_z * other.m_x - this->m_x * other.m_z,
                        this->m_x * other.m_y - this->m_y * other.m_x);
    }

    /**
     * @brief Dot product of this and other
     */
    double dot(const Vector3D& other) const
    {
        return this->m_x * other.m_x + this->m_y * other.m_y + this->m_z * other.m_z;
    }

    /**
     * @brief Convert to algoim's 3-component vector type (x, y, z in slots 0, 1, 2)
     */
    algoim::uvector3 getUVector3Data() const
    {
        algoim::uvector3 node;
        node(0) = this->m_x;
        node(1) = this->m_y;
        node(2) = this->m_z;
        return node;
    }

    // Components are value-initialised so that default-constructed vectors (and the
    // Direction3D / Point3D objects built on them) never hold indeterminate values.
    double m_x{0.0};
    double m_y{0.0};
    double m_z{0.0};
};

class Direction3D : public Vector3D
{
public:
    Direction3D() = default;
    Direction3D(const double x, const double y, const double z)
    {
        this->m_x = x;
        this->m_y = y;
        this->m_z = z;

        this->normalized();
    }

    Direction3D(const Vector3D& vector)
    {
        this->m_x = vector.m_x;
        this->m_y = vector.m_y;
        this->m_z = vector.m_z;

        this->normalized();
    }

    Vector3D cross(const Direction3D& other) const
    {
        return Vector3D(this->m_y * other.m_z - this->m_z * other.m_y,
                        this->m_z * other.m_x - this->m_x * other.m_z,
                        this->m_x * other.m_y - this->m_y * other.m_x);
    }

    double dot(const Direction3D& other) const
    {
        return this->m_x * other.m_x + this->m_y * other.m_y + this->m_z * other.m_z;
    }

    void normalized()
    {
        double length = this->length();
        if (std::abs(length) < 1e-8)
        {
            throw std::runtime_error("Cannot normalize a zero-length vector.");
        }

        this->m_x /= length;
        this->m_y /= length, this->m_z /= length;
    }

    bool isParallel(const Direction3D& other) const
    {
        auto cross = this->cross(other);
        return std::abs(cross.length()) < 1e-8;
    }

    Direction3D operator-() const
    {
        return Direction3D(-this->m_x, -this->m_y, -this->m_z);
    }
};

class Point3D : public Vector3D
{
public:
    Point3D() = default;
    Point3D(const double x, const double y, const double z)
    {
        this->m_x = x;
        this->m_y = y;
        this->m_z = z;
    }

    Vector3D operator-(const Point3D& other) const
    {
        return Vector3D(this->m_x - other.m_x, this->m_y - other.m_y, this->m_z - other.m_z);
    }

    Point3D operator-(const Direction3D& direction) const
    {
        return Point3D(this->m_x - direction.m_x, this->m_y - direction.m_y, this->m_z - direction.m_z);
    }

    Point3D operator-(const Vector3D& offset) const
    {
        return Point3D(this->m_x - offset.m_x, this->m_y - offset.m_y, this->m_z - offset.m_z);
    }

    Point3D operator+(const Direction3D& direction) const
    {
        return Point3D(this->m_x + direction.m_x, this->m_y + direction.m_y, this->m_z + direction.m_z);
    }

    Point3D operator+(const Vector3D& offset) const
    {
        return Point3D(this->m_x + offset.m_x, this->m_y + offset.m_y, this->m_z + offset.m_z);
    }

    double getDistance(const Point3D& other) const
    {
        return std::sqrt((other.m_x - this->m_x) * (other.m_x - this->m_x) +
                         (other.m_y - this->m_y) * (other.m_y - this->m_y) +
                         (other.m_z - this->m_z) * (other.m_z - this->m_z));
    }

    Point3D getMiddlePoint(const Point3D& other) const
    {
        return Point3D((this->m_x + other.m_x) / 2.0, (this->m_z + other.m_z) / 2.0, (this->m_z + other.m_z) / 2.0);
    }
};

/// Handle of a body created by Loader: the body's index in Loader's internal body list.
using BodyTag = unsigned int;

class Loader
{
public:
    /**
     * @brief Compute the barycentric coordinates of polygon
     * @param[in] points All points which define the polygon
     * @return The barycentric coordinates
     */
    Point3D computePolygonCentroid(const std::vector<Point3D>& points) const
    {
        double centroidX = 0, centroidY = 0, centroidZ = 0;
        for (const auto& point : points)
        {
            centroidX += point.m_x;
            centroidY += point.m_y;
            centroidZ += point.m_z;
        }
        int n = points.size();
        return Point3D(centroidX / n, centroidY / n, centroidZ / n);
    }

    /**
     * @brief Create an empty blob tree
     * @return The created empty blob tree
     */
    algoim::organizer::BlobTree createEmptyBlobTree()
    {
        algoim::organizer::BlobTree tree;

        algoim::organizer::Blob blob0;
        blob0.isPrimitive = 1;
        blob0.nodeOp = 0;
        blob0.inOut = 0;
        blob0.oneChildInOut = 0;
        blob0.isLeft = 1;
        blob0.ancestor = 2;
        tree.structure.push_back(blob0);

        algoim::organizer::Blob blob1;
        blob0.isPrimitive = 1;
        blob0.nodeOp = 0;
        blob0.inOut = 0;
        blob0.oneChildInOut = 0;
        blob0.isLeft = 0;
        blob0.ancestor = 0;
        tree.structure.push_back(blob1);

        algoim::organizer::Blob blob2;
        blob0.isPrimitive = 0;
        blob0.nodeOp = 3; // no set
        blob0.inOut = 0;
        blob0.oneChildInOut = 0;
        blob0.isLeft = 0;
        blob0.ancestor = 0;
        tree.structure.push_back(blob2);

        tree.primitiveNodeIdx.push_back(0);
        tree.primitiveNodeIdx.push_back(1);

        return tree;
    }

    /**
     * @brief Union two visible primitive node
     * @param[in] rep1 The first visible primitive node
     * @param[in] rep2 The second visible primitive node
     * @return The unioned visible primitive
     */
    algoim::organizer::VisiblePrimitiveRep unionNode(const algoim::organizer::VisiblePrimitiveRep& rep1,
                                                     const algoim::organizer::VisiblePrimitiveRep& rep2)
    {
        auto tree = createEmptyBlobTree();
        tree.structure[2].nodeOp = 0;

        const std::vector<algoim::organizer::VisiblePrimitiveRep> reps = {rep1, rep2};
        std::vector<algoim::organizer::MinimalPrimitiveRep> minimalReps;
        algoim::organizer::mergeSubtree2Leaf(tree, minimalReps, reps);

        algoim::organizer::VisiblePrimitiveRep result;
        result.subBlobTree = tree;
        result.aabb = rep1.aabb;
        result.aabb.extend(rep2.aabb);

        for (auto& iter : minimalReps)
        {
            result.tensors.push_back(iter.tensor);
        }

        return result;
    }

    /**
     * @brief Intersect two visible primitive node
     * @param[in] rep1 The first visible primitive node
     * @param[in] rep2 The second visible primitive node
     * @return The intersected visible primitive
     */
    algoim::organizer::VisiblePrimitiveRep intersectNode(const algoim::organizer::VisiblePrimitiveRep& rep1,
                                                         const algoim::organizer::VisiblePrimitiveRep& rep2)
    {
        auto tree = createEmptyBlobTree();
        tree.structure[2].nodeOp = 1;

        const std::vector<algoim::organizer::VisiblePrimitiveRep> reps = {rep1, rep2};
        std::vector<algoim::organizer::MinimalPrimitiveRep> minimalReps;
        algoim::organizer::mergeSubtree2Leaf(tree, minimalReps, reps);

        algoim::organizer::VisiblePrimitiveRep result;
        result.subBlobTree = tree;
        result.aabb = rep1.aabb;
        result.aabb.intersect(rep2.aabb);

        for (auto& iter : minimalReps)
        {
            result.tensors.push_back(iter.tensor);
        }

        return result;
    }

    /**
     * @brief Difference two visible primitive node
     * @param[in] rep1 The first visible primitive node
     * @param[in] rep2 The second visible primitive node
     * @return The differenced visible primitive
     */
    algoim::organizer::VisiblePrimitiveRep differentNode(const algoim::organizer::VisiblePrimitiveRep& rep1,
                                                         const algoim::organizer::VisiblePrimitiveRep& rep2)
    {
        auto tree = createEmptyBlobTree();
        tree.structure[2].nodeOp = 2;

        const std::vector<algoim::organizer::VisiblePrimitiveRep> reps = {rep1, rep2};
        std::vector<algoim::organizer::MinimalPrimitiveRep> minimalReps;
        algoim::organizer::mergeSubtree2Leaf(tree, minimalReps, reps);

        algoim::organizer::VisiblePrimitiveRep result;
        result.subBlobTree = tree;
        result.aabb = rep1.aabb;

        for (auto& iter : minimalReps)
        {
            result.tensors.push_back(iter.tensor);
        }

        return result;
    }

    void unionNode(const BodyTag body1, const BodyTag body2)
    {
        auto& rep1 = this->m_allVisible[body1];
        auto& rep2 = this->m_allVisible[body2];

        auto result = this->unionNode(rep1, rep2);

        this->m_allVisible[body1] = result;
    }

    void intersectNode(const BodyTag body1, const BodyTag body2)
    {
        auto& rep1 = this->m_allVisible[body1];
        auto& rep2 = this->m_allVisible[body2];

        auto result = this->intersectNode(rep1, rep2);

        this->m_allVisible[body1] = result;
    }

    void differentNode(const BodyTag body1, const BodyTag body2)
    {
        auto& rep1 = this->m_allVisible[body1];
        auto& rep2 = this->m_allVisible[body2];

        auto result = this->differentNode(rep1, rep2);

        this->m_allVisible[body1] = result;
    }

    void offset(const BodyTag body, const Direction3D& directrion, const double length)
    {
        algoim::uvector<algoim::real, 3> scale = 1;
        algoim::uvector<algoim::real, 3> bias = -directrion.getUVector3Data();

        auto& rep = this->m_allVisible[body];

        for (auto& iter : rep.tensors)
        {
            algoim::organizer::detail::powerTransformation(scale, bias, iter);
        }

        rep.aabb += directrion.getUVector3Data();
    }

    void split(const BodyTag body, const Point3D& basePoint, const Direction3D& normal)
    {
        auto& rep = this->m_allVisible[body];
        auto halfPlane = this->createHalfPlane(basePoint, -normal);
        auto result = this->intersectNode(rep, halfPlane);
        this->m_allVisible[body] = result;
    }

    /**
     * @brief Create a polygonal column without top face and bottom face
     * @param[in] points All the bottom point with counter clockwise
     * @param[in] extusion The stretch direction
     * @return The polygonal column
     */
    algoim::organizer::VisiblePrimitiveRep createPolygonalColumnWithoutTopBottom(const std::vector<Point3D>& points,
                                                                                 const Vector3D& extusion)
    {
        int pointNumber = points.size();

        std::vector<algoim::uvector3> vertices;
        std::vector<int> indices;
        std::vector<int> indexInclusiveScan;

        /* All bottom point */
        for (int i = 0; i < pointNumber; i++)
        {
            vertices.push_back(points[i].getUVector3Data());
        }
        /* All top point */
        for (int i = 0; i < pointNumber; i++)
        {
            vertices.push_back((points[i] + extusion).getUVector3Data());
        }

        /* Side face */
        int index = 0;
        for (int i = 0; i < pointNumber; i++)
        {
            indices.push_back(i);
            indices.push_back((i + 1) % pointNumber);
            indices.push_back((i + 1) % pointNumber + pointNumber);
            indices.push_back(i + pointNumber);

            index += 4;
            indexInclusiveScan.push_back(index);
        }

        algoim::organizer::MeshDesc polygonalColumn(vertices, indices, indexInclusiveScan);
        algoim::organizer::VisiblePrimitiveRep result;
        result.tensors.resize(pointNumber, algoim::tensor3(nullptr, 3));
        std::vector<algoim::SparkStack<algoim::real>*> temp;
        algoim::algoimSparkAllocHeapVector(temp, result.tensors);
        algoim::organizer::makeMesh(polygonalColumn, result);

        for (auto& pointer : temp)
        {
            this->m_allPointer.push_back(pointer);
        }

        return result;
    }

    /**
     * @brief Create a cylinder column without top face and bottom face
     * @param[in] origion The origion point of bottom circle of the cylinder
     * @param[in] radius The radius of the cylinder
     * @param[in] length The length of the cylinder
     * @param[in] alignAxis The align axis of the cylinder
     * @return The cylinder column
     */
    algoim::organizer::VisiblePrimitiveRep createCylinderWithoutTopBottom(const Point3D& origion,
                                                                          const double radius,
                                                                          const double length,
                                                                          const int alignAxis)
    {
        algoim::uvector3 ext = 3;
        ext(alignAxis) = 1;

        algoim::organizer::VisiblePrimitiveRep result;
        result.tensors.resize(1, algoim::tensor3(nullptr, ext));
        std::vector<algoim::SparkStack<algoim::real>*> resultTemp;
        algoim::algoimSparkAllocHeapVector(resultTemp, result.tensors);
        this->m_allPointer.push_back(resultTemp[0]);

        algoim::organizer::CylinderDesc cylinderDesc(origion.getUVector3Data(), radius, length, alignAxis);
        algoim::organizer::VisiblePrimitiveRep cylinder;
        cylinder.tensors.resize(3, algoim::tensor3(nullptr, 3));
        cylinder.tensors[0].ext_ = ext;
        algoim::algoim_spark_alloc(algoim::real, cylinder.tensors);
        algoim::organizer::makeCylinder(cylinderDesc, cylinder);

        result.tensors[0] = cylinder.tensors[0];
        result.aabb = cylinder.aabb;
        result.subBlobTree.primitiveNodeIdx.push_back(0);
        result.subBlobTree.structure.push_back(algoim::organizer::Blob{1, 2, 0, 0, 0, 0});

        return result;
    }

    /**
     * @brief Create a half plane
     * @param[in] basePoint The base point of the plane
     * @param[in] normal The normal of the plane
     * @return The half plane
     */
    algoim::organizer::VisiblePrimitiveRep createHalfPlane(const Point3D& basePoint, const Direction3D& normal)
    {
        auto halfPlaneDesc = algoim::organizer::HalfPlaneDesc(basePoint.getUVector3Data(), normal.getUVector3Data());
        algoim::organizer::VisiblePrimitiveRep halfPlane;
        halfPlane.tensors.resize(1, algoim::tensor3(nullptr, 3));
        std::vector<algoim::SparkStack<algoim::real>*> temp;
        algoim::algoimSparkAllocHeapVector(temp, halfPlane.tensors);
        algoim::organizer::makeHalfPlane(halfPlaneDesc, halfPlane);
        this->m_allPointer.push_back(temp[0]);

        return halfPlane;
    }

    /**
     * @brief Add a extrude body to csg tree with only two points
     * @param[in] points All the bottom point which define the base face
     * @param[in] bulges All the bulge on each edge of the base face
     * @param[in] extusion The Stretch direction and length
     */
    BodyTag addExtrudeWithTwoPoint(const std::vector<Point3D>& points,
                                   const std::vector<double>& bulges,
                                   const Vector3D& extusion)
    {
        assert(bulges[0] >= 0.0 && bulges[1] >= 0.0);

        auto normal = Direction3D(extusion);
        auto& point1 = points[0];
        auto& point2 = points[1];
        auto bulge1 = bulges[0];
        auto bulge2 = bulges[1];

        algoim::organizer::VisiblePrimitiveRep rep1, rep2, result;

        auto halfDistance = point1.getDistance(point2) / 2.0;
        auto middlePoint = point1.getMiddlePoint(point2);

        auto middleToOrigion1 = normal.cross(Direction3D(point2 - point1));
        auto middleToOrigion2 = normal.cross(Direction3D(point1 - point2));

        /* Determine which axis is aligned */
        int alignAxis;
        if (normal.isParallel(Direction3D(1, 0, 0)))
        {
            alignAxis = 0;
        }
        else if (normal.isParallel(Direction3D(0, 1, 0)))
        {
            alignAxis = 1;
        }
        else if (normal.isParallel(Direction3D(0, 0, 1)))
        {
            alignAxis = 2;
        }
        else
        {
            throw std::runtime_error("Non align axis cylinder.");
        }

        auto getPrimitive = [this, normal, halfDistance, middlePoint, extusion, alignAxis](
                                const Point3D& point1, const Point3D& point2, const double bulge) {
            auto middleToOrigion = normal.cross(Direction3D(point2 - point1));
            double sinHalfTheta = 2 * bulge / (1 + bulge * bulge);
            double radius = halfDistance / sinHalfTheta;
            double scalar = std::sqrt(radius * radius - halfDistance * halfDistance);
            auto origion = middlePoint + middleToOrigion * scalar;

            /* Create the cylinder face */
            return this->createCylinderWithoutTopBottom(origion, radius, extusion.length(), alignAxis);
        };

        if (std::abs(bulge1) <= 1e-8)
        {
            assert(std::abs(bulge2) > 1e-8);
            rep1 = this->createHalfPlane(point1, -Direction3D{middleToOrigion2});
            rep2 = getPrimitive(point2, point1, bulge2);
            result = this->intersectNode(rep1, rep2);
        }
        else if (std::abs(bulge2) <= 1e-8)
        {
            assert(std::abs(bulge1) > 1e-8);
            rep1 = getPrimitive(point1, point2, bulge1);
            rep2 = this->createHalfPlane(point2, -Direction3D{middleToOrigion1});
            result = this->intersectNode(rep1, rep2);
        }
        else
        {
            rep1 = getPrimitive(point1, point2, bulge1);
            rep2 = getPrimitive(point2, point1, bulge2);

            /* if the bulge == 1 and bulge == 2, it is a cylinder */
            if (std::abs(bulge1 - 1.0) < 1e-8 && std::abs(bulge2 - 1.0) < 1e-8)
            {
                result = getPrimitive(point1, point2, bulge1);
            }
            /* if the bulge1 and bulge2 has the same symbol, it is merge */
            else if (bulge1 * bulge2 > 0.0)
            {
                result = this->intersectNode(rep1, rep2);
            }
            else if (bulge1 > 0.0)
            {
                result = this->differentNode(rep1, rep2);
            }
            else
            {
                result = this->differentNode(rep2, rep1);
            }
        }

        auto halfPlane1 = createHalfPlane(points[0], -normal);
        auto halfPlane2 = createHalfPlane(points[0] + extusion, normal);
        result = this->unionNode(result, halfPlane1);
        result = this->unionNode(result, halfPlane2);

        this->m_allVisible.push_back(result);
        return this->m_allVisible.size() - 1;
    }

    /**
     * @brief Add a extrude body to csg tree
     * @param[in] points All the bottom point which define the base face
     * @param[in] bulges All the bulge on each edge of the base face
     * @param[in] extusion The Stretch direction and length
     */
    BodyTag addExtrude(const std::vector<Point3D>& points, const std::vector<double>& bulges, const Vector3D& extusion)
    {
        int pointNumber = points.size();
        assert(pointNumber >= 2);
        if (pointNumber == 2)
        {
            return addExtrudeWithTwoPoint(points, bulges, extusion);
        }

        /* Get base polygonal column */
        auto base = createPolygonalColumnWithoutTopBottom(points, extusion);

        auto normal = Direction3D(extusion);
        for (int i = 0; i < points.size(); i++)
        {
            /* Get point and bulge data */
            auto bulge = bulges[i];
            if (std::abs(bulge) < 1e-8)
            {
                continue;
            }

            auto& point1 = points[i];
            Point3D point2;
            if (i + 1 == points.size())
            {
                point2 = points[0];
            }
            else
            {
                point2 = points[i + 1];
            }

            /* Compute the origion and radius */
            auto halfDistance = point1.getDistance(point2) / 2.0;
            auto middlePoint = point1.getMiddlePoint(point2);
            auto middleToOrigion = normal.cross(Direction3D(point2 - point1));

            double sinHalfTheta = 2 * bulge / (1 + bulge * bulge);
            double radius = halfDistance / sinHalfTheta;
            double scalar = std::sqrt(radius * radius - halfDistance * halfDistance);
            auto origion = middlePoint + middleToOrigion * scalar;

            /* Determine whether to merge or subtract */
            /* The operation is merge if flag is true, otherwise it is subtract */
            bool flag;
            auto centroidPoint = computePolygonCentroid(points);
            auto middleToCentroid = Direction3D(centroidPoint - middlePoint);
            if (middleToCentroid.dot(middleToOrigion) > 0.0)
            {
                if (bulge > 0.0)
                {
                    flag = true;
                }
                else
                {
                    flag = false;
                }
            }
            else
            {
                if (bulge > 0.0)
                {
                    flag = false;
                }
                else
                {
                    flag = true;
                }
            }

            /* Determine which axis is aligned */
            int alignAxis;
            if (normal.isParallel(Direction3D(1, 0, 0)))
            {
                alignAxis = 0;
            }
            else if (normal.isParallel(Direction3D(0, 1, 0)))
            {
                alignAxis = 1;
            }
            else if (normal.isParallel(Direction3D(0, 0, 1)))
            {
                alignAxis = 2;
            }
            else
            {
                throw std::runtime_error("Non align axis cylinder.");
            }

            /* Create the cylinder face */
            auto cylinder = createCylinderWithoutTopBottom(origion, radius, extusion.length(), alignAxis);

            /* Perform union and difference operations on the basic prismatic faces */
            if (flag)
            {
                /* Union */
                /* cylinder - half plane */
                auto halfPlane = createHalfPlane(middlePoint, middleToOrigion);
                auto subtraction = this->intersectNode(cylinder, halfPlane);

                /* base + (cylinder - column) */
                base = this->unionNode(base, subtraction);
            }
            else
            {
                /* difference */
                /* base - cylinder */
                base = this->differentNode(base, cylinder);
            }
        }

        auto halfPlane1 = createHalfPlane(points[0], -normal);
        auto halfPlane2 = createHalfPlane(points[0] + extusion, normal);
        base = this->unionNode(base, halfPlane1);
        base = this->unionNode(base, halfPlane2);

        this->m_allVisible.push_back(base);

        return this->m_allVisible.size() - 1;
    }

    BodyTag addCone(const Point3D& topPoint, const Point3D& bottomPoint, const double radius1, const double radius2)
    {
        auto bottomToTop = topPoint - bottomPoint;
        auto normal = Direction3D{bottomToTop};

        /* Determine which axis is aligned */
        int alignAxis;
        if (normal.isParallel(Direction3D(1, 0, 0)))
        {
            alignAxis = 0;
        }
        else if (normal.isParallel(Direction3D(0, 1, 0)))
        {
            alignAxis = 1;
        }
        else if (normal.isParallel(Direction3D(0, 0, 1)))
        {
            alignAxis = 2;
        }
        else
        {
            throw std::runtime_error("Non align axis cone.");
        }

        auto coneDesc = algoim::organizer::ConeDesc{bottomPoint.getUVector3Data(), radius1, radius2, alignAxis};
        algoim::organizer::VisiblePrimitiveRep cone;
        cone.tensors.resize(3, algoim::tensor3(nullptr, 3));
        std::vector<algoim::SparkStack<algoim::real>*> temp;
        algoim::algoimSparkAllocHeapVector(temp, cone.tensors);
        algoim::organizer::makeCone(coneDesc, cone);

        this->m_allVisible.push_back(cone);
        for (auto& pointer : temp)
        {
            this->m_allPointer.push_back(pointer);
        }

        return this->m_allVisible.size() - 1;
    }

    void output(const BodyTag& tag)
    {
        auto rep = this->m_allVisible[tag];
        auto result = rep.tensors[0].m(algoim::uvector3(0));
    }

private:
    std::vector<algoim::organizer::VisiblePrimitiveRep> m_allVisible;
    std::vector<algoim::SparkStack<algoim::real>*> m_allPointer;
};