#include <core/logger.h>
#include <core/transform.h>
#include <render/common.h>
#include <render/volumegrid.h>

#include <Eigen/Dense>
#include <algorithm>

// Resets every stored voxel value, every per-channel maximum and the global
// maximum to zero. Container sizes are left unchanged.
void VolumeGrid::setZero() {
    for (auto &voxel_value : m_data)
        voxel_value = 0.;
    for (auto &channel_peak : m_channel_max)
        channel_peak = 0.;
    m_max = 0;
}

// Merges the voxel data of `other` into this grid by forwarding both data
// arrays to merge_vector().
//
// NOTE(review): merge_vector() is defined elsewhere; presumably it combines
// the two arrays element-wise -- confirm against its definition.
// NOTE(review): only m_data is touched here; m_max and m_channel_max are not
// refreshed by this function.
void VolumeGrid::merge(const VolumeGrid &other) {
    merge_vector(m_data, other.m_data);
}

// Builds a volume from a property set.
//
// Recognised values of the "type" property:
//   - "gridvolume":  loads voxel data from the file named by "filename".
//   - "constvolume": stores the three components of the Spectrum "value".
// A missing "filename"/"value" or an unrecognised type only emits a warning.
// In every case the optional "toWorld" matrix (identity when absent) is stored
// and configure() is run afterwards.
VolumeGrid::VolumeGrid(const Properties &props) {
    const std::string type = props.get<std::string>("type");

    if (type == "gridvolume") {
        if (!props.has("filename")) {
            PSDR_WARN("VolumeGrid: no filename specified");
        } else {
            is_constant = false;
            FileStream stream(props.get<std::string>("filename"));
            read(stream);
        }
    } else if (type == "constvolume") {
        if (!props.has("value")) {
            PSDR_WARN("VolumeGrid: no value specified");
        } else {
            is_constant           = true;
            const Spectrum albedo = props.get<Spectrum>("value");
            // A constant volume keeps a single colour triple instead of a grid.
            m_data.resize(albedo.size());
            m_data[0] = albedo.x();
            m_data[1] = albedo.y();
            m_data[2] = albedo.z();
        }
    } else {
        PSDR_WARN("Unknown volume grid type: %s", type.c_str());
    }

    m_volumeToWorld = props.get<Matrix4x4>("toWorld", Matrix4x4::Identity());
    // The original author flagged this call with a FIXME (no reason given).
    configure(); // FIXME
}

// Builds a non-constant grid directly from in-memory data.
//
//   res      - grid resolution along x, y and z
//   nchannel - number of values stored per voxel
//   min/max  - corners of the bounding box in volume-local space
//   data     - voxel values (copied into the grid)
//   toWorld  - volume-to-world transform
//
// configure() is run last so the derived maxima and transforms are available
// as soon as the object exists.
VolumeGrid::VolumeGrid(const Vector3i &res, int nchannel,
                       const Vector &min, const Vector &max,
                       const std::vector<Float> &data,
                       const Matrix4x4          &toWorld) {
    m_volumeToWorld = toWorld;
    m_bbox          = AABB(min, max);
    m_data          = data;
    m_nchannel      = nchannel;
    m_res           = res;
    is_constant     = false;

    configure();
}

// Builds a constant volume: the data array holds just the x, y and z
// components of `value`, and the grid is flagged as constant so lookups
// return that value everywhere.
VolumeGrid::VolumeGrid(const Spectrum &value) {
    m_data.resize(value.size());
    m_data[2] = value.z();
    m_data[1] = value.y();
    m_data[0] = value.x();

    is_constant = true;
}

// Loads a grid volume from the file at `filename`.
//
// Throws (via Throw) when the file does not exist; read() throws on a
// malformed file.
VolumeGrid::VolumeGrid(const std::string &filename) {
    const fs::path path = filename;
    if (!fs::exists(path)) {
        Throw("VolumeGrid: file not found: {}", path.string());
    }

    // configure() reads is_constant and m_volumeToWorld. The other
    // constructors set both explicitly; do the same here so the result does
    // not depend on default member initializers (not visible in this file).
    // Identity matches the default used by the Properties constructor.
    is_constant     = false;
    m_volumeToWorld = Matrix4x4::Identity();

    FileStream stream(path);
    read(stream);
    configure();
}

// Loads a grid volume from the file at `path`. read() throws on a malformed
// file.
VolumeGrid::VolumeGrid(const fs::path &path) {
    // configure() reads is_constant and m_volumeToWorld. The other
    // constructors set both explicitly; do the same here so the result does
    // not depend on default member initializers (not visible in this file).
    // Identity matches the default used by the Properties constructor.
    is_constant     = false;
    m_volumeToWorld = Matrix4x4::Identity();

    FileStream stream(path);
    read(stream);
    configure();
}

// Loads a grid volume from an already opened stream. read() throws on a
// malformed stream.
VolumeGrid::VolumeGrid(Stream &stream) {
    // configure() reads is_constant and m_volumeToWorld. The other
    // constructors set both explicitly; do the same here so the result does
    // not depend on default member initializers (not visible in this file).
    // Identity matches the default used by the Properties constructor.
    is_constant     = false;
    m_volumeToWorld = Matrix4x4::Identity();

    read(stream);
    configure();
}

void VolumeGrid::configure() {
    if (is_constant)
        return;
    // compute m_max
    m_max = -std::numeric_limits<Float>::max();
    for (int i = 0; i < m_data.size(); i++) {
        m_max = std::max(m_max, m_data[i]);
    }
    // compute m_channel_max
    m_channel_max.resize(m_nchannel, -std::numeric_limits<Float>::max());
    for (size_t i = 0; i < m_res.prod(); ++i) {
        for (size_t j = 0; j < m_nchannel; ++j) {
            m_channel_max[j] = std::max(m_channel_max[j],
                                        m_data[i * m_nchannel + j]);
        }
    }

    m_worldToVolume = m_volumeToWorld.inverse();
    m_volumeToGrid  = volumeToGrid();
    m_worldToGrid   = m_volumeToGrid * m_worldToVolume;
}

// Parses a binary volume from `stream` into this grid.
//
// Layout consumed, in order:
//   - 3 bytes   magic "VOL"
//   - uint8     version (must be 3)
//   - int32     data type (must be 1, i.e. Float32)
//   - 3 x int32 resolution along x, y, z
//   - int32     number of channels per voxel
//   - 6 x float bounding box (min x/y/z, then max x/y/z)
//   - float     values, channels interleaved per voxel
//
// Fills m_res, m_nchannel, m_bbox, m_data, m_max and m_channel_max.
// Throws (via Throw) on a bad magic, unsupported version/type, or a negative
// resolution / channel count.
//
// NOTE(review): any byte-order handling is up to Stream; not visible here.
void VolumeGrid::read(Stream &stream) {
    char header[3];
    stream.read(header, 3);
    if (header[0] != 'V' || header[1] != 'O' || header[2] != 'L')
        Throw("Invalid volume file!");

    uint8_t version;
    stream.read(version);
    if (version != 3)
        Throw("Invalid version, currently only version 3 is supported (found "
              "{}) ",
              version);

    int32_t data_type;
    stream.read(data_type);
    // Uses the same "{}" placeholder style as the other Throw() calls in this
    // file (the previous "%d" was inconsistent with them).
    if (data_type != 1)
        Throw("Wrong type, currently only type == 1 (Float32) data is "
              "supported (found type = {})",
              data_type);

    int32_t size_x, size_y, size_z;
    stream.read(size_x);
    stream.read(size_y);
    stream.read(size_z);
    // The file is external input: reject sizes that would turn into a bogus
    // allocation size below.
    if (size_x < 0 || size_y < 0 || size_z < 0)
        Throw("Invalid volume resolution: {} x {} x {}", size_x, size_y,
              size_z);
    m_res.x() = size_x;
    m_res.y() = size_y;
    m_res.z() = size_z;

    int32_t channel_count;
    stream.read(channel_count);
    if (channel_count < 0)
        Throw("Invalid channel count: {}", channel_count);
    m_nchannel = channel_count;

    float dims[6];
    stream.read_array(dims, 6);
    m_bbox = AABB(Vector(dims[0], dims[1], dims[2]),
                  Vector(dims[3], dims[4], dims[5]));

    // Count voxels in size_t so large grids do not overflow a 32-bit int.
    const size_t nvoxel = static_cast<size_t>(size_x) *
                          static_cast<size_t>(size_y) *
                          static_cast<size_t>(size_z);
    const size_t nchannel = static_cast<size_t>(channel_count);

    m_max = -std::numeric_limits<Float>::max();
    // assign() resets existing entries too, in case read() runs twice.
    m_channel_max.assign(nchannel, -std::numeric_limits<Float>::max());

    m_data   = std::vector<Float>(nvoxel * nchannel);
    size_t k = 0;
    for (size_t i = 0; i < nvoxel; ++i) {
        for (size_t j = 0; j < nchannel; ++j) {
            // Values are stored as 32-bit floats regardless of Float.
            float val;
            stream.read(val);
            const Float value = Float(val);
            m_data[k]         = value;
            m_max             = std::max(m_max, value);
            m_channel_max[j]  = std::max(m_channel_max[j], value);
            ++k;
        }
    }
}

// Returns a copy of the cached world-to-volume matrix. configure() sets it to
// the inverse of m_volumeToWorld.
Matrix4x4 VolumeGrid::worldToVolume() const {
    return m_worldToVolume;
}

// Builds the matrix taking volume-local coordinates to continuous grid
// coordinates: a translation by -m_bbox.min followed by a per-axis scale of
// (resolution - 1) / bounding-box extent.
//
// Assuming psdr::transform::scale/translate build the usual scaling and
// translation matrices, m_bbox.min lands on grid coordinate 0 and m_bbox.max
// on (resolution - 1) along each axis.
Matrix4x4 VolumeGrid::volumeToGrid() const {
    const Vector box_size = m_bbox.getExtents();

    // Grid cells per unit length along each axis.
    const auto scale_x = (m_res[0] - 1) / box_size[0];
    const auto scale_y = (m_res[1] - 1) / box_size[1];
    const auto scale_z = (m_res[2] - 1) / box_size[2];

    return psdr::transform::scale({ scale_x, scale_y, scale_z }) *
           psdr::transform::translate(-m_bbox.min);
}

// Returns a copy of the cached world-to-grid matrix. configure() sets it to
// m_volumeToGrid * m_worldToVolume.
Matrix4x4 VolumeGrid::worldToGrid() const {
    return m_worldToGrid;
}

// Maps the point `x` through the cached world-to-grid matrix using
// psdr::transform_pos(). The lookup functions floor the result to find the
// enclosing voxel cell.
Vector VolumeGrid::toGrid(const Vector &x) const {
    return psdr::transform_pos(m_worldToGrid, x);
}

// Maps the point `x` through the cached world-to-volume matrix using
// psdr::transform_pos().
Vector VolumeGrid::toLocal(const Vector &x) const {
    return psdr::transform_pos(m_worldToVolume, x);
}

// Evaluates a single-channel volume at the point `_p`.
//
// Constant volumes return their first stored value. Otherwise `_p` is mapped
// through toGrid() and the eight surrounding voxels are blended with trilinear
// interpolation. Returns 0 when the enclosing cell is not completely inside
// the grid.
//
// Precondition (asserted): the grid stores exactly one channel.
Float VolumeGrid::lookupFloat(const Vector &_p) const {
    // FIXME (kept from the original, which gave no reason)
    if (is_constant)
        return m_data.at(0);

    assert(m_nchannel == 1);

    const Vector p = toGrid(_p);

    // Lower and upper corner of the cell containing p.
    const int x1 = static_cast<int>(std::floor(p.x()));
    const int y1 = static_cast<int>(std::floor(p.y()));
    const int z1 = static_cast<int>(std::floor(p.z()));
    const int x2 = x1 + 1, y2 = y1 + 1, z2 = z1 + 1;

    const bool outside = x1 < 0 || y1 < 0 || z1 < 0 ||      //
                         x2 >= m_res.x() || y2 >= m_res.y() || //
                         z2 >= m_res.z();
    if (outside)
        return 0;

    // Fractional position inside the cell and its complement.
    const Float fx = p.x() - x1, fy = p.y() - y1, fz = p.z() - z1;
    const Float gx = 1.0f - fx, gy = 1.0f - fy, gz = 1.0f - fz;

    // Index in size_t so large grids cannot overflow int arithmetic; the
    // bounds check above guarantees all coordinates are non-negative. Reading
    // through m_data directly also avoids casting away const.
    const size_t nx    = static_cast<size_t>(m_res.x());
    const size_t ny    = static_cast<size_t>(m_res.y());
    const auto   voxel = [&](int z, int y, int x) -> Float {
        return m_data[(static_cast<size_t>(z) * ny + static_cast<size_t>(y)) *
                          nx +
                      static_cast<size_t>(x)];
    };

    const Float d000 = voxel(z1, y1, x1), d001 = voxel(z1, y1, x2),
                d010 = voxel(z1, y2, x1), d011 = voxel(z1, y2, x2),
                d100 = voxel(z2, y1, x1), d101 = voxel(z2, y1, x2),
                d110 = voxel(z2, y2, x1), d111 = voxel(z2, y2, x2);

    // Interpolate along x, then y, then z.
    return ((d000 * gx + d001 * fx) * gy + //
            (d010 * gx + d011 * fx) * fy) * //
               gz +
           ((d100 * gx + d101 * fx) * gy + //
            (d110 * gx + d111 * fx) * fy) * //
               fz;
}

// Evaluates a three-channel volume at the point `_p`.
//
// Constant volumes return the three stored values as a Spectrum. Otherwise
// `_p` is mapped through toGrid() and the eight surrounding voxels are blended
// with trilinear interpolation. Returns a zero Spectrum when the enclosing
// cell is not completely inside the grid.
//
// Precondition (asserted): the grid stores exactly three channels.
Spectrum VolumeGrid::lookupSpectrum(const Vector &_p) const {
    if (is_constant)
        return Spectrum(m_data.at(0), m_data.at(1), m_data.at(2));

    assert(m_nchannel == 3);

    const Vector p = toGrid(_p);

    // Lower and upper corner of the cell containing p.
    const int x1 = static_cast<int>(std::floor(p.x()));
    const int y1 = static_cast<int>(std::floor(p.y()));
    const int z1 = static_cast<int>(std::floor(p.z()));
    const int x2 = x1 + 1, y2 = y1 + 1, z2 = z1 + 1;

    const bool outside = x1 < 0 || y1 < 0 || z1 < 0 ||      //
                         x2 >= m_res.x() || y2 >= m_res.y() || //
                         z2 >= m_res.z();
    if (outside)
        return Spectrum(0.);

    // Fractional position inside the cell and its complement.
    const Float fx = p.x() - x1, fy = p.y() - y1, fz = p.z() - z1;
    const Float gx = 1.0f - fx, gy = 1.0f - fy, gz = 1.0f - fz;

    // Each voxel occupies three consecutive Floats in m_data. Build the
    // Spectrum from them explicitly instead of reinterpreting the Float array
    // as a Spectrum array: that cast dropped const and depended on Spectrum's
    // size and alignment. Offsets use size_t so large grids cannot overflow
    // int; the bounds check above guarantees non-negative coordinates.
    const size_t nx    = static_cast<size_t>(m_res.x());
    const size_t ny    = static_cast<size_t>(m_res.y());
    const auto   voxel = [&](int z, int y, int x) -> Spectrum {
        const size_t base =
            3 * ((static_cast<size_t>(z) * ny + static_cast<size_t>(y)) * nx +
                 static_cast<size_t>(x));
        return Spectrum(m_data[base], m_data[base + 1], m_data[base + 2]);
    };

    const Spectrum d000 = voxel(z1, y1, x1), d001 = voxel(z1, y1, x2),
                   d010 = voxel(z1, y2, x1), d011 = voxel(z1, y2, x2),
                   d100 = voxel(z2, y1, x1), d101 = voxel(z2, y1, x2),
                   d110 = voxel(z2, y2, x1), d111 = voxel(z2, y2, x2);

    // Interpolate along x, then y, then z.
    return ((d000 * gx + d001 * fx) * gy + //
            (d010 * gx + d011 * fx) * fy) * //
               gz +
           ((d100 * gx + d101 * fx) * gy + //
            (d110 * gx + d111 * fx) * fy) * //
               fz;
}