import torch, time
import enoki
from enoki.cuda_autodiff import Float32 as FloatD
from enoki.cuda_autodiff import Vector3f as Vector3fD
from enoki.cuda_autodiff import Vector3i as Vector3iD
from enoki.cuda import Vector3f as Vector3fC
from utils import print_gpu_usage

class RenderFunction(torch.autograd.Function):
    """Autograd bridge between PyTorch and the Enoki renderer for vertex positions.

    ``forward`` installs ``verts`` as the vertex positions of ``"Mesh[0]"`` in
    the scene, renders sensor ``sensor_id`` ``npass`` times and returns the
    averaged image as a torch tensor.  ``backward`` pushes the incoming image
    gradient through Enoki's autodiff graph and returns the gradient with
    respect to ``verts``.
    """

    @staticmethod
    def _render_average(scene, integrator, npass, sensor_id, quiet):
        """Return the average of ``npass`` ``integrator.renderD`` results.

        Progress is printed at most every 0.2 s (as measured by
        ``time.process_time``) unless ``quiet`` is true.

        Raises:
            ValueError: if ``npass`` is smaller than 1 (no image could be
                produced).
        """
        if npass < 1:
            raise ValueError("npass must be at least 1, got %r" % (npass,))

        t0 = time.process_time()
        t1 = t0
        for i in range(npass):
            # Uncomment to trace GPU memory per pass:
            # print_gpu_usage("sensor = {0}, spp = {1}, ipass={2}".format(sensor_id, scene.opts.spp, i))
            contribution = integrator.renderD(scene, sensor_id) / npass
            if i == 0:
                img_ad = contribution
            else:
                img_ad += contribution
            t2 = time.process_time()
            if t2 - t1 > 0.2 and not quiet:
                print("(%d/%d) done in %.2f seconds." % (i + 1, npass, t2 - t0), end="\r")
                t1 = t2
        t2 = time.process_time()
        if not quiet:
            print("(%d/%d) Total deriv. rendering time: %.2f seconds." % (npass, npass, t2 - t0))
        return img_ad

    @staticmethod
    def forward(ctx, verts: torch.Tensor, scene, integrator, npass, sensor_id, quiet):
        """Render ``sensor_id`` with ``verts`` as differentiable vertex positions.

        Args:
            ctx: autograd context; the Enoki input and output arrays are kept
                on it for ``backward``.
            verts: vertex positions; copied to CUDA before conversion to Enoki.
            scene: project scene object exposing ``param_map`` and
                ``configure2``.
            integrator: object exposing ``renderD(scene, sensor_id)``.
            npass: number of render passes to average (must be >= 1).
            sensor_id: index of the sensor to render.
            quiet: suppress progress output when true.

        Returns:
            The averaged image as a torch tensor on ``verts.device``.
        """
        # Convert from PyTorch to Enoki (Vector3fX)
        ctx.verts = Vector3fD(verts.cuda())
        enoki.set_requires_gradient(ctx.verts, True)
        scene.param_map["Mesh[0]"].vertex_positions = ctx.verts
        scene.configure2([sensor_id])

        # render the target image
        ctx.out = RenderFunction._render_average(scene, integrator, npass, sensor_id, quiet)

        img = ctx.out.torch().to(verts.device)

        # presumably hands cached Enoki CUDA memory back — TODO confirm
        enoki.cuda_malloc_trim()
        return img

    @staticmethod
    def backward(ctx, grad_out):
        """Return the gradient w.r.t. ``verts``; the other inputs get ``None``.

        ``grad_out`` is flattened to rows of 3 components before being handed
        to Enoki as the gradient of the rendered image.
        """
        enoki.set_gradient(ctx.out, Vector3fC(grad_out.reshape(-1, 3).cuda()))
        # static version
        FloatD.backward()
        # nan_to_num keeps NaN/inf entries from reaching the optimizer.
        verts_grad = torch.nan_to_num(enoki.gradient(ctx.verts).torch().to(grad_out.device))

        # garbage collection: drop the Enoki arrays held on the context
        del ctx.verts, ctx.out
        enoki.cuda_malloc_trim()
        # One entry per forward() argument after ctx.
        return verts_grad, None, None, None, None, None

class RenderLayer(torch.nn.Module):
    """Module that renders every configured sensor through ``RenderFunction``.

    Args:
        scene: scene object forwarded unchanged to ``RenderFunction``.
        integrator: integrator object forwarded unchanged to ``RenderFunction``.
        num_passes: number of render passes averaged per sensor.
        sensor_indices: iterable of sensor indices to render.
        quiet: suppress progress output when true.
    """

    def __init__(self, scene, integrator, num_passes, sensor_indices, quiet=False):
        super().__init__()
        self.scene = scene
        self.integrator = integrator
        self.num_passes = num_passes
        self.sensor_indices = sensor_indices
        self.quiet = quiet

    def forward(self, verts: torch.Tensor) -> list:
        """Render ``verts`` from each sensor.

        Returns:
            A list with one rendered image tensor per entry of
            ``sensor_indices``, in iteration order (empty if there are no
            sensors).
        """
        return [
            RenderFunction.apply(
                verts, self.scene, self.integrator, self.num_passes, sensor_id, self.quiet
            )
            for sensor_id in self.sensor_indices
        ]


class RenderFunctionNRF(torch.autograd.Function):
    """Autograd bridge between PyTorch and the Enoki renderer for per-vertex BSDF parameters.

    ``forward`` splits ``vtx_bsdf_params`` into diffuse (columns 0-2),
    specular (columns 3-5) and roughness (column 6), installs them on
    ``"Mesh[0]"`` and returns the average of ``npass`` renders of sensor
    ``sensor_id``.  ``backward`` returns the gradient with respect to
    ``vtx_bsdf_params`` as a 7-column tensor.
    """

    @staticmethod
    def _render_average(scene, integrator, npass, sensor_id, quiet):
        """Return the average of ``npass`` ``integrator.renderD`` results.

        Progress is printed at most every 0.2 s (as measured by
        ``time.process_time``) unless ``quiet`` is true.

        Raises:
            ValueError: if ``npass`` is smaller than 1 (no image could be
                produced).
        """
        if npass < 1:
            raise ValueError("npass must be at least 1, got %r" % (npass,))

        t0 = time.process_time()
        t1 = t0
        for i in range(npass):
            # Uncomment to trace GPU memory per pass:
            # print_gpu_usage("sensor = {0}, spp = {1}, ipass={2}".format(sensor_id, scene.opts.spp, i))
            contribution = integrator.renderD(scene, sensor_id) / npass
            if i == 0:
                img_ad = contribution
            else:
                img_ad += contribution
            t2 = time.process_time()
            if t2 - t1 > 0.2 and not quiet:
                print("(%d/%d) done in %.2f seconds." % (i + 1, npass, t2 - t0), end="\r")
                t1 = t2
        t2 = time.process_time()
        if not quiet:
            print("(%d/%d) Total deriv. rendering time: %.2f seconds." % (npass, npass, t2 - t0))
        return img_ad

    @staticmethod
    def forward(ctx, vtx_bsdf_params: torch.Tensor, scene, integrator, npass, sensor_id, quiet):
        """Render ``sensor_id`` with differentiable per-vertex BSDF parameters.

        Args:
            ctx: autograd context; the Enoki inputs and output are kept on it
                for ``backward``.
            vtx_bsdf_params: 2-D tensor whose columns are diffuse (0-2),
                specular (3-5) and roughness (6).  It is handed to Enoki
                without a device transfer, so it presumably must already be
                on CUDA — TODO confirm.
            scene: project scene object exposing ``param_map`` and
                ``configure2``.
            integrator: object exposing ``renderD(scene, sensor_id)``.
            npass: number of render passes to average (must be >= 1).
            sensor_id: index of the sensor to render.
            quiet: suppress progress output when true.

        Returns:
            The averaged image as a CUDA torch tensor.
        """
        # Convert from PyTorch to Enoki (Vector3fX)
        ctx.diffuse = Vector3fD(vtx_bsdf_params[:, :3])
        ctx.specular = Vector3fD(vtx_bsdf_params[:, 3:6])
        ctx.roughness = FloatD(vtx_bsdf_params[:, 6])
        enoki.set_requires_gradient(ctx.diffuse, True)
        enoki.set_requires_gradient(ctx.specular, True)
        enoki.set_requires_gradient(ctx.roughness, True)
        mesh = scene.param_map["Mesh[0]"]
        mesh.vertex_diffuse = ctx.diffuse
        mesh.vertex_specular = ctx.specular
        mesh.vertex_roughness = ctx.roughness

        scene.configure2([sensor_id])

        # render the target image
        ctx.out = RenderFunctionNRF._render_average(scene, integrator, npass, sensor_id, quiet)

        img = ctx.out.torch().cuda()

        # presumably hands cached Enoki CUDA memory back — TODO confirm
        enoki.cuda_malloc_trim()
        return img

    @staticmethod
    def backward(ctx, grad_out):
        """Return the gradient w.r.t. ``vtx_bsdf_params``; other inputs get ``None``.

        The diffuse, specular and roughness gradients are concatenated back
        into the 3 + 3 + 1 column layout used by ``forward``.
        """
        enoki.set_gradient(ctx.out, Vector3fC(grad_out.reshape(-1, 3).cuda()))
        # static version
        FloatD.backward()
        diffuse_grad = enoki.gradient(ctx.diffuse).torch().to(grad_out.device)
        specular_grad = enoki.gradient(ctx.specular).torch().to(grad_out.device)
        # roughness is a flat array; add a column axis so it can be concatenated
        roughness_grad = enoki.gradient(ctx.roughness).torch().to(grad_out.device).unsqueeze(1)
        vtx_bsdf_params_grad = torch.cat([diffuse_grad, specular_grad, roughness_grad], dim=1)

        # garbage collection: drop the Enoki arrays held on the context
        del ctx.diffuse, ctx.specular, ctx.roughness, ctx.out
        enoki.cuda_malloc_trim()
        # One entry per forward() argument after ctx.
        return vtx_bsdf_params_grad, None, None, None, None, None

class RenderLayerNRF(torch.nn.Module):
    """Module that renders every configured sensor through ``RenderFunctionNRF``.

    Args:
        scene: scene object forwarded unchanged to ``RenderFunctionNRF``.
        integrator: integrator object forwarded unchanged to
            ``RenderFunctionNRF``.
        num_passes: number of render passes averaged per sensor.
        sensor_indices: iterable of sensor indices to render.
        quiet: suppress progress output when true.
    """

    def __init__(self, scene, integrator, num_passes, sensor_indices, quiet=False):
        super().__init__()
        self.scene = scene
        self.integrator = integrator
        self.num_passes = num_passes
        self.sensor_indices = sensor_indices
        self.quiet = quiet

    def forward(self, vtx_bsdf_params: torch.Tensor) -> list:
        """Render with the given per-vertex BSDF parameters from each sensor.

        Returns:
            A list with one rendered image tensor per entry of
            ``sensor_indices``, in iteration order (empty if there are no
            sensors).
        """
        return [
            RenderFunctionNRF.apply(
                vtx_bsdf_params, self.scene, self.integrator, self.num_passes, sensor_id, self.quiet
            )
            for sensor_id in self.sensor_indices
        ]

class RenderFunction_SDF_NRF(torch.autograd.Function):
    """Autograd bridge for joint vertex-position and per-vertex BSDF optimisation.

    ``forward`` installs ``vtx`` as the vertex positions of ``"Mesh[0]"`` and
    splits ``vtx_bsdf_params`` into diffuse (columns 0-2), specular (columns
    3-5) and roughness (column 6), then returns the average of ``npass``
    renders of sensor ``sensor_id``.  ``backward`` returns gradients for both
    tensors.
    """

    @staticmethod
    def _render_average(scene, integrator, npass, sensor_id, quiet):
        """Return the average of ``npass`` ``integrator.renderD`` results.

        Progress is printed at most every 0.2 s (as measured by
        ``time.process_time``) unless ``quiet`` is true.

        Raises:
            ValueError: if ``npass`` is smaller than 1 (no image could be
                produced).
        """
        if npass < 1:
            raise ValueError("npass must be at least 1, got %r" % (npass,))

        t0 = time.process_time()
        t1 = t0
        for i in range(npass):
            # Uncomment to trace GPU memory per pass:
            # print_gpu_usage("sensor = {0}, spp = {1}, ipass={2}".format(sensor_id, scene.opts.spp, i))
            contribution = integrator.renderD(scene, sensor_id) / npass
            if i == 0:
                img_ad = contribution
            else:
                img_ad += contribution
            t2 = time.process_time()
            if t2 - t1 > 0.2 and not quiet:
                print("(%d/%d) done in %.2f seconds." % (i + 1, npass, t2 - t0), end="\r")
                t1 = t2
        t2 = time.process_time()
        if not quiet:
            print("(%d/%d) Total deriv. rendering time: %.2f seconds." % (npass, npass, t2 - t0))
        return img_ad

    @staticmethod
    def forward(ctx, vtx: torch.Tensor, vtx_bsdf_params: torch.Tensor, scene, integrator, npass, sensor_id, quiet):
        """Render ``sensor_id`` with differentiable geometry and BSDF parameters.

        Args:
            ctx: autograd context; the Enoki inputs and output are kept on it
                for ``backward``.
            vtx: vertex positions; copied to CUDA before conversion to Enoki.
            vtx_bsdf_params: 2-D tensor whose columns are diffuse (0-2),
                specular (3-5) and roughness (6).  It is handed to Enoki
                without a device transfer, so it presumably must already be
                on CUDA — TODO confirm.
            scene: project scene object exposing ``param_map`` and
                ``configure2``.
            integrator: object exposing ``renderD(scene, sensor_id)``.
            npass: number of render passes to average (must be >= 1).
            sensor_id: index of the sensor to render.
            quiet: suppress progress output when true.

        Returns:
            The averaged image as a CUDA torch tensor.
        """
        mesh = scene.param_map["Mesh[0]"]

        # Convert from PyTorch to Enoki (Vector3fX)
        # Remember where vtx lives so backward can return its gradient there.
        ctx.vtx_device = vtx.device
        ctx.vtx = Vector3fD(vtx.cuda())
        enoki.set_requires_gradient(ctx.vtx, True)
        mesh.vertex_positions = ctx.vtx

        ctx.diffuse = Vector3fD(vtx_bsdf_params[:, :3])
        ctx.specular = Vector3fD(vtx_bsdf_params[:, 3:6])
        ctx.roughness = FloatD(vtx_bsdf_params[:, 6])
        enoki.set_requires_gradient(ctx.diffuse, True)
        enoki.set_requires_gradient(ctx.specular, True)
        enoki.set_requires_gradient(ctx.roughness, True)
        mesh.vertex_diffuse = ctx.diffuse
        mesh.vertex_specular = ctx.specular
        mesh.vertex_roughness = ctx.roughness

        scene.configure2([sensor_id])

        # render the target image
        ctx.out = RenderFunction_SDF_NRF._render_average(scene, integrator, npass, sensor_id, quiet)

        img = ctx.out.torch().cuda()

        # presumably hands cached Enoki CUDA memory back — TODO confirm
        enoki.cuda_malloc_trim()
        return img

    @staticmethod
    def backward(ctx, grad_out):
        """Return gradients w.r.t. ``vtx`` and ``vtx_bsdf_params``.

        The remaining five forward inputs are non-differentiable and receive
        ``None``.  NaN/inf entries in both gradients are sanitised with
        ``torch.nan_to_num``.
        """
        enoki.set_gradient(ctx.out, Vector3fC(grad_out.reshape(-1, 3).cuda()))
        # static version
        FloatD.backward()
        # Return the gradient on the device vtx was given on: forward accepts
        # vtx on any device, and autograd expects a matching gradient device.
        vtx_grad = enoki.gradient(ctx.vtx).torch().to(ctx.vtx_device)
        vtx_grad = torch.nan_to_num(vtx_grad)

        diffuse_grad = enoki.gradient(ctx.diffuse).torch().to(grad_out.device)
        specular_grad = enoki.gradient(ctx.specular).torch().to(grad_out.device)
        # roughness is a flat array; add a column axis so it can be concatenated
        roughness_grad = enoki.gradient(ctx.roughness).torch().to(grad_out.device).unsqueeze(1)
        vtx_bsdf_params_grad = torch.cat([diffuse_grad, specular_grad, roughness_grad], dim=1)
        vtx_bsdf_params_grad = torch.nan_to_num(vtx_bsdf_params_grad)

        # garbage collection: drop the Enoki arrays held on the context
        del ctx.vtx, ctx.diffuse, ctx.specular, ctx.roughness, ctx.out
        enoki.cuda_malloc_trim()
        # One entry per forward() argument after ctx.
        return vtx_grad, vtx_bsdf_params_grad, None, None, None, None, None

class RenderLayer_SDF_NRF(torch.nn.Module):
    """Module that renders every configured sensor through ``RenderFunction_SDF_NRF``.

    Args:
        scene: scene object forwarded unchanged to ``RenderFunction_SDF_NRF``.
        integrator: integrator object forwarded unchanged to
            ``RenderFunction_SDF_NRF``.
        num_passes: number of render passes averaged per sensor.
        sensor_indices: iterable of sensor indices to render.
        quiet: suppress progress output when true.
    """

    def __init__(self, scene, integrator, num_passes, sensor_indices, quiet=False):
        super().__init__()
        self.scene = scene
        self.integrator = integrator
        self.num_passes = num_passes
        self.sensor_indices = sensor_indices
        self.quiet = quiet

    def forward(self, vtx: torch.Tensor, vtx_bsdf_params: torch.Tensor) -> list:
        """Render the given geometry and BSDF parameters from each sensor.

        Returns:
            A list with one rendered image tensor per entry of
            ``sensor_indices``, in iteration order (empty if there are no
            sensors).
        """
        return [
            RenderFunction_SDF_NRF.apply(
                vtx,
                vtx_bsdf_params,
                self.scene,
                self.integrator,
                self.num_passes,
                sensor_id,
                self.quiet,
            )
            for sensor_id in self.sensor_indices
        ]