from torch.autograd import Function
import numpy as np
import scipy.sparse as sp
import scipy.sparse as cps
import sksparse.cholmod as cholmod
import torch
import scipy
from scipy.sparse.linalg import spsolve

class Solver:
    """
    Common interface for sparse linear system solvers.

    Subclasses receive the system matrix in their constructor and override
    `solve` to compute the solution for a given right hand side.
    """
    def __init__(self, M):
        """
        Create the solver. The base class does nothing with `M`.

        Parameters
        ----------
        M : object
            The system matrix. It is neither inspected nor stored here.
        """

    def solve(self, b, backward=False):
        """
        Solve the linear system Mx=b.

        Parameters
        ----------
        b : torch.Tensor
            The right hand side of the system
        backward : bool (optional)
            Whether this is the backward or forward solve

        Raises
        ------
        NotImplementedError
            Always: the base class provides no implementation.
        """
        raise NotImplementedError

class CholeskySolver(Solver):
    """
    Cholesky solver.

    Precomputes the Cholesky decomposition of the system matrix once, and
    reuses that factorization for every subsequent solve.
    """
    def __init__(self, M):
        """
        Initialize the solver

        Parameters
        ----------
        M : scipy.sparse matrix
            The matrix to decompose; it must provide ``tocsc()``. It is assumed
            to be symmetric positive definite.
        """
        super().__init__(M)
        # CHOLMOD operates on CSC storage
        M_cpu = M.tocsc()
        self.M_cpu = M_cpu
        factor = cholmod.cholesky(M_cpu, ordering_method='nesdis', mode='simplicial')
        # Keep the factorization itself: solve() reuses it instead of
        # decomposing the matrix again on every call.
        self.factor = factor
        L, P = factor.L(), factor.P()
        # Invert the permutation
        Pi = np.argsort(P).astype(np.int32)
        # Explicit factor and permutations, kept as scipy/numpy objects for
        # code that reads these attributes directly.
        self.L = cps.csc_matrix(L.astype(np.float32))
        self.U = self.L.T
        self.P = np.array(P)
        self.Pi = np.array(Pi)

    def solve(self, b, backward=False):
        """
        Solve the sparse linear system Mx=b with the precomputed factorization.

        Parameters
        ----------
        b : torch.Tensor
            The right hand side of the system. It may live on any device and
            may require gradients; it is detached and copied to the CPU.
        backward : bool (optional)
            Whether this is the backward or forward solve. Unused here: the
            matrix is assumed symmetric, so both solves use the same system.

        Returns
        -------
        torch.Tensor
            The solution, contiguous, with the same shape as `b` and placed on
            the device of `b`.
        """
        # detach() + cpu(): Tensor.numpy() rejects CUDA tensors and (depending
        # on the torch version) tensors that require grad.
        rhs = b.detach().cpu().numpy()
        x = np.asarray(self.factor(rhs))
        # Return the promoted dtype of matrix and right hand side, which is
        # what scipy's spsolve yields for the same inputs.
        x = x.astype(np.promote_types(self.M_cpu.dtype, rhs.dtype), copy=False)
        return torch.tensor(x).to(b.device).contiguous()

class DifferentiableSolve(Function):
    """
    Autograd wrapper around a Solver.

    Both passes are delegated to the `solve` method of the solver object: the
    forward pass with `backward=False`, the backward pass with `backward=True`.
    """
    @staticmethod
    def forward(ctx, solver, b):
        # Remember the solver so that the backward pass can reuse it
        ctx.solver = solver
        solution = solver.solve(b, backward=False)
        return solution

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient is returned per argument of forward: (solver, b).
        # The solver is not a tensor, so its slot is always None.
        if not ctx.needs_input_grad[1]:
            return (None, None)
        # The gradient w.r.t. b is obtained from a backward solve on the
        # incoming gradient.
        return (None, ctx.solver.solve(grad_output, backward=True))

# Functional alias: solve(solver, b) is shorthand for
# DifferentiableSolve.apply(solver, b)
solve = DifferentiableSolve.apply
