import torch
from torch import Tensor
from typing import List
import numpy as np
from pypsdr.largesteps.geometry import compute_matrix

from pypsdr.largesteps.optimize import AdamUniform
from pypsdr.largesteps.parameterize import from_differential, to_differential
from pypsdr.utils.timer import Timer


def sparse_eye(size):
    """Return a ``size`` x ``size`` identity matrix as a sparse COO tensor.

    Args:
        size: Number of rows (and columns) of the identity matrix.

    Returns:
        A sparse COO tensor of shape ``(size, size)`` with ones on the
        diagonal, using the current default floating-point dtype.
    """
    diagonal = torch.arange(size, dtype=torch.long)
    # Row and column indices are identical for the diagonal entries.
    indices = torch.stack((diagonal, diagonal))
    values = torch.ones(size)
    # torch.sparse_coo_tensor is the supported factory; the legacy typed
    # constructors (torch.sparse.FloatTensor, ...) are deprecated and had to
    # be looked up by parsing the tensor's type string.
    return torch.sparse_coo_tensor(indices, values, (size, size))


def adamax(params: List[Tensor],
           grads: List[Tensor],
           m1_tp: List[Tensor],
           m2_tp: List[Tensor],
           state_steps: List[int],
           *,
           beta1: float,
           beta2: float,
           lr: float,
           IL_term, IL_solver):
    r"""Apply one in-place update to every tensor in ``params``.

    For each parameter the gradient is passed through ``IL_solver``, the
    running first and second moments are updated in place, and a
    bias-corrected step is computed whose denominator is the single largest
    entry of the second moment (one scalar for the whole tensor). The step
    is subtracted from ``IL_term @ param``; the result is passed through
    ``IL_solver`` again and copied back into the parameter.

    Args:
        params: Tensors to update in place.
        grads: Gradients, aligned index-by-index with ``params``.
        m1_tp: Running first-moment buffers (one per parameter), updated
            in place.
        m2_tp: Running second-moment buffers (one per parameter), updated
            in place.
        state_steps: Step count per parameter. Must be >= 1, because the
            bias correction divides by ``1 - beta ** step``.
        beta1: Decay rate of the first moment.
        beta2: Decay rate of the second moment.
        lr: Learning rate.
        IL_term: Matrix left-multiplied onto each parameter via
            ``torch.matmul``.
        IL_solver: Callable that takes a NumPy array and returns something
            ``torch.as_tensor`` accepts.
            NOTE(review): presumably solves the linear system associated
            with ``IL_term`` - confirm against the caller.
    """
    for i, param in enumerate(params):
        grad = grads[i]
        # Use distinct local names: rebinding ``m1_tp``/``m2_tp`` here would
        # replace the lists with the first parameter's buffers and break
        # every later iteration.
        m1 = m1_tp[i]
        m2 = m2_tp[i]
        step = state_steps[i]

        grad = torch.as_tensor(IL_solver(np.asarray(grad)))
        m1.mul_(beta1).add_(grad, alpha=1 - beta1)
        m2.mul_(beta2).add_(grad.square(), alpha=1 - beta2)
        u = torch.matmul(IL_term, param.detach())
        # NOTE(review): there is no epsilon in the denominator, so an
        # all-zero second moment yields a non-finite step.
        clr = lr / ((1 - beta1 ** step) * (m2.amax() /
                    (1 - beta2 ** step)).sqrt()) * m1
        u = u - clr
        param.copy_(torch.as_tensor(IL_solver(np.asarray(u))))


class LGDescent(torch.optim.Optimizer):
    """Optimizer that delegates every parameter-group update to ``adamax``.

    Each parameter keeps a step counter and two zero-initialised moment
    buffers (``m1_tp`` and ``m2_tp``) in the optimizer state. ``IL_term``
    and ``IL_solver`` are stored on the instance and forwarded unchanged to
    ``adamax`` on every step.

    Args:
        params: Parameters (or parameter groups) to optimize.
        IL_term: Forwarded to ``adamax`` as ``IL_term``.
        IL_solver: Forwarded to ``adamax`` as ``IL_solver``.
        lr: Learning rate; must not be negative.
        betas: Pair of moment decay rates, each in ``[0, 1)``.

    Raises:
        ValueError: If ``lr`` or either entry of ``betas`` is out of range.
    """

    def __init__(self, params, IL_term, IL_solver, lr=2e-3, betas=(0.9, 0.999)):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        beta_errors = (
            "Invalid beta parameter at index 0: {}",
            "Invalid beta parameter at index 1: {}",
        )
        for index, template in enumerate(beta_errors):
            beta = betas[index]
            if not 0.0 <= beta < 1.0:
                raise ValueError(template.format(beta))
        self.IL_term = IL_term
        self.IL_solver = IL_solver
        super().__init__(params, dict(lr=lr, betas=betas))

    def _gather_group(self, group):
        """Collect the tensors of one group that have a gradient.

        Creates the per-parameter state on first use and increments its
        step counter. Returns the parameters, gradients, first moments,
        second moments and step counts as five aligned lists.
        """
        active, grads, first_moments, second_moments, steps = [], [], [], [], []
        for p in group['params']:
            grad = p.grad
            if grad is None:
                continue
            active.append(p)
            if grad.is_sparse:
                raise RuntimeError(
                    'Adamax does not support sparse gradients')
            grads.append(grad)

            state = self.state[p]
            if not state:
                # First time this parameter is seen: start from zero.
                state['step'] = 0
                state['m1_tp'] = torch.zeros_like(
                    p, memory_format=torch.preserve_format)
                state['m2_tp'] = torch.zeros_like(
                    p, memory_format=torch.preserve_format)

            first_moments.append(state['m1_tp'])
            second_moments.append(state['m2_tp'])

            state['step'] += 1
            steps.append(state['step'])
        return active, grads, first_moments, second_moments, steps

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without one.
        """
        if closure is None:
            loss = None
        else:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            lr = group['lr']
            active, grads, first_moments, second_moments, steps = \
                self._gather_group(group)

            adamax(active,
                   grads,
                   first_moments,
                   second_moments,
                   steps,
                   beta1=beta1,
                   beta2=beta2,
                   lr=lr,
                   IL_term=self.IL_term, IL_solver=self.IL_solver)

        return loss


class LargeSteps(torch.optim.Optimizer):
    """Optimizer that steps ``V[0]`` through an auxiliary variable ``u``.

    ``u`` is obtained from ``to_differential(M, V[0])`` with
    ``M = compute_matrix(V[0], F, lmbda)`` and is optimized by an inner
    ``AdamUniform`` instance. After every step ``V[0]`` is overwritten with
    ``from_differential(M, u, 'Cholesky')``.

    Args:
        V: Sequence of parameters registered with the base optimizer; only
            ``V[0]`` is updated by :meth:`step`.
            NOTE(review): presumably the mesh vertex positions - confirm.
        F: Forwarded to ``compute_matrix`` and stored in the defaults.
            NOTE(review): presumably the face index buffer - confirm.
        lr: Learning rate of the inner ``AdamUniform``.
        betas: Moment decay rates of the inner ``AdamUniform``.
        lmbda: Third argument passed to ``compute_matrix``.
    """

    def __init__(self, V, F, lr=0.1, betas=(0.9, 0.999), lmbda=0.1):
        target = V[0]
        self.V = target
        self.F = F
        self.M = compute_matrix(target, F, lmbda)
        differential = to_differential(self.M, target.detach())
        self.u = torch.tensor(differential, requires_grad=True)
        # The inner optimizer only ever sees ``u``.
        self.optimizer = AdamUniform([self.u], lr=lr, betas=betas)
        super().__init__(V, dict(F=F, lr=lr, betas=betas))

    def step(self):
        """Move the gradient of ``V[0]`` onto ``u``, step ``u``, refresh ``V[0]``."""
        with Timer("LargeSteps.step"):
            upstream_grad = self.V.grad
            # Differentiable map u -> V, used to back-propagate the gradient
            # already accumulated on V into u.
            reconstructed = from_differential(self.M, self.u, 'Cholesky')
            reconstructed.backward(upstream_grad)
            self.optimizer.step()
            # Write the values implied by the updated u back into the
            # parameter.
            refreshed = from_differential(self.M, self.u, 'Cholesky')
            self.V.data.copy_(refreshed)

    def zero_grad(self):
        """Clear gradients on the registered parameters and on ``u``."""
        super().zero_grad()
        self.optimizer.zero_grad()

class LargeSteps2(torch.optim.Optimizer):
    """Variant of the ``u``-space optimizer that updates only column 1 of ``V[0]``.

    ``u`` is obtained from ``to_differential(M, V[0][:, 1])`` with
    ``M = compute_matrix(V[0], F, lmbda)`` and is optimized by an inner
    ``AdamUniform`` instance. After every step, column 1 of ``V[0]`` is
    overwritten with ``from_differential(M, u, 'Cholesky')``; the other
    columns are left untouched.

    Args:
        V: Sequence of parameters registered with the base optimizer; only
            column 1 of ``V[0]`` is updated by :meth:`step`.
            NOTE(review): presumably the mesh vertex positions - confirm.
        F: Forwarded to ``compute_matrix`` and stored in the defaults.
            NOTE(review): presumably the face index buffer - confirm.
        lr: Learning rate of the inner ``AdamUniform``.
        betas: Moment decay rates of the inner ``AdamUniform``.
        lmbda: Third argument passed to ``compute_matrix``.
    """

    def __init__(self, V, F, lr=0.1, betas=(0.9, 0.999), lmbda=0.1):
        target = V[0]
        self.V = target
        self.F = F
        self.M = compute_matrix(target, F, lmbda)
        column = target.detach()[:, 1]
        self.u = torch.tensor(to_differential(self.M, column),
                              requires_grad=True)
        # The inner optimizer only ever sees ``u``.
        self.optimizer = AdamUniform([self.u], lr=lr, betas=betas)
        super().__init__(V, dict(F=F, lr=lr, betas=betas))

    def step(self):
        """Move column 1 of the gradient onto ``u``, step ``u``, refresh column 1."""
        # NOTE(review): the timer label matches the one used by LargeSteps;
        # confirm whether sharing it is intentional.
        with Timer("LargeSteps.step"):
            upstream_grad = self.V.grad[:, 1]
            # Differentiable map u -> column, used to back-propagate the
            # gradient already accumulated on V into u.
            reconstructed = from_differential(self.M, self.u, 'Cholesky')
            reconstructed.backward(upstream_grad)
            self.optimizer.step()
            # Only column 1 of the parameter is overwritten.
            refreshed = from_differential(self.M, self.u, 'Cholesky')
            self.V.data[:, 1] = refreshed.detach()

    def zero_grad(self):
        """Clear gradients on the registered parameters and on ``u``."""
        super().zero_grad()
        self.optimizer.zero_grad()