#include "pif.h"

#include <cmath>

#include "utils.h"

// Evaluate the PIF of bin `bin_index` at path length x.
// Bin k uses the center tau + k * stepTau and delegates to evalSingle().
Float PIF::eval(Float x, int bin_index) const {
    return evalSingle(x, tau.val + stepTau * bin_index);
}

// Differentiable counterpart of eval(): evaluates bin `bin_index` at x, keeping the
// derivative information carried by `tau` (the bin center is tau + k * stepTau).
FloatAD PIF::evalAD(const FloatAD &x, int bin_index) const {
    const FloatAD binCenter = tau + stepTau * bin_index;
    return evalSingleAD(x, binCenter);
}

// Sample a path length from the PIF.
//
// The first coordinate of a 2D random sample chooses a bin uniformly. Its fractional
// part (after scaling by num_bins) is then reused as a fresh uniform number for
// sampleSingle(), centered on that bin.
//
// bin_index: [out] selected bin, in [0, num_bins).
// pdf:       [out] pdf reported by sampleSingle(), divided by num_bins to include the
//            uniform bin-selection probability.
// Returns the sampled path length.
Float PIF::sample(RndSampler *sampler, int &bin_index, Float &pdf) const {
    Vector2 rnd2 = sampler->next2D();

    // Scale once: the integer part selects the bin, the fractional part is reused.
    const Float scaled = rnd2.x() * num_bins;
    bin_index = (int)std::floor(scaled);
    assert(0 <= bin_index && bin_index < num_bins);

    // scaled - floor(scaled) is computed exactly in floating point, so the reused number
    // stays in [0, 1). The previous form (u - k / n) * n went through a rounded division
    // and could drift marginally outside that interval.
    Vector2 rnd2_reuse(rnd2);
    rnd2_reuse[0] = scaled - (Float)bin_index;
    const Float centerTau = tau.val + stepTau * bin_index;

    Float res = sampleSingle(sampler, rnd2_reuse, centerTau, pdf);
    pdf /= (Float)num_bins;
    return res;
}

// Draw one path length t via sample() and pair it with its reflection about the center
// of the chosen bin. Returns (t, 2 * center - t); bin_index and pdf are exactly what
// sample() wrote for t.
Vector2 PIF::sampleAntithetic(RndSampler *sampler, int &bin_index, Float &pdf) const {
    const Float primary = sample(sampler, bin_index, pdf);
    const Float center = tau.val + stepTau * bin_index;
    const Float mirrored = 2 * center - primary;
    return Vector2(primary, mirrored);
}

// Pick a bin from the uniform number `rnd` and return whatever sampleSingleBoundary()
// produces for that bin's center (tau + k * stepTau).
// Asserts that this PIF is a "Boxcar" or "TruncatedGaussian".
//
// NOTE(review): unlike sample(), the pdf written by sampleSingleBoundary() is not divided
// by num_bins here, even though the bin is chosen uniformly. Confirm that callers account
// for the bin-selection probability.
Vector2 PIF::sampleBoundary(Float rnd, int &bin_index, Float &pdf) const {
    assert(name == "TruncatedGaussian" || name == "Boxcar");

    const int chosen = (int)std::floor(rnd * num_bins);
    assert(0 <= chosen && chosen < num_bins);
    bin_index = chosen;

    const Float binCenter = tau.val + stepTau * chosen;
    return sampleSingleBoundary(binCenter, pdf);
}

// Map a path length x to the index of a bin whose support contains it, or -1 if none.
//
// - "Delta": bin k matches only when |tau + k * stepTau - x| < Epsilon. x is snapped to
//   the *nearest* center (round, not floor), so values just below a center are accepted
//   symmetrically with values just above it.
// - otherwise: bin k covers [tau + k*stepTau - tau_extent, tau + k*stepTau + tau_extent].
//   If several windows overlap at x, the highest-index one is returned.
// NaN or out-of-range inputs yield -1 instead of reaching an undefined Float->int cast.
int PIF::getBinIndex(Float x) const {
    if (name == "Delta") {
        const Float nearest = std::round((x - tau.val) / stepTau);
        // Range-check before converting: casting an out-of-range or NaN Float to int is UB.
        if (!(nearest >= 0 && nearest < (Float)num_bins))
            return -1;
        const int bin_index = (int)nearest;
        if (std::abs(tau.val + bin_index * stepTau - x) < Epsilon)
            return bin_index;
        return -1;
    }

    // Offset of x from the lower edge of bin 0's window.
    const Float tmp = x - (tau.val - tau_extent);
    const Float upper = stepTau * (num_bins - 1) + 2 * tau_extent;
    // Negated range test so that NaN is rejected as well.
    if (!(tmp >= 0 && tmp <= upper))
        return -1;

    // Highest-index bin whose window starts at or before x.
    int bin_index = (int)std::floor(tmp / stepTau);
    // With overlapping windows (2 * tau_extent >= stepTau), x near the end of the range
    // floors to an index >= num_bins. The range check above guarantees the last bin's
    // window reaches x, so clamp there instead of returning an out-of-range index.
    if (bin_index > num_bins - 1)
        bin_index = num_bins - 1;

    const Float x_residual = tmp - stepTau * bin_index;
    if (x_residual > 2 * tau_extent)
        return -1;
    return bin_index;
}

// Return a bin index bounding the range of bins whose window may contain x, clamped to
// [0, num_bins - 1]. Assuming stepTau > 0:
// - which_end == 0: smallest k with tau + k * stepTau + tau_extent >= x
//   (first bin whose upper edge reaches x);
// - otherwise:      largest k with tau + k * stepTau - tau_extent <= x
//   (last bin whose lower edge is at or before x).
int PIF::getPotentialBinIndex(Float x, int which_end) const {
    const Float pos = (which_end == 0)
        ? std::ceil((x - tau.val - tau_extent) / stepTau)
        : std::floor((x - tau.val + tau_extent) / stepTau);

    // Clamp in floating point *before* converting to int. Converting a huge, infinite or
    // NaN value to int is undefined behavior. pos is integer-valued, so in-range results
    // are unchanged.
    if (!(pos > 0))  // also maps NaN to the first bin
        return 0;
    if (pos >= (Float)(num_bins - 1))
        return num_bins - 1;
    return (int)pos;
}