import torch
from pypsdr.largesteps_gpu.geometry import compute_matrix
from pypsdr.largesteps_gpu.optimize import AdamUniform
from pypsdr.largesteps_gpu.parameterize import from_differential, to_differential
from pypsdr.utils.timer import Timer


class LargeStepsGpu(torch.optim.Optimizer):
    """Optimizer that updates ``V[0]`` through a differential variable ``u``.

    The optimized variable is ``u = to_differential(M, V[0])`` where
    ``M = compute_matrix(V[0], F, lmbda)``; both live on the CUDA device.
    An inner ``AdamUniform`` optimizer steps ``u``, and the parameter is
    refreshed from ``from_differential(M, u, 'Cholesky')`` after every step.

    NOTE(review): the names suggest this is the "Large Steps in Inverse
    Rendering" preconditioned update with V as vertices and F as faces --
    confirm against ``pypsdr.largesteps_gpu``.

    Args:
        V: indexable collection of parameters handed to
            ``torch.optim.Optimizer``. Only ``V[0]`` is actually updated
            by this optimizer.
        F: tensor moved to CUDA and passed to ``compute_matrix``.
        lr: learning rate forwarded to the inner ``AdamUniform``.
        betas: beta coefficients forwarded to the inner ``AdamUniform``.
        lmbda: third argument of ``compute_matrix``.
    """

    def __init__(self, V, F, lr=0.1, betas=(0.9, 0.999), lmbda=0.1):
        self.V = V[0]
        self.F = F.cuda()
        # move V to the GPU once and reuse it for both calls below
        V_gpu = self.V.cuda()
        self.M = compute_matrix(V_gpu, self.F, lmbda)
        # fresh leaf tensor holding the differential coordinates; built with
        # detach().clone() instead of torch.tensor(tensor), which triggers a
        # copy-construct warning when given an existing tensor
        u0 = to_differential(self.M, V_gpu.detach())
        self.u = torch.as_tensor(u0).detach().clone().requires_grad_(True)
        defaults = dict(F=self.F, lr=lr, betas=betas)
        # NOTE(review): lr/betas stored in param_groups are not read back by
        # step(); the inner optimizer keeps its own copies.
        self.optimizer = AdamUniform([self.u], lr=lr, betas=betas)
        super().__init__(V, defaults)

    def step(self):
        """Apply one update to ``u`` and write the result back into ``V[0]``.

        Does nothing when ``V[0].grad`` is ``None`` (no backward pass has
        populated it), matching how built-in torch optimizers skip
        parameters that have no gradient.
        """
        grad = self.V.grad
        if grad is None:
            return
        with Timer("LargeSteps_gpu.step"):
            # build compute graph from u to V
            V = from_differential(self.M, self.u, 'Cholesky')
            # propagate gradients from V to u
            V.backward(grad.cuda())
            # step u
            self.optimizer.step()
            # update param; no graph is needed for this solve since the
            # result is only copied into the parameter's storage
            with torch.no_grad():
                updated = from_differential(self.M, self.u, 'Cholesky')
            self.V.data.copy_(updated.cpu())

    def zero_grad(self):
        """Reset gradients of the registered parameters and of ``u``."""
        super().zero_grad()
        self.optimizer.zero_grad()
