//
// Created by cflin on 4/7/23.
//

#include "StaticSim.h"
#include "Utils.hpp"
#include "Material.hpp"

#include <iostream>
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <spdlog/spdlog.h>
#include <igl/write_triangle_mesh.h>
#include <igl/read_triangle_mesh.h>
#include <igl/in_element.h>
#include <igl/AABB.h>

namespace ssim {

    // Build a static simulation from the JSON configuration at `jsonPath`:
    // read material parameters and boundary conditions, load the tetrahedral
    // mesh, resolve the boundary conditions onto mesh nodes, build the topology
    // tables and set up the linear solver's sparsity pattern.
    // Any unusable input is fatal: an error is logged and the process exits(-1).
    StaticSim::StaticSim(const SimTargetOption &option, const std::string &jsonPath) : option_(option) {
        Config config;
        if (!config.loadFromJSON(jsonPath)) {
            spdlog::error("error on reading json file!");
            exit(-1);
        }

        DirichletBCs = config.DirichletBCs;
        NeumannBCs = config.NeumannBCs;

        Utils::elasticMatrix(config.YM, config.PR, D);
        if (!Utils::readTetMesh(config.mshFilePath, TV, TT, SF)) {
            spdlog::error("Unable to read input msh file: {:s}", config.mshFilePath);
            exit(-1);
        }
        // Everything below assumes a non-empty 3D linear-tet mesh; in particular
        // colwise().minCoeff() on an empty matrix is an Eigen assertion (UB in release),
        // and the element code hard-codes 4 nodes / 12 DOFs per element.
        if (TV.rows() == 0 || TV.cols() != 3 || TT.rows() == 0 || TT.cols() != 4) {
            spdlog::error("Input msh file holds no usable tetrahedral mesh: {:s}", config.mshFilePath);
            exit(-1);
        }
        nN = (int) TV.rows();
        nEle = (int) TT.rows();
        nDof = nN * 3;
        eleNodeNum = 4;
        eleDofNum = 12;

        // update absBBox of Boundary Conditions (BC boxes are resolved against the model's bounding box)
        Eigen::Vector3d modelMinBBox = TV.colwise().minCoeff();
        Eigen::Vector3d modelMaxBBox = TV.colwise().maxCoeff();
        for (auto &DBC: DirichletBCs) {
            DBC.calcAbsBBox(modelMinBBox, modelMaxBBox);
        }
        for (auto &NBC: NeumannBCs) {
            NBC.calcAbsBBox(modelMinBBox, modelMaxBBox);
        }
        setBC();

        computeFeatures();
#ifdef LINSYSSOLVER_USE_CHOLMOD
        linSysSolver = std::make_shared<CHOLMODSolver<Eigen::VectorXi, Eigen::VectorXd>>();
#else
        linSysSolver = std::make_shared<EigenLibSolver<Eigen::VectorXi, Eigen::VectorXd>>();
#endif
        // The pattern depends only on mesh connectivity, so it is analyzed once here.
        linSysSolver->set_pattern(vNeighbor);
        linSysSolver->analyze_pattern();
        spdlog::info("mesh constructed");
        spdlog::info("nodes number: {}, dofs number: {}, tets element number: {}", nN, nDof, nEle);

    }


    // Run the full pipeline: assemble the stiffness matrix, solve for U, derive
    // the per-surface-vertex results, then evaluate every target whose option
    // bit is set into map_target_to_evaluated_ (one slot per target id).
    void StaticSim::simulation() {
        const int targetCount = SimTargetOption::Target::ENUM_SIZE;
        map_target_to_evaluated_.clear();
        map_target_to_evaluated_.resize(targetCount);

        computeK();
        solve();
        prepare_surf_result();

        for (int target = 0; target < targetCount; ++target) {
            if (!option_.is_option_set(target)) {
                continue;
            }
            MapAppendTarget(target);
        }
    }

    void StaticSim::prepare_surf_result() {
        surf_U_.resize(SVI.size(), 3);
        surf_stress_.resize(SVI.size(), 6);
        surf_vonstress_.resize(SVI.size());

        for (int svI = 0; svI < SVI.size(); ++svI) {
            int vI = SVI[svI];
            Eigen::Vector3d u = U.segment<3>(vI * 3);
            surf_U_.row(svI) = u.transpose();

            // stress
            Eigen::VectorXd stress_vI = Eigen::VectorXd::Zero(6);
            double vonstress_vI = 0;
            int cnt = 0;
            for (const auto &item: vFLoc[vI]) {
                int eleI = item.first;
                const auto &U_eleI = Utils::SubVector(U, eDof[eleI]);
                Eigen::Matrix<double, 4, 3> X = Utils::SubMatrix(TV, TT.row(eleI), Eigen::Vector3i(0, 1, 2));

                Eigen::Matrix<double, 6, 12> B;
                Material::computeB_tet(X, B);

                Eigen::VectorXd stress = D * B * U_eleI;
                stress_vI += stress;

                vonstress_vI += Utils::vonStress(stress);
                ++cnt;
            }
            surf_stress_.row(svI) = stress_vI / cnt;
            surf_vonstress_(svI) = vonstress_vI / cnt;
        }
    }

    // Sample the solved field at arbitrary query points.
    //   Q       (in/out) #Q x 3 query positions. On return, every point located inside the
    //                    mesh is moved by its interpolated displacement (its deformed position).
    //   QU      (out)    #Q displacement magnitudes |u(q)|.
    //   Qstress (out)    #Q x 6 stress D * B * U_e of the containing tetrahedron. B is built
    //                    from the element vertices only, so it does not depend on q.
    // Points that lie in no tetrahedron (igl::in_element reports -1) would index eDof/TT out
    // of bounds. Instead, they are left in place with zero displacement and zero stress, and
    // a warning reports how many there were.
    void StaticSim::postprocess(Eigen::MatrixXd &Q,
                                Eigen::VectorXd &QU,
                                Eigen::MatrixXd &Qstress) {
        igl::AABB<Eigen::MatrixXd, 3> tree;
        tree.init(TV, TT);
        Eigen::VectorXi tetI;
        igl::in_element(TV, TT, Q, tree, tetI);

        const int nQ = static_cast<int>(Q.rows());
        QU.setZero(nQ);
        Qstress.setZero(nQ, 6);
        int nOutside = 0;
        for (int qI = 0; qI < nQ; ++qI) {
            const int tI = tetI(qI);
            if (tI < 0) {
                ++nOutside;
                continue;
            }
            const auto &U_tI = Utils::SubVector(U, eDof[tI]);
            Eigen::Matrix<double, 4, 3> X = Utils::SubMatrix(TV, TT.row(tI), Eigen::Vector3i(0, 1, 2));

            Eigen::Matrix<double, 3, 12> N;
            Eigen::Matrix<double, 6, 12> B;
            Material::computeN_tet(Q.row(qI), X, N);  // uses the undeformed position
            Material::computeB_tet(X, B);

            Eigen::Vector3d u = N * U_tI;
            QU(qI) = u.norm();
            Q.row(qI) += u;
            Qstress.row(qI) = D * B * U_tI;
        }
        if (nOutside > 0) {
            spdlog::warn("postprocess: {} query point(s) lie outside the mesh; displacement and stress set to zero",
                         nOutside);
        }
    }

    void StaticSim::computeFeatures() {
        // compute F_surf
        int cnt = 0;
        std::unordered_map<int, int> vI2SVI;
        for (int sfI = 0; sfI < SF.rows(); ++sfI) {
            for (int j = 0; j < 3; ++j) {
                const int &vI = SF(sfI, j);
                if (!vI2SVI.count(vI)) {
                    vI2SVI[vI] = cnt++;
                    SVI.conservativeResize(cnt);
                    SVI(cnt - 1) = vI;
                }
            }
        }
        F_surf.resize(SF.rows(), 3);
        for (int sfI = 0; sfI < SF.rows(); ++sfI) {
            F_surf(sfI, 0) = vI2SVI[SF(sfI, 0)];
            F_surf(sfI, 1) = vI2SVI[SF(sfI, 1)];
            F_surf(sfI, 2) = vI2SVI[SF(sfI, 2)];
        }
        // eDof
        eDof.resize(nEle);
#ifdef USE_TBB
        tbb::parallel_for(0, nEle, 1, [&](int eI)
#else
                                  for (int eI=0; eI < nEle; ++eI)
#endif
                          {
                              Eigen::VectorXi TT_I = TT.row(eI);
                              eDof[eI].resize(12);
                              for (int i_ = 0; i_ < 4; ++i_) {
                                  eDof[eI](3 * i_) = TT_I(i_) * 3;
                                  eDof[eI](3 * i_ + 1) = TT_I(i_) * 3 + 1;
                                  eDof[eI](3 * i_ + 2) = TT_I(i_) * 3 + 2;
                              }
                          }
#ifdef USE_TBB
        );
#endif
        // vNeighbor
        vNeighbor.resize(0);
        vNeighbor.resize(nN);
        for (int eI = 0; eI < nEle; ++eI) {
            const Eigen::Matrix<int, 1, 4> &eleVInd = TT.row(eI);
            std::vector<int> eleVInd_vec{eleVInd(0), eleVInd(1), eleVInd(2), eleVInd(3)};
            for (const auto &nI: eleVInd_vec) {
                vNeighbor[nI].insert(eleVInd_vec.begin(), eleVInd_vec.end());
            }
        }
        for (int nI = 0; nI < nN; ++nI) { // remove itself
            vNeighbor[nI].erase(nI);
        }

        // vFLoc
        vFLoc.resize(0);
        vFLoc.resize(nN);
        for (int eI = 0; eI < nEle; eI++) {
            for (int _nI = 0; _nI < eleNodeNum; ++_nI) {
                const int &nI = TT(eI, _nI);
                vFLoc[nI].insert(std::make_pair(eI, _nI));
            }
        }
    }

    void StaticSim::computeK() { // assembly stiffness matrix
        spdlog::info("assembly stiffness matrix");

        std::vector<Eigen::MatrixXd> eleKe(nEle);
        std::vector<Eigen::VectorXi> vInds(nEle);
#ifdef USE_TBB
        tbb::parallel_for(0, nEle, 1, [&](int eI)
#else
                                  for (int eI=0; eI < nEle; ++eI)
#endif
                          {
                              // eleKe
                              Eigen::Matrix<double, 4, 3> X = Utils::SubMatrix(TV, TT.row(eI), Eigen::Vector3i(0, 1, 2));
                              Eigen::Matrix<double, 12, 12> eleKe_I;
                              double vol;
                              Material::computeKe_tet(X, D, eleKe_I, vol);
                              eleKe[eI] = eleKe_I;

                              // vInds
                              vInds[eI].resize(eleNodeNum);
                              for (int _nI = 0; _nI < eleNodeNum; ++_nI) {
                                  int nI = TT(eI, _nI);
                                  vInds[eI](_nI) = isDBC(nI) ? (-nI - 1) : nI;
                              }
                          }
#ifdef USE_TBB
        );
#endif

        linSysSolver->setZero();
#ifdef USE_TBB
        tbb::parallel_for(0, nN, 1, [&](int nI)
#else
                                  for (int nI = 0; nI < nN; nI++)
#endif
                          {
                              for (const auto &FLocI: vFLoc[nI]) {
                                  Utils::addBlockToMatrix<DIM_>(eleKe[FLocI.first].block(FLocI.second * DIM_, 0, DIM_, eleDofNum),
                                                                vInds[FLocI.first], FLocI.second, linSysSolver);
                              }
                          }
#ifdef USE_TBB
        );
#endif

    }

    // Factorize the system assembled into linSysSolver by computeK(), solve it
    // for the nodal displacement vector U with right-hand side `load` (built in
    // setBC()), and store and log the compliance C = load^T U.
    // Per-surface-vertex stresses are derived afterwards in prepare_surf_result().
    void StaticSim::solve() {
        spdlog::info("solve");
        linSysSolver->factorize();
        linSysSolver->solve(load, U);

        compliance_ = load.dot(U);
        spdlog::info("compliance C = {:e}", compliance_);
    }

    // Resolve the configured boundary conditions onto mesh nodes.
    //   Dirichlet: a node inside any DirichletBCs region is flagged in isDBC and listed in DBC_nI.
    //   Neumann:   a node inside a NeumannBCs region gets that region's force written into
    //              `load` (the first matching region wins).
    // Loads on Dirichlet nodes are cleared at the end, so the two node sets never overlap.
    void StaticSim::setBC() {
        spdlog::info("set Boundary Conditions");

        // Dirichlet nodes
        isDBC.setZero(nN);
        DBC_nI.resize(nN);
        int nDBC = 0;
        for (int nI = 0; nI < nN; ++nI) {
            Eigen::Vector3d p = TV.row(nI);
            bool fixed = false;
            for (auto &bc: DirichletBCs) {
                if (bc.inDBC(p)) {
                    fixed = true;
                    break;
                }
            }
            if (!fixed) {
                continue;
            }
            isDBC(nI) = 1;
            DBC_nI(nDBC++) = nI;
        }
        DBC_nI.conservativeResize(nDBC);

        // Neumann nodes
        load.setZero(nDof);
        int nNBC = 0;
        for (int nI = 0; nI < nN; ++nI) {
            Eigen::Vector3d p = TV.row(nI);
            for (auto &bc: NeumannBCs) {
                if (!bc.inNBC(p)) {
                    continue;
                }
                load.segment<DIM_>(nI * DIM_) = bc.force;
                ++nNBC;
                break;
            }
        }

        spdlog::info("#DBC nodes: {}, #NBC particles: {}", nDBC, nNBC);

        // ensure (DBC intersect NBC) = (empty)
        for (int k = 0; k < nDBC; ++k) {
            load.segment<DIM_>(DBC_nI(k) * DIM_).setZero();
        }
    }

    // Return (a copy of) the surface mesh. It is built lazily on the first call:
    // vertices are TV gathered at SVI, and faces are F_surf, which already
    // indexes into SVI order.
    Model StaticSim::get_mesh() {
        if (mesh_.NumVertex() != 0) {
            return mesh_;
        }

        const Eigen::Index nSurfV = SVI.size();
        Eigen::MatrixXd surfacePositions(nSurfV, 3);
        for (Eigen::Index k = 0; k < nSurfV; ++k) {
            surfacePositions.row(k) = TV.row(SVI(k));
        }
        mesh_.V = surfacePositions;
        mesh_.F = F_surf;
        return mesh_;
    }

    void StaticSim::MapAppendTarget(int target) {
        switch (target) {
            case SimTargetOption::U_NORM:
                map_target_to_evaluated_[target] = EvaluateUNorm();
                break;
            case SimTargetOption::UX:
                map_target_to_evaluated_[target] = EvaluateUX();
                break;
            case SimTargetOption::UY:
                map_target_to_evaluated_[target] = EvaluateUY();
                break;
            case SimTargetOption::UZ:
                map_target_to_evaluated_[target] = EvaluateUZ();
                break;
            case SimTargetOption::S_NORM:
                map_target_to_evaluated_[target] = EvaluateSNorm();
                break;
            case SimTargetOption::S_VON_Mises:
                map_target_to_evaluated_[target] = EvaluateSVonMises();
                break;
            case SimTargetOption::SX:
                map_target_to_evaluated_[target] = EvaluateSX();
                break;
            case SimTargetOption::SY:
                map_target_to_evaluated_[target] = EvaluateSY();
                break;
            case SimTargetOption::SZ:
                map_target_to_evaluated_[target] = EvaluateSZ();
                break;
            case SimTargetOption::COMPLIANCE:
                map_target_to_evaluated_[target] = EvaluateCompliance();
                break;
            default:
                spdlog::warn("Wrong target {:d} !", target);
        }
    }

    // Return: #mesh.V.rows() x 1 — displacement magnitude |u| at each surface vertex.
    Eigen::MatrixXd StaticSim::EvaluateUNorm() const {
        Eigen::MatrixXd magnitudes(SVI.size(), 1);
        magnitudes.leftCols<1>() = surf_U_.rowwise().norm();
        return magnitudes;
    }

    // Return: #mesh.V.rows() x 1 — x-component of the surface-vertex displacement.
    Eigen::MatrixXd StaticSim::EvaluateUX() const {
        Eigen::MatrixXd component(SVI.size(), 1);
        component.leftCols<1>() = surf_U_.middleCols<1>(0);
        return component;
    }

    // Return: #mesh.V.rows() x 1 — y-component of the surface-vertex displacement.
    Eigen::MatrixXd StaticSim::EvaluateUY() const {
        Eigen::MatrixXd component(SVI.size(), 1);
        component.leftCols<1>() = surf_U_.middleCols<1>(1);
        return component;
    }

    // Return: #mesh.V.rows() x 1 — z-component of the surface-vertex displacement.
    Eigen::MatrixXd StaticSim::EvaluateUZ() const {
        Eigen::MatrixXd component(SVI.size(), 1);
        component.leftCols<1>() = surf_U_.middleCols<1>(2);
        return component;
    }

    // Return: #mesh.V.rows() x 1 — Euclidean norm of each surface vertex's 6-entry stress row.
    // NOTE(review): if the 6 entries are Voigt components, shear terms are counted once here,
    // so this is not the tensor Frobenius norm — confirm this is the intended measure.
    Eigen::MatrixXd StaticSim::EvaluateSNorm() const {
        Eigen::MatrixXd magnitudes(SVI.size(), 1);
        magnitudes.leftCols<1>() = surf_stress_.rowwise().norm();
        return magnitudes;
    }

    // Return: #mesh.V.rows() x 1 — averaged Utils::vonStress value per surface vertex.
    Eigen::MatrixXd StaticSim::EvaluateSVonMises() const {
        Eigen::MatrixXd values(SVI.size(), 1);
        values.leftCols<1>() = surf_vonstress_;
        return values;
    }

    // Return: #mesh.V.rows() x 1 — column 0 of the averaged surface stress.
    Eigen::MatrixXd StaticSim::EvaluateSX() const {
        Eigen::MatrixXd component(SVI.size(), 1);
        component.leftCols<1>() = surf_stress_.middleCols<1>(0);
        return component;
    }

    // Return: #mesh.V.rows() x 1 — column 1 of the averaged surface stress.
    Eigen::MatrixXd StaticSim::EvaluateSY() const {
        Eigen::MatrixXd component(SVI.size(), 1);
        component.leftCols<1>() = surf_stress_.middleCols<1>(1);
        return component;
    }

    // Return: #mesh.V.rows() x 1 — column 2 of the averaged surface stress.
    Eigen::MatrixXd StaticSim::EvaluateSZ() const {
        Eigen::MatrixXd component(SVI.size(), 1);
        component.leftCols<1>() = surf_stress_.middleCols<1>(2);
        return component;
    }

    // Return: 1 x 1 matrix holding the compliance computed in solve().
    Eigen::MatrixXd StaticSim::EvaluateCompliance() const {
        return Eigen::MatrixXd::Constant(1, 1, compliance_);
    }

} // ssim