#include "camera.h"
#include "ray.h"
#include "shape.h"
#include "camera.h"
#include "scene.h"
#include <math.h>
// #define COS_FILTER

// Debug helper: prints "<expression text>: <value>" to std::cerr with 15
// significant digits, then std::endl (newline + flush). The argument is
// parenthesised so expressions containing operators that bind looser than
// `<<` (comparisons, ==, &, |, ?:, ...) are printed instead of mis-parsed.
// Note that std::setprecision is sticky: std::cerr keeps precision 15 afterwards.
// NOTE(review): needs <iostream> and <iomanip>, which this file does not include
// directly -- presumably they arrive transitively via the project headers.
#define DEBUG(x)                                                                  \
	do                                                                            \
	{                                                                             \
		std::cerr << std::setprecision(15) << #x << ": " << (x) << std::endl; \
	} while (0)

struct CosFilter
{
    Float a;
    Float l, r;
    Float norm;
    Float pi = M_PI;
    CosFilter(Float a) : a(a)
    {
        l = 0.5 - 1 / (2 * a);
        r = 0.5 + 1 / (2 * a);
        norm = FF(r) - FF(l);
    }

    Float ff(Float x)
    {
        return 1 + cos(pi * a * (2 * x - 1));
    }

    FloatAD ffAD(const FloatAD &x)
    {
        return 1 + (pi * a * (2 * x - 1)).cos();
    }

    Float FF(Float x)
    {
        return x - sin(a * pi * (1 - 2 * x)) / (2 * a * pi);
    }

    Float f(Float x)
    {
        Float ret;
        if (l < x && x < r)
            ret = ff(x);
        else
            ret = 0;
        return ret / norm;
    }

    FloatAD fAD(const FloatAD &x)
    {
        FloatAD ret;
        if (l < x && x < r)
            ret = ffAD(x);
        else
            ret = 0;
        return ret / norm;
    }

    Float F(Float x)
    {
        Float ret;
        if (x < l){
            ret = 0;
        }
        else if (x < r)
        {
            ret = (FF(x)-FF(l)) / norm;
        }
        else
            ret = 1;
        return ret;
    }

    Float invF(Float y)
    {
        Float a = l, b = r;
        auto func = [this, y](Float x) {
            return this->F(x) - y;
        };

        if (func(a) * func(b) > Epsilon)
        {
            std::cerr << "You have not assumed right a and b\n";
            assert(false);
        }

        Float c = a;
        while ((b - a) >= Epsilon)
        {
            // Find middle point
            c = (a + b) / 2;

            // Check if middle point is root
            if (func(c) == 0.0)
                break;

            // Decide the side to repeat the steps
            else if (func(c) * func(a) < 0)
                b = c;
            else
                a = c;
        }
        return c;
    }
};

// File-scope raised-cosine filter with a = 0.5, i.e. support (-0.5, 1.5).
// NOTE(review): not referenced by the Camera methods in this part of the file,
// which sample and evaluate through `rfilter` instead; confirm whether it is
// still needed before removing it.
static CosFilter filter(0.5);

// Build a pinhole camera.
//   width, height : image resolution in pixels
//   cam_to_world  : 16 floats, a row-major 4x4 camera-to-world transform
//   cam_to_ndc    : 9 floats, a row-major 3x3 camera-to-NDC projection whose
//                   x and y scales must match (asserted)
//   clip_near     : near clipping distance along the camera z axis
//   med_id        : presumably the index of the medium containing the camera
//   velocities    : optional; 6 * nder floats holding two column-major 3 x nder
//                   blocks back to back -- position derivatives first, then the
//                   angular velocities applied to the frame axes
// The reconstruction filter is a TentFilter(0.5).
Camera::Camera(int width, int height, ptr<float> cam_to_world, ptr<float> cam_to_ndc, float clip_near, int med_id, ptr<float> velocities) {
    // The parameters shadow the members of the same name, hence `this->`.
    this->width = width;
    this->height = height;
    this->clip_near = clip_near;
    this->med_id = med_id;

    for (int row = 0; row < 4; ++row)
        for (int col = 0; col < 4; ++col)
            this->cam_to_world(row, col) = cam_to_world[4 * row + col];
    this->world_to_cam = this->cam_to_world.inverse();

    for (int row = 0; row < 3; ++row)
        for (int col = 0; col < 3; ++col)
            this->cam_to_ndc(row, col) = cam_to_ndc[3 * row + col];
    assert(this->cam_to_ndc(0,0) == this->cam_to_ndc(1,1));
    this->ndc_to_cam = this->cam_to_ndc.inverse();

    // Frame axes are the first three columns of cam_to_world, the position is the fourth.
    cframe.s.val = this->cam_to_world.block<3,1>(0,0);
    cframe.t.val = this->cam_to_world.block<3,1>(0,1);
    cframe.n.val = this->cam_to_world.block<3,1>(0,2);
    cpos.val     = this->cam_to_world.block<3,1>(0,3);

    if (!velocities.is_null()) {
        auto *raw = velocities.get();
        const Eigen::MatrixXf translation = Eigen::Map<Eigen::MatrixXf>(raw, 3, nder);
        const Eigen::MatrixXf rotation    = Eigen::Map<Eigen::MatrixXf>(raw + 3 * nder, 3, nder);
        initVelocities(translation.cast<Float>(), rotation.cast<Float>());
    }

    rfilter = std::make_shared<TentFilter>(0.5);
}

// Evaluate the 1D reconstruction filter at offset x. The constructor installs
// a TentFilter(0.5), which is piecewise linear.
Float Camera::eval1DFilter(Float x) const
{
    auto &reconstruction = *rfilter;
    return reconstruction.eval(x);
}

// Differentiable counterpart of eval1DFilter: evaluates the reconstruction
// filter at offset x while propagating x's derivatives.
FloatAD Camera::eval1DFilterAD(const FloatAD &x) const
{
    auto &reconstruction = *rfilter;
    return reconstruction.evalAD(x);
}

// Project a ray leaving the camera position onto continuous pixel coordinates.
// The direction is expressed in the camera frame and perspective-divided by its
// z component. Screen x grows with NDC x; screen y grows as NDC y decreases.
// The crop window is not applied here.
Vector2 Camera::getPixelUV(const Ray &ray) const
{
    // Non-AD path: compare against the plain camera position so the debug
    // check is a Float comparison without derivative arithmetic.
    assert((ray.org - cpos.val).norm() < Epsilon);
    Vector local = cframe.toFrame().toLocal(ray.dir);
    Vector ndc = cam_to_ndc * local;
    Float t = 1.0 / local.z();
    ndc *= t;
    auto aspect_ratio = Float(width) / Float(height);
    Vector2 ret;
    ret.x() = ((ndc.x() / 2.0 + 0.5)) * width;
    ret.y() = (ndc.y() * aspect_ratio / -2.0 + 0.5) * height;
    return ret;
}

// Differentiable version of getPixelUV: project a ray leaving the camera
// position onto continuous pixel coordinates, propagating derivatives of the
// ray direction and of the camera frame.
Vector2AD Camera::getPixelUVAD(const RayAD &ray) const
{
    assert((ray.org - cpos).norm().val < Epsilon);

    VectorAD in_camera = cframe.toLocal(ray.dir);
    VectorAD projected = Matrix3x3AD(cam_to_ndc) * in_camera;
    FloatAD inv_depth = 1.0 / in_camera.z();
    projected *= inv_depth;

    const Float aspect = Float(width) / Float(height);
    Vector2AD uv;
    uv.x() = ((projected.x() / 2.0 + 0.5)) * width;
    uv.y() = (projected.y() * aspect / -2.0 + 0.5) * height;
    return uv;
}

// Differentiable filter weight for connecting the surface point `its` to pixel
// (pixel_x, pixel_y).
//
// f is the separable filter evaluated at the offsets of the projected point
// from the pixel corner (pixel_x, pixel_y). G = cos_x / (cos_y^3 * dist^2),
// where:
//   - cos_y is the direction's z component in the camera frame;
//   - cos_x is its.wi.z(), presumably the surface cosine with its.wi in the
//     local shading frame (TODO confirm).
// The result is f * G divided by its own detached value. Its value is 1 and
// its derivatives are those of log(f * G).
// Returns 0 (value and derivatives) when f * G vanishes, e.g. when the point
// lies outside this pixel's filter support.
FloatAD Camera::evalFilterAD(int pixel_x, int pixel_y, const IntersectionAD &its) const
{
    VectorAD dir = its.p - cpos;
    FloatAD dist = dir.norm();
    dir /= dist;
    VectorAD local = cframe.toLocal(dir);
    VectorAD ndc = Matrix3x3AD(cam_to_ndc) * local;
    FloatAD t = 1.0 / local.z();
    ndc *= t;
    auto aspect_ratio = Float(width) / Float(height);
    // Continuous screen position, then the offset relative to this pixel.
    FloatAD x_screen = ((ndc.x() / 2.0 + 0.5)) * width;
    FloatAD y_screen = (ndc.y() * aspect_ratio / -2.0 + 0.5) * height;
    FloatAD x_offset = x_screen - Float(pixel_x);
    FloatAD y_offset = y_screen - Float(pixel_y);
    FloatAD fx = eval1DFilterAD(x_offset);
    FloatAD fy = eval1DFilterAD(y_offset);
    FloatAD f = fx * fy;
    FloatAD cosy = local.z();
    FloatAD cosx = its.wi.z();
    FloatAD G = cosx / (cosy * cosy * cosy * dist.square());
    Float pdf = f.val * G.val;
    // Self-normalisation guard: a zero pdf would make the value 0/0 and the
    // derivatives Inf/NaN, poisoning everything they are accumulated into.
    if (pdf == 0.0)
        return 0.0;
    FloatAD throughput = f * G / pdf;
    return throughput;
}
// Differentiable filter weight for connecting a (volume) scatterer position to
// pixel (pixel_x, pixel_y).
//
// Like the intersection overload, but without a surface cosine:
// G = 1 / (cos_y^3 * dist^2), where cos_y is the direction's z component in the
// camera frame. f is the separable filter at the projected point's offsets from
// the pixel corner.
// The result is f * G divided by its own detached value. Its value is 1 and
// its derivatives are those of log(f * G).
// Returns 0 (value and derivatives) when f * G vanishes, e.g. outside this
// pixel's filter support.
FloatAD Camera::evalFilterAD(int pixel_x, int pixel_y, const VectorAD &scatterer) const
{
    VectorAD dir = scatterer - cpos;
    FloatAD dist = dir.norm();
    dir /= dist;
    VectorAD local = cframe.toLocal(dir);
    VectorAD ndc = Matrix3x3AD(cam_to_ndc) * local;
    FloatAD t = 1.0 / local.z();
    ndc *= t;
    auto aspect_ratio = Float(width) / Float(height);
    // Continuous screen position, then the offset relative to this pixel.
    FloatAD x_screen = ((ndc.x() / 2.0 + 0.5)) * width;
    FloatAD y_screen = (ndc.y() * aspect_ratio / -2.0 + 0.5) * height;
    FloatAD x_offset = x_screen - Float(pixel_x);
    FloatAD y_offset = y_screen - Float(pixel_y);
    FloatAD fx = eval1DFilterAD(x_offset);
    FloatAD fy = eval1DFilterAD(y_offset);
    FloatAD f = fx * fy;
    FloatAD cosy = local.z();
    FloatAD G = 1 / (cosy * cosy * cosy * dist.square());
    Float pdf = f.val * G.val;
    // Self-normalisation guard: a zero pdf would make the value 0/0 and the
    // derivatives Inf/NaN, poisoning everything they are accumulated into.
    if (pdf == 0.0)
        return 0.0;
    FloatAD throughput = f * G / pdf;
    return throughput;
}

// Map the uniform random number rnd to a 1D offset distributed according to
// the reconstruction filter. The pdf reported by the filter is discarded.
Float Camera::sampleFromFilter1D(Float rnd) const
{
    Float discarded_pdf;
    const Float offset = rfilter->sample(rnd, discarded_pdf);
    return offset;
}

// Sample a 2D offset from the separable filter: each axis independently via
// sampleFromFilter1D (x from rnd.x(), then y from rnd.y()).
Vector2 Camera::sampleFromFilter(const Vector2 &rnd) const
{
    const Float sx = sampleFromFilter1D(rnd.x());
    const Float sy = sampleFromFilter1D(rnd.y());
    return Vector2(sx, sy);
}

void Camera::samplePrimaryRayFromFilter(int pixel_x, int pixel_y, const Vector2 &rnd2, Ray &ray, Ray &dual) const
{
    Float x = sampleFromFilter1D(rnd2.x());
    Float y = sampleFromFilter1D(rnd2.y());

    Float x_dual = 1.0 - x;
    Float y_dual = 1.0 - y;

    ray = samplePrimaryRay(pixel_x + x, pixel_y + y);
    dual = samplePrimaryRay(pixel_x + x_dual, pixel_y + y_dual);
}

// Primary ray through pixel (pixel_x, pixel_y) at sub-pixel offset rnd2.
// This is the continuous-coordinate overload applied to
// (pixel_x + rnd2.x(), pixel_y + rnd2.y()), which performs exactly the same
// normalisation and NDC mapping.
Ray Camera::samplePrimaryRay(int pixel_x, int pixel_y, const Vector2 &rnd2) const {
    const Float px = pixel_x + rnd2.x();
    const Float py = pixel_y + rnd2.y();
    return samplePrimaryRay(px, py);
}

// World-space primary ray through continuous pixel coordinates (x, y).
// x / width maps to NDC x in [-1, 1]. y / height maps to NDC y scaled by
// 1 / aspect_ratio, with larger pixel y giving smaller NDC y. The NDC point at
// z = 1 is taken to camera space with ndc_to_cam, normalised, and transformed
// to world space.
Ray Camera::samplePrimaryRay(Float x, Float y) const {
    const Float aspect_ratio = Float(width) / Float(height);
    const Float u = x / width;
    const Float v = y / height;
    Vector3 dir = ndc_to_cam * Vector3((u - 0.5f)*2.f, (v - 0.5f)*(-2.f)/aspect_ratio, 1.0f);
    dir.normalize();
    return Ray{xfm_point(cam_to_world, Vector3::Zero()), xfm_vector(cam_to_world, dir)};
}

// Differentiable primary ray through pixel (pixel_x, pixel_y) at sub-pixel
// offset rnd2. This is the continuous-coordinate overload applied to
// (pixel_x + rnd2.x(), pixel_y + rnd2.y()), which performs the same arithmetic.
RayAD Camera::samplePrimaryRayAD(int pixel_x, int pixel_y, const Vector2 &rnd2) const {
    const Float px = pixel_x + rnd2.x();
    const Float py = pixel_y + rnd2.y();
    return samplePrimaryRayAD(px, py);
}

// Differentiable primary ray through continuous pixel coordinates (x, y).
// The camera-space direction is computed exactly as in samplePrimaryRay. It is
// then lifted to AD and taken to world space through the differentiable frame
// `cframe`, starting at the differentiable position `cpos`.
RayAD Camera::samplePrimaryRayAD(Float x, Float y) const {
    const Float aspect_ratio = Float(width) / Float(height);
    const Float u = x / width;
    const Float v = y / height;
    Vector3 dir = ndc_to_cam * Vector3((u - 0.5f)*2.f, (v - 0.5f)*(-2.f)/aspect_ratio, 1.0f);
    dir.normalize();
    return RayAD(cpos, cframe.toWorld(VectorAD(dir)));
}

// Reset the derivative parts (via zeroGrad) of the camera position and of the
// three frame axes.
void Camera::zeroVelocities() {
    cpos.zeroGrad();
    cframe.s.zeroGrad();
    cframe.t.zeroGrad();
    cframe.n.zeroGrad();
}

// Set every derivative of the camera position: column k of dx becomes
// derivative k. dx must have exactly nder columns.
void Camera::initVelocities(const Eigen::Matrix<Float, 3, -1> &dx)
{
    assert(dx.cols() == nder);
    for (int k = 0; k < nder; ++k) {
        cpos.grad(k) = dx.col(k);
    }
}

// Set only derivative `der_index` of the camera position to dx.
void Camera::initVelocities(const Eigen::Matrix<Float, 3, 1> &dx, int der_index)
{
    assert(0 <= der_index && der_index < nder);
    cpos.grad(der_index) = dx;
}

// Set all derivatives of the camera.
// - Position: column k of dx.
// - Frame axes: column k of dw, treated as an angular velocity, so each axis
//   gets derivative dw_k x axis.
// Both matrices must have exactly nder columns.
void Camera::initVelocities(const Eigen::Matrix<Float, 3, -1> &dx, const Eigen::Matrix<Float, 3, -1> &dw)
{
    assert(dx.cols() == nder && dw.cols() == nder);
    initVelocities(dx);
    for (int k = 0; k < nder; ++k) {
        const Eigen::Matrix<Float, 3, 1> omega = dw.col(k);
        cframe.s.grad(k) = omega.cross(cframe.s.val);
        cframe.t.grad(k) = omega.cross(cframe.t.val);
        cframe.n.grad(k) = omega.cross(cframe.n.val);
    }
}

// Set derivative `der_index` of the camera.
// - Position: dx.
// - Frame axes: dw is treated as an angular velocity, so each axis gets
//   dw x axis.
void Camera::initVelocities(const Eigen::Matrix<Float, 3, 1> &dx, const Eigen::Matrix<Float, 3, 1> &dw, int der_index)
{
    assert(0 <= der_index && der_index < nder);
    initVelocities(dx, der_index);

    auto &frame = cframe;
    frame.s.grad(der_index) = dw.cross(frame.s.val);
    frame.t.grad(der_index) = dw.cross(frame.t.val);
    frame.n.grad(der_index) = dw.cross(frame.n.val);
}

// Call advance(stepSize, derId) on the camera position and on each frame axis.
// Presumably this moves each value along its derId-th derivative.
void Camera::advance(Float stepSize, int derId) {
    assert(0 <= derId && derId < nder);
    cframe.s.advance(stepSize, derId);
    cframe.t.advance(stepSize, derId);
    cframe.n.advance(stepSize, derId);
    cpos.advance(stepSize, derId);
}

// Connect world-space point p directly to the camera.
// On success:
//   - pixel_uv receives the continuous pixel coordinates of p, relative to the
//     crop window when one is set;
//   - dir receives the unit direction from p towards the camera;
//   - the return value is 1 / (dist^2 * cos^3) divided by the image-plane area
//     of one pixel.
// Returns 0 when p's camera-space z is below clip_near, or when it projects
// outside the (cropped) image.
Float Camera::sampleDirect(const Vector& p, Vector2& pixel_uv, Vector& dir) const {
    Vector refP = xfm_point(world_to_cam, p);
    if (refP.z() < clip_near)
        return 0.0;

    const auto fov_factor = cam_to_ndc(0, 0);
    const auto aspect_ratio = Float(width) / Float(height);
    // Reciprocal of the visible image-plane area at unit depth. With
    // f = cam_to_ndc(0,0) (= cam_to_ndc(1,1)), the screen mapping below covers
    // an x range of 2/f and a y range of 2/(f * aspect_ratio), so the area is
    // 4 / (f^2 * aspect_ratio). This assumes cam_to_ndc has no off-diagonal
    // terms -- TODO confirm.
    Float inv_area = 0.25 * fov_factor*fov_factor*aspect_ratio;

    int xmin = 0, xmax = width, ymin = 0, ymax = height;
    if (rect.isValid()) {
        // The crop window covers only crop_w/width * crop_h/height of that area.
        inv_area *= static_cast<Float>(width) / static_cast<Float>(rect.crop_width) * static_cast<Float>(height) / rect.crop_height;
        xmin = rect.offset_x;
        ymin = rect.offset_y;
        xmax = xmin + rect.crop_width;
        ymax = ymin + rect.crop_height;
    }

    Vector ndc = cam_to_ndc * refP;
    ndc.x() /= ndc.z();
    ndc.y() /= ndc.z();
    const Vector2 screen_pos((ndc.x() * 0.5f + 0.5f) * width,
                             (-ndc.y() * 0.5f * aspect_ratio + 0.5f) * height);

    const bool inside = screen_pos.x() >= xmin && screen_pos.x() <= xmax &&
                        screen_pos.y() >= ymin && screen_pos.y() <= ymax;
    if (!inside)
        return 0.0;

    pixel_uv.x() = screen_pos.x() - xmin;
    pixel_uv.y() = screen_pos.y() - ymin;

    const Float dist = refP.norm();
    const Float inv_dist = 1.0f / dist;
    refP *= inv_dist;
    const Float inv_cosTheta = 1.0f / refP.z();
    dir = (cpos.val - p) * inv_dist;
    // Divide by the per-pixel area: area / pixel count.
    const Float inv_pixel_area = inv_area * (xmax - xmin) * (ymax - ymin);
    return inv_dist * inv_dist * inv_cosTheta * inv_cosTheta * inv_cosTheta * inv_pixel_area;
}


// Connect world-space point p to the camera and splat it onto every pixel
// whose reconstruction filter reaches it (at most a 2x2 neighbourhood).
//
// On success:
//   - dir receives the unit direction from p towards the camera;
//   - each used column of pixel_uvs holds the continuous pixel coordinates,
//     relative to the crop window, shifted to a neighbouring pixel;
//   - the matching entry of weights holds the direct-connection value times
//     the filter weight for that pixel.
// Unused weights stay 0. Nothing but weights is written when p's camera-space
// z is below clip_near or when p projects outside the (cropped) image.
void Camera::sampleDirect(const Vector& p, Matrix2x4& pixel_uvs, Array4& weights, Vector& dir) const {
    weights = Array4(0.0);
    Vector refP = xfm_point(world_to_cam, p);
    if (refP.z() < clip_near)
        return;
    auto fov_factor = cam_to_ndc(0, 0);
    auto aspect_ratio = Float(width) / Float(height);
    // Reciprocal image-plane area at unit depth (see the single-pixel overload).
    Float inv_area = 0.25 * fov_factor*fov_factor*aspect_ratio;

    int xmin = 0, xmax = width;
    int ymin = 0, ymax = height;

    if (rect.isValid()) {
        inv_area *= (Float)width/(Float)rect.crop_width *  (Float)height/rect.crop_height;
        xmin = rect.offset_x; xmax = rect.offset_x + rect.crop_width;
        ymin = rect.offset_y; ymax = rect.offset_y + rect.crop_height;
    }
    // Resolution of the (possibly cropped) image. Pixel indices below are
    // relative to (xmin, ymin), so neighbour bounds must use these extents
    // rather than the absolute xmax / ymax.
    const int res_x = xmax - xmin;
    const int res_y = ymax - ymin;

    Vector pos_camera = cam_to_ndc * refP;
    pos_camera.x() /= pos_camera.z();
    pos_camera.y() /= pos_camera.z();
    Vector2 screen_pos = Vector2( (pos_camera.x() * 0.5f + 0.5f) * width,
                                  (-pos_camera.y() * 0.5f * aspect_ratio + 0.5f) * height);
    if (!(screen_pos.x() >= xmin && screen_pos.x() <= xmax &&
          screen_pos.y() >= ymin && screen_pos.y() <= ymax))
        return;

    Vector2 pixel_uv;
    pixel_uv.x() = screen_pos.x() - xmin;
    pixel_uv.y() = screen_pos.y() - ymin;
    Float dist = refP.norm(), inv_dist = 1.0f/dist;
    refP *= inv_dist;
    Float inv_cosTheta = 1.0f/refP.z();
    dir = (cpos.val - p) * inv_dist;
    Float inv_pixel_area = inv_area * res_x * res_y;
    Float val0 = inv_dist * inv_dist * inv_cosTheta * inv_cosTheta * inv_cosTheta * inv_pixel_area;

    // Base pixel and the fractional position inside it.
    // NOTE(review): the inclusive bounds test above admits
    // screen_pos == xmax / ymax, which yields a base index equal to
    // res_x / res_y; callers should clamp.
    Vector2i pixel_index(static_cast<int>(pixel_uv.x()), static_cast<int>(pixel_uv.y()));
    Float x_offset = pixel_uv.x() - pixel_index.x();
    Float y_offset = pixel_uv.y() - pixel_index.y();
    // Include the previous/next neighbour when the sample lies within `a` of
    // that pixel edge and the neighbour exists. `a` is a Camera member not
    // shown here, presumably the filter's reach into neighbouring pixels.
    int offsetX_min = (x_offset < a     && pixel_index.x() > 0)         ? -1 : 0;
    int offsetX_max = (x_offset > 1.0-a && pixel_index.x() < res_x - 1) ? 1 : 0;
    int offsetY_min = (y_offset < a     && pixel_index.y() > 0)         ? -1 : 0;
    int offsetY_max = (y_offset > 1.0-a && pixel_index.y() < res_y - 1) ? 1 : 0;

    int index = 0;
    for (int ox = offsetX_min; ox <= offsetX_max; ox++)
        for (int oy = offsetY_min; oy <= offsetY_max; oy++) {
            // pixel_uvs / weights have room for a 2x2 neighbourhood only.
            assert(index < 4);
            pixel_uvs(0, index) = pixel_uv.x() + ox;
            pixel_uvs(1, index) = pixel_uv.y() + oy;
            weights(index) = val0 * eval1DFilter(x_offset-ox) * eval1DFilter(y_offset-oy);
            index++;
        }
}

// Differentiable version of the single-pixel sampleDirect.
// p and the camera pose carry derivatives. On success:
//   - pixel_uv receives the continuous pixel coordinates, relative to the crop
//     window;
//   - dir receives the unit direction from p towards the camera;
//   - the return value is 1 / (dist^2 * cos^3) divided by the per-pixel
//     image-plane area.
// Returns 0 when p's camera-frame z is below clip_near, or when its projection
// falls outside the (cropped) image.
FloatAD Camera::sampleDirectAD(const VectorAD &p, Vector2AD &pixel_uv, VectorAD &dir) const
{
    VectorAD refP = cframe.toLocal(p - cpos);
    if (refP.z() < clip_near)
        return 0.0;

    const auto fov_factor = cam_to_ndc(0, 0);
    const auto aspect_ratio = Float(width) / Float(height);
    // Reciprocal of the visible image-plane area at unit depth:
    // 4 / (f^2 * aspect_ratio) with f = cam_to_ndc(0,0).
    Float inv_area = 0.25 * fov_factor * fov_factor * aspect_ratio;

    int xmin = 0, xmax = width, ymin = 0, ymax = height;
    if (rect.isValid())
    {
        inv_area *= static_cast<Float>(width) / static_cast<Float>(rect.crop_width) * static_cast<Float>(height) / rect.crop_height;
        xmin = rect.offset_x;
        ymin = rect.offset_y;
        xmax = xmin + rect.crop_width;
        ymax = ymin + rect.crop_height;
    }

    VectorAD ndc = Matrix3x3AD(cam_to_ndc) * refP;
    ndc.x() /= ndc.z();
    ndc.y() /= ndc.z();
    Vector2AD screen_pos = Vector2AD((ndc.x() * 0.5f + 0.5f) * width,
                                     (-ndc.y() * 0.5f * aspect_ratio + 0.5f) * height);

    // The visibility test uses plain values only.
    const bool inside = screen_pos.x().val >= xmin && screen_pos.x().val <= xmax &&
                        screen_pos.y().val >= ymin && screen_pos.y().val <= ymax;
    if (!inside)
        return 0.0;

    pixel_uv.x() = screen_pos.x() - xmin;
    pixel_uv.y() = screen_pos.y() - ymin;

    FloatAD dist = refP.norm();
    FloatAD inv_dist = 1.0f / dist;
    refP *= inv_dist;
    FloatAD inv_cosTheta = 1.0f / refP.z();
    dir = (cpos - p) * inv_dist;
    const Float inv_pixel_area = inv_area * (xmax - xmin) * (ymax - ymin);
    return inv_dist * inv_dist * inv_cosTheta * inv_cosTheta * inv_cosTheta * inv_pixel_area;
}

// Differentiable version of the neighbourhood-splatting sampleDirect.
// Connect p to the camera and write up to a 2x2 neighbourhood of pixels.
// On success:
//   - each used column of pixel_uvs holds continuous pixel coordinates,
//     relative to the crop window;
//   - each matching weight is the direct-connection value times the filter
//     weight, with derivatives;
//   - dir receives the unit direction from p towards the camera.
// weights is reset to a default-constructed Vector4AD first. Nothing else is
// written when p's camera-frame z is below clip_near or when it projects
// outside the (cropped) image.
void Camera::sampleDirectAD(const VectorAD &p, Matrix2x4AD &pixel_uvs, Vector4AD &weights, VectorAD &dir) const
{
    weights = Vector4AD();
    VectorAD dir_world = p - cpos;
    VectorAD refP = cframe.toLocal(dir_world);
    if (refP.z() < clip_near)
        return;
    auto fov_factor = cam_to_ndc(0, 0);
    auto aspect_ratio = Float(width) / Float(height);
    // Reciprocal image-plane area at unit depth (see the single-pixel overload).
    Float inv_area = 0.25 * fov_factor*fov_factor*aspect_ratio;

    int xmin = 0, xmax = width;
    int ymin = 0, ymax = height;

    if (rect.isValid()) {
        inv_area *= (Float)width/(Float)rect.crop_width *  (Float)height/rect.crop_height;
        xmin = rect.offset_x; xmax = rect.offset_x + rect.crop_width;
        ymin = rect.offset_y; ymax = rect.offset_y + rect.crop_height;
    }
    // Resolution of the (possibly cropped) image. Pixel indices below are
    // relative to (xmin, ymin), so neighbour bounds must use these extents
    // rather than the absolute xmax / ymax.
    const int res_x = xmax - xmin;
    const int res_y = ymax - ymin;

    VectorAD pos_camera = Matrix3x3AD(cam_to_ndc) * refP;
    pos_camera.x() /= pos_camera.z();
    pos_camera.y() /= pos_camera.z();
    Vector2AD screen_pos = Vector2AD( (pos_camera.x() * 0.5f + 0.5f) * width,
                                      (-pos_camera.y() * 0.5f * aspect_ratio + 0.5f) * height);
    if (!(screen_pos.x().val >= xmin && screen_pos.x().val <= xmax &&
          screen_pos.y().val >= ymin && screen_pos.y().val <= ymax))
        return;

    Vector2AD pixel_uv;
    pixel_uv.x() = screen_pos.x() - xmin;
    pixel_uv.y() = screen_pos.y() - ymin;
    FloatAD dist = refP.norm(), inv_dist = 1.0f/dist;
    refP *= inv_dist;
    FloatAD inv_cosTheta = 1.0f/refP.z();
    dir = (cpos - p) * inv_dist;
    Float inv_pixel_area = inv_area * res_x * res_y;
    FloatAD val0 = inv_dist * inv_dist * inv_cosTheta * inv_cosTheta * inv_cosTheta * inv_pixel_area;

    // Base pixel (from plain values) and the fractional position inside it.
    // NOTE(review): the inclusive bounds test above admits
    // screen_pos == xmax / ymax, which yields a base index equal to
    // res_x / res_y; callers should clamp.
    Vector2i pixel_index(static_cast<int>(pixel_uv.x().val), static_cast<int>(pixel_uv.y().val));
    FloatAD x_offset = pixel_uv.x() - pixel_index.x();
    FloatAD y_offset = pixel_uv.y() - pixel_index.y();
    // Include the previous/next neighbour when the sample lies within `a` of
    // that pixel edge and the neighbour exists. `a` is a Camera member not
    // shown here, presumably the filter's reach into neighbouring pixels.
    int offsetX_min = (x_offset < a     && pixel_index.x() > 0)         ? -1 : 0;
    int offsetX_max = (x_offset > 1.0-a && pixel_index.x() < res_x - 1) ? 1 : 0;
    int offsetY_min = (y_offset < a     && pixel_index.y() > 0)         ? -1 : 0;
    int offsetY_max = (y_offset > 1.0-a && pixel_index.y() < res_y - 1) ? 1 : 0;

    int index = 0;
    for (int ox = offsetX_min; ox <= offsetX_max; ox++)
        for (int oy = offsetY_min; oy <= offsetY_max; oy++) {
            // pixel_uvs / weights have room for a 2x2 neighbourhood only.
            assert(index < 4);
            pixel_uvs(0, index) = pixel_uv.x() + ox;
            pixel_uvs(1, index) = pixel_uv.y() + oy;
            weights(index) = val0 * eval1DFilterAD(x_offset-ox) * eval1DFilterAD(y_offset-oy);
            index++;
        }
}
