From 815001e7fcdcba11a060d53b0b61794d71a4f31c Mon Sep 17 00:00:00 2001 From: Smit-create Date: Wed, 19 Jan 2022 11:41:03 +0530 Subject: [PATCH] ENH: Add uarray support for special matrices and sketches --- scipy/linalg/__init__.py | 2 - scipy/linalg/_api.py | 2 + scipy/linalg/_backend.py | 6 +- scipy/linalg/_multimethods.py | 166 +++++++++++++++++++++++++++- scipy/linalg/tests/mock_backend.py | 8 +- scipy/linalg/tests/test_backend.py | 8 +- scipy/linalg/tests/test_matfuncs.py | 2 +- 7 files changed, 186 insertions(+), 8 deletions(-) diff --git a/scipy/linalg/__init__.py b/scipy/linalg/__init__.py index 3a6e2e47f774..e65e4b2fce48 100644 --- a/scipy/linalg/__init__.py +++ b/scipy/linalg/__init__.py @@ -210,10 +210,8 @@ from ._multimethods import * from .blas import * from .lapack import * -from ._special_matrices import * from ._procrustes import * from ._decomp_update import * -from ._sketches import * from ._decomp_cossin import * # Deprecated namespaces, to be removed in v2.0.0 diff --git a/scipy/linalg/_api.py b/scipy/linalg/_api.py index 43dcebe0d5ad..3a6f104a928d 100644 --- a/scipy/linalg/_api.py +++ b/scipy/linalg/_api.py @@ -1,3 +1,5 @@ from ._solvers import * from ._decomp import * from ._matfuncs import * +from ._special_matrices import * +from ._sketches import * diff --git a/scipy/linalg/_backend.py b/scipy/linalg/_backend.py index 25159990ae3b..19464fd379f6 100644 --- a/scipy/linalg/_backend.py +++ b/scipy/linalg/_backend.py @@ -1,7 +1,9 @@ import scipy._lib.uarray as ua from scipy.linalg import _api +from scipy.sparse import issparse import numpy as np + __all__ = [ 'register_backend', 'set_backend', 'set_global_backend', 'skip_backend' @@ -31,8 +33,8 @@ def __ua_function__(method, args, kwargs): @ua.wrap_single_convertor def __ua_convert__(value, dispatch_type, coerce): - if value is None: - return None + if value is None or issparse(value): + return value if dispatch_type is np.ndarray: if not coerce and not isinstance(value, np.ndarray): diff --git a/scipy/linalg/_multimethods.py b/scipy/linalg/_multimethods.py index 9c021b7aee19..61dce16835df 100644 --- a/scipy/linalg/_multimethods.py +++ b/scipy/linalg/_multimethods.py @@ -4,6 +4,12 @@ from scipy.linalg import _api from scipy.linalg._backend import scalar_tuple_callable_array +# they don't need to be dispatchabled +from ._api import (hilbert, helmert, invhilbert, pascal, invpascal, dft, + block_diag) + + + __all__ = [ # solvers 'solve_sylvester', @@ -18,7 +24,14 @@ 'expm', 'cosm', 'sinm', 'tanm', 'coshm', 'sinhm', 'tanhm', 'logm', 'funm', 'signm', 'sqrtm', 'expm_frechet', 'expm_cond', 'fractional_matrix_power', - 'khatri_rao' + 'khatri_rao', + # sketches + 'clarkson_woodruff_transform', + # special matrices + 'tri', 'tril', 'triu', 'toeplitz', 'circulant', 'hankel', + 'hadamard', 'leslie', 'kron', 'block_diag', 'companion', + 'helmert', 'hilbert', 'invhilbert', 'pascal', 'invpascal', 'dft', + 'fiedler', 'fiedler_companion', 'convolution_matrix' ] @@ -342,3 +355,154 @@ def expm_frechet(A, E, method=None, compute_expm=True, check_finite=True): @_get_docs def khatri_rao(a, b): return a, b + +############################### sketches ####################################### + + +def _inputmatrix_replacer(args, kwargs, dispatchables): + def self_method(input_matrix, *args, **kwargs): + return dispatchables + args, kwargs + + return self_method(*args, **kwargs) + + +@_create_linalg(_inputmatrix_replacer) +@all_of_type(np.ndarray) +@_get_docs +def clarkson_woodruff_transform(input_matrix, sketch_size, seed=None): + return (input_matrix, ) + + +############################### special matrices ############################### + +def _N_M_k_dtype_replacer(args, kwargs, dispatchables): + def self_method(N, M=None, k=0, dtype=None, *args, **kwargs): + return (N, M, k, dispatchables[0]) + args, kwargs + + return self_method(*args, **kwargs) + + +@_create_linalg(_N_M_k_dtype_replacer) +@all_of_type(np.dtype) +@_get_docs +def tri(N, M=None, k=0, dtype=None): + return (dtype, ) + + +def _m_replacer(args, kwargs, dispatchables): + def self_method(m, *args, **kwargs): + return dispatchables + args, kwargs + + return self_method(*args, **kwargs) + + +@_create_linalg(_m_replacer) +@all_of_type(np.ndarray) +@_get_docs +def tril(m, k=0): + return (m, ) + + +@_create_linalg(_m_replacer) +@all_of_type(np.ndarray) +@_get_docs +def triu(m, k=0): + return (m, ) + + +def _c_r_replacer(args, kwargs, dispatchables): + def self_method(c, r=None, *args, **kwargs): + return dispatchables + args, kwargs + + return self_method(*args, **kwargs) + + +@_create_linalg(_c_r_replacer) +@all_of_type(np.ndarray) +@_get_docs +def toeplitz(c, r=None): + return c, r + + +@_create_linalg(_c_r_replacer) +@all_of_type(np.ndarray) +@_get_docs +def hankel(c, r=None): + return c, r + + +def _c_replacer(args, kwargs, dispatchables): + def self_method(c, *args, **kwargs): + return dispatchables + args, kwargs + + return self_method(*args, **kwargs) + + +@_create_linalg(_c_replacer) +@all_of_type(np.ndarray) +@_get_docs +def circulant(c): + return (c, ) + + +def _n_dtype_replacer(args, kwargs, dispatchables): + def self_method(n, dtype=int, *args, **kwargs): + return (n, dispatchables[0]) + args, kwargs + + return self_method(*args, **kwargs) + + +@_create_linalg(_n_dtype_replacer) +@all_of_type(np.dtype) +@_get_docs +def hadamard(n, dtype=int): + return (dtype, ) + + +def _f_s_replacer(args, kwargs, dispatchables): + def self_method(f, s, *args, **kwargs): + return dispatchables + args, kwargs + + return self_method(*args, **kwargs) + + +@_create_linalg(_f_s_replacer) +@all_of_type(np.ndarray) +@_get_docs +def leslie(f, s): + return f, s + + +@_create_linalg(_a_b_replacer) +@all_of_type(np.ndarray) +@_get_docs +def kron(a, b): + return a, b + + +@_create_linalg(_a_replacer) +@all_of_type(np.ndarray) +@_get_docs +def companion(a): + return (a, ) + + +@_create_linalg(_a_replacer) +@all_of_type(np.ndarray) +@_get_docs +def fiedler(a): + return (a, ) + + +@_create_linalg(_a_replacer) +@all_of_type(np.ndarray) +@_get_docs +def fiedler_companion(a): + return (a, ) + + +@_create_linalg(_a_replacer) +@all_of_type(np.ndarray) +@_get_docs +def convolution_matrix(a, n, mode='full'): + return (a, ) diff --git a/scipy/linalg/tests/mock_backend.py b/scipy/linalg/tests/mock_backend.py index 61618deb41f6..dce97218e146 100644 --- a/scipy/linalg/tests/mock_backend.py +++ b/scipy/linalg/tests/mock_backend.py @@ -27,7 +27,13 @@ def __call__(self, *args, **kwargs): 'expm', 'cosm', 'sinm', 'tanm', 'coshm', 'sinhm', 'tanhm', 'logm', 'funm', 'signm', 'sqrtm', 'expm_frechet', 'expm_cond', 'fractional_matrix_power', - 'khatri_rao' + 'khatri_rao', + # sketches + 'clarkson_woodruff_transform', + # special matrices + 'tri', 'tril', 'triu', 'toeplitz', 'circulant', 'hankel', + 'hadamard', 'leslie', 'kron', 'companion', + 'fiedler', 'fiedler_companion', 'convolution_matrix' ] for name in method_names: diff --git a/scipy/linalg/tests/test_backend.py b/scipy/linalg/tests/test_backend.py index f567e4b3d74b..dfd98b4d3314 100644 --- a/scipy/linalg/tests/test_backend.py +++ b/scipy/linalg/tests/test_backend.py @@ -21,7 +21,13 @@ 'expm', 'cosm', 'sinm', 'tanm', 'coshm', 'sinhm', 'tanhm', 'logm', 'funm', 'signm', 'sqrtm', 'expm_frechet', 'expm_cond', 'fractional_matrix_power', - 'khatri_rao' + 'khatri_rao', + # sketches + 'clarkson_woodruff_transform', + # special matrices + 'tri', 'tril', 'triu', 'toeplitz', 'circulant', 'hankel', + 'hadamard', 'leslie', 'kron', 'companion', + 'fiedler', 'fiedler_companion', 'convolution_matrix' ] diff --git a/scipy/linalg/tests/test_matfuncs.py b/scipy/linalg/tests/test_matfuncs.py index 034a65db9614..b804c110a214 100644 --- a/scipy/linalg/tests/test_matfuncs.py +++ b/scipy/linalg/tests/test_matfuncs.py @@ -16,7 +16,7 @@ import pytest import scipy.linalg -from scipy.linalg._matfuncs import (funm, signm, logm, sqrtm, +from scipy.linalg import (funm, signm, logm, sqrtm, fractional_matrix_power, expm, expm_frechet, expm_cond, norm, khatri_rao) from scipy.linalg import _matfuncs_inv_ssq