import os

import psdr_cuda
import cv2
import numpy as np
from .file import create_output_dir

def _prepare_output_dirs(init_path):
    """Create ``init_path`` and its ``exr/`` and ``png/`` subdirectories.

    Returns:
        Tuple ``(exr_dir, png_dir)`` with the two subdirectory paths (each
        ending in a separator).
    """
    # os.path.join inserts a separator when init_path lacks a trailing slash;
    # plain concatenation would have produced sibling directories such as
    # "outexr/" next to "out". With a trailing slash the paths are unchanged.
    exr_dir = os.path.join(init_path, "exr/")
    png_dir = os.path.join(init_path, "png/")
    create_output_dir(init_path)
    create_output_dir(exr_dir)
    create_output_dir(png_dir)
    return exr_dir, png_dir


def _load_scene(scene_file, spp):
    """Load ``scene_file`` into a configured ``psdr_cuda.Scene``.

    Args:
        scene_file: Path of the scene description passed to ``Scene.load_file``.
        spp: Samples per pixel for a single render pass.

    Returns:
        Tuple ``(scene, num_sensors)``.
    """
    sc = psdr_cuda.Scene()
    # Second argument kept from the original code; its meaning is defined by
    # psdr_cuda -- TODO(review): document.
    sc.load_file(scene_file, False)
    opts = sc.opts
    opts.spp = spp
    # NOTE(review): sppe/sppse presumably set edge / secondary-edge sample
    # counts (used for derivatives); zero appears to disable them for plain
    # forward rendering -- confirm against psdr_cuda docs.
    opts.sppe = 0
    opts.sppse = 0
    opts.log_level = 0
    # Read before configure(), matching the original ordering.
    num_sensors = sc.num_sensors
    sc.configure()
    return sc, num_sensors


def _render_all_sensors(init_path, scene_file, spp, npass, make_integrator):
    """Render every sensor of a scene and save EXR and PNG images for each.

    Each sensor image is the mean of ``npass`` independent renders. The raw
    averaged values are written to ``<init_path>/exr/sensor_<i>.exr``, and a
    1/2.2 gamma-encoded, clipped 8-bit copy goes to
    ``<init_path>/png/sensor_<i>.png``.

    Args:
        init_path: Output root directory.
        scene_file: Scene description file to load.
        spp: Samples per pixel per pass.
        npass: Number of render passes averaged per sensor; must be >= 1.
        make_integrator: Zero-argument callable returning the integrator. It is
            called only after the scene has been configured.

    Raises:
        ValueError: If ``npass`` is less than 1. Previously this divided the
            image by zero.
    """
    if npass < 1:
        raise ValueError("npass must be >= 1, got %r" % (npass,))

    exr_dir, png_dir = _prepare_output_dirs(init_path)
    sc, num_sensors = _load_scene(scene_file, spp)
    integrator = make_integrator()
    height, width = sc.opts.height, sc.opts.width

    for sensor_id in range(num_sensors):
        # Accumulate npass renders, then average them.
        img = integrator.renderC(sc, sensor_id)
        for _ in range(1, npass):
            img += integrator.renderC(sc, sensor_id)
        img /= npass
        img = img.numpy().reshape((height, width, 3))

        # OpenCV's imwrite expects BGR channel order.
        output = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(exr_dir, "sensor_" + str(sensor_id) + ".exr"), output)
        im = np.power(output, 1/2.2)
        im = np.uint8(np.clip(im * 255., 0., 255.))
        cv2.imwrite(os.path.join(png_dir, "sensor_" + str(sensor_id) + ".png"), im)
        # Drop references before the next render to keep peak memory down.
        del img, im, output
        # Report the number of cameras finished so far (1-based).
        print("(%d/%d) cameras done." % (sensor_id + 1, num_sensors), end="\r")

    if num_sensors > 0:
        # Terminate the carriage-return progress line.
        print()


def render_scene(init_path, scene_file, spp, npass):
    """Render all sensors of ``scene_file`` with ``psdr_cuda.DirectIntegrator``.

    Emitters are hidden in the output (``hide_emitters = True``). Images are
    written under ``<init_path>/exr/`` and ``<init_path>/png/`` as
    ``sensor_<i>.exr`` / ``sensor_<i>.png``.

    Args:
        init_path: Output root directory.
        scene_file: Scene description file to load.
        spp: Samples per pixel per pass.
        npass: Number of passes averaged per sensor; must be >= 1.

    Raises:
        ValueError: If ``npass`` is less than 1.
    """
    def make_integrator():
        # Constructor arguments kept from the original code; their meaning is
        # defined by psdr_cuda -- TODO(review): document.
        integrator = psdr_cuda.DirectIntegrator(1, 1)
        integrator.hide_emitters = True
        return integrator

    _render_all_sensors(init_path, scene_file, spp, npass, make_integrator)


def render_scene_silhouette(init_path, scene_file, spp, npass):
    """Render the "silhouette" field of every sensor of ``scene_file``.

    Uses ``psdr_cuda.FieldExtractionIntegrator("silhouette")``. Images are
    written under ``<init_path>/exr/`` and ``<init_path>/png/`` as
    ``sensor_<i>.exr`` / ``sensor_<i>.png``.

    Args:
        init_path: Output root directory.
        scene_file: Scene description file to load.
        spp: Samples per pixel per pass.
        npass: Number of passes averaged per sensor; must be >= 1.

    Raises:
        ValueError: If ``npass`` is less than 1.
    """
    def make_integrator():
        return psdr_cuda.FieldExtractionIntegrator("silhouette")

    _render_all_sensors(init_path, scene_file, spp, npass, make_integrator)
