#pragma once
#ifndef MEDIUM_H__
#define MEDIUM_H__

#include "fwd.h"
#include "tetra.hpp"
struct Ray;
struct RayAD;
struct RndSampler;

// Abstract interface for a participating medium.
//
// Concrete media must implement the primal queries (sampleDistance,
// evalTransmittance, isHomogeneous, sigS, sigT) as well as sigSAD. The other
// *AD entry points have default bodies that assert, so a subclass only needs
// to override the ones it actually supports.
//
// NOTE(review): the "AD" suffix presumably denotes the automatic-differentiation
// variants operating on FloatAD/VectorAD/SpectrumAD -- confirm against fwd.h.
struct Medium
{
    // Stores the phase-function id and zeroes both velocity matrices.
    // tet_ptr is NOT set here: it stays empty (null) until some other code
    // assigns a mesh to it.
    Medium(int phase_id) : phase_id(phase_id)
    {
        // Disabled legacy initialisation, kept for reference:
        // m_volumeToWorld_R = Matrix3x3AD(Matrix3x3::Identity(), Matrix3x3::Zero());
        // m_volumeToWorld_T = VectorAD(Vector(0., 0., 0.), Vector(0., 1., 0.));
        m_transVelocities.setZero();
        m_rotVelocities.setZero();
    }

    // Virtual so that derived media can be destroyed through a Medium pointer.
    virtual ~Medium() = default;

    // Primal distance sampling; must be implemented by every subclass.
    // Outputs are written through p_scatter and throughput.
    // NOTE(review): presumably samples a scattering location along `ray` within
    // [0, tmax] and returns whether a scattering event occurred -- confirm
    // against the concrete implementations.
    virtual bool sampleDistance(const Ray &ray, const Float &tmax, const Array2 &rnd2, RndSampler *sampler,
                                Vector &p_scatter, Spectrum &throughput) const = 0;

    // AD counterpart of sampleDistance. The default body is a "not implemented"
    // stub: it asserts in debug builds. When assertions are compiled out
    // (NDEBUG) it returns false without touching p_scatter or throughput.
    virtual bool sampleDistanceAD(const RayAD &ray, const FloatAD &tmax, const Array2 &rnd2, RndSampler *sampler,
                                  VectorAD &p_scatter, SpectrumAD &throughput) const
    {
        assert(false && "Medium::sampleDistanceAD is not implemented by this medium");
        return false;
    }

    // Primal transmittance query; must be implemented by every subclass.
    // NOTE(review): presumably the transmittance along `ray` over [tmin, tmax]
    // -- confirm against the concrete implementations.
    virtual Float evalTransmittance(const Ray &ray, const Float &tmin, const Float &tmax, RndSampler *sampler) const = 0;

    // AD counterpart of evalTransmittance. The default body is a "not
    // implemented" stub: it asserts in debug builds and returns a
    // default-constructed FloatAD when assertions are compiled out (NDEBUG).
    virtual FloatAD evalTransmittanceAD(const RayAD &ray, const FloatAD &tmin, const FloatAD &tmax, RndSampler *sampler) const
    {
        assert(false && "Medium::evalTransmittanceAD is not implemented by this medium");
        return FloatAD();
    }

    // Returns evalTransmittanceAD(...) divided by its own primal value (T.val).
    // NOTE(review): T.val is used as a divisor without a zero check; if the
    // transmittance can be exactly zero the result is presumably non-finite --
    // confirm whether callers guarantee T.val != 0.
    virtual FloatAD evalTransmittanceRatioAD(const RayAD &ray, const FloatAD &tmin, const FloatAD &tmax, RndSampler *sampler) const
    {
        FloatAD T = evalTransmittanceAD(ray, tmin, tmax, sampler);
        return T/T.val;
    }

    // Pure virtual; the name suggests it reports whether the medium's
    // coefficients are spatially constant -- confirm against implementations.
    virtual bool isHomogeneous() const = 0;

    // Coefficient queries at a point x. NOTE(review): presumably the scattering
    // (sigS) and extinction (sigT) coefficients -- confirm naming convention.
    virtual Spectrum sigS(const Vector &x) const = 0;
    virtual Float sigT(const Vector &x) const = 0;
    virtual SpectrumAD sigSAD(const VectorAD &x) const = 0;

    // Disabled legacy helper, kept for reference:
    // Vector volumeToWorld(const Vector &p) const
    // {
    //     return m_volumeToWorld_R.val * p + m_volumeToWorld_T.val;
    // }

    // Forwards the point query to the attached tetrahedral mesh.
    //
    //   p : query point (input)
    //   x : written by TetrahedronMesh::query (output)
    //   J : written by TetrahedronMesh::query (output); the disabled code below
    //       set it to FloatAD(1.0), so it is presumably a Jacobian-type factor
    //       -- TODO confirm.
    //
    // Returns whatever TetrahedronMesh::query returns. If no mesh is attached
    // (tet_ptr is empty) this asserts in debug builds and returns false without
    // modifying x or J in release builds, instead of dereferencing a null
    // pointer.
    bool getPoint(const Vector &p, VectorAD &x, FloatAD &J) const
    {
        // Disabled legacy implementations, kept for reference:
        // x = m_volumeToWorld_R * VectorAD(p) + m_volumeToWorld_T;
        // assert(x.val==p);

        // x.val = p;
        // for (int i = 0; i < nder; ++i)
        //     x.grad(i) = m_rotVelocities.col(i).cross(p) + m_transVelocities.col(i);
        // J = FloatAD(1.0);

        // The constructor leaves tet_ptr empty, so guard before dereferencing.
        assert(tet_ptr && "Medium::getPoint called before a tetrahedral mesh was attached");
        if (!tet_ptr)
            return false;

        return tet_ptr->query(p, x, J);
    }

    // Phase-function identifier supplied at construction.
    // NOTE(review): presumably an index into a scene-level phase list -- confirm.
    int phase_id;
    // Owned tetrahedral mesh used by getPoint. Declared mutable, so it can be
    // (re)assigned even through a const Medium. Empty until assigned.
    mutable std::unique_ptr<TetrahedronMesh> tet_ptr;
    // Translational and rotational velocities: 3 x nder matrices, zeroed by the
    // constructor. Within this header they are only referenced by the disabled
    // code in getPoint.
    Eigen::Matrix<Float, 3, nder> m_transVelocities, m_rotVelocities;
    // Matrix3x3AD m_volumeToWorld_R;
    // VectorAD m_volumeToWorld_T;
};

#endif //MEDIUM_H__
