import pytest
import torch
import numpy as np
import psdr_cpu as psdr
from tests.psdr import *
from tests.integrator.config import *
from tests.integrator.integrator import *
import os
# Allow more than one OpenMP runtime to coexist in this process.
# NOTE(review): presumably needed because psdr_cpu and torch/numpy each bring
# their own (Intel) OpenMP runtime — confirm. This only suppresses the
# duplicate-runtime check; it does not remove the duplicate. It is also set
# after those modules are imported above, so it only helps if the runtime is
# initialised later — consider moving it above the imports.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Resolve scene files relative to this test file rather than the current
# working directory, so the scenes are found regardless of where pytest is
# launched from (e.g. the repository root, which the ``tests.*`` imports need).
_HERE = os.path.dirname(os.path.abspath(__file__))
_SCENE_DIR = os.path.normpath(os.path.join(_HERE, "..", "..", "data", "scenes"))
file = os.path.join(_SCENE_DIR, "volcube", "scene.xml")
file_thick = os.path.join(_SCENE_DIR, "volcube", "thick.xml")
# file = os.path.join(_SCENE_DIR, "two_triangles", "scene.xml")
# A single Volpath integrator instance, shared by every test in this module
# that receives it through @pytest.mark.parametrize('integrator', [integrator]).
integrator = psdr.Volpath()


@pytest.fixture
def sampler():
    """Provide a new ``psdr.RndSampler(0, 2)`` for each test (function scope).

    NOTE(review): the meaning of the two constructor arguments is defined by
    psdr (seed? sample dimension?) — confirm before relying on it.
    """
    rng = psdr.RndSampler(0, 2)
    return rng


@pytest.fixture
def set_xform(scene):
    """Flag the first shape of ``scene`` with ``requires_grad`` and set its
    translation to ``[1, 0, 0]``.

    Returns nothing: tests request this fixture only for its side effect on
    ``scene``. NOTE(review): presumably ``sceneAD`` is built from this same
    ``scene`` instance, so the change is visible there — confirm.
    """
    first_shape = scene.shapes[0]
    first_shape.requires_grad = True
    first_shape.setTranslation([1.0, 0.0, 0.0])


@pytest.fixture
def ray(camera, pixel_idx):
    """Return the camera's primary ray for the pixel named by ``pixel_idx``.

    The first two entries of ``pixel_idx`` are handed to
    ``camera.samplePrimaryRay`` in list order.
    """
    p0, p1 = pixel_idx[0], pixel_idx[1]
    return camera.samplePrimaryRay(p0, p1)


@pytest.fixture
def rRec(pixel_idx, sampler, max_bounces):
    """Build a ``psdr.RadianceQueryRecord`` from the pixel index, the shared
    sampler fixture and the parametrized bounce limit."""
    record = psdr.RadianceQueryRecord(pixel_idx, sampler, max_bounces)
    return record


@pytest.fixture
def pRec(pixel_idx, sampler, max_bounces, nsamples):
    """Build a ``psdr.PixelQueryRecord`` from ``pixel_idx``, ``sampler``,
    ``max_bounces`` and ``nsamples``.

    NOTE(review): the trailing ``True`` is an unnamed positional flag of
    PixelQueryRecord; its meaning is defined in psdr — document it once known.
    """
    record = psdr.PixelQueryRecord(
        pixel_idx, sampler, max_bounces, nsamples, True
    )
    return record


@pytest.mark.parametrize('file', [file])
@pytest.mark.parametrize('max_bounces', [1])
@pytest.mark.parametrize('pixel_idx', [[164, 172]])
@pytest.mark.parametrize('integrator', [integrator])
class TestVolpath:
    """Smoke tests for the Volpath integrator on one pixel of the volcube scene.

    The class-level parametrization supplies the scene file, bounce limit,
    pixel index and integrator to every test. The tests assert nothing: they
    check that each entry point runs without raising and print the results.
    """

    @staticmethod
    def _prepare_adjoint(sceneAD):
        """Return a fresh adjoint vector ``[1, 0, 0]`` after zeroing the first
        shape's derivative slot in ``sceneAD``.

        NOTE(review): the reset presumably matters because the AD calls
        accumulate into ``param`` — confirm.
        """
        adjoint = np.array([1., 0., 0.])
        sceneAD.getDer().shapes[0].param = 0
        return adjoint

    def _run_LiAD(self, integrator, sceneAD, rRec, ray):
        """Run ``LiAD`` for one ray and print the resulting shape derivative."""
        adjoint = self._prepare_adjoint(sceneAD)
        integrator.LiAD(sceneAD, ray, rRec, adjoint)
        print(sceneAD.getDer().shapes[0].param)
        print("finished")

    def test_Li(self, Li):
        print(Li)

    def test_LiFwd(self, integrator, sceneAD, rRec, ray):
        # Set the first shape's derivative slot to 1, then evaluate LiFwd.
        sceneAD.getDer().shapes[0].param = 1
        print(integrator.LiFwd(sceneAD, ray, rRec))

    def test_LiAD(self, integrator, set_xform, sceneAD, rRec, ray):
        # ``set_xform`` is requested only for its side effect on the scene.
        self._run_LiAD(integrator, sceneAD, rRec, ray)

    def test_LiAD_2(self, integrator, set_xform, sceneAD, rRec, ray):
        # NOTE(review): exact duplicate of test_LiAD. If the intent is to check
        # that a second run on the shared module-level integrator behaves the
        # same, document that; otherwise this test can be dropped.
        self._run_LiAD(integrator, sceneAD, rRec, ray)

    @pytest.mark.parametrize('nsamples', [100])
    def test_pixelColor(self, pixelColor):
        print(pixelColor)

    @pytest.mark.parametrize('nsamples', [100])
    def test_pixelColorAD(self, set_xform, integrator, sceneAD, pRec):
        adjoint = self._prepare_adjoint(sceneAD)
        integrator.pixelColorAD(sceneAD, pRec, adjoint)
        print(sceneAD.getDer().shapes[0].param)
        print("finish")
    
# Render settings shared by test_renderC and test_renderD, injected through
# @pytest.mark.parametrize('options', [options]). RenderOptions is not defined
# in this file; presumably it comes from one of the star imports — see its
# definition for what each of the seven positional slots means.
options = RenderOptions(0, 20, 1000, 0, 0, 0, 0)

@pytest.mark.parametrize('file', [file_thick])
@pytest.mark.parametrize('options', [options])
@pytest.mark.parametrize('integrator', [integrator])
def test_renderC(integrator, scene, options, height, width):
    """Render the ``thick.xml`` volcube scene and save it as ``forward.exr``.

    Smoke test only: it checks that ``renderC`` runs and that its output can
    be reshaped to ``(height, width, 3)``. The output path is relative, so the
    file presumably lands in the current working directory.
    """
    flat = integrator.renderC(scene, options)
    imwrite(flat.reshape(height, width, 3), "forward.exr")
    print("finished")


@pytest.mark.parametrize('file', [file])
@pytest.mark.parametrize('options', [options])
@pytest.mark.parametrize('integrator', [integrator])
def test_renderD(integrator, set_xform, sceneAD, options, d_image, height, width):
    """Run ``renderD`` with ``d_image`` and save the result as ``backward.exr``.

    ``set_xform`` is requested only for its side effect: it sets
    ``requires_grad`` on the first shape and calls ``setTranslation([1, 0, 0])``.
    The result is reshaped to ``(height, width, 3)`` before being written.
    """
    # Left disabled in the original test: psdr.set_forward(True)
    flat_grad = integrator.renderD(sceneAD, options, d_image)
    imwrite(flat_grad.reshape(height, width, 3), "backward.exr")
    print("finished")


if __name__ == "__main__":
    # Run one selected test from this module when it is executed as a script.
    # To switch tests, uncomment a different ``selected`` line.
    # selected = "TestVolpath::test_Li"
    # selected = "TestVolpath::test_LiFwd"
    # selected = "TestVolpath::test_LiAD"
    # selected = "TestVolpath::test_pixelColor"
    # selected = "TestVolpath::test_pixelColorAD"
    # selected = "test_renderD"
    selected = "test_renderC"
    # Address this module by its absolute path so the script works from any
    # working directory.
    this_file = os.path.abspath(__file__)
    # Pass pytest's exit status back to the shell; otherwise a failing run
    # would still exit with status 0.
    raise SystemExit(pytest.main(["-x", f"{this_file}::{selected}", "-s"]))
