#include "medium/homogeneous.h"

#include <core/ray.h>
#include <core/sampler.h>

bool Homogeneous::sampleDistance(const Ray &ray, const Float &tmax,
                                 const Array2 &rnd2, RndSampler *sampler,
                                 Vector &p_scatter, Spectrum &throughput) const {
    assert(false);
    if (sampling_weight > Epsilon) {
        Float t = -std::log(rnd2.x()) / sigma_t;
        if (t < tmax) {
            p_scatter = ray(t);
            if (rnd2.y() < sampling_weight)
                throughput *= albedo.lookupSpectrum(p_scatter) / sampling_weight;
            else
                throughput = Spectrum::Zero();
            return true;
        } else
            return false;
    } else {
        throughput *= std::exp(-tmax * sigma_t);
        return false;
    }
}

bool Homogeneous::sampleDistance(const Ray &ray, const Float &tmax,
                                 RndSampler           *sampler,
                                 MediumSamplingRecord &mRec) const {
    // sample distance
    Float rnd             = sampler->next1D();
    Float sampledDistance = 0;
    if (rnd < sampling_weight) // sampling_weight = albedo
    {
        rnd /= sampling_weight;
        sampledDistance = -std::log(rnd) / sigma_t; // p(t) = simga_s * exp(-t*sigma_t)
    } else {
        sampledDistance = std::numeric_limits<Float>::infinity();
    }

    // store the data
    bool success = true;

    if (sampledDistance < tmax) {
        mRec.t          = sampledDistance;
        mRec.p          = ray(mRec.t);
        mRec.wi         = -ray.dir;
        mRec.pdfSuccess = sigma_t / sampling_weight;
        mRec.sigmaS     = sigS(mRec.p);
        mRec.sigmaT     = sigma_t;
        mRec.medium     = this;
        assert(mRec.p != ray.org);
        success = true;
    } else {
        mRec.t  = tmax;
        mRec.p  = ray(mRec.t);
        success = false;
    }

    mRec.transmittance = exp(sigma_t * (-mRec.t));
    mRec.pdfSuccess    = sampling_weight * sigma_t * mRec.transmittance;     // p(t) = sigma_s * exp(-t*sigma_t)
    mRec.pdfFailure    = 1 - sampling_weight * (1 - exp(sigma_t * (-tmax))); // p = 1 - sigma_s / sigma_t * (1 - exp(-t*sigma_t))
    mRec.medium        = this;

    return success;
}

// Transmittance of the segment [tmin, tmax] through a medium with constant
// extinction sigma_t (Beer-Lambert law): exp(-sigma_t * (tmax - tmin)).
// `ray` and `sampler` are not needed for a homogeneous medium and are ignored.
Float Homogeneous::evalTransmittance(const Ray &ray, const Float &tmin,
                                     const Float &tmax, RndSampler *sampler) const {
    return std::exp(-(sigma_t * (tmax - tmin)));
}

// Returns T / detach(T) for the transmittance T = exp(-sigma_t * (tmax - tmin))
// of the segment [tmin, tmax]. Its value is 1. If Float carries derivatives
// (the use of detach suggests it may), its derivative is that of log T.
//
// The result is evaluated as exp(logT - detach(logT)) rather than as a literal
// division. The two forms agree in value and derivative whenever T > 0. The
// division, however, becomes 0/0 = NaN once T underflows to zero (long
// segments / dense media), whereas the exponent form stays exactly 1.
// NOTE(review): an unbounded segment (tmax = +inf) still yields NaN, as before.
//
// logT uses the same expression as evalTransmittance's exponent.
Float Homogeneous::evalTransmittanceRatio(const Ray &ray, const Float &tmin,
                                          const Float &tmax, RndSampler *sampler) const {
    const Float logT = (tmin - tmax) * sigma_t;
    return std::exp(logT - detach(logT));
}

// Project-defined macro instantiated for Homogeneous; presumably it emits shared
// Medium helper definitions -- see the macro's definition for what it expands to.
PSDR_IMPL_MEDIUM_HELPER_FUNCTIONS(Homogeneous)