#pragma once
#ifndef MALA_UTILS_H__
#define MALA_UTILS_H__
#include <core/fwd.h>
#include <core/stats.h>
#include <core/sampler.h>
#include <core/pmf.h>
#include <core/utils.h>
#include <core/logger.h>
#include <fstream>

// Number of cells along each of the three axes of the GridCache lookup grid.
// GridCache sizes its per-cell arrays with these values and maps a coordinate
// c to the cell index (int)(c * GRID_*).
#define GRID_X 60
#define GRID_Y 60
#define GRID_Z 60

namespace MALA {
    // struct MLTState {
    //     Vector u = Vector::Zero();
    //     Float f_u = 0;
    //     Vector g = Vector::Zero(); // d_log_f_u
    //     bool reuseable = false;
    //     MLTState() {
    //         u = Vector::Zero();
    //         f_u = 0;
    //         g = Vector::Zero();
    //     }

    //     void setZero() {
    //         u = Vector::Zero();
    //         f_u = 0;
    //         g = Vector::Zero();
    //         bool reuseable = false;
    //     }
    // };
    // Flat, dynamically sized array type used for states, gradients and
    // per-dimension (diagonal) quantities throughout this namespace.
    using MALAVector = ArrayXd;

    /// Reads the idx-th 3-tuple out of a flat state array.
    ///
    /// @param state  Flat array; entries 3*idx, 3*idx+1 and 3*idx+2 are read.
    /// @param idx    Zero-based tuple index; must satisfy
    ///               0 <= idx < state.size() / 3 (checked with assert only).
    /// @return       The three consecutive entries packed into a Vector.
    ///
    /// Marked inline because the body lives in a header: without it every
    /// translation unit that includes this file emits its own external
    /// definition and linking fails with duplicate symbols.
    inline Vector pss_get(const MALAVector &state, int idx) {
        assert(idx >= 0 && idx < state.size() / 3);
        const int base = 3 * idx;
        return Vector(state[base], state[base + 1], state[base + 2]);
    }

    /// Writes a 3-vector into the idx-th 3-tuple of a flat state array.
    ///
    /// @param state  Flat array; entries 3*idx, 3*idx+1 and 3*idx+2 are overwritten.
    /// @param idx    Zero-based tuple index; must satisfy
    ///               0 <= idx < state.size() / 3 (checked with assert only).
    /// @param v      Source vector; components 0..2 are copied.
    ///
    /// Marked inline because the body lives in a header: without it every
    /// translation unit that includes this file emits its own external
    /// definition and linking fails with duplicate symbols.
    inline void pss_set(MALAVector &state, int idx, const Vector &v) {
        assert(idx >= 0 && idx < state.size() / 3);
        const int base = 3 * idx;
        for (int c = 0; c < 3; ++c) {
            state[base + c] = v[c];
        }
    }

    struct PSS_State { // wrapper for n-bounce PSS sample; each bounce uses a 3-tuple
        MALAVector u;
        Float f_u = 0;
        MALAVector g; // d_log_f_u

        PSS_State(ArrayXd vec) {
            u = vec;
            f_u = 0;
            g = MALAVector(vec.size());
        }

        PSS_State(int dim) {
            u = MALAVector(dim);
            f_u = 0;
            g = MALAVector(dim);
        }

        void setZero() {
            // for (auto &v : state)
            //     v = 0.f;
            u.setZero();
            f_u = 0;
            g.setZero();
        }
    };


    /// Regular GRID_X x GRID_Y x GRID_Z grid that accumulates, per cell, the
    /// running sums of two 3-vectors (g and G) and the number of samples that
    /// contributed to them.
    ///
    /// The accumulator arrays are stored inline in the object, so instances
    /// are large. NOTE(review): avoid creating them on the stack.
    struct GridCache {
        Float grid_g[GRID_X][GRID_Y][GRID_Z][3]; // [x][y][z][g_x, g_y, g_z], per-cell sum of g
        Float grid_G[GRID_X][GRID_Y][GRID_Z][3]; // [x][y][z][G_x, G_y, G_z], per-cell sum of G
        int grid_sample_count[GRID_X][GRID_Y][GRID_Z]; // samples accumulated per cell
        int total_sample_count = 0;                    // samples accumulated over all cells
        // Advisory flag for callers: add()/merge() clear it once
        // total_sample_count exceeds max_cache_size, remove() sets it again
        // once the count is back within the limit. The cache itself never
        // refuses a write.
        bool write = true;
        int max_cache_size = 20000;

        /// Creates an empty cache.
        /// @param cache_size  Sample count above which `write` is cleared.
        GridCache(const int cache_size) {
            memset(grid_g, 0, sizeof(grid_g));
            memset(grid_G, 0, sizeof(grid_G));
            memset(grid_sample_count, 0, sizeof(grid_sample_count));
            max_cache_size = cache_size;
        }

        /// Member-wise copy. The hand-written version copied the arrays and
        /// total_sample_count but dropped `write` and `max_cache_size`, so a
        /// copy silently reverted to the default limit and to writable state.
        GridCache(const GridCache &other) = default;

        /// Maps one coordinate to its cell index along an axis with
        /// `resolution` cells. A coordinate of exactly 1 would produce
        /// index == resolution (one past the end), so it is folded into the
        /// last cell; any other out-of-range value still trips the assert.
        template <typename Scalar>
        static int cell_of(Scalar coord, int resolution) {
            int cell = static_cast<int>(coord * resolution);
            if (cell == resolution) {
                cell = resolution - 1;
            }
            assert(cell >= 0 && cell < resolution);
            return cell;
        }

        /// Adds g and G to the sums of the cell that contains u.
        /// @param u  Position; each component is expected in [0, 1].
        /// @param g  Value added to the cell's g-sum.
        /// @param G  Value added to the cell's G-sum.
        void add(const Vector &u, const Vector &g, const Vector &G) {
            const int x = cell_of(u[0], GRID_X);
            const int y = cell_of(u[1], GRID_Y);
            const int z = cell_of(u[2], GRID_Z);
            for (int c = 0; c < 3; ++c) {
                grid_g[x][y][z][c] += g[c];
                grid_G[x][y][z][c] += G[c];
            }
            ++grid_sample_count[x][y][z];
            ++total_sample_count;
            if (total_sample_count > max_cache_size) {
                write = false;
            }
        }

        /// Subtracts a previously added contribution from the cell that
        /// contains u. Arguments mirror add().
        void remove(const Vector &u, const Vector &g, const Vector &G) {
            const int x = cell_of(u[0], GRID_X);
            const int y = cell_of(u[1], GRID_Y);
            const int z = cell_of(u[2], GRID_Z);
            for (int c = 0; c < 3; ++c) {
                grid_g[x][y][z][c] -= g[c];
                grid_G[x][y][z][c] -= G[c];
            }
            --grid_sample_count[x][y][z];
            --total_sample_count;
            if (total_sample_count <= max_cache_size) {
                write = true;
            }
        }

        /// Averages the cell's stored sums together with the current sample.
        ///
        /// @param u  Position selecting the cell.
        /// @param g  Current value; it is always included in the average.
        /// @param m  Out: (sum_g + g) / (count + 1), or g when the cell is empty.
        /// @param G  Out: (sum_G + g*g) / (count + 1), or g*g when the cell is empty.
        /// @return   false when the cell holds no samples, true otherwise.
        bool query(const Vector &u, const Vector &g, Vector &m, Vector &G) const {
            const int x = cell_of(u[0], GRID_X);
            const int y = cell_of(u[1], GRID_Y);
            const int z = cell_of(u[2], GRID_Z);
            const int stored = grid_sample_count[x][y][z];
            if (stored == 0) {
                m = g;
                G = g.cwiseProduct(g);
                return false;
            }
            for (int c = 0; c < 3; ++c) {
                m[c] = grid_g[x][y][z][c];
                G[c] = grid_G[x][y][z][c];
            }
            m += g;
            G += g.cwiseProduct(g);
            // +1 accounts for the current sample that was just folded in.
            const Float s = 1.0 / (Float)(stored + 1);
            m *= s;
            G *= s;
            return true;
        }

        /// Adds every cell of `other` into this cache. Like add(), clears
        /// `write` when the combined sample count exceeds max_cache_size.
        void merge(const GridCache &other) {
            for (int i = 0; i < GRID_X; ++i) {
                for (int j = 0; j < GRID_Y; ++j) {
                    for (int k = 0; k < GRID_Z; ++k) {
                        for (int c = 0; c < 3; ++c) {
                            grid_g[i][j][k][c] += other.grid_g[i][j][k][c];
                            grid_G[i][j][k][c] += other.grid_G[i][j][k][c];
                        }
                        grid_sample_count[i][j][k] += other.grid_sample_count[i][j][k];
                    }
                }
            }
            total_sample_count += other.total_sample_count;
            if (total_sample_count > max_cache_size) {
                write = false;
            }
        }

        /// Total number of samples currently accumulated.
        inline int size() const {
            return total_sample_count;
        }
    };

    /// One cached sample: a position together with the two per-dimension
    /// quantities stored for it.
    struct GradNode {
        MALAVector u; // position of the sample
        MALAVector g; // first stored quantity (KNNCache averages it into `m`)
        MALAVector G; // second stored quantity (KNNCache averages it into `G`)

        /// Leaves all three arrays empty.
        GradNode() = default;

        /// Stores copies of the three arrays.
        GradNode(const MALAVector &u_, const MALAVector &g_, const MALAVector &G_)
            : u(u_),
              g(g_),
              G(G_) {}
    };

    /// Point-set adaptor exposing the kdtree_* accessors that the nanoflann
    /// index type used by KNNCache calls on its dataset.
    struct GradCloud {
        /// A single point of the cloud.
        struct Point {
            MALAVector point;

            Point() = default;
            Point(const MALAVector &u) : point(u) {}
        };

        std::vector<Point> pts;

        /// Number of points currently stored.
        inline size_t kdtree_get_point_count() const {
            return pts.size();
        }

        /// Appends a copy of u as a new point.
        inline void push_back(const MALAVector &u) {
            pts.emplace_back(u);
        }

        /// Returns coordinate `dim` of point `idx`; `dim` must be smaller
        /// than that point's length (assert only).
        inline Float kdtree_get_pt(const size_t idx, const size_t dim) const {
            const MALAVector &coords = pts[idx].point;
            assert(dim < coords.size());
            return coords[dim];
        }

        /// No precomputed bounding box is offered; always returns false.
        template <class BBOX>
        bool kdtree_get_bbox(BBOX & /* bb */) const {
            return false;
        }
    };


    struct KNNCache{
        // GridCache grid_cache;
        using KDtree_N = nanoflann::KDTreeSingleIndexDynamicAdaptor<nanoflann::L2_Simple_Adaptor<Float, GradCloud>, GradCloud, -1>;
        int dim = 0;
        std::vector<GradNode> grad_nodes;
        GradCloud grad_cloud;
        KDtree_N kdtree;
        Float radius = 0.05;
        bool write = true;
        int max_cache_size = 2000;
        KNNCache(int dim, int max_cache_size): max_cache_size(max_cache_size), dim(dim), grad_nodes(std::vector<GradNode>()), grad_cloud(GradCloud()), 
                            kdtree(KDtree_N(dim + 1, grad_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10))) {
            grad_nodes.reserve(max_cache_size);
            grad_cloud.pts.reserve(max_cache_size);
        }

        KNNCache(const KNNCache &other): grad_nodes(other.grad_nodes), grad_cloud(other.grad_cloud), write(other.write),
                            kdtree(KDtree_N(dim + 1, grad_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10))), max_cache_size(other.max_cache_size){
            kdtree.addPoints(0, grad_cloud.pts.size() - 1);
        }

        void add(const MALAVector &u, const MALAVector &g, const MALAVector &G, int cam_bounce) {
            // grid_cache.add(u, g, G);
            MALAVector u_(dim + 1), g_(dim + 1), G_(dim + 1);
            int input_dim = u.size();
            if (input_dim <= dim) {
                u_.head(input_dim) = u;
                u_.tail(dim - input_dim + 1) = -MALAVector::Ones(dim - input_dim + 1);
                u_[dim] = Float(0.0);
                g_.head(input_dim) = g;
                g_.tail(dim - input_dim + 1) = -MALAVector::Ones(dim - input_dim + 1);
                g_[dim] = Float(0.0);
                G_.head(input_dim) = G;
                G_.tail(dim - input_dim + 1) = -MALAVector::Ones(dim - input_dim + 1);
                G_[dim] = Float(0.0);
            } else {
                PSDR_INFO("Error: input_dim > dim");
                assert(false);
            }
            grad_nodes.push_back({u_, g_, G_});
            grad_cloud.push_back(u_);
            kdtree.addPoints(grad_cloud.pts.size() - 1, grad_cloud.pts.size() - 1);
            // PSDR_INFO("pts_size: {}", grad_cloud.pts.size());
            // PSDR_INFO("kdtree size: {}", kdtree.getAllIndices().size());
            // PSDR_INFO("kdtree size: {}", kdtree.getAllIndices()[0].m_size);
            // for (int i = 0; i < u.size(); i++) {
            //     PSDR_INFO("u[{}]: {}", i, u[i]);
            // }
            if (grad_cloud.pts.size() > max_cache_size) {
                write = false;
                // kdtree.addPoints(0, grad_cloud.pts.size() - 1);
            }
        }
        
        bool query(const MALAVector &u, const MALAVector &g, MALAVector &m, MALAVector &G, int cam_bounce) const {
            // Vector m_, G_;
            // bool found = grid_cache.query(u, g, m_, G_);
            // m = m_;
            // G = G_;
            // if (found) {
            //     return true;
            // } else {
            //     G = g.cwiseProduct(g);
            //     m = g;
            //     // PSDR_INFO("G: {}, {}, {}", G[0], G[1], G[2]);
            //     // PSDR_INFO("m: {}, {}, {}", m[0], m[1], m[2]);
            //     // PSDR_INFO("G_: {}, {}, {}", G_[0], G_[1], G_[2]);
            //     // PSDR_INFO("m_: {}, {}, {}", m_[0], m_[1], m_[2]);
            //     return false;

            // }
            // const size_t num_results = 20;
            // size_t                         ret_index[num_results];
            // Float                          out_dist_sqr[num_results];
            // nanoflann::KNNResultSet<Float> resultSet(num_results);
            // resultSet.init(ret_index, out_dist_sqr);
            // kdtree.findNeighbors(resultSet, &u.v[0], nanoflann::SearchParams());
        //     for (size_t i = 0; i < resultSet.size(); ++i)
        // {
        //     std::cout << "#" << i << ",\t"
        //               << "index: " << ret_index[i] << ",\t"
        //               << "dist: " << out_dist_sqr[i] << ",\t"
        //               << "point: (" << cloud.pts[ret_index[i]].x << ", "
        //               << cloud.pts[ret_index[i]].y << ", "
        //               << cloud.pts[ret_index[i]].z << ")" << std::endl;
        // }
            MALAVector u_(dim + 1);
            int input_dim = u.size();
            if (input_dim <= dim) {
                u_.head(input_dim) = u;
                u_.tail(dim - input_dim + 1) = -MALAVector::Ones(dim - input_dim + 1);
                u_[dim] = Float(0.0);
            } else {
                PSDR_INFO("Error: input_dim > dim");
                assert(false);
            }
            std::vector<std::pair<size_t, Float>> indices_dists;
            nanoflann::RadiusResultSet<Float, size_t>         resultSet(radius, indices_dists);
            nanoflann::SearchParams params;
            bool found = kdtree.findNeighbors(resultSet, &u_[0], nanoflann::SearchParams());

            // PSDR_INFO("found: {}", found);
            // PSDR_INFO("resultSet: {}", resultSet.size());
            int Q_size = resultSet.size();
            Q_size = Q_size > 20 ? 20 : Q_size;
            // PSDR_INFO("query_pt: {}, {}", query_pt[0], query_pt[1]);
            G = MALAVector::Zero(g.size());
            m = MALAVector::Zero(g.size());
            if (Q_size == 0 || !found) {
                G = g.cwiseProduct(g);
                m = g;
                // PSDR_INFO("G: {}, {}, {}", G[0], G[1], G[2]);
                // PSDR_INFO("m: {}, {}, {}", m[0], m[1], m[2]);
                // PSDR_INFO("G_: {}, {}, {}", G_[0], G_[1], G_[2]);
                // PSDR_INFO("m_: {}, {}, {}", m_[0], m_[1], m_[2]);
                return false;
            }
            for (int i = 0; i < Q_size; ++i) {
                // G += grad_nodes[ret_index[i]].g.cwiseProduct(grad_nodes[ret_index[i]].g);
                // m += grad_nodes[ret_index[i]].g;
                G += grad_nodes[indices_dists[i].first].G.head(G.size());
                m += grad_nodes[indices_dists[i].first].g.head(m.size());
            }
            G /= Q_size;
            m /= Q_size;
            // G = MALA::MALAVector(G_);
            // m = MALA::MALAVector(m_);
            // PSDR_INFO("G: {}, {}, {}", G[0], G[1], G[2]);
            // PSDR_INFO("m: {}, {}, {}", m[0], m[1], m[2]);
            // PSDR_INFO("G_: {}, {}, {}", G_[0], G_[1], G_[2]);
            // PSDR_INFO("m_: {}, {}, {}", m_[0], m_[1], m_[2]);
            return true;
        }

        void merge(const KNNCache &other) {
            int n = grad_nodes.size();
            grad_nodes.insert(grad_nodes.end(), other.grad_nodes.begin(), other.grad_nodes.end());
            grad_cloud.pts.insert(grad_cloud.pts.end(), other.grad_cloud.pts.begin(), other.grad_cloud.pts.end());
            kdtree.addPoints(n, grad_cloud.pts.size() - 1);
            if (grad_cloud.pts.size() > max_cache_size) {
                write = false;
            }
        }

        void copy(const KNNCache &other) {
            grad_nodes = other.grad_nodes;
            grad_cloud = other.grad_cloud;
            // kdtree = KDtree_N(dim, grad_cloud, nanoflann::KDTreeSingleIndexAdaptorParams(10));
            kdtree.addPoints(0, grad_cloud.pts.size() - 1);
            if (grad_cloud.pts.size() > max_cache_size) {
                write = false;
            }
        }

        int size() const {
            return grad_nodes.size();
        }

    };

    /// Interface for strategies that turn the current position `u` and value
    /// `g` into the outputs `m` and `M`.
    ///
    /// Judging from the implementations in this file: step() may update the
    /// strategy's own state and/or the cache, step_hypo() computes the same
    /// outputs without keeping such updates, and step_readonly() only reads
    /// the cache. setZero() resets any per-strategy running state.
    struct Mutation {
        virtual void step(KNNCache &cache, const MALAVector &u, const MALAVector &g,
                          MALAVector &m, MALAVector &M, int cam_bounce) = 0;

        virtual void step_hypo(KNNCache &cache, const MALAVector &u, const MALAVector &g,
                               MALAVector &m, MALAVector &M, int cam_bounce) = 0;

        virtual bool step_readonly(const KNNCache &cache, const MALAVector &u, const MALAVector &g,
                                   MALAVector &m, MALAVector &M, int cam_bounce) = 0;

        virtual void setZero() = 0;

        virtual ~Mutation() = default;
    };

    /// Strategy that keeps exponential moving averages of g (in `d`) and of
    /// g*g (in `G`) and scales their influence by 1/i, where i counts the
    /// calls to step() since the last reset. The cache argument is not used.
    struct MutationDiminishing : Mutation {
        const Float beta = 0.999;   // decay of the g*g average
        const Float alpha = 0.9;    // decay of the g average
        const Float delta = 0.001;  // added before inversion so M stays finite
        const Float c1 = 1.0;       // unused in this struct
        const Float c2 = 1.0;       // unused in this struct
        MALAVector G, d; // diagonal G
        bool first = true;          // true until step() has run once since the last reset
        int i = 1; // diminishing adaptation

        MutationDiminishing() {
        }

        /// Updates the running averages with g and writes
        ///   M = 1 / (delta + sqrt(G) / i),   m = d / i + g.
        /// On the first call both averages are seeded from g itself.
        void step(KNNCache &cache, const MALAVector &u, const MALAVector &g, MALAVector &m, MALAVector &M, int cam_bounce) override{
            if (first){
                first = false;
                G = g.cwiseProduct(g);
                d = g;
            } else {
                G = beta * G + (1 - beta) * g.cwiseProduct(g);
            }
            MALAVector H = delta * MALAVector::Ones(G.size()) + 1.0 / i * G.cwiseSqrt();
            M = H.cwiseInverse();
            d = alpha * d + (1 - alpha) * g;
            m = 1.0 / i * d + g;
            ++i;
        }

        /// Computes the m and M that step() would produce for g, without
        /// modifying G, d, first or i.
        ///
        /// Before the first step() (or right after setZero()) the stored `d`
        /// is empty or zeroed, so it must not be used: step() would seed both
        /// averages from g in that situation, and this mirrors it. The old
        /// code read the member `d` unconditionally, which mismatched step()
        /// and could combine arrays of different sizes.
        void step_hypo(KNNCache &cache, const MALAVector &u, const MALAVector &g, MALAVector &m, MALAVector &M, int cam_bounce) override{
            MALAVector G_;
            MALAVector d_;
            if (first){
                G_ = g.cwiseProduct(g);
                d_ = alpha * g + (1 - alpha) * g;
            } else {
                G_ = beta * G + (1 - beta) * g.cwiseProduct(g);
                d_ = alpha * d + (1 - alpha) * g;
            }
            MALAVector H = delta * MALAVector::Ones(g.size()) + 1.0 / i * G_.cwiseSqrt();
            M = H.cwiseInverse();
            m = 1.0 / i * d_ + g;
        }

        /// Not supported by this strategy: logs an error and asserts.
        /// Returns true if execution continues (asserts disabled).
        bool step_readonly(const KNNCache &cache, const MALAVector &u, const MALAVector &g, MALAVector &m, MALAVector &M, int cam_bounce) override{
            PSDR_INFO("Error: Diminishing::step_readonly shouldn't be called");
            assert(false); // shouldn't be called
            return true;
        }

        /// Clears both running averages and restarts the 1/i schedule.
        void setZero() override{
            G.setZero();
            d.setZero();
            first = true;
            i = 1;
        }
    };

    struct MutationCacheBased : Mutation {
        const Float delta = 0.001;
        int KDTREE_SIZE = 1000;
        Float query_radius = 0.05;

        MutationCacheBased(){
        }

        void step_knn(const KNNCache &cache, const MALAVector &u, const MALAVector &g, MALAVector &m, MALAVector &M, int cam_bounce) const {
            MALAVector G = MALAVector::Zero(g.size());
            cache.query(u, g, m, G, cam_bounce);
            // Get worst (furthest) point, without sorting:
            MALAVector H = delta * MALAVector::Ones(g.size()) + G.cwiseSqrt();
            M = H.cwiseInverse();
        }

        void step(KNNCache &cache, const MALAVector &u, const MALAVector &g, MALAVector &m, MALAVector &M, int cam_bounce) override{
            if (cache.write) { 
                cache.add(u, g, g.cwiseProduct(g), cam_bounce);
            } 
            
            step_knn(cache, u, g, m, M, cam_bounce);
        }

        void step_hypo(KNNCache &cache, const MALAVector &u, const MALAVector &g, MALAVector &m, MALAVector &M, int cam_bounce) override{
            
            step_knn(cache, u, g, m, M, cam_bounce);
        }

        bool step_readonly(const KNNCache &cache, const MALAVector &u, const MALAVector &g, MALAVector &m, MALAVector &M, int cam_bounce) override{
            return cache.query(u, g, m, M, cam_bounce);
        }

        void setZero() override{
        }
    };


    /// Strategy that uses running averages of g and g*g while the cache is
    /// still accepting writes, and switches to cache lookups afterwards.
    struct MutationHybrid : Mutation {
        const Float beta = 0.999;   // decay of the g*g average
        const Float alpha = 0.9;    // decay of the g average
        const Float delta = 0.001;  // added before inversion so M stays finite
        bool first = true;          // true until step_adam() has run once since the last reset
        MALAVector G, d; // diagonal G

        MutationHybrid() = default;

        /// Updates the running averages with g and writes
        ///   M = 1 / (delta + sqrt(G)),   m = d + g.
        /// The first call seeds both averages from g. `u` is not used.
        void step_adam(const MALAVector &u, const MALAVector &g, MALAVector &m, MALAVector &M) {
            const MALAVector g_squared = g.cwiseProduct(g);
            if (!first) {
                G = beta * G + (1 - beta) * g_squared;
            } else {
                first = false;
                G = g_squared;
                d = g;
            }
            const MALAVector damped = delta * MALAVector::Ones(g.size()) + G.cwiseSqrt();
            M = damped.cwiseInverse();
            d = alpha * d + (1 - alpha) * g;
            m = d + g;
        }

        /// Queries the cache around u and writes
        ///   m = value returned by the cache,   M = 1 / (delta + sqrt(G)),
        /// where G is the second quantity returned by the cache.
        void step_knn(const KNNCache &cache, const MALAVector &u, const MALAVector &g, MALAVector &m, MALAVector &M, int cam_bounce) {
            MALAVector second_moment = MALAVector::Zero(g.size());
            cache.query(u, g, m, second_moment, cam_bounce);
            const MALAVector damped = delta * MALAVector::Ones(g.size()) + second_moment.cwiseSqrt();
            M = damped.cwiseInverse();
        }

        /// While the cache accepts writes: running-average step, then the
        /// sample (u, g, g*g) is stored. Otherwise: cache lookup only.
        void step(KNNCache &cache, const MALAVector &u, const MALAVector &g, MALAVector &m, MALAVector &M, int cam_bounce) override {
            if (!cache.write) {
                step_knn(cache, u, g, m, M, cam_bounce);
                return;
            }
            step_adam(u, g, m, M);
            cache.add(u, g, g.cwiseProduct(g), cam_bounce);
        }

        /// Produces the outputs step() would give, but leaves both the
        /// running averages and the cache as they were.
        void step_hypo(KNNCache &cache, const MALAVector &u, const MALAVector &g, MALAVector &m, MALAVector &M, int cam_bounce) override {
            if (!cache.write) {
                step_knn(cache, u, g, m, M, cam_bounce);
                return;
            }
            // Snapshot the state, run the real update, then roll it back.
            const MALAVector saved_G = G;
            const MALAVector saved_d = d;
            const bool saved_first = first; // shouldn't need this
            step_adam(u, g, m, M);
            G = saved_G;
            d = saved_d;
            first = saved_first;
        }

        /// Forwards to KNNCache::query, so M receives the cache's second
        /// output as-is (no delta / sqrt / inverse is applied here).
        /// @return the cache's hit/miss result.
        bool step_readonly(const KNNCache &cache, const MALAVector &u, const MALAVector &g, MALAVector &m, MALAVector &M, int cam_bounce) override {
            const bool hit = cache.query(u, g, m, M, cam_bounce);
            return hit;
        }

        /// Clears both running averages and re-arms the seeding step.
        void setZero() override {
            G.setZero();
            d.setZero();
            first = true;
        }

        /// Intentionally does nothing.
        void merge(const MutationHybrid &other) {}
    };

    // const Vector3i dim_mutate = Vector3i(1, 1, 1); // variable for debugging
    
    /// Gaussian with diagonal covariance, filled in by ComputeGaussian().
    struct Gaussian {
        MALAVector mean;      // per-dimension mean
        Float logDet;         // log of the product of the diagonal covariance entries
        // We only consider diagonal preconditioning for now.
        MALAVector covL_d;    // diagonal covariance entries; sqrt() of these scales the samples
        MALAVector invCov_d;  // reciprocals of the diagonal covariance entries

        /// Draws one sample: standard-normal values scaled by sqrt(covL_d)
        /// and shifted by mean.
        ///
        /// One 2D draw from the sampler supplies two consecutive dimensions;
        /// for an odd dimension count the second value of the last draw is
        /// discarded.
        MALAVector GenerateSample(RndSampler *sampler) const {
            const int n = mean.size();
            MALAVector sample = MALAVector::Zero(n);
            for (int k = 0; k < n; k += 2) {
                const Vector2 pair = squareToGaussianDisk(sampler->next2D(), 1.0);
                sample[k] = pair[0];
                if (k + 1 < n) {
                    sample[k + 1] = pair[1];
                }
            }
            for (int k = 0; k < n; ++k) {
                sample[k] = sample[k] * sqrt(covL_d[k]) + mean[k];
            }
            return sample;
        }

        /// Log-density of this Gaussian at `offset`:
        ///   -n/2 * log(2*pi) - logDet/2 - sum_i (offset_i - mean_i)^2 * invCov_d_i / 2
        /// with n = offset.size().
        Float GaussianLogPdf(const MALAVector &offset/*, const MALAVector &discrete_dim*/) const {
            const int n = offset.size();
            // -0.9189385332046727 == -0.5 * log(2 * pi)
            Float logPdf = Float(n) * Float(-0.9189385332046727);
            logPdf -= Float(0.5) * logDet;
            for (int k = 0; k < n; ++k) {
                const auto diff = offset[k] - mean[k];
                logPdf -= Float(0.5) * diff * diff * invCov_d[k];
            }
            return logPdf;
        }
    };

    void ComputeGaussian(const MALAVector &eps, const MALAVector &m, const MALAVector &M/*, const MALAVector &discrete_dim*/, Gaussian &gaussian) {
        assert(m.size() == M.size());
        gaussian.covL_d = eps.cwiseProduct(M);
        // PSDR_INFO("covL_d: {}, {}, {}", gaussian.covL_d[0], gaussian.covL_d[1], gaussian.covL_d[2]);
        // gaussian.covL_d[0] /= 10000;
        // gaussian.invCov_d = Float(1.0) / eps * M.cwiseInverse();
        gaussian.invCov_d = Float(1.0) / eps.cwiseProduct(M);
        // gaussian.invCov_d[0] *= 10000;
        gaussian.mean = 0.5 * gaussian.covL_d.cwiseProduct(m);
        Float Det = 1.0;
        for (int i = 0; i < M.size(); i++) {
            // if (discrete_dim[i] < 0.5) {
                
                Det *= eps[i] * M[i];
                // if (i == 0) {
                //     Det /= 10000;
                // }
            // }
        }
        gaussian.logDet = log(Det);
    }

}

#endif // MALA_UTILS_H__