#pragma once
#ifndef CAMERA_H__
#define CAMERA_H__

#include <core/fwd.h>
#include <core/properties.h>
#include <core/ptr.h>
#include <core/object.h>
#include <core/pmf.h>
#include <core/utils.h>
#include <core/frame.h>
#include <render/rfilter.h>
#include <render/common.h>
#include <memory>

// Forward declarations so the signatures below can name these types without
// this header spelling out their definitions. Ray, Intersection and RndSampler
// appear in Camera's member declarations; Shape and Scene are not referenced
// anywhere else in this header.
struct Ray;
struct Shape;
struct Scene;
struct Intersection;
struct RndSampler;

// Rectangular sub-window of an image, stored as an integer offset plus an
// integer extent. A default-constructed rectangle carries a negative extent
// and therefore reports itself as invalid.
struct CropRectangle
{
    int offset_x = 0;
    int offset_y = 0;
    int crop_width = -1;
    int crop_height = -1;

    // Keeps the in-class defaults: zero offset and an invalid (-1) extent.
    CropRectangle() {}

    // Builds a rectangle at offset (ox, oy) with extent width x height.
    // No validation is performed here; query isValid() afterwards.
    CropRectangle(int ox, int oy, int width, int height)
        : offset_x(ox),
          offset_y(oy),
          crop_width(width),
          crop_height(height)
    {
    }

    // True only when both extents are strictly positive.
    bool isValid() const
    {
        if (crop_width <= 0)
            return false;
        return crop_height > 0;
    }
};

// Record exchanged with the Camera direct-sampling API declared further down:
// Camera::sampleDirect(const Vector &, CameraDirectSamplingRecord &) takes it by
// mutable reference, Camera::sampleDirectPixel and Camera::accumulateDirect
// take it by const reference.
// The per-field notes are inferred from the names only — TODO confirm against
// the implementation. No field has an initializer in this header.
struct CameraDirectSamplingRecord
{
    Vector2 pixel_uv;   // presumably continuous image-plane coordinates of the sample — verify
    Vector2i pixel_idx; // presumably integer pixel coordinates of the sample — verify
    Vector dir;         // presumably the direction associated with the sample — verify
    Float baseVal;      // NOTE(review): meaning is not visible from this header — confirm
};

// Camera description: image size, optional crop window, camera/world/NDC
// transforms and a reconstruction filter, together with the ray-sampling and
// filter-evaluation API. Only the small accessors are defined in this header;
// every other member function is declared here and defined elsewhere.
struct Camera : Object
{
    // Members with an in-class initializer (see the data section at the bottom)
    // take that value; the remaining members are default-constructed.
    Camera() {}

    Camera(const Properties &props);

    // NOTE(review): `type` presumably selects the reconstruction filter
    // (cf. initRfilter(int type)) — confirm against the implementation.
    Camera(int width, int height, Float fov, const Matrix4x4 &to_world, int type);
    Camera(int width, int height, Float fov, const Vector &origin, const Vector &target, const Vector &up, int type);

    Camera(int width, int height, const Matrix4x4 &cam_to_world, const Matrix3x3 &cam_to_ndc, int med_id, int type);
    // NOTE(review): if `filter` is stored in rfilter, ownership passes to the
    // camera, because the destructor deletes rfilter — confirm.
    Camera(int width, int height, Float fov, const Matrix4x4 &cam_to_world, ReconstructionFilter *filter);

    Camera(const Camera &camera);
    Camera &operator=(const Camera &camera);

    // NOTE(review): the members below are declared `inline` but are not defined
    // in this header. An inline function must be defined in every translation
    // unit that odr-uses it — confirm they are only called from the file that
    // defines them.
    inline void copyFrom(const Camera &camera);
    inline void initRfilter(const Properties &props);
    inline void initRfilter(int type);

    // Releases the reconstruction filter held through rfilter.
    ~Camera()
    {
        // `delete` on a null pointer is a no-op, so no guard is required.
        delete rfilter;
    }

    void merge(const Camera &camera);
    void setZero();

    void configure();

    // Replaces the crop window. The arguments are stored as given; whether the
    // window is used is decided later through rect.isValid().
    inline void setCropRect(int offset_x, int offset_y, int crop_width, int crop_height)
    {
        rect = CropRectangle(offset_x, offset_y, crop_width, crop_height);
    }

    // Number of pixels in the crop window when it is valid, otherwise in the
    // full width x height image.
    inline int getNumPixels() const { return rect.isValid() ? rect.crop_width * rect.crop_height : width * height; }
    inline int getMedID() const { return med_id; }

    // Row-major linear index for the coordinates pix_uv. With a valid crop
    // window the offset is subtracted first and the row stride is the crop
    // width; otherwise the row stride is the image width. The int()
    // conversions truncate toward zero. No bounds check is performed.
    inline int getPixelIndex(const Vector2 &pix_uv) const
    {
        return rect.isValid() ? int(pix_uv.x() - rect.offset_x) + int(pix_uv.y() - rect.offset_y) * rect.crop_width
                              : int(pix_uv.x()) + int(pix_uv.y()) * width;
    }

    // Offset of the crop window, or (0, 0) when no valid window is set.
    inline Array2i getOffset() const
    {
        return rect.isValid() ? Array2i(rect.offset_x, rect.offset_y) : Array2i(0, 0);
    }

    // Extent of the crop window, or (width, height) when no valid window is set.
    inline Array2i getCropSize() const
    {
        return rect.isValid() ? Array2i(rect.crop_width, rect.crop_height) : Array2i(width, height);
    }

    // Integer-coordinate counterpart of getPixelIndex(): same row-major layout
    // and crop handling, again without a bounds check.
    inline int getPixelIndex1D(const Vector2i &pix_idx) const
    {
        return rect.isValid() ? (pix_idx.x() - rect.offset_x) + (pix_idx.y() - rect.offset_y) * rect.crop_width
                              : pix_idx.x() + pix_idx.y() * width;
    }

    // Returns {x_begin, x_end, y_begin, y_end} of the active pixel range, where
    // each *_end equals the matching *_begin plus the extent along that axis.
    inline std::tuple<int, int, int, int> getPixelAABB() const {
        if (rect.isValid())
            return {rect.offset_x, rect.offset_x + rect.crop_width, rect.offset_y, rect.offset_y + rect.crop_height};
        else
            return {0, width, 0, height};
    }

    void advance(Float stepSize, int derId);

    // Given a camera ray, return the pixel indices whose filters overlap the ray. Store the indices in Vector4i, -1 is invalid.
    // void sampleAttenuatedDirect(const Ray &ray, Vector4i &pixel_indices) const;
    Vector2 getPixelUV(const Vector &dir) const;
    Vector2 getPixelUV(const Ray &ray) const;

    // NOTE(review): declared `inline` without a definition in this header; see
    // the note on copyFrom() above.
    inline Float evalFilter(int ix, int iy, Float x, Float y) const;
    inline Float _eval(int ix, int iy, const Vector &p, const Vector *n, bool eval_filter) const;

    Float evalFilter(int ix, int iy, const Intersection &its) const;
    Float evalFilter(int ix, int iy, const Vector &p, const Vector &n) const;
    Float evalFilter(int ix, int iy, const Vector &p) const;
    Float eval(int ix, int iy, const Vector &p, const Vector &n) const;
    Float eval(int ix, int iy, const Vector &p) const;
    Float evalDir(int ix, int iy, const Vector &dir) const;
    // Only evaluates the filter value; no geometric term is accounted for.
    Float evalFilterDir(int ix, int iy, const Vector &dir) const;
    // Given a direction, evaluates the pmf of choosing the pixel.
    Float pdfPixel(int ix, int iy, const Vector &dir) const;

    // NOTE(review): declared `inline` without a definition in this header; see
    // the note on copyFrom() above.
    inline void _samplePrimaryRayFromFilter(int pixel_x, int pixel_y, const Vector2 &rnd2, Float &x, Float &y) const;
    void samplePrimaryRayFromFilter(int pixel_x, int pixel_y, const Vector2 &rnd2, Ray &ray) const;
    void samplePrimaryRayFromFilter(int pixel_x, int pixel_y, const Vector2 &rnd2, Ray &ray, Ray &dual) const;
    void samplePrimaryRayFromFilter(int pixel_x, int pixel_y, const Vector2 &rnd2, Ray *rays) const;

    Ray sampleDualRay(const Array2i &pixel_idx, Ray &ray) const;
    Ray sampleDualRay(const Vector2i &pixel_idx, const Vector2 &pixel_uv) const;
    Ray samplePrimaryRay(int pixel_x, int pixel_y, const Vector2 &rnd2) const;
    Ray samplePrimaryRay(Float x, Float y) const;

    std::tuple<Ray, Vector, Vector, Vector> samplePrimaryBoundaryRay(const Array2i &pixel_idx, Float rnd) const;

    Float sampleDirect(const Vector &p, Vector2 &pixel_uv, Vector &dir) const;
    void sampleDirect(const Vector &p, Matrix2x4 &pixel_uvs, Array4 &weights, Vector &dir) const;

    bool sampleDirect(const Vector &p, CameraDirectSamplingRecord &cRec) const;
    std::pair<int, Float> sampleDirectPixel(const CameraDirectSamplingRecord &cRec, Float rnd) const;
    void accumulateDirect(const CameraDirectSamplingRecord &cRec, const Spectrum &value, RndSampler *sampler, Spectrum *buffer) const;

    Float sampleDirect(const Vector &p, Float rnd, int &pixel_idx, Vector &dir) const;

    Float geometric(const Vector &p, const Vector &n) const;
    Float geometric(const Vector &p) const;
    // Converts a pdf on the image plane to solid angle.
    Float convertPdf(const Vector &dir) const;

    // NOTE(review): getUp() returns cframe.n while getRight() returns cframe.s;
    // the comment on cframe below calls z the axis of projection — confirm that
    // n is the intended "up" axis.
    Vector getUp() const { return cframe.n; }
    Vector getRight() const { return cframe.s; }

    // Direction transforms: apply world_to_cam / cam_to_world through xfm_vector.
    Vector toLocal(const Vector &dir) const { return xfm_vector(world_to_cam, dir); }
    Vector toWorld(const Vector &dir) const { return xfm_vector(cam_to_world, dir); }
    // Sum of the transformed zero point (via xfm_point) and the transformed
    // argument (via xfm_vector), using cam_to_world / world_to_cam respectively.
    Vector camToWorld(const Vector &dir) const { return xfm_point(cam_to_world, Vector3::Zero()) + xfm_vector(cam_to_world, dir); }
    Vector worldToCam(const Vector &dir) const { return xfm_point(world_to_cam, Vector3::Zero()) + xfm_vector(world_to_cam, dir); }

    // Image size in pixels. Defaulted to zero so that a default-constructed
    // camera never exposes indeterminate values through the accessors above.
    int width = 0, height = 0;
    // Field of view (units are not visible from this header). Defaulted for
    // the same reason as width / height.
    Float m_fov = 0.0;
    CropRectangle rect;
    Matrix4x4 cam_to_world, world_to_cam;
    Matrix3x3 ndc_to_cam, cam_to_ndc;

    Vector cpos;    // Center of projection
    Frame cframe;   // (x, y): Horizontal and vertical axes of the image plane;
                    // (z): axis of projection

    int med_id = -1;

    // Deleted in ~Camera().
    ReconstructionFilter *rfilter = nullptr;
    Eigen::MatrixXd sigX, sigY;

    Float a = 0.5 - Epsilon; // To be removed

    EState m_state = EState::ESInit;

    PSDR_DECLARE_CLASS(Camera)
};

#endif