forked from scipy/scipy
-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: Add uarray support for scipy.linalg
#114
Open
Smit-create
wants to merge
8
commits into
rgommers:master
Choose a base branch
from
Smit-create:uarray-linalg
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 2 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
9beb51a
ENH: Add initial `uarray` support for `linalg`
Smit-create 426fe82
CI: Enable tests for `rgommers/scipy`
Smit-create f17bd00
ENH: Add uarray support for `linalg._solvers`
Smit-create 59022ef
ENH: Add uarray support for `linalg._decomp` (eigenvalue problems)
Smit-create eb2ad0a
ENH: Add uarray support for matrix functions
Smit-create 815001e
ENH: Add uarray support for special matrices and sketches
Smit-create b2f933c
ENH: Add uarray support for decompositions
Smit-create 76dd1c7
ENH: Add uarray support for `linalg._basic`
Smit-create File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,16 +6,16 @@ on: | |
|
||
jobs: | ||
build: | ||
name: Build Gitpod Docker image | ||
name: Build Gitpod Docker image | ||
runs-on: ubuntu-latest | ||
environment: scipy-dev | ||
if: "github.repository_owner == 'scipy' && !contains(github.event.head_commit.message, '[ci skip]') && !contains(github.event.head_commit.message, '[skip ci]') && !contains(github.event.head_commit.message, '[skip github]')" | ||
if: "github.repository_owner == 'rgommers' && !contains(github.event.head_commit.message, '[ci skip]') && !contains(github.event.head_commit.message, '[skip ci]') && !contains(github.event.head_commit.message, '[skip github]')" | ||
steps: | ||
- name: Clone repository | ||
uses: actions/checkout@v2 | ||
- name: Lint Docker | ||
- name: Lint Docker | ||
uses: brpaz/[email protected] | ||
with: | ||
with: | ||
dockerfile: ./tools/docker_dev/gitpod.Dockerfile | ||
- name: Get refs | ||
shell: bash | ||
|
@@ -49,6 +49,6 @@ jobs: | |
cache-to: type=local,dest=/tmp/.buildx-cache | ||
tags: | | ||
scipy/scipy-gitpod:${{ steps.getrefs.outputs.date }}-${{ steps.getrefs.outputs.branch}}-${{ steps.getrefs.outputs.sha8 }}, scipy/scipy-gitpod:latest | ||
- name: Image digest | ||
- name: Image digest | ||
# Return details of the image build: sha and shell | ||
run: echo ${{ steps.docker_build.outputs.digest }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from ._sketches import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import scipy._lib.uarray as ua | ||
from scipy.linalg import _api | ||
import numpy as np | ||
|
||
__all__ = [ | ||
'register_backend', 'set_backend', | ||
'set_global_backend', 'skip_backend' | ||
] | ||
|
||
|
||
class scalar_tuple_callable_array: | ||
""" | ||
Special case argument that can be either a scalar, tuple or array | ||
for __ua_convert__. | ||
""" | ||
pass | ||
|
||
class _ScipyLinalgBackend: | ||
__ua_domain__ = "numpy.scipy.linalg" | ||
|
||
|
||
@staticmethod | ||
def __ua_function__(method, args, kwargs): | ||
fn = getattr(_api, method.__name__, None) | ||
|
||
if fn is None: | ||
return NotImplemented | ||
|
||
return fn(*args, **kwargs) | ||
|
||
|
||
@ua.wrap_single_convertor | ||
def __ua_convert__(value, dispatch_type, coerce): | ||
if value is None: | ||
return None | ||
|
||
if dispatch_type is np.ndarray: | ||
if not coerce and not isinstance(value, np.ndarray): | ||
return NotImplemented | ||
|
||
return np.asarray(value) | ||
|
||
elif dispatch_type is np.dtype: | ||
return np.dtype(value) | ||
|
||
elif dispatch_type is scalar_tuple_callable_array: | ||
if (np.isscalar(value) or isinstance(value, (str, tuple)) or | ||
callable(value)): | ||
return value | ||
elif not coerce and not isinstance(value, np.ndarray): | ||
return NotImplemented | ||
|
||
return np.asarray(value) | ||
|
||
return value | ||
|
||
|
||
_named_backends = { | ||
'scipy': _ScipyLinalgBackend, | ||
} | ||
|
||
|
||
def _backend_from_arg(backend): | ||
if isinstance(backend, str): | ||
try: | ||
backend = _named_backends[backend] | ||
except KeyError as e: | ||
raise ValueError('Unknown backend {}'.format(backend)) from e | ||
|
||
if backend.__ua_domain__ != 'numpy.scipy.linalg': | ||
raise ValueError('Backend does not implement "numpy.scipy.linalg"') | ||
|
||
return backend | ||
|
||
|
||
def set_global_backend(backend, coerce=False, only=False, try_last=False): | ||
backend = _backend_from_arg(backend) | ||
ua.set_global_backend(backend, coerce=coerce, only=only, try_last=try_last) | ||
|
||
|
||
def register_backend(backend): | ||
backend = _backend_from_arg(backend) | ||
ua.register_backend(backend) | ||
|
||
|
||
def set_backend(backend, coerce=True, only=False): | ||
backend = _backend_from_arg(backend) | ||
return ua.set_backend(backend, coerce=coerce, only=only) | ||
|
||
|
||
def skip_backend(backend): | ||
backend = _backend_from_arg(backend) | ||
return ua.skip_backend(backend) | ||
|
||
|
||
set_global_backend('scipy', coerce=True, try_last=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import functools | ||
import numpy as np | ||
from scipy._lib.uarray import Dispatchable, all_of_type, create_multimethod | ||
from scipy.linalg import _api | ||
from scipy.linalg._backend import scalar_tuple_callable_array | ||
|
||
__all__ = [ | ||
# sketches | ||
'clarkson_woodruff_transform' | ||
] | ||
|
||
|
||
_create_linalg = functools.partial( | ||
create_multimethod, | ||
domain="numpy.scipy.linalg" | ||
) | ||
|
||
|
||
_mark_scalar_tuple_callable_array = functools.partial( | ||
Dispatchable, | ||
dispatch_type=scalar_tuple_callable_array, | ||
coercible=True | ||
) | ||
|
||
|
||
def _get_docs(func): | ||
func.__doc__ = getattr(_api, func.__name__).__doc__ | ||
return func | ||
|
||
|
||
def _input_sketch_seed_replacer(args, kwargs, dispatchables): | ||
def self_method(input_matrix, sketch_size, seed=None, *args, **kwargs): | ||
return (dispatchables[0], dispatchables[1], | ||
dispatchables[2]) + args, kwargs | ||
|
||
return self_method(*args, **kwargs) | ||
|
||
|
||
@_create_linalg(_input_sketch_seed_replacer) | ||
@all_of_type(np.ndarray) | ||
@_get_docs | ||
def clarkson_woodruff_transform(input_matrix, sketch_size, seed=None): | ||
return (input_matrix, Dispatchable(sketch_size, int), | ||
_mark_scalar_tuple_callable_array(seed)) | ||
rgommers marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import numpy as np | ||
|
||
|
||
class _MockFunction: | ||
def __init__(self, return_value=None): | ||
self.number_calls = 0 | ||
self.return_value = return_value | ||
self.last_args = ([], {}) | ||
|
||
def __call__(self, *args, **kwargs): | ||
self.number_calls += 1 | ||
self.last_args = (args, kwargs) | ||
return self.return_value | ||
|
||
|
||
method_names = [ | ||
# sketches | ||
'clarkson_woodruff_transform' | ||
] | ||
|
||
for name in method_names: | ||
globals()[name] = _MockFunction(np.array([[0, 0], [1, 1]])) | ||
|
||
|
||
__ua_domain__ = "numpy.scipy.linalg" | ||
|
||
|
||
def __ua_function__(method, args, kwargs): | ||
fn = globals().get(method.__name__) | ||
return (fn(*args, **kwargs) if fn is not None else NotImplemented) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import numpy as np | ||
import scipy.linalg | ||
from scipy.linalg import set_backend | ||
from scipy.linalg.tests import mock_backend | ||
|
||
from numpy.testing import assert_equal | ||
import pytest | ||
|
||
|
||
fnames = [ | ||
# sketches | ||
'clarkson_woodruff_transform' | ||
] | ||
|
||
|
||
funcs = [getattr(scipy.linalg, fname) for fname in fnames] | ||
mocks = [getattr(mock_backend, fname) for fname in fnames] | ||
|
||
|
||
@pytest.mark.parametrize("func, mock", zip(funcs, mocks)) | ||
def test_backend_call(func, mock): | ||
""" | ||
Checks fake backend dispatch. | ||
""" | ||
x = np.array([[0, 0], [1, 1]]) | ||
|
||
with set_backend(mock_backend, only=True): | ||
mock.number_calls = 0 | ||
y = func(x) | ||
assert_equal(y, mock.return_value) | ||
assert_equal(mock.number_calls, 1) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As we discussed, this:
asanyarray
internally, andThe
asanyarray
should be looked into per function (and possibly removed, as I think is likely a good idea forclarkson_woodruff
), the sparse matrices need handling with an explicit check here and don't need to be coerced.