Skip to content

Commit

Permalink
ENH: Add uarray support for special matrices and sketches
Browse files Browse the repository at this point in the history
  • Loading branch information
Smit-create committed Jan 19, 2022
1 parent eb2ad0a commit 815001e
Show file tree
Hide file tree
Showing 7 changed files with 186 additions and 8 deletions.
2 changes: 0 additions & 2 deletions scipy/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,8 @@
from ._multimethods import *
from .blas import *
from .lapack import *
from ._special_matrices import *
from ._procrustes import *
from ._decomp_update import *
from ._sketches import *
from ._decomp_cossin import *

# Deprecated namespaces, to be removed in v2.0.0
Expand Down
2 changes: 2 additions & 0 deletions scipy/linalg/_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from ._solvers import *
from ._decomp import *
from ._matfuncs import *
from ._special_matrices import *
from ._sketches import *
6 changes: 4 additions & 2 deletions scipy/linalg/_backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import scipy._lib.uarray as ua
from scipy.linalg import _api
from scipy.sparse import issparse
import numpy as np


__all__ = [
'register_backend', 'set_backend',
'set_global_backend', 'skip_backend'
Expand Down Expand Up @@ -31,8 +33,8 @@ def __ua_function__(method, args, kwargs):

@ua.wrap_single_convertor
def __ua_convert__(value, dispatch_type, coerce):
if value is None:
return None
if value is None or issparse(value):
return value

if dispatch_type is np.ndarray:
if not coerce and not isinstance(value, np.ndarray):
Expand Down
166 changes: 165 additions & 1 deletion scipy/linalg/_multimethods.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
from scipy.linalg import _api
from scipy.linalg._backend import scalar_tuple_callable_array

# they don't need to be dispatchabled
from ._api import (hilbert, helmert, invhilbert, pascal, invpascal, dft,
block_diag)



__all__ = [
# solvers
'solve_sylvester',
Expand All @@ -18,7 +24,14 @@
'expm', 'cosm', 'sinm', 'tanm', 'coshm', 'sinhm',
'tanhm', 'logm', 'funm', 'signm', 'sqrtm',
'expm_frechet', 'expm_cond', 'fractional_matrix_power',
'khatri_rao'
'khatri_rao',
# sketches
'clarkson_woodruff_transform',
# special matrices
'tri', 'tril', 'triu', 'toeplitz', 'circulant', 'hankel',
'hadamard', 'leslie', 'kron', 'block_diag', 'companion',
'helmert', 'hilbert', 'invhilbert', 'pascal', 'invpascal', 'dft',
'fiedler', 'fiedler_companion', 'convolution_matrix'
]


Expand Down Expand Up @@ -342,3 +355,154 @@ def expm_frechet(A, E, method=None, compute_expm=True, check_finite=True):
@_get_docs
def khatri_rao(a, b):
return a, b

############################### sketches #######################################


def _inputmatrix_replacer(args, kwargs, dispatchables):
def self_method(input_matrix, *args, **kwargs):
return dispatchables + args, kwargs

return self_method(*args, **kwargs)


@_create_linalg(_inputmatrix_replacer)
@all_of_type(np.ndarray)
@_get_docs
def clarkson_woodruff_transform(input_matrix, sketch_size, seed=None):
return (input_matrix, )


############################### special matrices ###############################

def _N_M_k_dtype_replacer(args, kwargs, dispatchables):
def self_method(N, M=None, k=0, dtype=None, *args, **kwargs):
return (N, M, k, dispatchables[0]) + args, kwargs

return self_method(*args, **kwargs)


@_create_linalg(_N_M_k_dtype_replacer)
@all_of_type(np.dtype)
@_get_docs
def tri(N, M=None, k=0, dtype=None):
return (dtype, )


def _m_replacer(args, kwargs, dispatchables):
def self_method(m, *args, **kwargs):
return dispatchables + args, kwargs

return self_method(*args, **kwargs)


@_create_linalg(_m_replacer)
@all_of_type(np.ndarray)
@_get_docs
def tril(m, k=0):
return (m, )


@_create_linalg(_m_replacer)
@all_of_type(np.ndarray)
@_get_docs
def triu(m, k=0):
return (m, )


def _c_r_replacer(args, kwargs, dispatchables):
def self_method(c, r=None, *args, **kwargs):
return dispatchables + args, kwargs

return self_method(*args, **kwargs)


@_create_linalg(_c_r_replacer)
@all_of_type(np.ndarray)
@_get_docs
def toeplitz(c, r=None):
return c, r


@_create_linalg(_c_r_replacer)
@all_of_type(np.ndarray)
@_get_docs
def hankel(c, r=None):
return c, r


def _c_replacer(args, kwargs, dispatchables):
def self_method(c, *args, **kwargs):
return dispatchables + args, kwargs

return self_method(*args, **kwargs)


@_create_linalg(_c_replacer)
@all_of_type(np.ndarray)
@_get_docs
def circulant(c):
return (c, )


def _n_dtype_replacer(args, kwargs, dispatchables):
def self_method(n, dtype=int, *args, **kwargs):
return (n, dispatchables[0]) + args, kwargs

return self_method(*args, **kwargs)


@_create_linalg(_n_dtype_replacer)
@all_of_type(np.dtype)
@_get_docs
def hadamard(n, dtype=int):
return (dtype, )


def _f_s_replacer(args, kwargs, dispatchables):
def self_method(f, s, *args, **kwargs):
return dispatchables + args, kwargs

return self_method(*args, **kwargs)


@_create_linalg(_f_s_replacer)
@all_of_type(np.ndarray)
@_get_docs
def leslie(f, s):
return f, s


@_create_linalg(_a_b_replacer)
@all_of_type(np.ndarray)
@_get_docs
def kron(a, b):
return a, b


@_create_linalg(_a_replacer)
@all_of_type(np.ndarray)
@_get_docs
def companion(a):
return (a, )


@_create_linalg(_a_replacer)
@all_of_type(np.ndarray)
@_get_docs
def fiedler(a):
return (a, )


@_create_linalg(_a_replacer)
@all_of_type(np.ndarray)
@_get_docs
def fiedler_companion(a):
return (a, )


@_create_linalg(_a_replacer)
@all_of_type(np.ndarray)
@_get_docs
def convolution_matrix(a, n, mode='full'):
return (a, )
8 changes: 7 additions & 1 deletion scipy/linalg/tests/mock_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@ def __call__(self, *args, **kwargs):
'expm', 'cosm', 'sinm', 'tanm', 'coshm', 'sinhm',
'tanhm', 'logm', 'funm', 'signm', 'sqrtm',
'expm_frechet', 'expm_cond', 'fractional_matrix_power',
'khatri_rao'
'khatri_rao',
# sketches
'clarkson_woodruff_transform',
# special matrices
'tri', 'tril', 'triu', 'toeplitz', 'circulant', 'hankel',
'hadamard', 'leslie', 'kron', 'companion',
'fiedler', 'fiedler_companion', 'convolution_matrix'
]

for name in method_names:
Expand Down
8 changes: 7 additions & 1 deletion scipy/linalg/tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
'expm', 'cosm', 'sinm', 'tanm', 'coshm', 'sinhm',
'tanhm', 'logm', 'funm', 'signm', 'sqrtm',
'expm_frechet', 'expm_cond', 'fractional_matrix_power',
'khatri_rao'
'khatri_rao',
# sketches
'clarkson_woodruff_transform',
# special matrices
'tri', 'tril', 'triu', 'toeplitz', 'circulant', 'hankel',
'hadamard', 'leslie', 'kron', 'companion',
'fiedler', 'fiedler_companion', 'convolution_matrix'
]


Expand Down
2 changes: 1 addition & 1 deletion scipy/linalg/tests/test_matfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import pytest

import scipy.linalg
from scipy.linalg._matfuncs import (funm, signm, logm, sqrtm,
from scipy.linalg import (funm, signm, logm, sqrtm,
fractional_matrix_power,
expm, expm_frechet, expm_cond, norm, khatri_rao)
from scipy.linalg import _matfuncs_inv_ssq
Expand Down

0 comments on commit 815001e

Please sign in to comment.