#pragma once
#ifndef SHAPE_H__
#define SHAPE_H__

#include "fwd.h"
#include "ptr.h"
#include "edge_manager.h"
#include "pmf.h"
#include <cmath>
#include <string>

struct Ray;
struct RayAD;
struct Intersection;
struct IntersectionAD;
struct Edge;

/// Result of sampling a point on a shape (non-differentiable variant).
struct PositionSamplingRecord {
    Vector p;    ///< sampled point
    Vector n;    ///< normal at the sampled point (geometric vs. shading: confirm with the sampler)
    Vector2 uv;  ///< presumably the surface/texture coordinates of the sample
    Float pdf;   ///< sampling density of this record (measure presumably area; confirm)
};

/// Differentiable counterpart of PositionSamplingRecord: the geometry carries
/// AD values, while the pdf is a plain (non-differentiated) Float.
struct PositionSamplingRecordAD {
    VectorAD p;  ///< sampled point (AD)
    VectorAD n;  ///< normal at the sampled point (AD)
    FloatAD J;   ///< presumably the Jacobian / change-of-variables factor of the sample; confirm
    Float pdf;   ///< sampling density (not differentiated)
};

/// Position sample taken with respect to a reference point `ref`.
/// Extends PositionSamplingRecord with the query point and the connection
/// from it to the sampled point.
struct DirectSamplingRecord : PositionSamplingRecord {
    /// Only `ref` is initialized here; the remaining fields are presumably
    /// filled in by the sampling routine that receives this record.
    DirectSamplingRecord(const Vector &refPoint) : ref(refPoint) {}

    Vector ref;  ///< reference (query) point
    Vector dir;  ///< presumably the direction from `ref` toward the sampled point
    Float dist;  ///< presumably the distance from `ref` to the sampled point
};

/// Differentiable counterpart of DirectSamplingRecord.
struct DirectSamplingRecordAD : PositionSamplingRecordAD {
    /// Only `ref` is initialized here; the remaining fields are presumably
    /// filled in by the sampling routine that receives this record.
    DirectSamplingRecordAD(const VectorAD &refPoint) : ref(refPoint) {}

    VectorAD ref;  ///< reference (query) point (AD)
    VectorAD dir;  ///< presumably the direction from `ref` toward the sampled point
    FloatAD dist;  ///< presumably the distance from `ref` to the sampled point
    FloatAD G;     ///< presumably the geometry term between `ref` and the sample; confirm
};

/// Triangle mesh storing geometry both as plain values and as
/// automatic-differentiation (AD) values.
///
/// Vertex positions, shading normals and per-face geometric normals are kept
/// as AD values; the non-AD accessors return their `.val` part.
/// `edge_distrb` backs getEdgeTotLength() (and presumably sampleEdge());
/// `face_distrb` backs getArea() (and presumably samplePosition()).
///
/// Scalar members carry in-class defaults so a default-constructed Shape never
/// exposes indeterminate values (e.g. through isEmitter() or
/// isMediumTransition()). An id of -1 means "none", matching the `>= 0` tests
/// in those predicates. The full constructor overrides these defaults.
struct Shape {
    Shape() {}
    // NOTE(review): parameters are unnamed in this declaration. Presumably
    // (vertices, indices, uvs, normals, num_vertices, num_triangles,
    //  light_id, bsdf_id, med_int_id, med_ext_id); verify against the
    // out-of-line definition before relying on this order.
    Shape(ptr<float>, ptr<int>, ptr<float>, ptr<float>, int, int, int, int, int, int, ptr<float> velocities = ptr<float>(nullptr));

    // --- Derivative (velocity) setup ------------------------------------
    void zeroVelocities();
    void initVelocities(const Eigen::Matrix<Float, -1, -1> &dx);
    void initVelocities(const Eigen::Matrix<Float, -1, -1> &dx, int der_index);
#ifndef SHAPE_COMPUTE_VTX_NORMAL
    // Normal derivatives `dn` are supplied explicitly only when vertex
    // normals are not recomputed from geometry.
    void initVelocities(const Eigen::Matrix<Float, -1, -1> &dx, const Eigen::Matrix<Float, -1, -1> &dn);
    void initVelocities(const Eigen::Matrix<Float, -1, -1> &dx, const Eigen::Matrix<Float, -1, -1> &dn, int der_index);
#endif

    void advance(Float stepSize, int derId = 0);
    void computeFaceNormals();
#ifdef SHAPE_COMPUTE_VTX_NORMAL
    void computeVertexNormals();
#endif

    // --- Queries ----------------------------------------------------------
    /// True when either an interior or an exterior medium id is assigned.
    inline bool isMediumTransition() const { return med_ext_id >= 0 || med_int_id >= 0;}
    inline bool hasUVs() const { return !uvs.empty(); }
    inline bool hasNormals() const { return !normals.empty(); }
    /// True when a light id is assigned to this shape.
    inline bool isEmitter() const { return light_id >= 0; }

    // Element accessors: no bounds checking is performed.
    inline const Vector3& getVertex(int index) const { return vertices[index].val; }
    inline const Vector3AD& getVertexAD(int index) const { return vertices[index]; }
    inline const Vector3& getShadingNormal(int index) const { return normals[index].val; }
    inline const Vector3AD& getShadingNormalAD(int index) const { return normals[index]; }
    inline const Vector3& getGeoNormal(int index) const { return faceNormals[index].val; }
    inline const Vector3AD& getGeoNormalAD(int index) const { return faceNormals[index]; }
    inline const Vector3i& getIndices(int index) const { return indices[index]; }
    inline const Vector2& getUV(int index) const { return uvs[index]; }
    /// Sum of the edge distribution (presumably the total edge length).
    inline Float getEdgeTotLength() const { return edge_distrb.getSum(); }
    inline const Edge& getEdge(int index) const { return edges[index]; }

    Float getArea(int index) const;
    FloatAD getAreaAD(int index) const;

    // void samplePosition(int index, const Vector2 &rnd2, Vector &p, Vector &n) const;

    int samplePosition(const Vector2 &rnd2, PositionSamplingRecord &pRec) const;
    int samplePositionAD(const Vector2 &rnd2, PositionSamplingRecordAD &pRec) const;

    void rayIntersect(int tri_index, const Ray &ray, Intersection& its) const;
    void rayIntersectAD(int tri_index, const RayAD &ray, IntersectionAD& its) const;

    void constructEdges();
    const Edge& sampleEdge(Float& rnd, Float& pdf) const;
    // NOTE: "Sihoulette" is misspelled; the name is kept so existing callers
    // and the out-of-line definition continue to match.
    int isSihoulette(const Edge& edge, const Vector& p) const;

    // For path-space diff. rendering
    void getPoint(int tri_index, const Vector2AD &barycentric, VectorAD &x, VectorAD &n, FloatAD &J) const;
    void getPoint(int tri_index, const Vector2AD &barycentric, IntersectionAD& its_AD, FloatAD &J) const;

    // For path-space volume
    void getPoint(Vector p_scatter, VectorAD &x);

    // --- Mesh data ----------------------------------------------------------
    std::vector<Vector3AD> vertices, normals;
    std::vector<Vector3i> indices;
    std::vector<Vector2> uvs;

    // Per triangle properties
    std::vector<Vector3AD> faceNormals;
    //std::vector<Matrix3x3> faceRotations;

    int num_vertices = 0;
    int num_triangles = 0;

    // Ids into scene-level tables; -1 = not assigned.
    int light_id = -1;    // >= 0 iff this shape is an emitter (see isEmitter)
    int bsdf_id = -1;     // presumably -1 means "no BSDF assigned"; confirm
    int med_int_id = -1;  // presumably the interior medium; -1 = none
    int med_ext_id = -1;  // presumably the exterior medium; -1 = none

    std::vector<Edge> edges;
    DiscreteDistribution edge_distrb;

    DiscreteDistribution face_distrb;
    /// Sum of the face distribution (presumably the total surface area).
    inline Float getArea() const { return face_distrb.getSum(); }
};

#endif
