Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add uarray support for scipy.linalg #114

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
2 changes: 1 addition & 1 deletion .github/workflows/docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
name: Build base Docker image
runs-on: ubuntu-latest
environment: scipy-dev
if: "github.repository_owner == 'scipy' && !contains(github.event.head_commit.message, '[ci skip]') && !contains(github.event.head_commit.message, '[skip ci]') && !contains(github.event.head_commit.message, '[skip github]')"
if: "github.repository_owner == 'rgommers' && !contains(github.event.head_commit.message, '[ci skip]') && !contains(github.event.head_commit.message, '[skip ci]') && !contains(github.event.head_commit.message, '[skip github]')"
steps:
- name: Clone repository
uses: actions/checkout@v2
Expand Down
10 changes: 5 additions & 5 deletions .github/workflows/gitpod.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@ on:

jobs:
build:
name: Build Gitpod Docker image
name: Build Gitpod Docker image
runs-on: ubuntu-latest
environment: scipy-dev
if: "github.repository_owner == 'scipy' && !contains(github.event.head_commit.message, '[ci skip]') && !contains(github.event.head_commit.message, '[skip ci]') && !contains(github.event.head_commit.message, '[skip github]')"
if: "github.repository_owner == 'rgommers' && !contains(github.event.head_commit.message, '[ci skip]') && !contains(github.event.head_commit.message, '[skip ci]') && !contains(github.event.head_commit.message, '[skip github]')"
steps:
- name: Clone repository
uses: actions/checkout@v2
- name: Lint Docker
- name: Lint Docker
uses: brpaz/[email protected]
with:
with:
dockerfile: ./tools/docker_dev/gitpod.Dockerfile
- name: Get refs
shell: bash
Expand Down Expand Up @@ -49,6 +49,6 @@ jobs:
cache-to: type=local,dest=/tmp/.buildx-cache
tags: |
scipy/scipy-gitpod:${{ steps.getrefs.outputs.date }}-${{ steps.getrefs.outputs.branch}}-${{ steps.getrefs.outputs.sha8 }}, scipy/scipy-gitpod:latest
- name: Image digest
- name: Image digest
# Return details of the image build: sha and shell
run: echo ${{ steps.docker_build.outputs.digest }}
4 changes: 2 additions & 2 deletions .github/workflows/linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ on:
jobs:
Python-38-dbg:
name: Python 3.8-dbg
if: "github.repository == 'scipy/scipy' && !contains(github.event.head_commit.message, '[ci skip]') && !contains(github.event.head_commit.message, '[skip ci]') && !contains(github.event.head_commit.message, '[skip github]')"
if: "github.repository == 'rgommers/scipy' && !contains(github.event.head_commit.message, '[ci skip]') && !contains(github.event.head_commit.message, '[skip ci]') && !contains(github.event.head_commit.message, '[skip github]')"
runs-on: ubuntu-18.04
steps:
- uses: actions/checkout@v2
Expand Down Expand Up @@ -53,7 +53,7 @@ jobs:

test_numpy_main:
name: NumPy main
if: "github.repository == 'scipy/scipy' && !contains(github.event.head_commit.message, '[ci skip]') && !contains(github.event.head_commit.message, '[skip ci]') && !contains(github.event.head_commit.message, '[skip github]') && !contains(github.ref, 'maintenance/') && !contains(github.base_ref, 'maintenance/')"
if: "github.repository == 'rgommers/scipy' && !contains(github.event.head_commit.message, '[ci skip]') && !contains(github.event.head_commit.message, '[skip ci]') && !contains(github.event.head_commit.message, '[skip github]') && !contains(github.ref, 'maintenance/') && !contains(github.base_ref, 'maintenance/')"
runs-on: ubuntu-latest
strategy:
matrix:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/linux_meson.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
test_meson:
name: Meson build
# If using act to run CI locally the github object does not exist and the usual skipping should not be enforced
if: "github.repository != '' || github.repository == 'scipy/scipy' && !contains(github.event.head_commit.message, '[ci skip]') && !contains(github.event.head_commit.message, '[skip ci]') && !contains(github.event.head_commit.message, '[skip github]') && !contains(github.ref, 'maintenance/') && !contains(github.base_ref, 'maintenance/')"
if: "github.repository != '' || github.repository == 'rgommers/scipy' && !contains(github.event.head_commit.message, '[ci skip]') && !contains(github.event.head_commit.message, '[skip ci]') && !contains(github.event.head_commit.message, '[skip github]') && !contains(github.ref, 'maintenance/') && !contains(github.base_ref, 'maintenance/')"
runs-on: ubuntu-latest
strategy:
matrix:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/macos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ on:
jobs:
test_macos:
name: macOS Test Matrix
if: "github.repository == 'scipy/scipy' && !contains(github.event.head_commit.message, '[ci skip]') && !contains(github.event.head_commit.message, '[skip ci]') && !contains(github.event.head_commit.message, '[skip github]')"
if: "github.repository == 'rgommers/scipy' && !contains(github.event.head_commit.message, '[ci skip]') && !contains(github.event.head_commit.message, '[skip ci]') && !contains(github.event.head_commit.message, '[skip github]')"
runs-on: macos-latest
strategy:
max-parallel: 3
Expand Down
3 changes: 2 additions & 1 deletion scipy/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@

from ._misc import *
from ._cythonized_array_utils import *
from ._backend import *
from ._basic import *
from ._decomp import *
from ._decomp_lu import *
Expand All @@ -208,13 +209,13 @@
from ._decomp_schur import *
from ._decomp_polar import *
from ._matfuncs import *
from ._multimethods import *
from .blas import *
from .lapack import *
from ._special_matrices import *
from ._solvers import *
from ._procrustes import *
from ._decomp_update import *
from ._sketches import *
from ._decomp_cossin import *

# Deprecated namespaces, to be removed in v2.0.0
Expand Down
1 change: 1 addition & 0 deletions scipy/linalg/_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ._sketches import *
96 changes: 96 additions & 0 deletions scipy/linalg/_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import scipy._lib.uarray as ua
from scipy.linalg import _api
import numpy as np

__all__ = [
'register_backend', 'set_backend',
'set_global_backend', 'skip_backend'
]


class scalar_tuple_callable_array:
"""
Special case argument that can be either a scalar, tuple or array
for __ua_convert__.
"""
pass

class _ScipyLinalgBackend:
__ua_domain__ = "numpy.scipy.linalg"


@staticmethod
def __ua_function__(method, args, kwargs):
fn = getattr(_api, method.__name__, None)

if fn is None:
return NotImplemented

return fn(*args, **kwargs)


@ua.wrap_single_convertor
def __ua_convert__(value, dispatch_type, coerce):
if value is None:
return None

if dispatch_type is np.ndarray:
if not coerce and not isinstance(value, np.ndarray):
return NotImplemented

return np.asarray(value)

elif dispatch_type is np.dtype:
return np.dtype(value)

elif dispatch_type is scalar_tuple_callable_array:
if (np.isscalar(value) or isinstance(value, (str, tuple)) or
callable(value)):
return value
elif not coerce and not isinstance(value, np.ndarray):
return NotImplemented

return np.asarray(value)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we discussed, this:

  • overrides functions which use asanyarray internally, and
  • breaks for sparse matrix inputs

The asanyarray should be looked into per function (and possibly removed, as I think is likely a good idea for clarkson_woodruff), the sparse matrices need handling with an explicit check here and don't need to be coerced.


return value


_named_backends = {
'scipy': _ScipyLinalgBackend,
}


def _backend_from_arg(backend):
if isinstance(backend, str):
try:
backend = _named_backends[backend]
except KeyError as e:
raise ValueError('Unknown backend {}'.format(backend)) from e

if backend.__ua_domain__ != 'numpy.scipy.linalg':
raise ValueError('Backend does not implement "numpy.scipy.linalg"')

return backend


def set_global_backend(backend, coerce=False, only=False, try_last=False):
backend = _backend_from_arg(backend)
ua.set_global_backend(backend, coerce=coerce, only=only, try_last=try_last)


def register_backend(backend):
backend = _backend_from_arg(backend)
ua.register_backend(backend)


def set_backend(backend, coerce=True, only=False):
backend = _backend_from_arg(backend)
return ua.set_backend(backend, coerce=coerce, only=only)


def skip_backend(backend):
backend = _backend_from_arg(backend)
return ua.skip_backend(backend)


set_global_backend('scipy', coerce=True, try_last=True)
44 changes: 44 additions & 0 deletions scipy/linalg/_multimethods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import functools
import numpy as np
from scipy._lib.uarray import Dispatchable, all_of_type, create_multimethod
from scipy.linalg import _api
from scipy.linalg._backend import scalar_tuple_callable_array

__all__ = [
# sketches
'clarkson_woodruff_transform'
]


_create_linalg = functools.partial(
create_multimethod,
domain="numpy.scipy.linalg"
)


_mark_scalar_tuple_callable_array = functools.partial(
Dispatchable,
dispatch_type=scalar_tuple_callable_array,
coercible=True
)


def _get_docs(func):
func.__doc__ = getattr(_api, func.__name__).__doc__
return func


def _input_sketch_seed_replacer(args, kwargs, dispatchables):
def self_method(input_matrix, sketch_size, seed=None, *args, **kwargs):
return (dispatchables[0], dispatchables[1],
dispatchables[2]) + args, kwargs

return self_method(*args, **kwargs)


@_create_linalg(_input_sketch_seed_replacer)
@all_of_type(np.ndarray)
@_get_docs
def clarkson_woodruff_transform(input_matrix, sketch_size, seed=None):
return (input_matrix, Dispatchable(sketch_size, int),
_mark_scalar_tuple_callable_array(seed))
rgommers marked this conversation as resolved.
Show resolved Hide resolved
1 change: 1 addition & 0 deletions scipy/linalg/_sketches.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,5 +171,6 @@ def clarkson_woodruff_transform(input_matrix, sketch_size, seed=None):
In Foundations and Trends in Theoretical Computer Science, 2014.

"""
print(type(input_matrix))
S = cwt_matrix(sketch_size, input_matrix.shape[0], seed)
return S.dot(input_matrix)
3 changes: 3 additions & 0 deletions scipy/linalg/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,8 @@ _cythonized_array_utils = py3.extension_module('_cythonized_array_utils',

python_sources = [
'__init__.py',
'_api.py',
'_backend.py',
'_basic.py',
'_decomp.py',
'_decomp_cholesky.py',
Expand All @@ -260,6 +262,7 @@ python_sources = [
'_matfuncs_inv_ssq.py',
'_matfuncs_sqrtm.py',
'_misc.py',
'_multimethods.py',
'_procrustes.py',
'_sketches.py',
'_solvers.py',
Expand Down
2 changes: 2 additions & 0 deletions scipy/linalg/tests/meson.build
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
python_sources = [
'__init__.py',
'mock_backend.py',
'test_backend.py',
'test_basic.py',
'test_blas.py',
'test_cython_blas.py',
Expand Down
30 changes: 30 additions & 0 deletions scipy/linalg/tests/mock_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import numpy as np


class _MockFunction:
def __init__(self, return_value=None):
self.number_calls = 0
self.return_value = return_value
self.last_args = ([], {})

def __call__(self, *args, **kwargs):
self.number_calls += 1
self.last_args = (args, kwargs)
return self.return_value


method_names = [
# sketches
'clarkson_woodruff_transform'
]

for name in method_names:
globals()[name] = _MockFunction(np.array([[0, 0], [1, 1]]))


__ua_domain__ = "numpy.scipy.linalg"


def __ua_function__(method, args, kwargs):
fn = globals().get(method.__name__)
return (fn(*args, **kwargs) if fn is not None else NotImplemented)
31 changes: 31 additions & 0 deletions scipy/linalg/tests/test_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numpy as np
import scipy.linalg
from scipy.linalg import set_backend
from scipy.linalg.tests import mock_backend

from numpy.testing import assert_equal
import pytest


fnames = [
# sketches
'clarkson_woodruff_transform'
]


funcs = [getattr(scipy.linalg, fname) for fname in fnames]
mocks = [getattr(mock_backend, fname) for fname in fnames]


@pytest.mark.parametrize("func, mock", zip(funcs, mocks))
def test_backend_call(func, mock):
"""
Checks fake backend dispatch.
"""
x = np.array([[0, 0], [1, 1]])

with set_backend(mock_backend, only=True):
mock.number_calls = 0
y = func(x)
assert_equal(y, mock.return_value)
assert_equal(mock.number_calls, 1)