#include <render/common.h>
#include <render/intersection.h>
#include <core/math_func.h>
#include <core/logger.h>

// Build a direct-sampling record anchored at the intersection `refIts`.
// `ref` is the intersection point. `refN` stays zero unless the BSDF is neither
// transmissive nor two-sided; only in that case is the geometric normal stored.
// NOTE(review): a zero refN presumably tells consumers to skip the cosine term
// at the reference point — confirm with the code that reads refN.
DirectSamplingRecord::DirectSamplingRecord(const Intersection &refIts)
    : ref(refIts.p), refN(Vector::Zero())
{
    const auto &bsdf = refIts.getBSDF();
    const bool keepNormal = !bsdf->isTransmissive() && !bsdf->isTwosided();
    if (keepNormal)
    {
        refN = refIts.geoFrame.n;
    }
}

// |v x dir|, where dir is the unit direction from xS to xD. This equals the sine
// of the angle between v and dir when v has unit length. xB is not used.
__attribute__((noinline)) Float sinB(const Vector &xS,
                                     const Vector &xB, const Vector &v,
                                     const Vector &xD)
{
    const Vector unitSD = (xD - xS).normalized();
    const Vector perp = v.cross(unitSD);
    return perp.norm();
}

// |dir x project(v, dir, n)|, where dir is the unit direction from xS to xD.
// NOTE(review): `project` is defined elsewhere; presumably it projects v along
// dir onto the plane with normal n — confirm. xB is not used.
__attribute__((noinline)) Float sinD(const Vector &xS,
                                     const Vector &xB, const Vector &v,
                                     const Vector &xD, const Vector &n)
{
    const Vector unitSD = (xD - xS).normalized();
    const Vector projected = project(v, unitSD, n);
    const Vector perp = unitSD.cross(projected);
    return perp.norm();
}

// Returns |xS - xD| / |xB - xD| * sinB / sinD, where dir = unit(xS - xD),
// sinB = |v x dir| and sinD = |dir x project(v, dir, n)|.
// Presumably the length Jacobian dl_S / dl_B (per the name).
// Returns 0 unless both sines are strictly greater than Epsilon (NaN included).
__attribute__((noinline)) Float dlS_dlB(const Vector &xD,
                                        const Vector &xB, const Vector &v,
                                        const Vector &xS, const Vector &n)
{
    const Vector toS = xS - xD;
    const Vector unitDS = toS.normalized();
    const Float sin_b = v.cross(unitDS).norm();
    const Float sin_d = unitDS.cross(project(v, unitDS, n)).norm();
    // Negated conjunction (not `<= Epsilon`) so that NaN sines also yield 0.
    if (!(sin_b > Epsilon && sin_d > Epsilon))
        return 0.;
    const Vector toB = xB - xD;
    return toS.norm() / toB.norm() * sin_b / sin_d;
}

// Same factor as dlS_dlB with the two endpoints exchanged: this function's xS
// fills dlS_dlB's first (xD) slot and this function's xD fills its fourth (xS) slot.
__attribute__((noinline)) Float dlD_dlB(const Vector &xS,
                                        const Vector &xB, const Vector &v,
                                        const Vector &xD, const Vector &n)
{
    const Float J = dlS_dlB(/* xD := */ xS, xB, v, /* xS := */ xD, n);
    return J;
}

// paper p.9 eq.41
// Returns |xB - xS| * |xD - xS| * sinB / |cosS|, where dir = unit(xD - xS),
// sinB = |v x dir| and cosS = nS . dir. There is no guard, so the result is
// inf/NaN when nS is perpendicular to dir.
__attribute__((noinline)) Float dASdAD_dlBdwBdrD(const Vector &xS, const Vector &nS,
                                                 const Vector &xB, const Vector &v,
                                                 const Vector &xD)
{
    const Vector toD = xD - xS;
    const Vector unitSD = toD.normalized();
    const Float sin_b = v.cross(unitSD).norm();
    const Float cos_s = nS.dot(unitSD);
    const Float lenSB = (xB - xS).norm();
    const Float lenSD = toD.norm();
    return lenSB * lenSD * sin_b / abs(cos_s);
}

// Solid-angle-to-area factor |p - ref|^2 / |cos|. Here cos is the dot product
// of n with the unit direction from p back toward ref.
__attribute__((noinline)) Float dA_dw(const Vector &ref,
                                      const Vector &p, const Vector &n)
{
    const Vector offset = p - ref;
    const Float cosine = n.dot(-offset.normalized());
    return offset.squaredNorm() / abs(cosine);
}

// Point-to-point geometry term with no cosine factors: 1 / |p - ref|^2.
// Divides by zero when p == ref.
Float geometric(const Vector &ref,
                const Vector &p)
{
    // Only the squared distance is needed, so no direction vector is formed.
    return 1. / (p - ref).squaredNorm();
}

// Geometry term with a cosine factor only at p: |n . (-dir)| / |p - ref|^2,
// where dir is the unit direction from ref to p.
Float geometric(const Vector &ref,
                const Vector &p, const Vector &n)
{
    const Vector offset = p - ref;
    const Float dist = offset.norm();
    const Vector unit = offset / dist;
    const Float cosP = n.dot(-unit);
    return abs(cosP) / (dist * dist);
}

// Two-sided geometry term |(refN . dir) * (n . -dir)| / |p - ref|^2, where dir
// is the unit direction from ref to p.
Float geometric(const Vector &ref, const Vector &refN,
                const Vector &p, const Vector &n)
{
    const Vector offset = p - ref;
    const Float dist = offset.norm();
    const Vector unit = offset / dist;
    const Float cosRef = refN.dot(unit);
    const Float cosP = n.dot(-unit);
    return abs(cosRef * cosP) / (dist * dist);
}

// Apply the project macro INACTIVE_FN to the dlS_dlB, dlD_dlB and dA_dw helpers.
// NOTE(review): presumably this marks them as inactive for the Enzyme AD pass
// (callers differentiate through them as constants) — confirm against the
// INACTIVE_FN definition.
INACTIVE_FN(dlS_dlB, dlS_dlB);
INACTIVE_FN(dlD_dlB, dlD_dlB);
INACTIVE_FN(dA_dw, dA_dw);

// Intersect `ray` with triangle (v0, v1, v2) and return the hit point.
// `mode` selects which quantities keep derivatives (detach() removes them):
//   EMaterial  - barycentrics (u, v) are detached; the point is interpolated
//                from the (attached) vertices.
//   ESpatial   - ray.org + ray.dir * t, with t taken from the intersection.
//   EReference - vertices are detached; derivatives flow only through (u, v).
// An unknown mode trips PSDR_ASSERT_MSG. If that assertion is compiled out, a
// zero vector is returned instead of falling off the end of a non-void
// function, which would be undefined behavior.
Vector rayIntersectTriangle(const Vector &v0, const Vector &v1, const Vector &v2, const Ray &ray, IntersectionMode mode)
{
    // uvt = (u, v, t): barycentric weights of v1/v2 and the ray parameter.
    Vector uvt = rayIntersectTriangle(v0, v1, v2, ray);
    if (mode == IntersectionMode::EMaterial)
    {
        Float u = detach(uvt(0)),
              v = detach(uvt(1));
        return (1 - u - v) * v0 +
               u * v1 + v * v2;
    }
    if (mode == IntersectionMode::ESpatial)
    {
        return ray.org + ray.dir * uvt(2);
    }
    if (mode == IntersectionMode::EReference)
    {
        Float u = uvt(0), v = uvt(1);
        return (1 - u - v) * detach(v0) +
               u * detach(v1) +
               v * detach(v2);
    }
    PSDR_ASSERT_MSG(false, "Invalid intersection mode");
    // Well-defined fallback for release builds where the assertion is a no-op.
    return Vector::Zero();
}

// Alias of the mode-aware rayIntersectTriangle overload; forwards every
// argument unchanged and returns its hit point.
Vector rayIntersectTriangle2(const Vector &v0, const Vector &v1, const Vector &v2, const Ray &ray, IntersectionMode mode)
{
    Vector hit = rayIntersectTriangle(v0, v1, v2, ray, mode);
    return hit;
}

// Unit normal of triangle (v0, v1, v2): normalized (v1 - v0) x (v2 - v0).
// Its orientation follows the vertex winding order.
Vector face_normal(const Vector &v0, const Vector &v1, const Vector &v2)
{
    const Vector edge01 = v1 - v0;
    const Vector edge02 = v2 - v0;
    const Vector unnormalized = edge01.cross(edge02);
    return unnormalized.normalized();
}

// xS on surface
// Normal velocity of the surface point xS along direction u. u is the
// normalized (nB x n) x n, which is parallel to nB's component tangent to the
// plane with normal n, and math::signum aligns it with nB.
// Returns xS.u - detach(xS.u): zero in value, carrying the derivative of xS
// projected on the detached u. xD and xB are not used.
Float normal_velocity(const Vector &xD,
                      const Vector &xB, const Vector &nB,
                      const Vector &xS, const Vector &n)
{
    Vector tangent = nB.cross(n).cross(n).normalized();
    tangent *= math::signum(tangent.dot(nB)); // make sure the u points to the visible side
    const Float along = xS.dot(detach(tangent));
    const Float frozen = detach(along);
    return along - frozen; // val = 0, der != 0
}

// Cast a ray from xB pointing away from xD, find its EReference-mode hit on
// triangle (xS_0, xS_1, xS_2), and return the negated surface normal velocity
// of that hit. The (xD, xB, nB, xS, n) overload receives detach(n) as nB and
// the detached face normal. v is not used.
Float normal_velocity(const Vector &xD,
                      const Vector &xB, const Vector &v,
                      const Vector &xS_0, const Vector &xS_1, const Vector &xS_2,
                      const Vector &n)
{
    const Vector awayFromD = (xB - xD).normalized();
    Ray ray(xB, awayFromD);
    const Vector hitS = rayIntersectTriangle(xS_0, xS_1, xS_2, ray, IntersectionMode::EReference);
    const Vector faceN = face_normal(xS_0, xS_1, xS_2);
    const Float vel = normal_velocity(xD, xB, detach(n), hitS, detach(faceN));
    return -vel;
}

// xS in volume
// Normal velocity of a volume point xS along the detached direction nB.
// Returns xS.nB - detach(xS.nB): zero in value, carrying the derivative.
Float normal_velocity(const Vector &nB,
                      const Vector &xS)
{
    const Float along = xS.dot(detach(nB));
    const Float frozen = detach(along);
    return along - frozen; // val = 0, der != 0
}

// Boundary-segment normal velocity when the source lies on triangle S:
//  1. xB = xB_0 + (xB_1 - xB_0) * t on the boundary edge.
//  2. xS = EMaterial-mode hit of Ray(xB, dir) on triangle S.
//  3. nB = unit normal of the plane spanned by the edge and xB - xS. Its sign
//     is chosen so that nB . (xB_2 - xB_0) is non-positive.
//  4. xD = EReference-mode hit on triangle D of the ray from xB away from xS.
//  5. Return the negated surface normal velocity of xD, computed with the
//     detached nB and the detached face normal of D.
Float normal_velocity(const Vector &xS_0, const Vector &xS_1, const Vector &xS_2,
                      const Vector &xB_0, const Vector &xB_1, const Vector &xB_2, const Float t, const Vector &dir,
                      const Vector &xD_0, const Vector &xD_1, const Vector &xD_2)
{
    const Vector xB = xB_0 + (xB_1 - xB_0) * t;
    // rayIntersect and getPoint
    const Vector xS = rayIntersectTriangle(xS_0, xS_1, xS_2, Ray(xB, dir),
                                           IntersectionMode::EMaterial);
    const Vector toB = xB - xS;
    const Vector edge = xB_0 - xB_1;
    Vector nB = edge.cross(toB).normalized();
    const auto side = math::signum(nB.dot(xB_2 - xB_0));
    nB *= -side; // make sure n points to the visible side
    Ray ray(xB, toB.normalized());
    const Vector xD = rayIntersectTriangle(xD_0, xD_1, xD_2, ray, IntersectionMode::EReference);
    const Vector nD = detach(face_normal(xD_0, xD_1, xD_2));
    return -normal_velocity(xS,
                            xB, detach(nB),
                            xD, nD);
}

// compatible with point lights
// Boundary-segment normal velocity for a fixed source point xS:
//  - xB is the edge point at parameter t.
//  - nB is the normal of the plane through the edge and xB - xS, signed so that
//    nB . (xB_2 - xB_0) is non-positive.
//  - xD is the EReference-mode hit on triangle D of the ray from xB away from xS.
// Returns the negated surface normal velocity of xD. `dir` is not used.
Float normal_velocity(const Vector &xS,
                      const Vector &xB_0, const Vector &xB_1, const Vector &xB_2, const Float t, const Vector &dir,
                      const Vector &xD_0, const Vector &xD_1, const Vector &xD_2)
{
    const Vector edgeDir = xB_1 - xB_0;
    const Vector xB = xB_0 + edgeDir * t;
    Vector nB = (xB_0 - xB_1).cross(xB - xS).normalized();
    nB *= -math::signum(nB.dot(xB_2 - xB_0)); // make sure n points to the visible side
    Ray ray(xB, (xB - xS).normalized());
    const Vector hitD = rayIntersectTriangle(xD_0, xD_1, xD_2, ray, IntersectionMode::EReference);
    const Vector faceD = face_normal(xD_0, xD_1, xD_2);
    const Float vel = normal_velocity(xS, xB, detach(nB), hitD, faceD);
    return -vel;
}

// vertex -> surface vertex
// Same construction as the point-light variant, without a `dir` parameter:
// xB on the edge at parameter t, nB the edge/(xB - xS) plane normal signed so
// that nB . (xB_2 - xB_0) is non-positive, and xD the EReference-mode hit on
// triangle D along the ray from xB away from xS.
// Returns the negated surface normal velocity of xD.
Float normal_velocity(const Vector &xS,
                      const Vector &xB_0, const Vector &xB_1, const Vector &xB_2, const Float t,
                      const Vector &xD_0, const Vector &xD_1, const Vector &xD_2)
{
    const Vector xB = xB_0 + (xB_1 - xB_0) * t;
    const Vector fromS = xB - xS;
    Vector nB = (xB_0 - xB_1).cross(fromS).normalized();
    const auto side = math::signum(nB.dot(xB_2 - xB_0));
    nB *= -side; // make sure n points to the visible side
    Ray ray(xB, fromS.normalized());
    const Vector target = rayIntersectTriangle(xD_0, xD_1, xD_2, ray, IntersectionMode::EReference);
    const Float vel = normal_velocity(xS, xB, detach(nB), target, face_normal(xD_0, xD_1, xD_2));
    return -vel;
}

// Negated surface normal velocity of the EReference-mode hit on triangle
// (xD_0, xD_1, xD_2) of the ray from xB pointing away from xS. The boundary
// normal nB is supplied by the caller and passed in detached.
Float normal_velocity_pixel(const Vector &xS,
                            const Vector &xB, const Vector &nB,
                            const Vector &xD_0, const Vector &xD_1, const Vector &xD_2)
{
    const Vector awayFromS = (xB - xS).normalized();
    Ray ray(xB, awayFromS);
    const Vector hit = rayIntersectTriangle(xD_0, xD_1, xD_2, ray, IntersectionMode::EReference);
    const Float vel = normal_velocity(xS, xB, detach(nB), hit, face_normal(xD_0, xD_1, xD_2));
    return -vel;
}

namespace
{
    double tetra_vol(const Vector3d &a, const Vector3d &b, const Vector3d &c, const Vector3d &d)
    {
        return std::abs((a - d).dot((b - d).cross(c - d))) / 6.;
    }
}
// Barycentric coordinates of p in tetrahedron (a, b, c, d), computed as ratios
// of unsigned sub-volumes. Because the volumes are unsigned, the weights are
// always non-negative; they sum to 1 only when p lies inside (or on) the
// tetrahedron. A degenerate tetrahedron causes a division by zero.
Vector4 getBarycentric(const Vector &p,
                       const Vector3d &a, const Vector3d &b, const Vector3d &c, const Vector3d &d)
{
    const double total = tetra_vol(a, b, c, d);
    Vector4 bary;
    bary << tetra_vol(b, c, d, p),
            tetra_vol(a, c, d, p),
            tetra_vol(a, b, d, p),
            tetra_vol(a, b, c, p);
    bary /= total;
    return bary;
}
// vertex -> volume vertex
// Normal velocity of a volume point for boundary edge (xB_0, xB_1) seen from xS.
// xB is the edge point at parameter t. nB is the normal of the plane through
// the edge and xB - xS, signed so that nB . (xB_2 - xB_0) is non-positive.
// The traced point is `ray(dist)` on the ray from xS toward xB (presumably
// org + dist * dir). It is re-expressed with barycentric weights over the
// detached tetrahedron vertices xD_0..xD_3, so derivatives flow only through
// the weights. The result is the negated volume normal velocity w.r.t. detached nB.
Float normal_velocity(const Vector &xS,
                      const Vector &xB_0, const Vector &xB_1, const Vector &xB_2, const Float t,
                      const Vector &xD_0, const Vector &xD_1, const Vector &xD_2, const Vector &xD_3, const Float dist)
{
    // getPoint xB
    const Vector xB = xB_0 + (xB_1 - xB_0) * t;
    const Vector toB = xB - xS;
    Vector nB = (xB_0 - xB_1).cross(toB).normalized();
    nB *= -math::signum(nB.dot(xB_2 - xB_0)); // make sure n points to the visible side
    Ray ray(xS, toB.normalized());
    // trace xD and project xD back to reference
    const Vector traced = ray(dist);
    const Vector4 w = getBarycentric(traced, xD_0, xD_1, xD_2, xD_3);
    assert(w.cwiseAbs().sum() > 0.99 && w.cwiseAbs().sum() < 1.01);
    const Vector xD = detach(xD_0) * w[0] +
                      detach(xD_1) * w[1] +
                      detach(xD_2) * w[2] +
                      detach(xD_3) * w[3];
    return -normal_velocity(detach(nB), xD);
}

#if 1
void __normal_velocity(const BoundarySegmentInfo &seg,
                       const Vector &xB_2, const Float &t, const Vector &dir,
                       Float &res)
{
    res = normal_velocity(seg.xS_0, seg.xS_1, seg.xS_2,
                          seg.xB_0, seg.xB_1, xB_2, t, dir,
                          seg.xD_0, seg.xD_1, seg.xD_2);
}

void __normal_velocity_pixel(const PixelBoundarySegmentInfo &seg,
                             const Vector &nB, Float &res)
{
    res = normal_velocity_pixel(seg.xD,
                                seg.xB, nB,
                                seg.xS_0, seg.xS_1, seg.xS_2);
}

// gradient backpropagation: d_res -> d_seg
void d_normal_velocity(const BoundarySegmentInfo &seg, BoundarySegmentInfo &d_seg,
                       const Vector &xB_2, const Float t, const Vector &dir,
                       const Float d_res)
{
    [[maybe_unused]] Float res;
#if defined(ENZYME)
    __enzyme_autodiff((void *)__normal_velocity,
                      enzyme_dup, &seg, &d_seg,
                      enzyme_const, &xB_2,
                      enzyme_const, &t,
                      enzyme_const, &dir,
                      enzyme_dup, &res, &d_res);
#endif
}

// gradient backpropagation: d_res -> d_seg
void d_normal_velocity_pixel(const PixelBoundarySegmentInfo &seg, PixelBoundarySegmentInfo &d_seg,
                             const Vector &nB,
                             const Float d_res)
{
    [[maybe_unused]] Float res;
#if defined(ENZYME)
    __enzyme_autodiff((void *)__normal_velocity_pixel,
                      enzyme_dup, &seg, &d_seg,
                      enzyme_const, &nB,
                      enzyme_dup, &res, &d_res);
#endif
}

#endif
// ====================================================================
// Global verbosity flag (initially off), read via get_verbose() and written via
// set_verbose(). There is no synchronization around it.
bool verbose{false};

auto get_verbose() -> bool
{
    return verbose;
}

auto set_verbose(bool v) -> void
{
    verbose = v;
}
// ====================================================================
// Global "forward" mode flag (initially off), read via get_forward() and written
// via set_forward(). There is no synchronization around it.
bool forward{false};

auto get_forward() -> bool
{
    return forward;
}

auto set_forward(bool v) -> void
{
    forward = v;
}