#pragma once

#include <core/aabb.h>
#include <core/object.h>
#include <core/pmf.h>
#include <core/ptr.h>
#include <core/ray.h>
#include <core/sampler.h>
#include <core/utils.h>
#include <embree3/rtcore.h>
#include <embree3/rtcore_ray.h>
#include <emitter/envmap.h>
#include <render/bsdf.h>
#include <render/camera.h>
#include <render/embree_scene.h>
#include <render/emitter.h>
#include <render/intersection.h>
#include <render/medium.h>
#include <render/phase.h>
#include <render/common.h>
#include <render/shape.h>

#include <array>
#include <filesystem>
#include <memory>
#include <vector>

// Endpoint record written by Scene::trace() (passed there as an out-parameter).
// It holds either surface data (its, shape_id/tri_id/barycentric3) or medium
// data (mRec, med_id/tet_id/barycentric4), selected presumably by onSurface.
// Scalars and the position/direction vectors are default-initialized so a
// record that trace() does not fully populate never exposes indeterminate values.
struct BoundaryEndpointSamplingRecord {
    bool onSurface = false; // presumably: true -> surface endpoint (its), false -> medium endpoint (mRec)

    Intersection its;
    MediumSamplingRecord mRec;

    Vector p = Vector::Zero();
    Vector n = Vector::Zero(); // 0 for medium interaction
    Vector sh_n = Vector::Zero();
    Vector wi = Vector::Zero();
    Spectrum throughput;
    Float pdf = 0;
    int interactions = 0;

    // Surface endpoint location (-1 = unset).
    int shape_id = -1;
    int tri_id = -1;
    Vector3 barycentric3;

    // Medium endpoint location (-1 = unset).
    int med_id = -1;
    int tet_id = -1;
    Vector4 barycentric4;
};

/**
 * Scene container: the active camera, the component lists (shapes, BSDFs,
 * emitters, phase functions, media), the Embree handles, and the sampling
 * distributions set up by the configure*() family.
 *
 * Component lists hold raw pointers. The copy constructor, copy assignment and
 * clone() all copy every list through deepcopy()/deepcopy1(), so a copied scene
 * does not share component instances with its source.
 * NOTE(review): releasing these pointers is the job of ~Scene(), whose
 * definition is not visible here.
 */
struct Scene : Object {
    // Loader/configuration flags.
    struct State {
        bool use_tetmesh = true; // presumably gates configure_tetmesh(); confirm in the implementation
    };

    using ParamMap = std::unordered_map<std::string, const Object *>;

    Scene() : m_state(ESInit) {}

    // Builds a scene from pre-constructed components; semantics of each
    // argument (including use_hierarchy) live in the out-of-line definition.
    Scene(
        const Camera &camera, const std::vector<Shape *> &shapes, const std::vector<BSDF *> &bsdfs,
        const std::vector<Emitter *> &area_lights,
        const std::vector<PhaseFunction *> &phases,
        const std::vector<Medium *> &media,
        bool use_hierarchy = true);

    ~Scene();

    // Loads a scene description from file_name (implemented out of line).
    void load_file(const char *file_name, bool auto_configure = true, const Properties &props = Properties());

    // Replaces emitter_list[0] with an EnvironmentMap loaded from file_name.
    // The new emitter inherits the replaced emitter's shape_ptr and shape_id.
    // Precondition: emitter_list is non-empty.
    // NOTE(review): the replaced Emitter is not deleted here; confirm whether
    // it is owned elsewhere or leaks.
    void load_envmap(const char *file_name) {
        assert(!emitter_list.empty() && "load_envmap requires at least one emitter");
        auto envmap = new EnvironmentMap(file_name, 0);
        envmap->shape_ptr = emitter_list[0]->shape_ptr;
        envmap->shape_id = emitter_list[0]->shape_id;
        emitter_list[0] = envmap;
    }

    // Merges another scene's camera, shapes, BSDFs, emitters and media into
    // this one via Camera::merge / merge_pointer_vector.
    // NOTE(review): phase_list is NOT merged, although setZero() zeroes it;
    // confirm whether phase-function gradients should be accumulated here too.
    void merge(const Scene &scene) {
        camera.merge(scene.camera);
        merge_pointer_vector(shape_list, scene.shape_list);
        merge_pointer_vector(bsdf_list, scene.bsdf_list);
        merge_pointer_vector(emitter_list, scene.emitter_list);
        merge_pointer_vector(medium_list, scene.medium_list);
    }

    // Copy constructor: deep-copies every component list (shapes through
    // deepcopy1). Unlike operator=, it does not call configure().
    Scene(const Scene &scene) : camera(scene.camera) {
        shape_list   = deepcopy1(scene.shape_list);
        bsdf_list    = deepcopy(scene.bsdf_list);
        emitter_list = deepcopy(scene.emitter_list);
        phase_list   = deepcopy(scene.phase_list);
        medium_list  = deepcopy(scene.medium_list);
    }

    // Loads the scene from file_name (see load_file).
    Scene(const std::string &file_name, bool auto_configure = true, const Properties &props = Properties()) {
        load_file(file_name.c_str(), auto_configure, props);
    }

    // Equivalent to merge(scene); returns *this.
    Scene &operator+=(const Scene &scene) {
        merge(scene);
        return *this;
    }

    // Copy assignment: the same deep copy as the copy constructor, followed by
    // configure(). Phase functions and media are deep-copied as well (they
    // used to be shared), so e.g. `d = s; d.setZero();` no longer zeroes the
    // media/phase functions of `s`. Self-assignment is a no-op.
    Scene &operator=(const Scene &scene) {
        if (this == &scene)
            return *this;
        camera       = scene.camera;
        shape_list   = deepcopy1(scene.shape_list);
        bsdf_list    = deepcopy(scene.bsdf_list);
        emitter_list = deepcopy(scene.emitter_list);
        phase_list   = deepcopy(scene.phase_list);
        medium_list  = deepcopy(scene.medium_list);
        configure();
        return *this;
    }

    void assign(const Camera &camera, const std::vector<Shape *> &shapes, const std::vector<BSDF *> &bsdfs,
        const std::vector<Emitter *> &area_lights,
        const std::vector<PhaseFunction *> &phases,
        const std::vector<Medium *> &media);

    void configure(const Properties &props = Properties());
    void configureD(const Scene &scene); // only called on d_scene
    // configure embree scene
    void configure_embree(const Properties &props = Properties());
    // distribution for emitters, used for emitter sampling in nee
    void configure_light_distrib(const Properties &props = Properties());
    // distribution for all shapes, not used
    void configure_shape_distrib(const Properties &props = Properties());
    // distribution for all medium shapes
    void configure_medium_shape_distrib(const Properties &props = Properties());
    // configure distributions
    void configure_distrib(const Properties &props = Properties());
    // build tetmesh
    void configure_tetmesh(const Properties &props = Properties());

    void configure_aabb(const Properties &props = Properties());
    void configure_bounding_mesh(const Properties &props = Properties());

    void initTetmesh();

    // Returns an independent scene built through the component constructor
    // from deep copies of every list. Does not modify *this.
    Scene clone() const {
        Camera cam = camera;
        std::vector<Shape *>         shapes      = deepcopy1(shape_list);
        std::vector<BSDF *>          bsdfs       = deepcopy(bsdf_list);
        std::vector<Emitter *>       area_lights = deepcopy(emitter_list);
        std::vector<PhaseFunction *> phases      = deepcopy(phase_list);
        std::vector<Medium *>        media       = deepcopy(medium_list);
        return Scene(cam, shapes, bsdfs, area_lights, phases, media);
    }

    // Zeroes the camera and every component list (used on derivative scenes).
    inline void setZero() {
        camera.setZero();
        common::setZero(shape_list);
        common::setZero(bsdf_list);
        common::setZero(emitter_list);
        common::setZero(phase_list);
        common::setZero(medium_list);
    }

    inline Shape *getShape(int i) const {
        assert(i >= 0 && i < static_cast<int>(shape_list.size()));
        return shape_list[i];
    }

    // Returns nullptr for index -1 (no medium).
    inline Medium *getMedium(int i) const {
        if (i == -1)
            return nullptr;
        assert(i >= 0 && i < static_cast<int>(medium_list.size()));
        return medium_list[i];
    }

    inline PhaseFunction *getPhase(int i) const {
        assert(i >= 0 && i < static_cast<int>(phase_list.size()));
        return phase_list[i];
    }

    inline BSDF *getBSDF(int i) const {
        assert(i >= 0 && i < static_cast<int>(bsdf_list.size()));
        return bsdf_list[i];
    }

    // Sum of shape_distrb (built by configure_shape_distrib).
    inline Float getArea() const { return shape_distrb.getSum(); }
    // Sum of medium_shape_distrb (built by configure_medium_shape_distrib).
    inline Float getMediumArea() const { return medium_shape_distrb.getSum(); }
    // Simple visibility test (IGNORING null interfaces!)
    bool isVisible(
        const Vector &p, bool pOnSurface, const Vector &q, bool qOnSurface) const;

    // Path Tracer
    bool rayIntersect(
        const Ray &ray, bool onSurface, Intersection &its, IntersectionMode mode = ESpatial) const;

    // Variant that relies on the ray's own tmin/tmax.
    bool rayIntersect(
        const Ray &ray, Intersection &its, IntersectionMode mode) const;

    std::vector<Intersection> rayIntersectAll(const Ray &ray, bool onSurface) const;

    bool trace(
        const Ray &ray, RndSampler *sampler, int max_depth, int med_id,
        BoundaryEndpointSamplingRecord &bERec) const;

    // uniform sampling, with sample reuse
    [[nodiscard]] int sampleEmitter(Float &rnd, Float &pdf) const;
    [[nodiscard]] int samplePointEmitter(Float &rnd, Float &pdf) const;
    [[nodiscard]] int sampleAreaEmitter(Float &rnd, Float &pdf) const;

    [[nodiscard]] Float pdfEmitter(int index) const;

    Spectrum sampleEmitterDirect(const Vector2 &rnd_light, DirectSamplingRecord &dRec) const;

    Spectrum sampleEmitterDirect(
        const Intersection &its, const Vector2 &rnd_light, RndSampler *sampler,
        Vector &wo, Float &pdf) const;

    Spectrum rayIntersectAndLookForEmitter(
        const Ray &ray, bool onSurface, RndSampler *sampler, const Medium *medium,
        Intersection &its, DirectSamplingRecord &dRec) const;

    Spectrum sampleAttenuatedEmitterDirect(
        const Intersection &its, const Vector2 &rnd_light, RndSampler *sampler,
        const Medium *ptr_medium, Vector &wo, /* in local space */
        Float &pdf, bool flag = false) const;

    Spectrum sampleAttenuatedEmitterDirect(
        const Vector &pscatter, const Vector2 &rnd_light, RndSampler *sampler,
        const Medium *ptr_medium, Vector &wo, /* in *world* space */
        Float &pdf) const;

    Spectrum sampleBoundaryAttenuatedEmitterDirect(
        DirectSamplingRecord &dRec, const Vector2 &_rnd_light, RndSampler *sampler,
        const Medium *medium) const;

    Spectrum sampleAttenuatedEmitterDirect(
        DirectSamplingRecord &dRec, const Vector2 &rnd_light, RndSampler *sampler, const Medium *ptr_medium) const;

    Spectrum sampleAttenuatedEmitterDirect(
        DirectSamplingRecord &dRec, const Intersection &its, const Vector2 &rnd_light, RndSampler *sampler,
        const Medium *ptr_medium) const;

    Float evalTransmittance(
        const Ray &ray, bool onSurface, const Medium *ptr_medium, Float remaining, RndSampler *sampler) const;

    Float evalTransmittance(
        const Vector &p1, bool p1OnSurface, const Vector &p2, bool p2OnSurface,
        const Medium *ptr_medium, RndSampler *sampler, bool isRatio = false) const;

    std::pair<Float, Float> evalTransmittanceAndPdf(
        const Vector &p1, bool p1OnSurface, const Vector &p2, bool p2OnSurface,
        const Medium *ptr_medium, RndSampler *sampler) const;

    bool traceForSurface(const Ray &_ray, bool onSurface, Intersection &its) const;

    bool traceForMedium(
        const Ray &_ray, bool onSurface, const Medium *medium, Float targetTransmittance,
        RndSampler *sampler, MediumSamplingRecord &mRec) const;

    // equal distance sampling
    bool traceForMedium2(
        const Ray &_ray, bool onSurface, const Medium *medium, Float tarDist,
        RndSampler *sampler, MediumSamplingRecord &mRec) const;

    Float pdfEmitterSample(const Intersection &its) const;

    Float pdfEmitterDirect(const DirectSamplingRecord &dRec) const;

    Float pdfMediumBoundaryPoint(const Intersection &its) const;

    Spectrum sampleEmitterPosition(
        const Vector2 &rnd_light, Intersection &its, EEmitter type) const;

    Spectrum sampleEmitterPosition(
        const Vector2 &rnd_light, Intersection &its, Float *pdf = nullptr) const;

    Array4 sampleAttenuatedSensorDirect(
        const Intersection &its, RndSampler *sampler, Matrix2x4 &pixel_uvs, Vector &dir) const;

    Array4 sampleAttenuatedSensorDirect(
        const Vector &p, const Medium *ptr_med, RndSampler *sampler, Matrix2x4 &pixel_uvs, Vector &dir) const;

    std::tuple<Float, int, Vector> sampleAttenuatedSensorDirect(
        const Intersection &its, RndSampler *sampler) const;

    std::tuple<Float, int, Vector> sampleAttenuatedSensorDirect(
        const Vector &p, const Medium *ptr_med, RndSampler *sampler) const;

    Float sampleAttenuatedSensorDirect(
        const Intersection &its, RndSampler *sampler, CameraDirectSamplingRecord &cRec) const;

    // return the transmittance
    Float sampleAttenuatedSensorDirect(
        const Vector &p, const Medium *ptr_med, RndSampler *sampler, CameraDirectSamplingRecord &cRec) const;

    Vector2i samplePosition(const Vector2 &rnd2, PositionSamplingRecord &pRec) const;
    Intersection sampleMediumBoundary(const Vector2 &rnd2, PositionSamplingRecord &pRec) const;
    // x, n, J
    std::tuple<Vector, Vector, Float> getPoint(const Intersection &its) const;

    void getPoint(const Intersection &its, Intersection &itsAD) const;

    void sampleEdgePoint(
        const Float &_rnd, const DiscreteDistribution &edge_dist, const std::vector<Vector2i> &edge_ind,
        EdgeSamplingRecord &eRec, EdgePrimarySampleRecord *ePSRec = nullptr) const;

#ifdef FORWARD
    // Index of the first shape with requires_grad set, or -1 if none.
    int getShapeRequiresGrad() const {
        for (int i = 0; i < static_cast<int>(shape_list.size()); i++)
            if (shape_list[i]->requires_grad)
                return i;
        return -1;
    }

    // Sum of `param` over all shapes with requires_grad set.
    Float getParameter() const {
        Float param = 0.;
        for (auto &shape : shape_list) {
            if (shape->requires_grad)
                param += shape->param;
        }
        return param;
    }

    // Resets `param` to 0 on every shape.
    void zeroParameter() {
        for (auto &shape : shape_list) {
            shape->param = 0.;
        }
    }

#endif

    // Scene objects
    Camera camera; // active camera
    std::vector<Camera *> cameras;
    std::vector<Shape *> shape_list;
    std::vector<BSDF *> bsdf_list;
    std::vector<Emitter *> emitter_list;
    std::vector<Emitter *> point_emitter_list;
    std::vector<Emitter *> area_emitter_list;
    std::vector<PhaseFunction *> phase_list;
    std::vector<Medium *> medium_list;

    AABB m_aabb;

    // Embree handles
    RTCDevice embree_device = nullptr;
    RTCScene embree_scene = nullptr;
    EmbreeScene m_embree_scene;

    // Point sampling
    DiscreteDistribution shape_distrb;
    // NOTE: the size of medium_shape_distrb is the same as the size of shape_list
    DiscreteDistribution medium_shape_distrb;

    // Loader
    EState m_state = EState::ESInit;
    ParamMap m_param_map;
    Properties m_properties;
    State state;

    // Rendering
    RenderOptions m_render_options;
    PSDR_DECLARE_CLASS(Scene)
};

struct SceneAD {
    Scene val;                                                             // val
    Scene der;                                                             // der
    GradientManager<Scene> gm;                                             // multi-threaded gradient manager that stores a reference to d_scene
    SceneAD(const Scene &s) : val(s), der(s), gm(der, omp_get_num_procs()) // copy constructor
    {
        val.configure(s.m_properties);
        der.setZero();
        gm.setZero();
    }
    Scene &getDer() {
        return gm.get(omp_get_thread_num());
    }
    void zeroGrad() {
        der.setZero();
        gm.setZero();
    }
};

#include <core/template.h>
// Scene instantiation of the TypeAD wrapper from core/template.h; the meaning
// of the `true` template argument is defined there.
using SceneAD1 = TypeAD<Scene, true>;

/* normal related */
#ifdef NORMAL_PREPROCESS
// Only declared when NORMAL_PREPROCESS is defined; definition lives elsewhere.
// NOTE(review): presumably writes normal-precomputation derivatives into d_scene — confirm.
void d_precompute_normal(const Scene &scene, Scene &d_scene);
#endif
