Skip to content

Commit

Permalink
Add sparse QR solver
Browse files Browse the repository at this point in the history
  • Loading branch information
jaisw7 committed Mar 21, 2019
1 parent 1c58698 commit f161fcf
Showing 1 changed file with 68 additions and 3 deletions.
71 changes: 68 additions & 3 deletions dgfs1D/cusolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,17 @@ def __init__(self):
self.CUBLAS_OP_T = 1
self.CUBLAS_OP_C = 2

# cusolverCreate
# cusolverDnCreate
self.cusolverDnCreate = lib.cusolverDnCreate
self.cusolverDnCreate.argtypes = [POINTER(c_void_p)]
self.cusolverDnCreate.errcheck = self._errcheck

# cusolverDestroy
# cusolverDnDestroy
self.cusolverDnDestroy = lib.cusolverDnDestroy
self.cusolverDnDestroy.argtypes = [c_void_p]
self.cusolverDnDestroy.errcheck = self._errcheck

# cusolverSetStream
# cusolverDnSetStream
self.cusolverDnSetStream = lib.cusolverDnSetStream
self.cusolverDnSetStream.argtypes = [c_void_p, c_void_p]
self.cusolverDnSetStream.errcheck = self._errcheck
Expand Down Expand Up @@ -143,6 +143,40 @@ def __init__(self):
]
self.cusolverDnSgetrs.errcheck = self._errcheck

# sparse functions

# cusolverSpCreate
self.cusolverSpCreate = lib.cusolverSpCreate
self.cusolverSpCreate.argtypes = [POINTER(c_void_p)]
self.cusolverSpCreate.errcheck = self._errcheck

# cusolverSpDestroy
self.cusolverSpDestroy = lib.cusolverSpDestroy
self.cusolverSpDestroy.argtypes = [c_void_p]
self.cusolverSpDestroy.errcheck = self._errcheck

# cusolverSpSetStream
self.cusolverSpSetStream = lib.cusolverSpSetStream
self.cusolverSpSetStream.argtypes = [c_void_p, c_void_p]
self.cusolverSpSetStream.errcheck = self._errcheck

# cusolverSpScsrlsvlu
self.cusolverSpScsrlsvqr = lib.cusolverSpScsrlsvqr
self.cusolverSpScsrlsvqr.argtypes = [
c_void_p, c_int, c_int,
c_void_p, c_void_p, c_void_p, c_void_p,
c_void_p, c_float, c_int, c_void_p, POINTER(c_int)
]
self.cusolverSpScsrlsvqr.errcheck = self._errcheck

# cusolverSpDcsrlsvlu
self.cusolverSpDcsrlsvqr = lib.cusolverSpDcsrlsvqr
self.cusolverSpDcsrlsvqr.argtypes = [
c_void_p, c_int, c_int,
c_void_p, c_void_p, c_void_p, c_void_p,
c_void_p, c_double, c_int, c_void_p, POINTER(c_int)
]
self.cusolverSpDcsrlsvqr.errcheck = self._errcheck


def _errcheck(self, status, fn, args):
Expand Down Expand Up @@ -181,6 +215,9 @@ def __init__(self):
self._handle = c_void_p()
self._wrappers.cusolverDnCreate(self._handle)

self._handle_sp = c_void_p()
self._wrappers.cusolverSpCreate(self._handle_sp)

# Init CUBLAS
self._handle_blas = c_void_p()
self._wrappers_blas.cublasCreate(self._handle_blas)
Expand Down Expand Up @@ -256,3 +293,31 @@ def solveLU(self, A, sA, b, n):
cusolvergetrs(self._handle, w.CUBLAS_OP_T, n, 1, A.ptr, m, ipiv.ptr,
b.ptr, n, info.ptr)



# Solve Ax=b via LU decomposition: Faster than QRF for small matrices
def solveQRSparse(self, descrA, a, b, x, tol=1e-6, reorder=0):
# Wroks for square matrices only
w = self._wrappers

# Ensure the matrices are compatible
#if sA[1] != n:
# raise ValueError('Incompatible matrices for solve(A, b)')

# CUSOLVER expects inputs to be column-major (or Fortran order in
# numpy parlance).
m, n, nnz = a.shape

if a.csrVal.dtype == np.float64:
cusolverspcsrlsvqr = w.cusolverSpDcsrlsvqr
tol_ct = c_double(tol)
else:
cusolverspcsrlsvqr = w.cusolverSpScsrlsvqr
tol_ct = c_float(tol)

reorder_ct, ret = c_int(reorder), c_int(0)

cusolverspcsrlsvqr(self._handle_sp, n, nnz,
descrA, a.csrVal.ptr, a.csrRowPtr.ptr, a.csrColInd.ptr,
b.ptr, tol_ct, reorder_ct, x.ptr, ret)

0 comments on commit f161fcf

Please sign in to comment.