#pragma once
#ifndef AQ3D_H__
#define AQ3D_H__

#include <iostream>

#include <fstream>

#include <string>

#include <cmath>

#include <algorithm>

#include <core/fwd.h>

#include <core/stats.h>

#include <core/sampler.h>

#include <core/pmf.h>

#include <fstream>


#define LINEAR

#ifdef LINEAR


// Fill `sample_distrb` with one weight per voxel of a dims[0] x dims[1] x dims[2]
// grid, taken from `data` (one value per voxel).
//
// The weights are rescaled by num_voxels / data.sum() so that, before
// normalize(), they sum to the voxel count. grid3D::sample multiplies the
// sampled cell's probability by distrb.getSum() to obtain a density on the
// unit cube, which only works with this rescaling (assumes getSum() reports the
// pre-normalization total — confirm against core/pmf.h).
//
// Preconditions: data.size() equals the voxel count and data.sum() is non-zero.
static void build_distrib(const Vector3i &dims, const Eigen::Array<Float, -1, 1> &data, DiscreteDistribution &sample_distrb)
{
    const long long num_voxels = static_cast<long long>(dims[0])*dims[1]*dims[2];
    assert(data.size() == num_voxels);
    std::cout << "grid mean: " << data.mean() << std::endl;
    sample_distrb.clear();
    // A single-voxel grid still needs its one entry; leaving the distribution
    // empty would leave nothing to sample from.
    if (num_voxels > 0) {
        const Float ratio = static_cast<Float>(num_voxels)/data.sum();
        sample_distrb.reserve(num_voxels);
        for (long long i = 0; i < num_voxels; ++i) sample_distrb.append(data[i]*ratio);
        sample_distrb.normalize();
    }
}

namespace Grid3D_Sampling {
    // Resolution and per-cell sample count used by grid3D::setup.
    struct grid3D_config {
        grid3D_config() {}
        grid3D_config(const Vector3i &dims, int spp) : dims(dims), spp(spp) {}
        Vector3i dims;  // number of cells along x, y and z
        int spp = 0;    // samples of func averaged per cell (zero, not indeterminate, when default-constructed)
    };

    // Piecewise-constant importance sampler on the unit cube backed by a regular grid.
    struct grid3D {
        DiscreteDistribution distrb;
        grid3D_config config;

        // Estimate the average of func over every cell with cfg.spp jittered
        // samples (in parallel, one RndSampler per worker thread). Then normalize
        // the cell averages to sum to 1, add 10% of their (max - min) range to
        // every cell, and build distrb from the result.
        void setup(const std:: function < Float(const Vector3&, RndSampler&) > & func, const grid3D_config & cfg) {
            config = cfg;
            const int nworker = omp_get_num_procs();
            std::vector<RndSampler> samplers;
            samplers.reserve(nworker);
            for ( int i = 0; i < nworker; ++i ) samplers.push_back(RndSampler(13, i));
            Eigen::Array<Float, -1, 1> g_data(config.dims[0]*config.dims[1]*config.dims[2]);
            // One parallel work item per (i, j) column of cells.
            const long long num_columns = static_cast<long long>(config.dims[0])*config.dims[1];
            long long blockProcessed = 0;

#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
            for ( long long omp_i = 0; omp_i < num_columns; ++omp_i ) {
                RndSampler &sampler = samplers[omp_get_thread_num()];
                const long long i = omp_i/config.dims[1], j = omp_i % config.dims[1];
                for (int k=0; k<config.dims[2]; ++k) {
                    Float res = 0.0f;
                    for (int t=0; t<config.spp; ++t) {
                        // Jittered point inside cell (i, j, k), in unit-cube coordinates.
                        Vector rnd = sampler.next3D();
                        rnd[0] = (rnd[0] + i)/static_cast<Float>(config.dims[0]);
                        rnd[1] = (rnd[1] + j)/static_cast<Float>(config.dims[1]);
                        rnd[2] = (rnd[2] + k)/static_cast<Float>(config.dims[2]);
                        res += func(rnd, sampler);
                    }
                    // Flattened index (i*dims[1] + j)*dims[2] + k, matching the decoding in sample().
                    g_data[omp_i*config.dims[2] + k] = res/static_cast<Float>(config.spp);
                }
                // Serialize progress reporting; a named critical section needs no
                // explicit lock init/destroy.
#pragma omp critical(grid3D_progress)
                {
                    progressIndicator(static_cast<Float>(++blockProcessed)/num_columns);
                }
            }
            std::cout << "grid max: " << g_data.maxCoeff() << std::endl;
            std::cout << "grid min: " << g_data.minCoeff() << std::endl;

            const Float total = g_data.sum();
            if (total != 0 && std::isfinite(total)) {
                g_data /= total;
            } else {
                // Cell averages summing to zero (or non-finite) would turn every
                // weight into NaN; fall back to a uniform grid instead.
                g_data.setConstant(static_cast<Float>(1)/static_cast<Float>(g_data.size()));
            }
            // Lift every cell by 10% of the value range so low cells keep some probability.
            g_data += 0.1*(g_data.maxCoeff()-g_data.minCoeff());

            build_distrib(config.dims, g_data, distrb);

            // std::cout << "save grid 3d to txt" << std::endl;
            // std::ofstream outFile_sample("grid_3d.txt");
            // for (int i=0; i<g_data.size(); ++i) {
            //     outFile_sample << g_data[i] << "\n";
            // }
        }

        // Map rnd in [0,1)^3 to a point of the unit cube distributed according to
        // the grid. rnd[2] selects the cell and is then reused as the z jitter
        // inside it (presumably rescaled by sampleReuse — see core/pmf.h).
        // On return pdf is the density w.r.t. the unit cube: the cell probability
        // times distrb.getSum(), which build_distrib arranges to equal the cell count.
        Vector3 sample(Vector3 rnd, Float & pdf) const {
            long long idx = distrb.sampleReuse(rnd[2], pdf);
            int unit = config.dims[1]*config.dims[2];
            int i = static_cast<int>(idx/unit);
            idx %= unit;
            int j = static_cast<int>(idx/config.dims[2]), k = static_cast<int>(idx % config.dims[2]);
            pdf *= distrb.getSum();
            rnd = (Vector(i, j, k) + rnd).cwiseQuotient(Vector(config.dims[0], config.dims[1], config.dims[2]));

            assert(rnd[0] > -Epsilon && rnd[0] < 1.0f + Epsilon &&
                   rnd[1] > -Epsilon && rnd[1] < 1.0f + Epsilon &&
                   rnd[2] > -Epsilon && rnd[2] < 1.0f + Epsilon);
            assert(pdf > 0.0);
            return rnd;
        }
    };

}


class Linear_Regression3D
{
public:
    double b0;
    double b1;
    double b2;
    double b3;
    // y = b0 + b1*x1 + b2*x2 + b3*x3

    Linear_Regression3D() { reset(); }
    void reset()  { N=0;
        b0=b1=b2=b3=0.0;
        y_mean=x1_mean=x2_mean=x3_mean=0.0;
        x1_sum=x2_sum=x3_sum=y_sum=0.0;
        x1y_sum=x2y_sum=x3y_sum=0.0;
        x1x2_sum=x1x3_sum=x2x3_sum=0.0;
        x1x1_sum=x2x2_sum=x3x3_sum=0.0;
    }

    void push(Vector3 x, double y, int region) {
        double x1 = x[0];
        double x2 = x[1];
        double x3 = x[2];
        ++N;

        if (region == 1) {
            x1a_data.push_back(x);
            y1a_data.push_back(y);
            x2a_data.push_back(x);
            y2a_data.push_back(y);
            x3a_data.push_back(x);
            y3a_data.push_back(y);
        } else if (region == 2) {
            x1b_data.push_back(x);
            y1b_data.push_back(y);
            x2a_data.push_back(x);
            y2a_data.push_back(y);
            x3a_data.push_back(x);
            y3a_data.push_back(y);

        } else if (region == 3) {
            x1a_data.push_back(x);
            y1a_data.push_back(y);
            x2b_data.push_back(x);
            y2b_data.push_back(y);
            x3a_data.push_back(x);
            y3a_data.push_back(y);
        } else if (region == 4) {
            x1a_data.push_back(x);
            y1a_data.push_back(y);
            x2a_data.push_back(x);
            y2a_data.push_back(y);
            x3b_data.push_back(x);
            y3b_data.push_back(y);
        } else if (region == 5) {
            x1b_data.push_back(x);
            y1b_data.push_back(y);
            x2b_data.push_back(x);
            y2b_data.push_back(y);
            x3a_data.push_back(x);
            y3a_data.push_back(y);
        } else if (region == 6) {
            x1b_data.push_back(x);
            y1b_data.push_back(y);
            x2a_data.push_back(x);
            y2a_data.push_back(y);
            x3b_data.push_back(x);
            y3b_data.push_back(y);
        } else if (region == 7) {
            x1a_data.push_back(x);
            y1a_data.push_back(y);
            x2b_data.push_back(x);
            y2b_data.push_back(y);
            x3b_data.push_back(x);
            y3b_data.push_back(y);
        } else {
            x1b_data.push_back(x);
            y1b_data.push_back(y);
            x2b_data.push_back(x);
            y2b_data.push_back(y);
            x3b_data.push_back(x);
            y3b_data.push_back(y);
        }


        x1y_sum += x1*y;
        x2y_sum += x2*y;
        x3y_sum += x3*y;

        x1x2_sum += x1*x2;
        x1x3_sum += x1*x3;
        x2x3_sum += x2*x3;

        x1x1_sum += x1*x1;
        x2x2_sum += x2*x2;
        x3x3_sum += x3*x3;

        x1_sum += x1;
        x2_sum += x2;
        x3_sum += x3;
        y_sum += y;

        x1_mean += (x1 - x1_mean)/static_cast<double>(N);
        x2_mean += (x2 - x2_mean)/static_cast<double>(N);
        x3_mean += (x3 - x3_mean)/static_cast<double>(N);
        y_mean += (y - y_mean)/static_cast<double>(N);


        x1_data.push_back(x1);
        x2_data.push_back(x2);
        x3_data.push_back(x3);
        y_data.push_back(y);

    }

    void compute() {
        double SS_x1y  = x1y_sum - x1_sum * y_sum / static_cast<double>(N);
        double SS_x2y  = x2y_sum - x2_sum * y_sum / static_cast<double>(N);
        double SS_x3y  = x3y_sum - x3_sum * y_sum / static_cast<double>(N);

        double SS_x1x2 = x1x2_sum - x1_sum * x2_sum / static_cast<double>(N);
        double SS_x1x3 = x1x3_sum - x1_sum * x3_sum / static_cast<double>(N);
        double SS_x2x3 = x2x3_sum - x2_sum * x3_sum / static_cast<double>(N);

        double SS_x1x1 = x1x1_sum - x1_sum * x1_sum / static_cast<double>(N);
        double SS_x2x2 = x2x2_sum - x2_sum * x2_sum / static_cast<double>(N);
        double SS_x3x3 = x3x3_sum - x3_sum * x3_sum / static_cast<double>(N);

        Matrix3x3 SS_mat;
        SS_mat << SS_x1x1, SS_x1x2, SS_x1x3, SS_x1x2, SS_x2x2, SS_x2x3, SS_x1x3, SS_x2x3, SS_x3x3;
        Vector3   Y_mat(SS_x1y, SS_x2y, SS_x3y);
        Vector3 linear = SS_mat.inverse()*Y_mat;

        b3 = linear[2];
        b2 = linear[1]; 
        b1 = linear[0]; 
        b0 = y_mean - b1 * x1_mean - b2 * x2_mean - b3 * x3_mean;

        // b3 = 0;
        // b2 = 0; 
        // b1 = 0; 
        // b0 = y_mean - b1 * x1_mean - b2 * x2_mean - b3 * x3_mean;


#if 1
        double bound_1 = b0;
        if (bound_1 < 0.0) {
            bound_1 = 0.0;
        }
        double bound_2 = b0+b1;
        if (bound_2 < 0.0) {
            bound_2 = 0.0;
        }
        double bound_3 = b0+b2;
        if (bound_3 < 0.0) {
            bound_3 = 0.0;
        }

        double bound_4 = b0+b3;
        if (bound_4 < 0.0) {
            bound_4 = 0.0;
        }

        b0 = bound_1;
        b1 = bound_2 - bound_1;
        b2 = bound_3 - bound_1;
        b3 = bound_4 - bound_1;

        if (b0+b1+b2 < 0.0) {
            double neg = b0+b1+b2;
            bound_2 -= 0.5*neg;
            bound_3 -= 0.5*neg;
            b1 = bound_2 - bound_1;
            b2 = bound_3 - bound_1;
        }

        if (b0+b2+b3 < 0.0) {
            double neg = b0+b2+b3;
            bound_3 -= 0.5*neg;
            bound_4 -= 0.5*neg;
            b2 = bound_3 - bound_1;
            b3 = bound_4 - bound_1;
        }

        if (b0+b1+b3 < 0.0) {
            double neg = b0+b1+b3;
            bound_2 -= 0.5*neg;
            bound_4 -= 0.5*neg;
            b1 = bound_2 - bound_1;
            b3 = bound_4 - bound_1;
        }
        if (b0+b1+b2+b3 < 0.0) {
            double neg = b0+b1+b2+b3;
            bound_2 -= 1.0/3.0*neg;
            bound_3 -= 1.0/3.0*neg;
            bound_4 -= 1.0/3.0*neg;
            b1 = bound_2 - bound_1;
            b2 = bound_3 - bound_1;
            b3 = bound_4 - bound_1;
        }

        // b0 += 2.0;
#else
        double max_neg = 0.0;
        if (b0 < max_neg) {
            max_neg = b0;
        }
        if (b0+b1 < max_neg) {
            max_neg = b0+b1;
        }
        if (b0+b2 < max_neg) {
            max_neg = b0+b2;
        }
        if (b0+b3 < max_neg) {
            max_neg = b0+b3;
        }
        if (b0+b1+b2 < max_neg) {
            max_neg = b0+b1+b2;
        }
        if (b0+b1+b3 < max_neg) {
            max_neg = b0+b1+b3;
        }
        if (b0+b2+b3 < max_neg) {
            max_neg = b0+b2+b3;
        }
        if (b0+b1+b2+b3 < max_neg) {
            max_neg = b0+b1+b2+b3;
        }
        b0 -= max_neg;
#endif
    }
    

    double getRes() {
        double MSE = 0.0;
        for (int i=0; i<N; ++i) {
            double expect_y = b0+b1*x1_data[i]+b2*x2_data[i]+b3*x3_data[i];
            MSE += pow((expect_y-y_data[i]), 2) / static_cast<double>(N);
        }
        return sqrt(MSE);
    };

    int getDimRes() {
        // check first dimension
        double MSE_1a = 0.0;
        for (int i=0; i<x1a_data.size(); ++i) {
            double expect_y = b0+b1*x1a_data[i][0]+b2*x1a_data[i][1]+b3*x1a_data[i][2];
            MSE_1a += pow((expect_y-y1a_data[i]), 2) / static_cast<double>(x1a_data.size());
        }

        double MSE_1b = 0.0;
        for (int i=0; i<x1b_data.size(); ++i) {
            double expect_y = b0+b1*x1b_data[i][0]+b2*x1b_data[i][1]+b3*x1b_data[i][2];
            MSE_1b += pow((expect_y-y1b_data[i]), 2) / static_cast<double>(x1b_data.size());
        }

        double MSE_2a = 0.0;
        for (int i=0; i<x2a_data.size(); ++i) {
            double expect_y = b0+b1*x2a_data[i][0]+b2*x2a_data[i][1]+b3*x2a_data[i][2];
            MSE_2a += pow((expect_y-y2a_data[i]), 2) / static_cast<double>(x2a_data.size());
        }

        double MSE_2b = 0.0;
        for (int i=0; i<x2b_data.size(); ++i) {
            double expect_y = b0+b1*x2b_data[i][0]+b2*x2b_data[i][1]+b3*x2b_data[i][2];
            MSE_2b += pow((expect_y-y2b_data[i]), 2) / static_cast<double>(x2b_data.size());
        }

        double MSE_3a = 0.0;
        for (int i=0; i<x3a_data.size(); ++i) {
            double expect_y = b0+b1*x3a_data[i][0]+b2*x3a_data[i][1]+b3*x3a_data[i][2];
            MSE_3a += pow((expect_y-y3a_data[i]), 2) / static_cast<double>(x3a_data.size());
        }

        double MSE_3b = 0.0;
        for (int i=0; i<x3b_data.size(); ++i) {
            double expect_y = b0+b1*x3b_data[i][0]+b2*x3b_data[i][1]+b3*x3b_data[i][2];
            MSE_3b += pow((expect_y-y3b_data[i]), 2) / static_cast<double>(x3b_data.size());
        }

        double error1 = abs(MSE_1b - MSE_1a);
        double error2 = abs(MSE_2b - MSE_2a);
        double error3 = abs(MSE_3b - MSE_3a);

        if (error1 > error2 && error1 > error3) {
            return 1;
        } else if (error2 > error1 && error2 > error3) {
            return 2;
        } else if (error3 > error1 && error3 > error2) {
            return 3;
        } else {
            if (error1 > error2 || error1 > error3) {
                return 1;
            } else if (error2 > error1 || error2 > error3) {
                return 2;
            } else if (error3 > error1 || error3 > error2) {
                return 3;
            } else {
                return 0; // this is rare
            }
        }
    }

    int getDimRes(bool stop_x, bool stop_y, bool stop_z) {
        // check first dimension
        double MSE_1a = 0.0;
        for (int i=0; i<x1a_data.size(); ++i) {
            double expect_y = b0+b1*x1a_data[i][0]+b2*x1a_data[i][1]+b3*x1a_data[i][2];
            MSE_1a += pow((expect_y-y1a_data[i]), 2) / static_cast<double>(x1a_data.size());
        }

        double MSE_1b = 0.0;
        for (int i=0; i<x1b_data.size(); ++i) {
            double expect_y = b0+b1*x1b_data[i][0]+b2*x1b_data[i][1]+b3*x1b_data[i][2];
            MSE_1b += pow((expect_y-y1b_data[i]), 2) / static_cast<double>(x1b_data.size());
        }

        double MSE_2a = 0.0;
        for (int i=0; i<x2a_data.size(); ++i) {
            double expect_y = b0+b1*x2a_data[i][0]+b2*x2a_data[i][1]+b3*x2a_data[i][2];
            MSE_2a += pow((expect_y-y2a_data[i]), 2) / static_cast<double>(x2a_data.size());
        }

        double MSE_2b = 0.0;
        for (int i=0; i<x2b_data.size(); ++i) {
            double expect_y = b0+b1*x2b_data[i][0]+b2*x2b_data[i][1]+b3*x2b_data[i][2];
            MSE_2b += pow((expect_y-y2b_data[i]), 2) / static_cast<double>(x2b_data.size());
        }

        double MSE_3a = 0.0;
        for (int i=0; i<x3a_data.size(); ++i) {
            double expect_y = b0+b1*x3a_data[i][0]+b2*x3a_data[i][1]+b3*x3a_data[i][2];
            MSE_3a += pow((expect_y-y3a_data[i]), 2) / static_cast<double>(x3a_data.size());
        }

        double MSE_3b = 0.0;
        for (int i=0; i<x3b_data.size(); ++i) {
            double expect_y = b0+b1*x3b_data[i][0]+b2*x3b_data[i][1]+b3*x3b_data[i][2];
            MSE_3b += pow((expect_y-y3b_data[i]), 2) / static_cast<double>(x3b_data.size());
        }

        double error1 = abs(MSE_1b - MSE_1a);
        double error2 = abs(MSE_2b - MSE_2a);
        double error3 = abs(MSE_3b - MSE_3a);

        if (stop_x) {
            error1 = 0.0;
        }
        if (stop_y) {
            error2 = 0.0;
        }
        if (stop_z) {
            error3 = 0.0;
        }

        if (error1 > error2 && error1 > error3) {
            return 1;
        } else if (error2 > error1 && error2 > error3) {
            return 2;
        } else if (error3 > error1 && error3 > error2) {
            return 3;
        } else {
            if (error1 > error2 || error1 > error3) {
                return 1;
            } else if (error2 > error1 || error2 > error3) {
                return 2;
            } else if (error3 > error1 || error3 > error2) {
                return 3;
            } else {
                return 0; // this is rare
            }
        }
    }


public:
    int N;

    double y_mean;
    double x1_mean;
    double x2_mean;
    double x3_mean;

    double x1_sum;
    double x2_sum;
    double x3_sum;
    double y_sum;

    double x1y_sum;
    double x2y_sum;
    double x3y_sum;

    double x1x2_sum;
    double x1x3_sum;
    double x2x3_sum;
    double x1x1_sum;
    double x2x2_sum;
    double x3x3_sum;

    std::vector<Vector3> x1a_data;
    std::vector<Vector3> x1b_data;

    std::vector<Vector3> x2a_data;
    std::vector<Vector3> x2b_data;

    std::vector<Vector3> x3a_data;
    std::vector<Vector3> x3b_data;

    std::vector<double> y1a_data;
    std::vector<double> y1b_data;
    std::vector<double> y2a_data;
    std::vector<double> y2b_data;
    std::vector<double> y3a_data;
    std::vector<double> y3b_data;

    std::vector<double> x1_data;
    std::vector<double> x2_data;
    std::vector<double> x3_data;
    std::vector<double> y_data;

};

#else

// Constant-model counterpart of Linear_Regression3D with the same interface,
// compiled only when LINEAR is not defined. compute() fits y = b0 with
// b1 = b2 = b3 = 0, and the split axis is scored from per-half sample variances.
class Uniform_Regression3D
{
public:
    double b0;
    double b1;
    double b2;
    double b3;
    // y = b0 + b1*x1 + b2*x2 + b3*x3

    Uniform_Regression3D() { reset(); }

    // Forget all samples. Coefficients are zeroed so they are never read
    // uninitialized before compute().
    void reset()  {
        b0 = b1 = b2 = b3 = 0.0;
        xa.reset();
        xb.reset();
        ya.reset();
        yb.reset();
        za.reset();
        zb.reset();
        data.reset();
    }

    // Record y overall and in the lower ("a") or upper ("b") half of each axis,
    // as selected by `region` (ids 1..8; any other id acts as 8).
    void push(Vector3 x, double y, int region) {
        (void)x;  // the constant model does not use the sample position
        data.push(y);

        // kUpper[r][d]: does region r+1 lie in the upper half along axis d?
        static constexpr bool kUpper[8][3] = {
            {false, false, false},  // 1
            {true,  false, false},  // 2
            {false, true,  false},  // 3
            {false, false, true },  // 4
            {true,  true,  false},  // 5
            {true,  false, true },  // 6
            {false, true,  true },  // 7
            {true,  true,  true },  // 8 (and any other id)
        };
        const bool *upper = kUpper[(region >= 1 && region <= 7) ? region - 1 : 7];
        (upper[0] ? xb : xa).push(y);
        (upper[1] ? yb : ya).push(y);
        (upper[2] ? zb : za).push(y);
    }

    void compute() {
        b1 = b2 = b3 = 0.0;
        b0 = data.getMean();
    }

    // Variance of all pushed values (a variance, unlike Linear_Regression3D's RMS).
    double getRes() {
        return data.getVar();
    }

    // Split-axis choice: each axis is scored by |Var(upper half) - Var(lower half)|.
    // Returns the axis (1 = x, 2 = y, 3 = z) with the largest score, preferring
    // the lower index on ties, or 0 if all scores are equal.
    int getDimRes() {
        return getDimRes(false, false, false);
    }

    // As getDimRes(), but axes whose stop flag is set are scored 0.
    int getDimRes(bool stop_x, bool stop_y, bool stop_z) {
        // std::abs from <cmath>: an unqualified abs() may resolve to C's
        // int abs(int) and truncate these differences to 0.
        double error1 = std::abs(xb.getVar() - xa.getVar());
        double error2 = std::abs(yb.getVar() - ya.getVar());
        double error3 = std::abs(zb.getVar() - za.getVar());

        if (stop_x) error1 = 0.0;
        if (stop_y) error2 = 0.0;
        if (stop_z) error3 = 0.0;

        return pickSplitDim(error1, error2, error3);
    }

private:
    // Index (1..3) of the largest score, lowest index on ties, 0 if all are equal.
    static int pickSplitDim(double error1, double error2, double error3) {
        if (error1 > error2 && error1 > error3) {
            return 1;
        } else if (error2 > error1 && error2 > error3) {
            return 2;
        } else if (error3 > error1 && error3 > error2) {
            return 3;
        } else {
            if (error1 > error2 || error1 > error3) {
                return 1;
            } else if (error2 > error1 || error2 > error3) {
                return 2;
            } else if (error3 > error1 || error3 > error2) {
                return 3;
            } else {
                return 0; // all scores equal
            }
        }
    }

public:
    Statistics xa;
    Statistics xb;
    Statistics ya;
    Statistics yb;
    Statistics za;
    Statistics zb;
    Statistics data;


};
#endif
namespace Adaptive_Sampling {
    struct Tree_3D {
        Vector3 left;
        Vector3 right;
        Vector4 para;
        Float error;
        int dim; // precompute which dimension to cut
        int depth = 1;

        int depth_x = 1;
        int depth_y = 1;
        int depth_z = 1;


        inline bool operator < (Tree_3D a) {
            if (a.error > error)
                return true;
            else
                return false;
        }
    };

    // Tuning parameters for adaptive3D::setup and the tree builders.
    struct adaptive3D_config {
        adaptive3D_config() = default;
        adaptive3D_config(double thold, int spg, double weight_decay, int max_depth, int npass)
            : thold{thold},
              spg{spg},
              weight_decay{weight_decay},
              max_depth{max_depth},
              npass{npass} {}

        double thold = 0.0001;      // a node whose error is <= thold becomes a leaf
        int spg = 16;               // base number of samples per octant
        int min_spg = 4;            // lower bound on samples per octant at large depth
        double weight_decay = 0.5;  // node error is scaled by weight_decay^depth
        double sample_decay = 0.5;  // samples per octant are scaled by sample_decay^depth
        int max_depth = 32;         // nodes deeper than this become leaves
        int npass = 10;
        bool use_heap = true;       // setup(): heap-driven builder if true, threshold-driven otherwise
        bool edge_draw = false;

        int max_depth_x = 8;
        int max_depth_y = 16;
        int max_depth_z = 16;

        double eps = 0.1;

        int shape_opt_id = -1;
        bool local_backward = false;
    };


    struct adaptive3D {
        std::vector < Tree_3D > tree_node;
        DiscreteDistribution distrb;

        // Recursively bisect `temp_tree` until its error is at most config.thold
        // or its depth exceeds config.max_depth. Each accepted box is appended to
        // tree_node inside an OpenMP critical section. Despite the name, the
        // recursion is depth-first (left child fully, then right child).
        //
        // For each split, both children are fitted from current_spp * 8 samples
        // cycling through the 8 octants of each child. A child's error is the
        // regression's getRes() scaled by weight_decay^depth, and its next split
        // axis is the regression's getDimRes().
        void BFS(const std::function<Float(const Vector3&, RndSampler&)> &func, const Tree_3D &temp_tree, const adaptive3D_config &config, RndSampler& rnd) {
            if (temp_tree.error <= config.thold || temp_tree.depth > config.max_depth) {
                #pragma omp critical
                {
                    tree_node.push_back(temp_tree);
                }
                return;
            }

            // Split axis from the parent's fit: 1 = x, 2 = y, anything else = z
            // (dim == 0 means getDimRes found no preferred axis).
            int axis;
            if (temp_tree.dim == 1) {
                axis = 0;
            } else if (temp_tree.dim == 2) {
                axis = 1;
            } else {
                axis = 2;
            }

            // Halve the parent box along `axis`.
            Tree_3D left_tree;
            Tree_3D right_tree;
            left_tree.left = temp_tree.left;
            left_tree.right = temp_tree.right;
            right_tree.left = temp_tree.left;
            right_tree.right = temp_tree.right;
            const Float temp_mid = (temp_tree.left[axis] + temp_tree.right[axis]) / 2.0;
            left_tree.right[axis] = temp_mid;
            right_tree.left[axis] = temp_mid;

            // u in [0,1]^3 moved into octant `region` (1..8) of the unit cube:
            // halved, then shifted by 0.5 along each axis whose upper half the
            // region occupies (same region ids as the regression's push()).
            auto octant_point = [](const Vector3 &u, int region) -> Vector3 {
                Vector3 p = u * 0.5;
                switch (region) {
                    case 2: p[0] += 0.5; break;
                    case 3: p[1] += 0.5; break;
                    case 4: p[2] += 0.5; break;
                    case 5: p[0] += 0.5; p[1] += 0.5; break;
                    case 6: p[0] += 0.5; p[2] += 0.5; break;
                    case 7: p[1] += 0.5; p[2] += 0.5; break;
                    case 8: p[0] += 0.5; p[1] += 0.5; p[2] += 0.5; break;
                    default: break;  // region 1: lower octant, no shift
                }
                return p;
            };

            // Local coordinates p in [0,1]^3 mapped into the box [node.left, node.right].
            auto box_point = [](const Tree_3D &node, const Vector3 &p) -> Vector3 {
                return Vector3(node.left[0] + (node.right[0] - node.left[0]) * p[0],
                               node.left[1] + (node.right[1] - node.left[1]) * p[1],
                               node.left[2] + (node.right[2] - node.left[2]) * p[2]);
            };

#ifdef LINEAR
            Linear_Regression3D eval_linear1;
            Linear_Regression3D eval_linear2;
#else
            Uniform_Regression3D eval_linear1;
            Uniform_Regression3D eval_linear2;
#endif

            // Samples per octant shrink geometrically with depth, floored at min_spg.
            int current_spp = static_cast<int>(config.spg * std::pow(config.sample_decay, temp_tree.depth));
            if (current_spp < config.min_spg) {
                current_spp = config.min_spg;
            }

            for (int s = 0; s < current_spp*8; ++s) {
                const int get_region = s % 8 + 1;
                // Both children's uniforms are drawn before func is evaluated.
                const Vector3 rnd_val_region1 = octant_point(rnd.next3D(), get_region);
                const Vector3 rnd_val_region2 = octant_point(rnd.next3D(), get_region);

                const Float temp1 = func(box_point(left_tree, rnd_val_region1), rnd);
                const Float temp2 = func(box_point(right_tree, rnd_val_region2), rnd);
                // The regressions work in each child's local [0,1]^3 coordinates.
                eval_linear1.push(rnd_val_region1, temp1, get_region);
                eval_linear2.push(rnd_val_region2, temp2, get_region);
            }
            eval_linear1.compute();
            left_tree.para = Vector4(eval_linear1.b0, eval_linear1.b1, eval_linear1.b2, eval_linear1.b3);

            eval_linear2.compute();
            right_tree.para = Vector4(eval_linear2.b0, eval_linear2.b1, eval_linear2.b2, eval_linear2.b3);

            left_tree.depth = temp_tree.depth + 1;
            right_tree.depth = temp_tree.depth + 1;

            left_tree.error = eval_linear1.getRes() * std::pow(config.weight_decay, left_tree.depth);
            right_tree.error = eval_linear2.getRes() * std::pow(config.weight_decay, right_tree.depth);
            left_tree.dim = eval_linear1.getDimRes();
            right_tree.dim = eval_linear2.getDimRes();

            BFS(func, left_tree, config, rnd);
            BFS(func, right_tree, config, rnd);
        }

        // Build the adaptive partition with either the heap-driven or the
        // threshold-driven builder (chosen by config.use_heap), then print the
        // size and total weight of the resulting distribution.
        void setup(const std::function<Float(const Vector3&, RndSampler&)> &func,
                   const std::vector<Float> &cdfx,
                   const adaptive3D_config &config) {
            const bool by_heap = config.use_heap;
            std::cout << (by_heap ? "adaptive3D setup_heap" : "adaptive3D setup_thold") << std::endl;
            if (by_heap)
                setup_heap(func, cdfx, config);
            else
                setup_thold(func, cdfx, config);

            std::cout << "Final aq size: " << distrb.size() << std::endl;
            std::cout << "Final aq sum: " << distrb.m_sum << std::endl;
        }


        /// Builds the partition by threshold-driven recursive refinement.
        ///
        /// Each x-segment [cdfx[i], cdfx[i+1]] x [0,1]^2 starts as one cell. The cell's
        /// regression is fitted from config.spg * 8 samples, cycled through the cell's eight
        /// octants, and the cell is then handed to BFS, which recursively splits it.
        /// Finally every cell in tree_node is weighted into distrb by (mean of its fitted model
        /// over the unit cube) x (cell volume), and distrb is normalized.
        ///
        /// @param func   integrand, evaluated at world-space points with a per-thread sampler
        /// @param cdfx   segment boundaries along x; fewer than two entries yield no segments
        /// @param config refinement parameters (spg, thresholds, depth caps, ...)
        void setup_thold(const std:: function < Float(const Vector3&, RndSampler&) > & func,
            const std::vector < Float > & cdfx,
                const adaptive3D_config & config) {
            tree_node.clear();
            distrb.clear();

            const int nworker = omp_get_num_procs();
            std::vector < RndSampler > samplers;
            for (int i = 0; i < nworker; ++i) samplers.push_back(RndSampler(13, i));

            // Number of x-segments. Computed in signed arithmetic so an empty cdfx gives 0
            // instead of wrapping `cdfx.size() - 1` around to SIZE_MAX.
            const int num_segments = cdfx.size() < 2 ? 0 : static_cast<int>(cdfx.size()) - 1;

            // Offset of octant `region` (1..8) inside the unit cube, indexed by region - 1:
            // 1 = low corner, 2/3/4 = +x/+y/+z, 5 = +x+y, 6 = +x+z, 7 = +y+z, 8 = +x+y+z.
            static const double kRegionOffset[8][3] = {
                {0.0, 0.0, 0.0}, {0.5, 0.0, 0.0}, {0.0, 0.5, 0.0}, {0.0, 0.0, 0.5},
                {0.5, 0.5, 0.0}, {0.5, 0.0, 0.5}, {0.0, 0.5, 0.5}, {0.5, 0.5, 0.5}};

            omp_lock_t messageLock;
            omp_init_lock(&messageLock);

            int blockProcessed = 0;

#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
            for (int omp_i = 0; omp_i < num_segments; ++omp_i) {
                RndSampler & sampler = samplers[omp_get_thread_num()];
                Tree_3D curr_tree;
                curr_tree.left = Vector3(cdfx[omp_i], 0.0, 0.0);
                curr_tree.right = Vector3(cdfx[omp_i + 1], 1.0, 1.0);

#ifdef LINEAR
                Linear_Regression3D eval_linear;
#else
                Uniform_Regression3D eval_linear;
#endif

                // The sample index (not the segment index) selects the octant. Every octant
                // therefore receives config.spg samples and the regression sees the whole cell.
                for (int gg = 0; gg < config.spg * 8; ++gg) {
                    int get_region = gg % 8 + 1;
                    Vector3 rnd_val = sampler.next3D();
                    Vector3 rnd_val_1 = rnd_val * 0.5;
                    const double * offset = kRegionOffset[get_region - 1];
                    Vector3 rnd_val_region(rnd_val_1[0] + offset[0],
                        rnd_val_1[1] + offset[1],
                        rnd_val_1[2] + offset[2]);
                    // Map the unit-cube point into this segment's box.
                    Vector3 repara = Vector3(curr_tree.left[0] + (curr_tree.right[0] - curr_tree.left[0]) * rnd_val_region[0],
                        curr_tree.left[1] + (curr_tree.right[1] - curr_tree.left[1]) * rnd_val_region[1],
                        curr_tree.left[2] + (curr_tree.right[2] - curr_tree.left[2]) * rnd_val_region[2]);
                    Float temp = func(repara, sampler);
                    eval_linear.push(rnd_val_region, temp, get_region);
                }
                eval_linear.compute();
                curr_tree.para = Vector4(eval_linear.b0, eval_linear.b1, eval_linear.b2, eval_linear.b3);
                curr_tree.error = eval_linear.getRes();
                curr_tree.dim = eval_linear.getDimRes();

                // NOTE(review): BFS runs concurrently on several threads. If it appends to
                // tree_node without synchronization (setup_heap guards its insert with
                // `omp critical`), that is a data race. Confirm against BFS.
                BFS(func, curr_tree, config, sampler);

                omp_set_lock(&messageLock);
                progressIndicator(static_cast<Float>(++blockProcessed) / num_segments);
                omp_unset_lock(&messageLock);
            }

            omp_destroy_lock(&messageLock);

            for (const auto & node : tree_node) {
                // Mean of b0 + b1*x + b2*y + b3*z over the unit cube.
                double mean_val = node.para[0] + node.para[1] * 0.5 +
                    node.para[2] * 0.5 + node.para[3] * 0.5;
                // mean_val += config.eps;
                double value = mean_val * (node.right[0] - node.left[0]) *
                    (node.right[1] - node.left[1]) *
                    (node.right[2] - node.left[2]);
                distrb.append(value);
            }
            distrb.normalize();
        }

        void setup_heap(const std:: function < Float(const Vector3&, RndSampler&) > & func,
            const std::vector < Float > & cdfx,
                const adaptive3D_config & config) {
            tree_node.clear();
            distrb.clear();

            const int nworker = omp_get_num_procs();
            std::vector < RndSampler > samplers;
            for (int i = 0; i < nworker; ++i) samplers.push_back(RndSampler(3333, i));

            omp_lock_t messageLock;
            omp_init_lock(&messageLock);
            int blockProcessed = 0;
            int non_zero = 0;

#pragma omp parallel for num_threads(nworker) schedule(dynamic, 1)
            for (int omp_i = 0; omp_i < cdfx.size() - 1; ++omp_i) {
                std::vector < Tree_3D > current_curve_heap;

                const int tid = omp_get_thread_num();
                RndSampler & sampler = samplers[tid];
                Tree_3D curr_tree;
                curr_tree.left = Vector3(cdfx[omp_i], 0.0, 0.0);
                curr_tree.right = Vector3(cdfx[omp_i + 1], 1.0, 1.0);

#ifdef LINEAR
                Linear_Regression3D eval_linear;
#else
                Uniform_Regression3D eval_linear;
#endif

                for (int gg = 0; gg < config.spg; ++gg) {
                        Vector3 rnd1 = sampler.next3D();
                        Vector3 rnd1_1 = rnd1 * 0.5;
                        Vector3 rnd1_2(rnd1_1[0] + 0.5, rnd1_1[1], rnd1_1[2]);
                        Vector3 rnd1_3(rnd1_1[0], rnd1_1[1] + 0.5, rnd1_1[2]);
                        Vector3 rnd1_4(rnd1_1[0], rnd1_1[1], rnd1_1[2] + 0.5);

                        Vector3 rnd1_5(rnd1_1[0] + 0.5, rnd1_1[1] + 0.5, rnd1_1[2]);
                        Vector3 rnd1_6(rnd1_1[0] + 0.5, rnd1_1[1], rnd1_1[2] + 0.5);
                        Vector3 rnd1_7(rnd1_1[0], rnd1_1[1] + 0.5, rnd1_1[2] + 0.5);
                        Vector3 rnd1_8(rnd1_1[0] + 0.5, rnd1_1[1] + 0.5, rnd1_1[2] + 0.5);
                        {
                            Float temp1 = func(Vector3((curr_tree.left[0] + (curr_tree.right[0] - curr_tree.left[0]) * rnd1_1[0]),
                                (curr_tree.left[1] + (curr_tree.right[1] - curr_tree.left[1]) * rnd1_1[1]),
                                (curr_tree.left[2] + (curr_tree.right[2] - curr_tree.left[2]) * rnd1_1[2])), sampler);

                            eval_linear.push(rnd1_1, temp1, 1);
                        }
                        {
                            Float temp1 = func(Vector3((curr_tree.left[0] + (curr_tree.right[0] - curr_tree.left[0]) * rnd1_2[0]),
                                (curr_tree.left[1] + (curr_tree.right[1] - curr_tree.left[1]) * rnd1_2[1]),
                                (curr_tree.left[2] + (curr_tree.right[2] - curr_tree.left[2]) * rnd1_2[2])), sampler);

                            eval_linear.push(rnd1_2, temp1, 2);
                        }

                        {
                            Float temp1 = func(Vector3((curr_tree.left[0] + (curr_tree.right[0] - curr_tree.left[0]) * rnd1_3[0]),
                                (curr_tree.left[1] + (curr_tree.right[1] - curr_tree.left[1]) * rnd1_3[1]),
                                (curr_tree.left[2] + (curr_tree.right[2] - curr_tree.left[2]) * rnd1_3[2])), sampler);

                            eval_linear.push(rnd1_3, temp1, 3);
                        }

                        {
                            Float temp1 = func(Vector3((curr_tree.left[0] + (curr_tree.right[0] - curr_tree.left[0]) * rnd1_4[0]),
                                (curr_tree.left[1] + (curr_tree.right[1] - curr_tree.left[1]) * rnd1_4[1]),
                                (curr_tree.left[2] + (curr_tree.right[2] - curr_tree.left[2]) * rnd1_4[2])), sampler);

                            eval_linear.push(rnd1_4, temp1, 4);
                        }

                        {
                            Float temp1 = func(Vector3((curr_tree.left[0] + (curr_tree.right[0] - curr_tree.left[0]) * rnd1_5[0]),
                                (curr_tree.left[1] + (curr_tree.right[1] - curr_tree.left[1]) * rnd1_5[1]),
                                (curr_tree.left[2] + (curr_tree.right[2] - curr_tree.left[2]) * rnd1_5[2])), sampler);

                            eval_linear.push(rnd1_5, temp1, 5);
                        }

                        {
                            Float temp1 = func(Vector3((curr_tree.left[0] + (curr_tree.right[0] - curr_tree.left[0]) * rnd1_6[0]),
                                (curr_tree.left[1] + (curr_tree.right[1] - curr_tree.left[1]) * rnd1_6[1]),
                                (curr_tree.left[2] + (curr_tree.right[2] - curr_tree.left[2]) * rnd1_6[2])), sampler);

                            eval_linear.push(rnd1_6, temp1, 6);
                        }

                        {
                            Float temp1 = func(Vector3((curr_tree.left[0] + (curr_tree.right[0] - curr_tree.left[0]) * rnd1_7[0]),
                                (curr_tree.left[1] + (curr_tree.right[1] - curr_tree.left[1]) * rnd1_7[1]),
                                (curr_tree.left[2] + (curr_tree.right[2] - curr_tree.left[2]) * rnd1_7[2])), sampler);
                            eval_linear.push(rnd1_7, temp1, 7);
                        }

                        {
                            Float temp1 = func(Vector3((curr_tree.left[0] + (curr_tree.right[0] - curr_tree.left[0]) * rnd1_8[0]),
                                (curr_tree.left[1] + (curr_tree.right[1] - curr_tree.left[1]) * rnd1_8[1]),
                                (curr_tree.left[2] + (curr_tree.right[2] - curr_tree.left[2]) * rnd1_8[2])), sampler);

                            eval_linear.push(rnd1_8, temp1, 8);
                        }

                }
                eval_linear.compute();
                curr_tree.para = Vector4(eval_linear.b0, eval_linear.b1, eval_linear.b2, eval_linear.b3);
                curr_tree.error = eval_linear.getRes();
                bool stop_x = false;
                bool stop_y = false;
                bool stop_z = false;
                if (curr_tree.depth_x > config.max_depth_x) {
                    stop_x = true;
                }
                if (curr_tree.depth_y > config.max_depth_y) {
                    stop_y = true;
                } 
                if (curr_tree.depth_z > config.max_depth_z) {
                    stop_z = true;
                } 

                curr_tree.dim = eval_linear.getDimRes(stop_x, stop_y, stop_z);
                if (curr_tree.dim == 1) {
                    curr_tree.depth_x++;
                } else if (curr_tree.dim == 2) {
                    curr_tree.depth_y++;
                } else if (curr_tree.dim == 3) {
                    curr_tree.depth_z++;
                }

                if (curr_tree.error > 0.0) {
                    non_zero++;
                }
                current_curve_heap.push_back(curr_tree);
                std::make_heap(current_curve_heap.begin(), current_curve_heap.end());


                for (int i = 0; i < config.npass; ++i) { // max cut per curvature
                    Tree_3D temp_tree = current_curve_heap.front();
                    if (temp_tree.error <= config.thold || temp_tree.depth > config.max_depth) {
                        break;
                    }
                    std::pop_heap(current_curve_heap.begin(), current_curve_heap.end());
                    current_curve_heap.pop_back();
                    
                    Tree_3D left_tree;
                    Tree_3D right_tree;
                    Float temp_mid;
                    if (temp_tree.dim == 0) {
                        // std::cout << "Failed to calculate which dimension with error: " << temp_tree.error << std::endl;
                        // WRAN: for direct we suggest to use dim=1 (unless mutiple emitter with sample reuse)
                        // WARN: for indirect we suggest to use dim=3
                        if (curr_tree.depth_z < curr_tree.depth_y) {
                            temp_tree.dim = 3;
                        } else {
                            temp_tree.dim = 2;
                        }
                    }
                    if (temp_tree.dim == 1) {
                        temp_mid = (temp_tree.left[0] + temp_tree.right[0]) / 2.0;
                        left_tree.right = Vector3(temp_mid, temp_tree.right[1], temp_tree.right[2]);
                        right_tree.left = Vector3(temp_mid, temp_tree.left[1], temp_tree.left[2]);
                    } else if (temp_tree.dim == 2) {
                        temp_mid = (temp_tree.left[1] + temp_tree.right[1]) / 2.0;
                        left_tree.right = Vector3(temp_tree.right[0], temp_mid, temp_tree.right[2]);
                        right_tree.left = Vector3(temp_tree.left[0], temp_mid, temp_tree.left[2]);
                    } else {
                        temp_mid = (temp_tree.left[2] + temp_tree.right[2]) / 2.0;
                        left_tree.right = Vector3(temp_tree.right[0], temp_tree.right[1], temp_mid);
                        right_tree.left = Vector3(temp_tree.left[0], temp_tree.left[1], temp_mid);
                    }
                    left_tree.left = temp_tree.left;
                    right_tree.right = temp_tree.right;

#ifdef LINEAR
                    Linear_Regression3D eval_linear1;
                    Linear_Regression3D eval_linear2;
#else
                    Uniform_Regression3D eval_linear1;
                    Uniform_Regression3D eval_linear2;
#endif
                    int current_spp = config.spg * pow(config.sample_decay, temp_tree.depth);
                    if (current_spp < config.min_spg) {
                        current_spp = config.min_spg;
                    }
                    for (int c = 0; c < current_spp; ++c) {
                        Vector3 rnd1 = sampler.next3D();
                        Vector3 rnd2 = sampler.next3D();

                        Vector3 rnd1_1 = rnd1 * 0.5;
                        Vector3 rnd1_2(rnd1_1[0] + 0.5, rnd1_1[1], rnd1_1[2]);
                        Vector3 rnd1_3(rnd1_1[0], rnd1_1[1] + 0.5, rnd1_1[2]);
                        Vector3 rnd1_4(rnd1_1[0], rnd1_1[1], rnd1_1[2] + 0.5);

                        Vector3 rnd1_5(rnd1_1[0] + 0.5, rnd1_1[1] + 0.5, rnd1_1[2]);
                        Vector3 rnd1_6(rnd1_1[0] + 0.5, rnd1_1[1], rnd1_1[2] + 0.5);
                        Vector3 rnd1_7(rnd1_1[0], rnd1_1[1] + 0.5, rnd1_1[2] + 0.5);
                        Vector3 rnd1_8(rnd1_1[0] + 0.5, rnd1_1[1] + 0.5, rnd1_1[2] + 0.5);

                        Vector3 rnd2_1 = rnd2 * 0.5;
                        Vector3 rnd2_2(rnd2_1[0] + 0.5, rnd2_1[1], rnd2_1[2]);
                        Vector3 rnd2_3(rnd2_1[0], rnd2_1[1] + 0.5, rnd2_1[2]);
                        Vector3 rnd2_4(rnd2_1[0], rnd2_1[1], rnd2_1[2] + 0.5);

                        Vector3 rnd2_5(rnd2_1[0] + 0.5, rnd2_1[1] + 0.5, rnd2_1[2]);
                        Vector3 rnd2_6(rnd2_1[0] + 0.5, rnd2_1[1], rnd2_1[2] + 0.5);
                        Vector3 rnd2_7(rnd2_1[0], rnd2_1[1] + 0.5, rnd2_1[2] + 0.5);
                        Vector3 rnd2_8(rnd2_1[0] + 0.5, rnd2_1[1] + 0.5, rnd2_1[2] + 0.5);

                        {
                            Float temp1 = func(Vector3((left_tree.left[0] + (left_tree.right[0] - left_tree.left[0]) * rnd1_1[0]),
                                (left_tree.left[1] + (left_tree.right[1] - left_tree.left[1]) * rnd1_1[1]),
                                (left_tree.left[2] + (left_tree.right[2] - left_tree.left[2]) * rnd1_1[2])), sampler);

                            Float temp2 = func(Vector3((right_tree.left[0] + (right_tree.right[0] - right_tree.left[0]) * rnd2_1[0]),
                                (right_tree.left[1] + (right_tree.right[1] - right_tree.left[1]) * rnd2_1[1]),
                                (right_tree.left[2] + (right_tree.right[2] - right_tree.left[2]) * rnd2_1[2])), sampler);
                            eval_linear1.push(rnd1_1, temp1, 1);
                            eval_linear2.push(rnd2_1, temp2, 1);
                        }
                        {
                            Float temp1 = func(Vector3((left_tree.left[0] + (left_tree.right[0] - left_tree.left[0]) * rnd1_2[0]),
                                (left_tree.left[1] + (left_tree.right[1] - left_tree.left[1]) * rnd1_2[1]),
                                (left_tree.left[2] + (left_tree.right[2] - left_tree.left[2]) * rnd1_2[2])), sampler);

                            Float temp2 = func(Vector3((right_tree.left[0] + (right_tree.right[0] - right_tree.left[0]) * rnd2_2[0]),
                                (right_tree.left[1] + (right_tree.right[1] - right_tree.left[1]) * rnd2_2[1]),
                                (right_tree.left[2] + (right_tree.right[2] - right_tree.left[2]) * rnd2_2[2])), sampler);
                            eval_linear1.push(rnd1_2, temp1, 2);
                            eval_linear2.push(rnd2_2, temp2, 2);
                        }

                        {
                            Float temp1 = func(Vector3((left_tree.left[0] + (left_tree.right[0] - left_tree.left[0]) * rnd1_3[0]),
                                (left_tree.left[1] + (left_tree.right[1] - left_tree.left[1]) * rnd1_3[1]),
                                (left_tree.left[2] + (left_tree.right[2] - left_tree.left[2]) * rnd1_3[2])), sampler);

                            Float temp2 = func(Vector3((right_tree.left[0] + (right_tree.right[0] - right_tree.left[0]) * rnd2_3[0]),
                                (right_tree.left[1] + (right_tree.right[1] - right_tree.left[1]) * rnd2_3[1]),
                                (right_tree.left[2] + (right_tree.right[2] - right_tree.left[2]) * rnd2_3[2])), sampler);
                            eval_linear1.push(rnd1_3, temp1, 3);
                            eval_linear2.push(rnd2_3, temp2, 3);
                        }

                        {
                            Float temp1 = func(Vector3((left_tree.left[0] + (left_tree.right[0] - left_tree.left[0]) * rnd1_4[0]),
                                (left_tree.left[1] + (left_tree.right[1] - left_tree.left[1]) * rnd1_4[1]),
                                (left_tree.left[2] + (left_tree.right[2] - left_tree.left[2]) * rnd1_4[2])), sampler);

                            Float temp2 = func(Vector3((right_tree.left[0] + (right_tree.right[0] - right_tree.left[0]) * rnd2_4[0]),
                                (right_tree.left[1] + (right_tree.right[1] - right_tree.left[1]) * rnd2_4[1]),
                                (right_tree.left[2] + (right_tree.right[2] - right_tree.left[2]) * rnd2_4[2])), sampler);
                            eval_linear1.push(rnd1_4, temp1, 4);
                            eval_linear2.push(rnd2_4, temp2, 4);
                        }

                        {
                            Float temp1 = func(Vector3((left_tree.left[0] + (left_tree.right[0] - left_tree.left[0]) * rnd1_5[0]),
                                (left_tree.left[1] + (left_tree.right[1] - left_tree.left[1]) * rnd1_5[1]),
                                (left_tree.left[2] + (left_tree.right[2] - left_tree.left[2]) * rnd1_5[2])), sampler);

                            Float temp2 = func(Vector3((right_tree.left[0] + (right_tree.right[0] - right_tree.left[0]) * rnd2_5[0]),
                                (right_tree.left[1] + (right_tree.right[1] - right_tree.left[1]) * rnd2_5[1]),
                                (right_tree.left[2] + (right_tree.right[2] - right_tree.left[2]) * rnd2_5[2])), sampler);
                            eval_linear1.push(rnd1_5, temp1, 5);
                            eval_linear2.push(rnd2_5, temp2, 5);
                        }

                        {
                            Float temp1 = func(Vector3((left_tree.left[0] + (left_tree.right[0] - left_tree.left[0]) * rnd1_6[0]),
                                (left_tree.left[1] + (left_tree.right[1] - left_tree.left[1]) * rnd1_6[1]),
                                (left_tree.left[2] + (left_tree.right[2] - left_tree.left[2]) * rnd1_6[2])), sampler);

                            Float temp2 = func(Vector3((right_tree.left[0] + (right_tree.right[0] - right_tree.left[0]) * rnd2_6[0]),
                                (right_tree.left[1] + (right_tree.right[1] - right_tree.left[1]) * rnd2_6[1]),
                                (right_tree.left[2] + (right_tree.right[2] - right_tree.left[2]) * rnd2_6[2])), sampler);
                            eval_linear1.push(rnd1_6, temp1, 6);
                            eval_linear2.push(rnd2_6, temp2, 6);
                        }

                        {
                            Float temp1 = func(Vector3((left_tree.left[0] + (left_tree.right[0] - left_tree.left[0]) * rnd1_7[0]),
                                (left_tree.left[1] + (left_tree.right[1] - left_tree.left[1]) * rnd1_7[1]),
                                (left_tree.left[2] + (left_tree.right[2] - left_tree.left[2]) * rnd1_7[2])), sampler);

                            Float temp2 = func(Vector3((right_tree.left[0] + (right_tree.right[0] - right_tree.left[0]) * rnd2_7[0]),
                                (right_tree.left[1] + (right_tree.right[1] - right_tree.left[1]) * rnd2_7[1]),
                                (right_tree.left[2] + (right_tree.right[2] - right_tree.left[2]) * rnd2_7[2])), sampler);
                            eval_linear1.push(rnd1_7, temp1, 7);
                            eval_linear2.push(rnd2_7, temp2, 7);
                        }

                        {
                            Float temp1 = func(Vector3((left_tree.left[0] + (left_tree.right[0] - left_tree.left[0]) * rnd1_8[0]),
                                (left_tree.left[1] + (left_tree.right[1] - left_tree.left[1]) * rnd1_8[1]),
                                (left_tree.left[2] + (left_tree.right[2] - left_tree.left[2]) * rnd1_8[2])), sampler);

                            Float temp2 = func(Vector3((right_tree.left[0] + (right_tree.right[0] - right_tree.left[0]) * rnd2_8[0]),
                                (right_tree.left[1] + (right_tree.right[1] - right_tree.left[1]) * rnd2_8[1]),
                                (right_tree.left[2] + (right_tree.right[2] - right_tree.left[2]) * rnd2_8[2])), sampler);
                            eval_linear1.push(rnd1_8, temp1, 8);
                            eval_linear2.push(rnd2_8, temp2, 8);
                        }
                    }
                    eval_linear1.compute();
                    left_tree.para = Vector4(eval_linear1.b0, eval_linear1.b1, eval_linear1.b2, eval_linear1.b3);

                    eval_linear2.compute();
                    right_tree.para = Vector4(eval_linear2.b0, eval_linear2.b1, eval_linear2.b2, eval_linear2.b3);

                    left_tree.depth = temp_tree.depth + 1;
                    right_tree.depth = temp_tree.depth + 1;

                    left_tree.error = eval_linear1.getRes() * pow(config.weight_decay, left_tree.depth);
                    right_tree.error = eval_linear2.getRes() * pow(config.weight_decay, left_tree.depth);
                    
                    if (left_tree.depth > config.max_depth) {
                        left_tree.error = 0.0;
                    }
                    if (right_tree.depth > config.max_depth) {
                        right_tree.error = 0.0;
                    }

                    if (left_tree.depth_x > config.max_depth_x &&
                        left_tree.depth_y > config.max_depth_y &&
                        left_tree.depth_z > config.max_depth_z) {
                        left_tree.error = 0.0;
                    }
                    if (right_tree.depth_x > config.max_depth_x &&
                        right_tree.depth_y > config.max_depth_y &&
                        right_tree.depth_z > config.max_depth_z) {
                        right_tree.error = 0.0;
                    }


                    bool stop_x1 = false;
                    bool stop_y1 = false;
                    bool stop_z1 = false;
                    if (left_tree.depth_x > config.max_depth_x) {
                        stop_x1 = true;
                    }
                    if (left_tree.depth_y > config.max_depth_y) {
                        stop_y1 = true;
                    } 
                    if (left_tree.depth_z > config.max_depth_z) {
                        stop_z1 = true;
                    }

                    bool stop_x2 = false;
                    bool stop_y2 = false;
                    bool stop_z2 = false;
                    if (right_tree.depth_x > config.max_depth_x) {
                        stop_x2 = true;
                    }
                    if (right_tree.depth_y > config.max_depth_y) {
                        stop_y2 = true;
                    } 
                    if (right_tree.depth_z > config.max_depth_z) {
                        stop_z2 = true;
                    } 

                    left_tree.dim = eval_linear1.getDimRes(stop_x1, stop_y1, stop_z1);
                    right_tree.dim = eval_linear2.getDimRes(stop_x2, stop_y2, stop_z2);

                    if (left_tree.dim == 1) {
                        left_tree.depth_x++;
                    } else if (left_tree.dim == 2) {
                        left_tree.depth_y++;
                    } else if (left_tree.dim == 3) {
                        left_tree.depth_z++;
                    }

                    if (right_tree.dim == 1) {
                        right_tree.depth_x++;
                    } else if (right_tree.dim == 2) {
                        right_tree.depth_y++;
                    } else if (right_tree.dim == 3) {
                        right_tree.depth_z++;
                    }

                    current_curve_heap.push_back(left_tree);
                    std::push_heap(current_curve_heap.begin(), current_curve_heap.end());
                    current_curve_heap.push_back(right_tree);
                    std::push_heap(current_curve_heap.begin(), current_curve_heap.end());

                }
                #pragma omp critical 
                {
                    tree_node.insert(std::end(tree_node), std::begin(current_curve_heap), std::end(current_curve_heap));
                }

                {                
                    omp_set_lock(&messageLock);
                    progressIndicator(static_cast<Float>(++blockProcessed)/cdfx.size());
                    omp_unset_lock(&messageLock);
                }

            }

            omp_destroy_lock(&messageLock);

            // std::cout << "finish building grid" << std::endl;
            // std::cout << "save aq 3d to txt" << std::endl;
            // std::ofstream outFile_sample("aq_3d.txt");


            for (int i = 0; i < tree_node.size(); ++i) {
                double mean_val = tree_node[i].para[0] + tree_node[i].para[1] * 0.5 +
                    tree_node[i].para[2] * 0.5 + tree_node[i].para[3] * 0.5;
                double value = mean_val * (tree_node[i].right[0] - tree_node[i].left[0]) *
                    (tree_node[i].right[1] - tree_node[i].left[1]) *
                    (tree_node[i].right[2] - tree_node[i].left[2]);
                distrb.append(value);
            }
            std::cout << "cut edge sum: " << non_zero << std::endl;

            distrb.normalize();

        }

        /// Dumps every stored cell to stdout. Each cell prints two lines: its two box corners,
        /// then its split axis, error and fitted linear model.
        void print() {
            for (std::size_t id = 0; id < tree_node.size(); ++id) {
                const auto & node = tree_node[id];
                const auto & lo = node.left;
                const auto & hi = node.right;
                const auto & coef = node.para;
                std::cout << "Tree ID " << id << " : "
                          << lo[0] << " " << lo[1] << " " << lo[2] << " "
                          << hi[0] << " " << hi[1] << " " << hi[2] << std::endl;
                std::cout << "Error " << node.dim << " = " << node.error
                          << " : y = " << coef[3] << "x3 + " << coef[2] << "x2 + "
                          << coef[1] << "x1 + " << coef[0] << std::endl;
            }
        }


        /// Draws a point from the adaptive distribution.
        ///
        /// rnd[0] first picks a cell through distrb.sampleReuse, and is then used again for x.
        /// Under LINEAR, local (x, y, z) are obtained by inverting z's marginal, then y given z,
        /// then x given (y, z). Otherwise they are taken directly from `rnd`.
        /// @param rnd  three uniform numbers in [0,1)
        /// @param pdf  receives the fitted model value at the sample divided by distrb.getSum()
        /// @return     the sampled point mapped into the chosen cell's box
        Vector3 sample(Vector3 rnd, Float & pdf) const {
            const int cell = distrb.sampleReuse(rnd[0]);
            const auto & node = tree_node[cell];
#ifdef LINEAR
            const double u_z = __FZ(node.para, rnd[2]);
            const double u_y = __FY(node.para, u_z, rnd[1]);
            const double u_x = __FX(node.para, u_z, u_y, rnd[0]);
#else
            const double u_z = rnd[2];
            const double u_y = rnd[1];
            const double u_x = rnd[0];
#endif
            const auto & coef = node.para;
            pdf = (coef[0] + u_x * coef[1] +
                u_y * coef[2] + u_z * coef[3]) / distrb.getSum();

            const auto & lo = node.left;
            const auto & hi = node.right;
            return Vector3(lo[0] + u_x * (hi[0] - lo[0]),
                lo[1] + u_y * (hi[1] - lo[1]),
                lo[2] + u_z * (hi[2] - lo[2]));
        }

        /// Linear interpolation: value 0 maps to data[0], value 1 maps to data[1].
        inline double linear_eval(const Vector2 & data,
            const double & value) const {
            const double start = data[0];
            const double span = data[1] - data[0];
            return value * span + start;
        }

        /// Trapezoid area: the average of the endpoint values poly[0], poly[1] times `value`.
        /// With value == 1 (all uses in this file) it is the integral over [0,1] of the line
        /// through those endpoints.
        inline double linear_int(const Vector2 & poly,
            const double & value) const {
            const double endpoint_sum = poly[0] + poly[1];
            return endpoint_sum * value / 2.0;
        }

        /// Inverts the CDF of the density on [0,1] proportional to y(x) = para[1]*x + para[0].
        ///
        /// The unnormalised CDF is para[0]*x + para[1]*x^2/2. After normalising its
        /// coefficients (a, b) so that a + b = 1, this solves a*x^2 + b*x = rnd for x.
        /// @param para  (intercept, slope) of the linear density
        /// @param rnd   uniform number in [0,1]
        /// @return      x with CDF(x) = rnd; rnd itself when the density is (numerically)
        ///              constant or has zero total mass
        double inv_cdf(Vector2 para, double rnd) const {
            const double a_raw = para[1] / 2.0;
            const double b_raw = para[0];
            const double norm = a_raw + b_raw;
            if (norm == 0.0) {
                // Zero total mass: the normalisation below would divide by zero (NaN result).
                return rnd;
            }

            // Kept in double; the previous code narrowed them through a Vector3.
            const double a = a_raw / norm;
            const double b = b_raw / norm;

            // std::abs: an unqualified abs(double) may resolve to int abs(int), which truncates
            // |a| < 1 to 0 and would send almost every call down the uniform branch.
            if (std::abs(a) < 0.00001) {
                return rnd;
            }
            // Since a + b == 1 and rnd is in [0,1], the discriminant is >= 0 mathematically.
            // Clamp only to absorb rounding.
            const double disc = std::max(0.0, b * b + 4.0 * a * rnd);
            return (-b + std::sqrt(disc)) / (2.0 * a);
        }

        /// Samples local z from the marginal of b0 + b1*x + b2*y + b3*z over (x, y).
        /// That marginal is linear in z. Its values at z = 0 and z = 1 are obtained by
        /// averaging over x along the y = 0 / y = 1 edges, then averaging those over y.
        double __FZ(Vector4 para, double rndz) const {
            const double b0 = para[0];
            const double b1 = para[1];
            const double b2 = para[2];
            const double b3 = para[3];

            // z = 0 face.
            const double z0_y0 = linear_int(Vector2(b0, b0 + b1), 1.0);
            const double z0_y1 = linear_int(Vector2(b0 + b2, b0 + b1 + b2), 1.0);
            const double at_z0 = linear_int(Vector2(z0_y0, z0_y1), 1.0);

            // z = 1 face (same construction shifted by b3).
            const double z1_y0 = linear_int(Vector2(b0 + b3, b0 + b1 + b3), 1.0);
            const double z1_y1 = linear_int(Vector2(b0 + b2 + b3, b0 + b1 + b2 + b3), 1.0);
            const double at_z1 = linear_int(Vector2(z1_y0, z1_y1), 1.0);

            return inv_cdf(Vector2(at_z0, at_z1 - at_z0), rndz);
        }

        /// Samples local y given the already-sampled z, from the model's x-average.
        /// First evaluates b0 + b1*x + b2*y + b3*z at the four (x, y) corners for this z.
        /// Averaging over x gives the linear density's values at y = 0 and y = 1.
        double __FY(Vector4 para, double rndz, double rndy) const {
            const double b0 = para[0];
            const double b1 = para[1];
            const double b2 = para[2];
            const double b3 = para[3];

            const double c00 = linear_eval(Vector2(b0, b0 + b3), rndz);                     // x=0, y=0
            const double c01 = linear_eval(Vector2(b0 + b2, b0 + b2 + b3), rndz);           // x=0, y=1
            const double c10 = linear_eval(Vector2(b0 + b1, b0 + b1 + b3), rndz);           // x=1, y=0
            const double c11 = linear_eval(Vector2(b0 + b1 + b2, b0 + b1 + b2 + b3), rndz); // x=1, y=1

            const double at_y0 = (c00 + c10) / 2.0;
            const double at_y1 = (c01 - c00 + c11 - c10) / 2.0 + at_y0;

            return inv_cdf(Vector2(at_y0, at_y1 - at_y0), rndy);
        }

        /// Samples local x given the already-sampled (y, z).
        /// Evaluates the model on the x = 0 face and on the x = 1 face at (y, z), then inverts
        /// the resulting linear density in x.
        double __FX(Vector4 para, double rndz, double rndy, double rndx) const {
            const double b0 = para[0];
            const double b1 = para[1];
            const double b2 = para[2];
            const double b3 = para[3];

            // Values at (y = 0, y = 1) for the given z, on each x face.
            const Vector2 x0_edge(linear_eval(Vector2(b0, b0 + b3), rndz),
                linear_eval(Vector2(b0 + b2, b0 + b2 + b3), rndz));
            const Vector2 x1_edge(linear_eval(Vector2(b0 + b1, b0 + b1 + b3), rndz),
                linear_eval(Vector2(b0 + b1 + b2, b0 + b1 + b2 + b3), rndz));

            const double at_x0 = linear_eval(x0_edge, rndy);
            const double at_x1 = linear_eval(x1_edge, rndy);

            return inv_cdf(Vector2(at_x0, at_x1 - at_x0), rndx);
        }

    };
}

#endif //AQ3D_H__