import sys
from opt import optimize
import torch

# Per-scene optimization settings. The top-level key is chosen by the first
# command-line argument, and the selected inner dict is unpacked verbatim as
# keyword arguments to opt.optimize() in the __main__ block below. Every inner
# key must therefore be a parameter that optimize() accepts.
#
# NOTE: the bsdf_param_* entries call .cuda() while this module is imported.
# Importing this file (even just to read plain values) needs a CUDA-capable
# torch build and device.
config = {
    'chair_env': {
        'scene_dir': 'chair_env',
        'max_iter': 5000, 
        'resolution': 128, 
        'batch_size': 5, 
        # Learning rate for the SDF parameter group. Presumably the
        # *_sched_milestones / *_sched_factor pair drives a step schedule
        # (LR multiplied by the factor at each milestone iteration).
        # TODO confirm against opt.optimize().
        'sdf_lr': 5e-5, 
        'sdf_sched_milestones': [100, 250, 500, 1000, 1500, 2000, 2500],
        'sdf_sched_factor': 0.5,
        # Same lr/milestones/factor pattern for the "nrf" parameter group.
        # What "nrf" refers to is defined in opt.optimize, not here.
        'nrf_lr': 5e-5, 
        'nrf_sched_milestones': [250, 500, 1000, 1500, 2500],
        'nrf_sched_factor': 0.75,
        'checkpoint_iter': 5,
        # Sensor indices used for visualization. Assumed to index into the
        # scene's sensor list; verify in opt.optimize.
        'vis_sensor_indices': [22, 122, 76, 176],
        # Renderer settings, forwarded to optimize() as one nested dict.
        'render_spec': {
            'integrator': 'direct',
            'sppse_mode': 0,
            'spp': 16,
            'sppe': 4,
            'sppse': 1,
            'log_level': 0,
            'npass': 4
        },
        'img_weight': 0.1,
        # 7-element lower bound for BSDF parameters. The component layout is
        # not defined in this file.
        'bsdf_param_min': torch.tensor([0, 0, 0, 0, 0, 0, 0.1]).float().cuda(),
        # NOTE(review): disabled upper bound kept for reference. chair_env
        # passes no bsdf_param_max, so presumably optimize()'s default is used.
        # 'bsdf_param_max': torch.tensor([1, 1, 1, 1, 1, 1, 0.6]).float().cuda(),
    },

    # Unlike chair_env, this scene sets an upper BSDF bound (no lower bound)
    # and an explicit 'sil_weight'.
    'pegasus_env': {
        'scene_dir': 'pegasus_env',
        'max_iter': 5000, 
        'resolution': 96, 
        'batch_size': 5, 
        'sdf_lr': 5e-5, 
        'sdf_sched_milestones': [100, 250, 1000, 2000, 5000],
        'sdf_sched_factor': 0.8,
        'nrf_lr': 1e-4, 
        'nrf_sched_milestones': [250, 1000, 2500],
        'nrf_sched_factor': 0.75,
        'checkpoint_iter': 5,
        'vis_sensor_indices': [277, 205, 32, 90],
        # 7-element upper bound for BSDF parameters (layout defined elsewhere).
        'bsdf_param_max': torch.tensor([1, 1, 1, 1, 1, 1, 0.5]).float().cuda(),
        'render_spec': {
            'integrator': 'direct',
            'sppse_mode': 0,
            'spp': 16,
            'sppe': 2,
            'sppse': 1,
            'log_level': 0,
            'npass': 2,
        },
        'img_weight': 0.1,
        'sil_weight': 5000,
    }
}

def parse_cli_args(scenes, argv=None):
    """Parse the command line ``SCENE [FROM_CHECKPOINT]``.

    Args:
        scenes: Iterable of valid scene names (the keys of ``config``).
        argv: Argument list to parse, excluding the program name. ``None``
            means ``sys.argv[1:]`` (argparse's default behavior).

    Returns:
        argparse.Namespace with ``scene`` (str, one of ``scenes``) and
        ``from_checkpoint`` (int; -1 when the argument is omitted, matching
        the previous hand-rolled parsing).

    Raises:
        SystemExit: if the scene is missing or unknown, the checkpoint is
            not an integer, or extra arguments are given. argparse prints a
            usage message to stderr first, instead of the old bare
            IndexError/KeyError/ValueError tracebacks.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='Run opt.optimize() with a named scene configuration.')
    parser.add_argument('scene', choices=list(scenes),
                        help='name of the scene configuration to run')
    # nargs='?' keeps the checkpoint optional. argparse still accepts a
    # negative number such as "-1" here because no option looks numeric.
    parser.add_argument('from_checkpoint', nargs='?', type=int, default=-1,
                        help='checkpoint passed to optimize() (default: -1)')
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_cli_args(config)
    optimize(**config[args.scene], from_checkpoint=args.from_checkpoint)