#pragma once
#ifndef PHASE_H__
#define PHASE_H__

#include <core/fwd.h>
#include <core/ptr.h>
#include <core/frame.h>
#include <core/utils.h>
#include <core/object.h>

/// Abstract base for phase functions.
///
/// Concrete subclasses must provide clone(), merge() and setZero(). The
/// sample()/pdf()/eval() members declared here are NOT virtual and are only
/// declared in this header; their definitions live elsewhere.
/// NOTE(review): presumably those definitions dispatch on m_type -- confirm
/// against the implementation file.
struct PhaseFunction : Object
{
    virtual ~PhaseFunction() = default;

    /// Returns a newly created phase function equivalent to this one.
    virtual PhaseFunction *clone() const = 0;

    /// Folds the state of `other` into this object (subclass-defined).
    virtual void merge(PhaseFunction *other) = 0;

    /// Resets the subclass's parameters to zero.
    virtual void setZero() = 0;

    /// Draws an outgoing direction into `wo` from `wi` and the 2D random
    /// sample `rnd2`; returns a Float.
    Float sample(const Vector &wi, const Vector2 &rnd2, Vector &wo) const;

    /// Density associated with the direction pair (wi, wo).
    Float pdf(const Vector &wi, const Vector &wo) const;

    /// Phase function value for the direction pair (wi, wo).
    Float eval(const Vector &wi, const Vector &wo) const;

    PSDR_DECLARE_VIRTUAL_CLASS()

    // --------------------------------------------------
    /// Integer tag identifying the concrete subclass; -1 until a subclass
    /// constructor overwrites it with its own TYPE_ID.
    int m_type{-1};
};

struct HGPhaseFunction : PhaseFunction
{
    inline HGPhaseFunction(Float g) : g(g) 
    {
        this->m_type = TYPE_ID;
    }
    PhaseFunction *clone() const override
    {
        return new HGPhaseFunction(*this);
    }
    void merge(PhaseFunction *other) override
    {
        g += dynamic_cast<HGPhaseFunction *>(other)->g;
    }
    void setZero() override
    {
        g = 0.;
    }

    Float eval(const Vector &wi, const Vector &wo) const
    {
        Float temp = 1.0f + g * g + 2.0f * g * wi.dot(wo);
        return INV_FOURPI * (1.0f - g * g) / (temp * std::sqrt(temp));
    }

    Float sample(const Vector &wi, const Vector2 &rnd2, Vector &wo) const
    {
        Float cosTheta;
        if (std::abs(g) < Epsilon)
        {
            cosTheta = 1.0f - 2.0f * rnd2.x();
        }
        else
        {
            Float sqrTerm = (1.0f - g * g) / (1.0f - g + 2.0f * g * rnd2.x());
            cosTheta = (1.0f + g * g - sqrTerm * sqrTerm) / (2.0f * g);
        }
        Float sinTheta = std::sqrt(1.0f - cosTheta * cosTheta);
        Float sinPhi = std::sin(2.0f * M_PI * rnd2.y()),
              cosPhi = std::cos(2.0f * M_PI * rnd2.y());
        wo = Frame(-wi).toWorld(Vector(sinTheta * cosPhi, sinTheta * sinPhi, cosTheta));
        return 1.0f;
    }

    Float pdf(const Vector &wi, const Vector &wo) const { return eval(wi, wo); }

    Float g;

    static const int TYPE_ID = 1;

    PSDR_DECLARE_CLASS(HGPhaseFunction)
    PSDR_IMPLEMENT_VIRTUAL_CLASS(HGPhaseFunction)
};

struct IsotropicPhaseFunction : PhaseFunction
{
    IsotropicPhaseFunction() 
    {
        this->m_type = TYPE_ID;
    }
    void setZero() override
    {
    }
    PhaseFunction *clone() const override
    {
        return new IsotropicPhaseFunction();
    }
    void merge(PhaseFunction *other) override{};
    Float eval(const Vector &wi, const Vector &wo) const { return INV_FOURPI; }
    Float pdf(const Vector &wi, const Vector &wo) const { return INV_FOURPI; }

    Float sample(const Vector &wi, const Vector2 &rnd2, Vector &wo) const
    {
        wo = squareToUniformSphere(rnd2);
        return 1.0f;
    }

    static const int TYPE_ID = 2;

    PSDR_DECLARE_CLASS(IsotropicPhaseFunction)
    PSDR_IMPLEMENT_VIRTUAL_CLASS(IsotropicPhaseFunction)
};

#endif