#include "t_gaussian.h"
#include "utils.h"

// Half-width of the truncation window in units of the Gaussian's standard
// deviation (deltaTau): the kernel is non-zero only on
// centerTau +/- TRUNCATED_GAUSSIAN_RANGE * deltaTau.
#define TRUNCATED_GAUSSIAN_RANGE 1.0

Float TruncatedGaussianPIF::evalSingle(Float x, Float centerTau) const {
    // Kernel value at x: the Gaussian density N(x; centerTau, deltaTau) scaled by
    // 2 * deltaTau, and 0 outside the truncation window
    // centerTau +/- TRUNCATED_GAUSSIAN_RANGE * deltaTau.
    // The window is widened by Epsilon on both sides and tested with strict
    // inequalities, so points lying exactly on the nominal edges count as inside.
    const auto sigma = deltaTau.val;
    const auto halfWidth = sigma * TRUNCATED_GAUSSIAN_RANGE;
    const Float lower = centerTau - halfWidth - Epsilon;
    const Float upper = centerTau + halfWidth + Epsilon;

    // Kept as a negated conjunction (rather than x <= lower || x >= upper)
    // so that a NaN x is rejected and yields 0.
    if (!(lower < x && x < upper))
        return 0.0;

    return normalPdf(x, centerTau, sigma) * 2.0 * sigma;
}

FloatAD TruncatedGaussianPIF::evalSingleAD(const FloatAD &x, const FloatAD &centerTau) const {
    // Differentiable counterpart of evalSingle: N(x; centerTau, deltaTau) * 2 * deltaTau
    // evaluated with AD arithmetic, and 0 outside the Epsilon-widened window
    // centerTau +/- TRUNCATED_GAUSSIAN_RANGE * deltaTau.
    // The window test uses only the plain .val parts, so the truncation edges
    // themselves contribute no derivative.
    FloatAD value = normalPdfAD(x, centerTau, deltaTau) * 2.0 * deltaTau;

    const auto halfWidth = deltaTau.val * TRUNCATED_GAUSSIAN_RANGE;
    const Float lower = centerTau.val - halfWidth - Epsilon;
    const Float upper = centerTau.val + halfWidth + Epsilon;

    if (lower < x.val && x.val < upper)
        return value;
    return 0.0;
}

Float TruncatedGaussianPIF::sampleSingle(RndSampler *sampler, const Vector2 &rnd2, Float centerTau, Float &pdf) const {
    // Draws x from N(centerTau, deltaTau^2) truncated to
    // [centerTau - R * deltaTau, centerTau + R * deltaTau] (R = TRUNCATED_GAUSSIAN_RANGE)
    // by rejection sampling, and writes the truncated-normal density at x to pdf.
    //
    // rnd2 is not used: rejection sampling consumes a variable number of random
    // numbers, so every draw comes from sampler instead.
    //
    // NOTE(review): the loop only terminates when the window is non-empty. If
    // deltaTau.val < 0 then min_tau > max_tau, and if it is NaN every comparison
    // fails; in both cases this never returns. Callers must keep deltaTau positive.
    const Float sigma = deltaTau.val;
    const Float min_tau = centerTau - sigma * TRUNCATED_GAUSSIAN_RANGE;
    const Float max_tau = centerTau + sigma * TRUNCATED_GAUSSIAN_RANGE;
    const auto inWindow = [min_tau, max_tau](Float t) { return min_tau <= t && t <= max_tau; };

    // squareToStdNormal maps a 2D sample to a 2D standard normal, whose two
    // components are independent N(0,1) variates (presumably via Box-Muller;
    // verify). Try both components before drawing a new pair instead of
    // discarding the second one. The accepted value is still exactly
    // truncated-normal distributed, because the proposals are i.i.d.
    Float x;
    while (true) {
        const Vector2 uv = squareToStdNormal(sampler->next2D());
        x = uv.x() * sigma + centerTau;
        if (inWindow(x))
            break;
        x = uv.y() * sigma + centerTau;
        if (inWindow(x))
            break;
    }

    // Renormalize the Gaussian density by the probability mass inside the window.
    const Float pdf_norm = normalCdf(max_tau, centerTau, sigma) - normalCdf(min_tau, centerTau, sigma);
    pdf = normalPdf(x, centerTau, sigma) / pdf_norm;
    return x;
}

Vector2 TruncatedGaussianPIF::sampleSingleBoundary(Float centerTau, Float &pdf) const {
    // Returns both edges of the truncation window as (lower, upper):
    // centerTau -/+ TRUNCATED_GAUSSIAN_RANGE * deltaTau. Sets pdf to 1 / num_bins.
    // NOTE(review): the 1 / num_bins pdf presumably corresponds to a uniform
    // choice among bins made by the caller; confirm against the PIF interface.
    const auto halfWidth = deltaTau.val * TRUNCATED_GAUSSIAN_RANGE;
    pdf = 1.0 / static_cast<Float>(num_bins);
    return Vector2(centerTau - halfWidth, centerTau + halfWidth);
}

Vector2i TruncatedGaussianPIF::getBinIndexRange(Float x) const {
    // Pair of potential bin indices for the edges of a window of half-width
    // TRUNCATED_GAUSSIAN_RANGE * deltaTau around x: the lower edge is queried
    // with 0, the upper edge with 1.
    // NOTE(review): the 0/1 argument presumably selects lower/upper-edge
    // rounding in getPotentialBinIndex; confirm.
    const auto reach = TRUNCATED_GAUSSIAN_RANGE * deltaTau.val;
    const auto firstBin = getPotentialBinIndex(x - reach, 0);
    const auto lastBin = getPotentialBinIndex(x + reach, 1);
    return Vector2i(firstBin, lastBin);
}