#pragma once
#include "real.hpp"
#include <cmath>
#include <array>
#include <assert.h>

template<size_t N>
class Vec {
public:
    std::array<real, N> data;

    Vec() {
        data.fill(0);
    }

    Vec(std::initializer_list<real> values) {
        std::copy(values.begin(), values.end(), data.begin());
    }

    Vec(real... args) : data{static_cast<real>(args)...} {}

    Vec(const Vec<N> &v) {
        data = v.data;
    }

    Vec<N> &operator=(const Vec<N> &v) {
        data = v.data;
        return *this;
    }

    real &operator[](size_t index) {
        return data[index];
    }

    const real &operator[](size_t index) const {
        return data[index];
    }

    Vec<N> operator+(const Vec<N> &v) const {
        Vec<N> result;
        for (size_t i = 0; i < N; ++i) {
            result[i] = data[i] + v[i];
        }
        return result;
    }

    Vec<N> operator-(const Vec<N> &v) const {
        Vec<N> result;
        for (size_t i = 0; i < N; ++i) {
            result[i] = data[i] - v[i];
        }
        return result;
    }

    Vec<N> operator*(real s) const {
        Vec<N> result;
        for (size_t i = 0; i < N; ++i) {
            result[i] = data[i] * s;
        }
        return result;
    }

    friend Vec<N> operator*(real s, const Vec<N> &v) {
        Vec<N> result;
        for (size_t i = 0; i < N; ++i) {
            result[i] = s * v[i];
        }
        return result;
    }


    Vec<N> operator/(real s) const {
        Vec<N> result;
        for (size_t i = 0; i < N; ++i) {
            result[i] = data[i] / s;
        }
        return result;
    }

    real dot(const Vec<N> &v) const {
        real sum = 0;
        for (size_t i = 0; i < N; ++i) {
            sum += data[i] * v[i];
        }
        return sum;
    }

    real norm() const {
        return std::sqrt(dot(*this));
    }

    Vec<N> normalize() const {
        return *this / norm();
    }

    Vec<N> reflect(const Vec<N> &n) const {
        return *this - n * 2 * dot(n);
    }
};


// specialize template class Vec<3>;
template<>
class Vec<3> {
public:
    union {
        std::array<real, 3> data;

        struct {
            real x, y, z;
        };

        struct {
            real u, v, w;
        };
    };

    Vec() : data{0, 0, 0} {}

    Vec(real x, real y, real z) : data{x, y, z} {}

    Vec(const Vec &v) : data{v.data[0], v.data[1], v.data[2]} {}

    Vec &operator=(const Vec &v) {
        data[0] = v.data[0];
        data[1] = v.data[1];
        data[2] = v.data[2];
        return *this;
    }

    real operator[](size_t index) const {
        return data[index];
    }

    real &operator[](size_t index) {
        return data[index];
    }

    Vec operator+(const Vec &v) const {
        return {x + v.x, y + v.y, z + v.z};
    }

    Vec operator-(const Vec &v) const {
        return {x - v.x, y - v.y, z - v.z};
    }

    Vec operator*(real s) const {
        return {x * s, y * s, z * s};
    }

    friend Vec operator*(real s, const Vec &v) {
        return {s * v.x, s * v.y, s * v.z};
    }


    Vec operator/(real s) const {
        assert(s != 0);
        real inv = 1 / s;
        return *this * inv;
    }

    real dot(const Vec &v) const {
        return x * v.x + y * v.y + z * v.z;
    }

    real length() const {
        return std::sqrt(dot(*this));
    }

    Vec normalize() const {
            return *this / length();
    }

    Vec cross(const Vec &v) const {
        return {y * v.z - z * v.y, z * v.x - x * v.z, x * v.y - y * v.x};
    }
};

using Vec2 = Vec<2>;
using Vec3 = Vec<3>;


//class Vec3 {
//public:
//    real x, y, z;
//    Vec3(): x(0), y(0), z(0) {}
//    Vec3(real x, real y, real z): x(x), y(y), z(z) {}
//    Vec3(const Vec3& v): x(v.x), y(v.y), z(v.z) {}
//    Vec3& operator=(const Vec3& v) {
//        x = v.x;
//        y = v.y;
//        z = v.z;
//        return *this;
//    }
//    Vec3 operator+(const Vec3& v) const {
//        return Vec3(x + v.x, y + v.y, z + v.z);
//    }
//    Vec3 operator-(const Vec3& v) const {
//        return Vec3(x - v.x, y - v.y, z - v.z);
//    }
//    Vec3 operator*(real s) const {
//        return Vec3(x * s, y * s, z * s);
//    }
//    Vec3 operator*(const Vec3& v) const {
//        return Vec3(x * v.x, y * v.y, z * v.z);
//    }
//    Vec3 operator-() const {
//        return Vec3(-x, -y, -z);
//    }
//    friend Vec3 operator*(real s, const Vec3& v) {
//        return Vec3(s * v.x, s * v.y, s * v.z);
//    }
//    Vec3 operator/(real s) const {
//        return Vec3(x / s, y / s, z / s);
//    }
//    real dot(const Vec3& v) const {
//        return x * v.x + y * v.y + z * v.z;
//    }
//    Vec3 cross(const Vec3& v) const {
//        return Vec3(y * v.z - z * v.y, z * v.x - x * v.z, x * v.y - y * v.x);
//    }
//    real norm() const {
//        return sqrt(x * x + y * y + z * z);
//    }
//    Vec3 normalize() const {
//        return *this / norm();
//    }
//    Vec3 reflect(const Vec3& n) const {
//        return *this - n * 2 * dot(n);
//    }
//};