#pragma once
#include <core/fwd.h>
#include <core/utils.h>
#include <pybind11/numpy.h>
#include <pybind11/eigen.h>
namespace py = pybind11;

// 2D array of Spectrum values stored in row-major order; both dimensions are
// dynamic (-1 is Eigen::Dynamic).
using ArrayXs = Eigen::Array<Spectrum, -1, -1, Eigen::RowMajor>;
// pybind11 type caster converting ArrayXs <-> numpy.ndarray of shape (rows, cols, 3)
namespace pybind11
{
    namespace detail
    {
        template <>
        struct type_caster<ArrayXs>
        {
            PYBIND11_TYPE_CASTER(ArrayXs, _("ArrayXs"));
            bool load(handle src, bool)
            {
                auto buf = array::ensure(src);
                py::buffer_info buffer_info = buf.request();
                if (buffer_info.ndim != 3)
                {
                    throw std::runtime_error("Image must be 3D");
                }
                if (buffer_info.shape[2] != 3)
                {
                    throw std::runtime_error("Image must have 3 channels");
                }
                // buf.shape
                value = ArrayXs::Zero(buffer_info.shape[0], buffer_info.shape[1]);
                std::copy((double *)buf.data(), (double *)buf.data() + buf.size(), (double *)value.data());
                return true;
            }

            static handle cast(ArrayXs src, return_value_policy /* policy */, handle /* parent */)
            {
                return array_t<double>({src.rows(),
                                        src.cols(),
                                        static_cast<py::ssize_t>(3)}, // shape

                                       {src.cols() * static_cast<py::ssize_t>(3) * sizeof(double),
                                        static_cast<py::ssize_t>(3) * sizeof(double),
                                        sizeof(double)}, // strides

                                       (double *)src.data())
                    .release();
            }
        };
    }
}

struct ImageBlock
{
    ImageBlock(const Array2i &offset, const Array2i &blockSize)
        : m_offset(offset), m_blockSize(blockSize)
    {
        m_pixelGenerated = 0;
        m_curPixel = Array2i(0, 0);
        m_pixelTotal = blockSize.prod();
        m_data = decltype(m_data)::Zero(blockSize.x(), blockSize.y());
    }

    ImageBlock(const Array2i &offset, const Array2i &blockSize, const ArrayXd &data);

    Array2i offset() const { return m_offset; }
    Array2i block_size() const { return m_blockSize; }

    Array2i curPixel()
    {
        return m_offset + m_curPixel;
    }
    Vector2i nextPixel()
    {
        if (!hasNext())
            assert(false);

        if (++m_curPixel.x() >= m_blockSize.x())
        {
            m_curPixel.y()++;
            m_curPixel.x() = 0;
        }
        m_pixelGenerated++;
        return curPixel();
    }

    bool hasNext()
    {
        return m_pixelGenerated < m_pixelTotal;
    }

    // get the pixel at the given position relative to the original image
    Spectrum get(const Array2i &pos);

    // accumulate another block into this block
    void put(const ImageBlock &block);
    // accumulate the pixel into this block
    // pos is the position of the pixel in the image not the position in the block
    void put(const Vector2i &pos, const Spectrum &value);

    ArrayXs &getData();
    
    const ArrayXs &getData() const;

    // void setData(py::array_t<Float> data);

    ArrayXd flattened() const;

    Array2i m_offset;
    Array2i m_blockSize;
    int m_pixelTotal;
    int m_pixelGenerated;
    ArrayXs m_data;     // image data
    Array2i m_curPixel; // pixel index inside the block
};

struct BlockedImage
{
    /**
     * @brief Construct a new Blocked Image object
     *
     * @param size
     *      size of the image
     * @param blockSize
     *      size of the square pixel block
     */
    BlockedImage(const Array2i &size, const Array2i &blockSize)
        : m_size(size),
          m_blockSize(blockSize)
    {
        m_curBlock = Vector2i(0, 0);
        m_numBlocks = ceil(size.cast<Float>() / blockSize.cast<Float>()).cast<int>();
        m_BlocksTotal = m_numBlocks.prod();
        m_BlocksGenerated = 0;
    }

    ImageBlock curBlock()
    {
        return ImageBlock(m_blockSize * m_curBlock,
                          m_blockSize.min(m_size - m_blockSize * m_curBlock));
    }

    ImageBlock nextBlock()
    {
        if (!hasNext())
            assert(false);

        if (++m_curBlock.x() >= m_numBlocks.x())
        {
            m_curBlock.y()++;
            m_curBlock.x() = 0;
        }
        m_BlocksGenerated++;
        return curBlock();
    }

    ImageBlock getBlock(int i)
    {
        Array2i idx = unravel_index(i, m_numBlocks);
        return ImageBlock(m_blockSize * idx,
                          m_blockSize.min(m_size - m_blockSize * idx));
    }

    bool hasNext()
    {
        return m_BlocksGenerated < m_BlocksTotal;
    }

    int size()
    {
        return m_BlocksTotal;
    }

    Array2i m_size; // image size
    Array2i m_blockSize;
    Array2i m_numBlocks;
    int m_BlocksTotal;
    int m_BlocksGenerated;
    Array2i m_curBlock;
};
