import vredner
import pyvredner
import torch
import sys

def loadShape(fname, grad):
    """Load the first mesh of an OBJ file and wrap it in a ``vredner.Shape``.

    Args:
        fname: path of the OBJ file, read with ``pyvredner.load_obj``.
        grad: tensor whose shape must equal the mesh's vertex tensor shape;
            its data pointer is passed as the last ``vredner.Shape`` argument.

    Returns:
        The ``vredner.Shape`` built from raw data pointers of the mesh tensors.

    Raises:
        AssertionError: if ``grad.shape`` differs from the vertex tensor shape.
    """
    mesh = pyvredner.load_obj(fname)[0]
    verts, tris = mesh.vertices, mesh.indices
    uv, nrm = mesh.uvs, mesh.normals
    assert verts.shape == grad.shape

    # Optional attributes that the OBJ does not provide become a null (0) pointer.
    uv_addr = 0 if uv is None else uv.data_ptr()
    nrm_addr = 0 if nrm is None else nrm.data_ptr()

    # NOTE(review): only raw pointers reach vredner.Shape, and the mesh tensors
    # become unreferenced when this function returns. Confirm that vredner.Shape
    # copies the data rather than keeping these pointers.
    return vredner.Shape(
        vredner.float_ptr(verts.data_ptr()),
        vredner.int_ptr(tris.data_ptr()),
        vredner.float_ptr(uv_addr),
        vredner.float_ptr(nrm_addr),
        len(verts),
        len(tris),
        # Four integer arguments whose meaning is defined by vredner.Shape
        # (not visible in this file) -- TODO confirm.
        -1, 0, -1, -1,
        vredner.float_ptr(grad.data_ptr()),
    )

def addGrad(fname, value, end=None):
    """Build a gradient tensor by repeating entries of *value*.

    There are two modes:

    * *end* omitted, ``None`` or empty: *value* is repeated once per vertex of
      the first mesh returned by ``pyvredner.load_obj(fname)``, so the result
      has one row per vertex (``loadShape`` requires ``grad.shape`` to equal
      the vertex tensor shape).
    * *end* non-empty: ``value[k]`` is repeated ``end[k]`` times for each group
      ``k`` in order. The mesh file is not read in this mode because its
      contents would not be used.

    Args:
        fname: path of the OBJ file (only read when *end* is empty).
        value: one gradient entry, or with *end* a sequence of entries, one
            per group.
        end: sequence of group sizes. Defaults to ``None``, which is treated
            like an empty sequence. It is not mutated.

    Returns:
        ``torch.Tensor`` built from the repeated entries.

    Raises:
        IndexError: if *end* is non-empty and *value* has fewer entries than
            *end* has groups.
    """
    # None replaces the former mutable [] default. An explicit [] still selects
    # the per-vertex mode, and len() keeps array-like inputs working.
    if end is None or len(end) == 0:
        mesh = pyvredner.load_obj(fname)[0]
        return torch.tensor([value] * len(mesh.vertices))

    grad = []
    for group, count in enumerate(end):
        # Indexing (not zip) keeps the IndexError when value is too short,
        # instead of silently truncating.
        grad.extend([value[group]] * count)
    return torch.tensor(grad)

def _append_shape(shapes, grads, fname, value, end=()):
    """Load *fname* as a vredner shape whose gradient is built by ``addGrad``.

    The shape is appended to *shapes* and its gradient tensor to *grads*.
    ``loadShape`` passes only ``grad.data_ptr()`` to vredner, so the caller
    keeps *grads* referenced for as long as the shapes are in use.
    """
    grad = addGrad(fname, value, end)
    shapes.append(loadShape(fname, grad))
    grads.append(grad)


def main(prob_id):
    """Set up and run the test scene selected by *prob_id*.

    Problems 1-4 and 7 call ``vredner.path_test``. Problem 6 calls
    ``vredner.compute_grid`` and saves the grid to ``data.bin``, which
    problem 7 loads.

    Args:
        prob_id: integer problem id (1, 2, 3, 4, 6 or 7).

    Raises:
        ValueError: if *prob_id* is not one of the known ids.
    """
    shapes = []
    # Keeps every gradient tensor alive until main returns (see _append_shape).
    grads = []

    # path_test gets vredner.float_ptr(0) as a null grid pointer when no grid
    # data is supplied. An int 0 matches the null-pointer convention loadShape
    # uses for missing uvs/normals.
    if prob_id == 1:
        _append_shape(shapes, grads, "./obj/prob1_1.obj",
                      [[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
                      [3, 3, 3])
        nBatches = [5, 0, 5, 0]
        FD_delta = 1e-3
        vredner.path_test(shapes, nBatches, FD_delta, [], vredner.float_ptr(0))
    elif prob_id == 2:
        # The gradient is built for the same mesh file that is loaded.
        _append_shape(shapes, grads, "./obj/prob2_1.obj",
                      [[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]], [3, 3])
        _append_shape(shapes, grads, "./obj/prob2_2.obj", [1.0, 1.0, 1.0])
        nBatches = [5, 5, 5, 0]
        FD_delta = 1e-2
        vredner.path_test(shapes, nBatches, FD_delta, [], vredner.float_ptr(0))
    elif prob_id == 3:
        _append_shape(shapes, grads, "./obj/prob2_1.obj", [0.0, 0.0, 0.0])
        _append_shape(shapes, grads, "./obj/dodecahedron.obj", [-1.0, 1.0, 1.0])
        nBatches = [0, 0, 5, 0]
        FD_delta = 1e-2
        vredner.path_test(shapes, nBatches, FD_delta, [], vredner.float_ptr(0))
    elif prob_id == 4:
        _append_shape(shapes, grads, "./obj/prob2_1.obj", [0.0, 0.0, 0.0])
        _append_shape(shapes, grads, "./obj/bunny.obj", [-1.0, 1.0, 1.0])
        nBatches = [0, 0, 5, 0]
        FD_delta = 1e-2
        vredner.path_test(shapes, nBatches, FD_delta, [], vredner.float_ptr(0))
    elif prob_id == 6:
        _append_shape(shapes, grads, "./obj/prob2_1.obj", [0.0, 0.0, 0.0])
        _append_shape(shapes, grads, "./obj/dodecahedron.obj", [-1.0, 1.0, 1.0])
        para = [100, 100, 100, 1000]
        # The first three entries of para give the grid tensor's shape.
        grid_data = torch.zeros(para[0:3])
        vredner.compute_grid(shapes, para, vredner.float_ptr(grid_data.data_ptr()))
        print(grid_data.sum())
        torch.save(grid_data, "data.bin")
    elif prob_id == 7:
        _append_shape(shapes, grads, "./obj/prob2_1.obj", [0.0, 0.0, 0.0])
        _append_shape(shapes, grads, "./obj/dodecahedron.obj", [-1.0, 1.0, 1.0])
        # data.bin is produced locally by problem 6.
        grid_data = torch.load("data.bin")
        nBatches = [0, 0, 5, 5]
        FD_delta = 1e-2
        vredner.path_test(shapes, nBatches, FD_delta, list(grid_data.shape),
                          vredner.float_ptr(grid_data.data_ptr()))
    else:
        # An explicit raise (not assert) still fires under `python -O`.
        raise ValueError("Unknown problem id: %d" % prob_id)

if __name__ == '__main__':
    # The problem id is the first command-line argument. Without it, print usage.
    if len(sys.argv) <= 1:
        print("Usage: python test.py [Prob ID]")
    else:
        main(int(sys.argv[1]))
