#include "roughdielectric.h"
#include <render/intersection.h>
#include <core/utils.h>
#include "microfacet.h"
#include <assert.h>
#include "light_path.h"
/* Polymorphic copy: allocates a new RoughDielectricBSDF that is a member-wise
   copy of this one and hands it back through the BSDF base pointer.
   The returned object is heap-allocated and owned by the caller
   (presumably released via delete through BSDF* -- confirm with the owner). */
BSDF *RoughDielectricBSDF::clone() const
{
    RoughDielectricBSDF *duplicate = new RoughDielectricBSDF(*this);
    return duplicate;
}
/**
 * Evaluate the rough dielectric BSDF for the direction pair (its.wi, wo).
 *
 * Both directions are passed straight to Frame::cosTheta, so they are
 * expected in the local frame of the intersection. The reflection lobe is
 * F * D * G / (4 |cos(theta_i)|) and the transmission lobe is the refraction
 * term of Walter et al. 2007. Neither divides by |cos(theta_o)|, so the
 * result presumably already contains the cosine foreshortening of wo
 * (Mitsuba convention) -- TODO confirm against the integrators.
 *
 * @param its   intersection record; its.wi is the fixed direction
 * @param wo    the other direction of the pair
 * @param mode  transport mode: every mode except ERadiance rescales the
 *              transmission lobe (see below); EImportanceWithCorrection
 *              additionally multiplies the result by correction(its, wo)
 * @return      the BSDF value, or zero when either direction is within
 *              Epsilon of grazing or the distribution value D is below Epsilon
 */
Spectrum RoughDielectricBSDF::eval(const Intersection &its, const Vector &wo, EBSDFMode mode) const
{
    const Float cosThetaI = Frame::cosTheta(its.wi);
    const Float cosThetaO = Frame::cosTheta(wo);

    /* Grazing configurations: the denominators below would blow up. */
    if (std::abs(cosThetaI) < Epsilon || std::abs(cosThetaO) < Epsilon)
        return Spectrum::Zero();

    Float correct = 1.0f;
    if (mode == EBSDFMode::EImportanceWithCorrection)
        correct *= correction(its, wo);

    /* Same hemisphere on both sides -> reflection, otherwise refraction */
    const bool reflect = cosThetaI * cosThetaO > 0;

    /* Relative IOR as seen from the side its.wi lies on (only used by the
       transmission paths below). */
    const Float eta = cosThetaI > 0 ? m_eta : 1.0f / m_eta;

    /* Generalized half-vector: the plain half-vector for reflection, the
       refraction half-vector (wi + eta * wo) for transmission. */
    Vector H;
    if (reflect)
        H = (wo + its.wi).normalized();
    else
        H = (its.wi + wo * eta).normalized();

    /* Ensure that the half-vector points into the
       same hemisphere as the macrosurface normal */
    H *= math::signum(Frame::cosTheta(H));

    /* Evaluate the microfacet normal distribution */
    const Float D = m_distr.eval(H);
    if (std::abs(D) < Epsilon)
        return Spectrum::Zero();

    /* Fresnel factor */
    const Float F = fresnelDielectricExt(its.wi.dot(H), m_eta);

    /* Smith's shadow-masking function */
    const Float G = m_distr.G(its.wi, wo, H);

    Spectrum ret = Spectrum::Zero();
    if (reflect)
    {
        /* Total amount of reflection. std::abs is used deliberately: an
           unqualified abs() can bind to C's int abs() and truncate. */
        ret = m_specularReflectance * (F * D * G / std::abs(4.0f * cosThetaI));
    }
    else
    {
        /* Total amount of transmission */
        const Float cosWiH = its.wi.dot(H);
        const Float cosWoH = wo.dot(H);
        const Float sqrtDenom = cosWiH + eta * cosWoH;
        const Float value = ((1.0f - F) * D * G * eta * eta * cosWiH * cosWoH) / (cosThetaI * sqrtDenom * sqrtDenom);

        /* Solid-angle compression across the interface.
           NOTE(review): Mitsuba applies this (1/eta)^2 scaling in ERadiance
           mode and none in importance mode; here it is the other way round
           (consistently with sample()). Confirm this matches the renderer's
           wi/wo convention. */
        Float factor = 1.0f;
        if (mode != EBSDFMode::ERadiance)
            factor = cosThetaI > 0 ? 1.0f / m_eta : m_eta;

        ret = m_specularTransmittance * std::abs(value * factor * factor);
    }

    return ret * correct;
}

/**
 * Importance-sample an outgoing direction for the fixed direction its.wi.
 *
 * A microfacet normal m is drawn from m_distr using (rnd[0], rnd[1]). Then
 * rnd[2] chooses reflection about m (probability F) or refraction through m
 * (probability 1 - F). Random numbers are presumably uniform in [0, 1) --
 * verify with the caller.
 *
 * @param its   intersection record; its.wi is the fixed direction
 * @param rnd   three random numbers (see above)
 * @param wo    [out] sampled direction; written once a lobe has been chosen,
 *              so it may hold a rejected direction when zero is returned
 * @param pdf   [out] solid-angle density of wo; 0 whenever the sample is rejected
 * @param eta   [out] relative IOR crossed: 1 for reflection, m_eta or m_invEta
 *              for refraction; 1 when the sample is rejected
 * @param mode  transport mode; controls the transmission scaling and whether
 *              correction(its, wo) is applied
 * @return      sample weight: lobe reflectance/transmittance times
 *              m_distr.smithG1(wo, m) and the optional factors (presumably
 *              f*|cos|/pdf for visible-normal sampling -- confirm), or zero
 *              when the sample is rejected
 */
Spectrum RoughDielectricBSDF::sample(const Intersection &its, const Array3 &rnd, Vector &wo, Float &pdf, Float &eta, EBSDFMode mode) const
{
    /* Start in the "rejected sample" state so that every early exit below
       leaves pdf/eta well-defined instead of unset or half-computed. */
    pdf = 0.0f;
    eta = 1.0f;

    Array2 sample(rnd[0], rnd[1]);

    /* Sample M, the microfacet normal */
    Float microfacetPDF;
    const Vector m = m_distr.sample(math::signum(Frame::cosTheta(its.wi)) * its.wi, sample, microfacetPDF);
    if (microfacetPDF < Epsilon)
        return Spectrum::Zero();

    Float cosThetaT;
    const Float F = fresnelDielectricExt(its.wi.dot(m), cosThetaT, m_eta);

    /* Choose the lobe proportionally to the Fresnel reflectance */
    const bool sampleReflection = !(rnd[2] > F);
    const Float lobeProb = sampleReflection ? F : 1.0f - F;

    Spectrum weight = Spectrum::Ones();
    Float sampledEta;
    Float dwh_dwo;
    if (sampleReflection)
    {
        /* Perfect specular reflection based on the microfacet normal */
        wo = reflect(its.wi, m);
        sampledEta = 1.0f;

        /* Side check */
        if (Frame::cosTheta(its.wi) * Frame::cosTheta(wo) < Epsilon)
            return Spectrum::Zero();

        /* Jacobian of the half-direction mapping */
        dwh_dwo = 1.0f / (4.0f * wo.dot(m));
        weight *= m_specularReflectance;
    }
    else
    {
        if (std::abs(cosThetaT) < Epsilon)
            return Spectrum::Zero();

        /* Perfect specular transmission based on the microfacet normal */
        wo = refract(its.wi, m, m_eta, cosThetaT);
        sampledEta = cosThetaT < 0 ? m_eta : m_invEta;

        /* Side check */
        if (Frame::cosTheta(its.wi) * Frame::cosTheta(wo) >= 0)
            return Spectrum::Zero();

        /* Solid-angle compression when crossing the interface.
           NOTE(review): Mitsuba scales in ERadiance mode and not in importance
           mode; this is the reverse (consistently with eval()). Confirm the
           intended convention. */
        const Float factor = (mode == EBSDFMode::ERadiance) ? 1.0f : (cosThetaT < 0 ? m_invEta : m_eta);
        weight *= m_specularTransmittance * (factor * factor);

        /* Jacobian of the half-direction mapping */
        const Float sqrtDenom = its.wi.dot(m) + sampledEta * wo.dot(m);
        dwh_dwo = (sampledEta * sampledEta * wo.dot(m)) / (sqrtDenom * sqrtDenom);
    }

    weight *= m_distr.smithG1(wo, m);

    if (mode == EBSDFMode::EImportanceWithCorrection)
        weight *= correction(its, wo);

    /* Only publish pdf/eta once the sample has been accepted */
    pdf = microfacetPDF * lobeProb * std::abs(dwh_dwo);
    eta = sampledEta;

    return weight;
}

/**
 * Solid-angle density of generating wo from its.wi. It mirrors the
 * computation in sample(): microfacet sampling density of the half-vector,
 * times the Fresnel lobe-selection probability (F or 1 - F), times the
 * Jacobian |dwh/dwo| of the half-direction mapping.
 *
 * Returns 0 for the same grazing configurations that eval() rejects, so the
 * two stay consistent. This matters for MIS, where a spurious or NaN
 * density would otherwise leak into the weights of zero-valued paths.
 */
Float RoughDielectricBSDF::pdf(const Intersection &its, const Vector &wo) const
{
    const Float cosThetaI = Frame::cosTheta(its.wi);
    const Float cosThetaO = Frame::cosTheta(wo);

    /* Same grazing-angle cutoff as eval() */
    if (std::abs(cosThetaI) < Epsilon || std::abs(cosThetaO) < Epsilon)
        return 0.0f;

    const bool reflect = cosThetaI * cosThetaO > 0;

    Vector H;
    Float dwh_dwo;

    if (reflect)
    {
        /* Calculate the reflection half-vector */
        H = (its.wi + wo).normalized();

        /* Jacobian of the half-direction mapping */
        dwh_dwo = 1.0f / (4.0f * wo.dot(H));
    }
    else
    {
        /* Calculate the transmission half-vector */
        const Float eta = cosThetaI > 0 ? m_eta : m_invEta;

        H = (its.wi + eta * wo).normalized();

        /* Jacobian of the half-direction mapping */
        const Float sqrtDenom = its.wi.dot(H) + eta * wo.dot(H);
        dwh_dwo = (eta * eta * wo.dot(H)) / (sqrtDenom * sqrtDenom);
    }

    /* Ensure that the half-vector points into the
       same hemisphere as the macrosurface normal */
    H *= math::signum(Frame::cosTheta(H));

    /* Evaluate the microfacet model sampling density function */
    Float prob = m_distr.pdf(math::signum(cosThetaI) * its.wi, H);

    /* Probability of having picked this lobe in sample() */
    const Float F = fresnelDielectricExt(its.wi.dot(H), m_eta);
    prob *= reflect ? F : (1.0f - F);

    return std::abs(prob * dwh_dwo);
}

// Project macro defined outside this file. It presumably expands to the
// boilerplate shared by every BSDF implementation; check its definition
// before relying on what it generates.
IMPLEMENT_BSDF_HELPER_FUNCTIONS(RoughDielectricBSDF)
