import pytest
import torch
import numpy as np
import psdr_cpu as psdr
from tests.psdr import *
from tests.integrator.config import *
from pypsdr.utils.io import mkdir
from pypsdr.validate import RenderOptions
import pypsdr
import os

# Pixel passed (via parametrize) to the `ray` fixture, which samples the
# primary camera ray through it.
pixel_idx = [156, 261]
# No module-level sampler here: the `sampler` fixture defined below would
# rebind the name anyway, and it supplies a fresh RndSampler per test.
# Scene used by TestEvalPath. The default is a developer-local path; set the
# PSDR_TEST_SCENE environment variable to point at the scene elsewhere.
file = os.environ.get(
    "PSDR_TEST_SCENE",
    "/home/zih/Projects/psdvr/psdr-enzyme/data/scenes/two_triangles/scene.xml",
)


@pytest.fixture
def ray(camera, pixel_idx):
    """Return the primary camera ray sampled through ``pixel_idx``.

    Only the first two entries of ``pixel_idx`` are forwarded to
    ``camera.samplePrimaryRay``.
    """
    i, j = pixel_idx[0], pixel_idx[1]
    return camera.samplePrimaryRay(i, j)


@pytest.fixture
def sampler():
    """Provide a new ``psdr.RndSampler`` constructed with ``(0, 0)``.

    The meaning of the two arguments (presumably seed/stream) is defined by
    psdr_cpu — TODO confirm.
    """
    rnd = psdr.RndSampler(0, 0)
    return rnd


@pytest.fixture
def rRec(pixel_idx, sampler, max_bounces):
    """Build the ``psdr.RadianceQueryRecord`` for ``pixel_idx``.

    ``sampler`` is the function-scoped sampler fixture and ``max_bounces``
    comes from the test's parametrization.
    """
    query = psdr.RadianceQueryRecord(pixel_idx, sampler, max_bounces)
    return query


@pytest.fixture
def lightPath(scene, ray, rRec):
    """Trace ``ray`` through ``scene`` and return the recorded ``psdr.LightPath``.

    A new, empty ``psdr.LightPath`` is handed to ``psdr.volpath_meta.__Li``,
    which presumably records the traced vertices into it (TODO confirm).
    ``algorithm1_vol.eval`` is then run on the same path with the record's
    sampler. The values returned by both calls are not needed by the tests
    and are discarded; only the path is returned.
    """
    # Alias so the fixture does not shadow the built-in ``eval``.
    from psdr_cpu.algorithm1_vol import eval as eval_path

    path = psdr.LightPath()
    psdr.volpath_meta.__Li(scene, ray, rRec, path)
    # Still invoked for any effect it has on ``path`` or the sampler state,
    # even though its return value is unused.
    eval_path(scene, path, rRec.sampler)
    return path


@pytest.fixture
def lightPathAD(lightPath):
    """Wrap the traced ``lightPath`` in a ``psdr.LightPathAD``.

    The tests read ``.val`` as the path and ``.der`` as its derivative
    counterpart ``d_path``.
    """
    wrapped = psdr.LightPathAD(lightPath)
    return wrapped


@pytest.mark.parametrize("file", [file])
@pytest.mark.parametrize("pixel_idx", [pixel_idx])
@pytest.mark.parametrize("max_bounces", [1])
class TestEvalPath:
    """Smoke tests for the path-evaluation routines in ``psdr_cpu.algorithm1_vol``.

    Every test is parametrized with the module-level scene ``file``, the
    ``pixel_idx`` whose primary ray builds the light path, and
    ``max_bounces=1``. The tests print derivative values for manual
    inspection; they make no assertions.
    """

    def test_evalPathAD(self, sceneAD, lightPathAD):
        """Run ``d_evalPath`` with a ``[1, 0, 0]`` seed and print vertex 1's ``nee_bsdf``."""
        from psdr_cpu.algorithm1_vol import d_evalPath
        path = lightPathAD.val
        d_path = lightPathAD.der
        scene = sceneAD.val
        d_scene = sceneAD.der
        # [1, 0, 0] is the seed passed for the path value — presumably an
        # RGB weight on the first channel; TODO confirm.
        d_evalPath(scene, d_scene, path, d_path, [1, 0, 0])
        print(d_path.vertices[1].nee_bsdf)

    def test_evalPathFwd(self, sceneAD, lightPathAD):
        """Seed vertex 1's ``nee_bsdf`` derivative and push it through ``evalPathFwd``."""
        from psdr_cpu.algorithm1_vol import evalPathFwd
        scene = sceneAD.val
        d_scene = sceneAD.der
        path = lightPathAD.val
        d_path = lightPathAD.der
        d_path.vertices[1].nee_bsdf = np.array([1, 0, 0])
        print(d_path.vertices[1].nee_bsdf)
        d_value = evalPathFwd(scene, d_scene, path, d_path)
        print(d_value[1])

    def test_evalAD(self, sceneAD, lightPathAD, sampler):
        """Run ``d_evalPath`` then ``d_evalVertex`` and print vertex 1's ``nee_bsdf``."""
        from psdr_cpu.algorithm1_vol import d_evalPath, d_evalVertex
        scene = sceneAD.val
        d_scene = sceneAD.der
        path = lightPathAD.val
        d_path = lightPathAD.der
        d_evalPath(scene, d_scene, path, d_path, [1, 0, 0])
        d_evalVertex(scene, d_scene, path, d_path, sampler)
        print(d_path.vertices[1].nee_bsdf)

    def test_evalFwd(self, sceneAD, lightPathAD, sampler):
        """Chain ``evalVertexFwd`` and ``evalPathFwd`` and print the result.

        This is the opposite ordering of ``test_evalAD``: the vertex step runs
        first, then the path step produces ``d_value``. Other forward entry
        points (``baselineFwd``, ``evalFwd``, ``getPathFwd``) are also exported
        by ``algorithm1_vol``.
        """
        from psdr_cpu.algorithm1_vol import evalPathFwd, evalVertexFwd
        scene = sceneAD.val
        d_scene = sceneAD.der
        path = lightPathAD.val
        d_path = lightPathAD.der
        # evalVertexFwd is expected to fill d_path from d_scene (TODO confirm);
        # evalPathFwd then maps d_path to the derivative of the path value.
        evalVertexFwd(scene, d_scene, path, d_path, sampler)
        # d_value must be bound here; without this call the print below
        # raises NameError.
        d_value = evalPathFwd(scene, d_scene, path, d_path)
        print(d_value)


if __name__ == "__main__":
    # Run a single test of this module directly: -x stops at the first
    # failure and -s disables output capture so the printed derivatives are
    # visible. Change the last component of the node id to run
    # test_evalPathAD, test_evalPathFwd or test_evalAD instead.
    # The node id is built from __file__ so the script works from any working
    # directory and always targets this file, whatever it is named.
    this_file = os.path.abspath(__file__)
    # Propagate pytest's exit status to the shell.
    raise SystemExit(
        pytest.main(["-x", this_file + "::TestEvalPath::test_evalFwd", "-s"])
    )
