#pragma once
#ifndef CAMERA_H__
#define CAMERA_H__

#include "ptr.h"
#include "fwd.h"
#include "frameAD.h"
#include "pmf.h"
#include "utils.h"
#include "rfilter.h"
#include "sampler.h"
#include "pif.h"
#include <memory>

struct Ray;
struct Shape;
struct Scene;
struct IntersectionAD;

// Sub-window of the full image, in pixels: top-left corner (offset_x, offset_y)
// and extent (crop_width x crop_height). The default-constructed rectangle has
// a negative extent and is therefore reported as invalid ("no crop").
struct CropRectangle {
    int offset_x = 0;
    int offset_y = 0;
    int crop_width = -1;
    int crop_height = -1;

    CropRectangle() = default;

    CropRectangle(int ox, int oy, int width, int height)
        : offset_x(ox), offset_y(oy), crop_width(width), crop_height(height) {}

    // A crop is usable only when both dimensions are strictly positive.
    bool isValid() const { return !(crop_width <= 0 || crop_height <= 0); }
};

// Camera model: image resolution, optional crop window, camera/world/NDC
// transforms, reconstruction filter and the primary-ray / direct-connection
// sampling entry points declared below. Only the small accessors are defined
// inline here; everything else is implemented out of line.
struct Camera {
    // Scalar members carry in-class defaults (see below) so a default-constructed
    // Camera never exposes indeterminate values through the inline accessors.
    Camera() = default;
    Camera(int width, int height, ptr<float> cam_to_world, ptr<float> cam_to_ndc, float clip_near, int med_id, ptr<float> velocities = ptr<float>(nullptr));
    void setPIF(int _pif, float _tau, float _deltaTau, int _num_bins, float _stepTau);

    // Restrict rendering to a sub-window of the image. A window with a
    // non-positive width or height is treated as "no crop" (CropRectangle::isValid).
    inline void setCropRect(int offset_x, int offset_y, int crop_width, int crop_height) { rect = CropRectangle(offset_x, offset_y, crop_width, crop_height); }

    // Number of pixels addressed by getPixelIndex(): the crop area when a valid
    // crop is set, otherwise the full width*height image.
    inline int getNumPixels() const { return rect.isValid() ? rect.crop_width*rect.crop_height : width*height; }

    inline int getMedID() const { return med_id; }

    // Map a pixel coordinate (full-image space) to a linear row-major index in
    // [0, getNumPixels()). With a valid crop, the crop offset is subtracted and
    // the crop width is used as the row stride.
    // NOTE(review): coordinates are converted with int() (truncation toward zero)
    // and no bounds check is performed; callers must pass in-range coordinates.
    inline int getPixelIndex(const Vector2& pix_uv) const {
        return rect.isValid() ? int(pix_uv.x() - rect.offset_x) + int(pix_uv.y() - rect.offset_y) * rect.crop_width
                              : int(pix_uv.x()) + int(pix_uv.y()) * width;
    }

    // Velocity (derivative-direction) setup and advancement; defined out of line.
    void zeroVelocities();
    void initVelocities(const Eigen::Matrix<Float, 3, -1> &dx);
    void initVelocities(const Eigen::Matrix<Float, 3, 1> &dx, int der_index);
    void initVelocities(const Eigen::Matrix<Float, 3, -1> &dx, const Eigen::Matrix<Float, 3, -1> &dw);
    void initVelocities(const Eigen::Matrix<Float, 3, 1> &dx, const Eigen::Matrix<Float, 3, 1> &dw, int der_index);
    void advance(Float stepSize, int derId);

    // set velocity for tau
    void initVelocities(const Eigen::Array<Float, nder, 1> &dt);

    // Reconstruction-filter evaluation and sampling.
    FloatAD evalFilterAD(int pixel_x, int pixel_y, const IntersectionAD &its) const;
    FloatAD evalFilterAD(int pixel_x, int pixel_y, const VectorAD &scatterer) const;
    FloatAD eval1DFilterAD(const FloatAD &x) const;
    Float eval1DFilter(Float x) const;
    Float sampleFromFilter1D(Float rnd) const;
    Vector2 sampleFromFilter(const Vector2 &rnd) const;
    void samplePrimaryRayFromFilter(int pixel_x, int pixel_y, const Vector2 &rnd2, Ray &ray, Ray &dual) const;

    // Primary-ray generation and direct (point-to-camera) sampling.
    Ray samplePrimaryRay(int pixel_x, int pixel_y, const Vector2 &rnd2) const;
    RayAD samplePrimaryRayAD(int pixel_x, int pixel_y, const Vector2 &rnd2) const;
    Ray samplePrimaryRay(Float x, Float y) const;
    RayAD samplePrimaryRayAD(Float x, Float y) const;
    Float sampleDirect(const Vector& p, Vector2& pixel_uv, Vector& dir) const;
    void sampleDirect(const Vector& p, Matrix2x4& pixel_uvs, Array4& weights, Vector& dir) const;
    FloatAD sampleDirectAD(const VectorAD& p, Vector2AD& pixel_uv, VectorAD& dir) const;
    void sampleDirectAD(const VectorAD& p, Matrix2x4AD& pixel_uvs, Vector4AD& weights, VectorAD& dir) const;

    // Full image resolution in pixels; width is the row stride in getPixelIndex().
    int width = 0;
    int height = 0;
    // Optional crop window; invalid (no crop) by default.
    CropRectangle rect;
    // Coordinate-frame transforms (presumably set up by the full constructor).
    Matrix4x4 cam_to_world;
    Matrix4x4 world_to_cam;
    Matrix3x3 ndc_to_cam;
    Matrix3x3 cam_to_ndc;
    VectorAD cpos;
    FrameAD cframe;
    std::shared_ptr<ReconstructionFilter> rfilter;

    // NOTE(review): meaning not visible from this header (presumably a bound
    // related to the pixel/filter footprint) — confirm against camera.cpp.
    Float a = 0.5 - Epsilon;

    float clip_near = 0.0f;
    // Medium index reported by getMedID(). -1 is assumed to be the "no medium"
    // sentinel used elsewhere in the renderer — TODO confirm.
    int med_id = -1;

    std::shared_ptr<PIF> pif;
};

#endif