#pragma once

#include <enoki/cuda.h>
#include <enoki/autodiff.h>
#include <enoki/random.h>
#include <array>
#include <chrono>
#include <iostream>
#include <random>
#include <stdexcept>
#include <vector>

namespace psdr {
using namespace enoki;

// Single-precision enoki CUDA array.
using FloatC    = CUDAArray<float>;
// Differentiable wrapper around FloatC; the element type used by MLP below.
using FloatD    = DiffArray<FloatC>;

// Fixed-size vector with N FloatD components.
template<size_t N>
using VectorND   = Array<FloatD, N>;

// Matrix stored as N vectors of length M each. MLP::mvmul treats coeff(i) as
// the i-th column, i.e. the matrix has M rows and N columns (column-major).
template<size_t M, size_t N>
using MatrixMND  = Array<VectorND<M>, N>;

template<size_t size_in, size_t size_out, size_t size_hidden, size_t num_hlayers>
class MLP {
private:
    template<size_t M, size_t N>
    VectorND<M> mvmul(const MatrixMND<M, N>& m, const VectorND<N>& v) const {
        VectorND<M> sum = m.coeff(0) * VectorND<M>::full_(v.coeff(0), 1);
        for (size_t i = 1; i < N; i++) {
            sum = fmadd(m.coeff(i), VectorND<M>::full_(v.coeff(i), 1), sum);
        }
        return sum;
    }

    template<size_t M, size_t N>
    void set_weight_(MatrixMND<M, N>& m, const FloatD& data, size_t& start) {
        for (size_t i = 0; i < N; i++) 
            for (size_t j = 0; j < M; j++) 
                m.coeff(i, j) = data.coeff(start + i * M + j);
        start += M * N;
    }

    template<size_t M, size_t N>
    void get_grad_(const MatrixMND<M, N>& m, std::vector<float>& grad_data) {
        MatrixMND<M, N> m_grad = gradient(m);
        size_t start = grad_data.size();
        grad_data.resize(start + M * N);
        for (size_t i = 0; i < N; i++) 
            for (size_t j = 0; j < M; j++)  
                grad_data[start + i * M + j] = m_grad.coeff(i, j, 0);
    }

    template<size_t M, size_t N>
    VectorND<M> eval_layer(const VectorND<N>& input, const MatrixMND<M, N+1>& weight) const {
        VectorND<N+1> input_;
        for (size_t i = 0; i < N; i++)
            input_.coeff(i) = input.coeff(i);
        input_.coeff(N) = zero<FloatD>(shape(input)[1]) + 1;
        VectorND<M> o = mvmul(weight, input_);
        o &= o >= 0.f;
        return o; 
    } 
public:
    MatrixMND<size_hidden, size_in + 1> w_in;
    MatrixMND<size_out, size_hidden + 1> w_out;
    std::array<MatrixMND<size_hidden, size_hidden + 1>, num_hlayers> w_hlayers;

    MLP() {}

    VectorND<size_out> eval(VectorND<size_in> input) const {
        VectorND<size_hidden> tmp = eval_layer(input, w_in);
        for (size_t i = 0; i < num_hlayers; i++) 
            tmp = eval_layer(tmp, w_hlayers[i]);
        VectorND<size_out> output = tanh(eval_layer(tmp, w_out));
        return output;
    }

    void set_weight(const FloatD& data, bool requires_grad) {
        size_t start = 0;
        set_weight_(w_in, data, start);
        for (size_t i = 0; i < num_hlayers; i++) 
            set_weight_(w_hlayers[i], data, start);
        set_weight_(w_out, data, start);
        if (requires_grad) enable_diff();
    }

    void enable_diff() {
        set_requires_gradient(w_in);
        for (size_t i = 0; i < num_hlayers; i++)
            set_requires_gradient(w_hlayers[i]);
        set_requires_gradient(w_out);
    }

    void get_grad(FloatD& grad) {
        cuda_eval(); cuda_sync();
        std::vector<float> grad_data;
        get_grad_(w_in, grad_data);
        for (size_t i = 0; i < num_hlayers; i++)
            get_grad_(w_hlayers[i], grad_data);
        get_grad_(w_out, grad_data);
        grad = FloatD::copy(grad_data.data(), grad_data.size());
    }
};

} // namespace psdr