Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
749d5ec
ENH Adds ArrayAPI support to LinearDiscriminantAnalysis
thomasjpfan Feb 20, 2022
d8cab77
DOC Adds PR number
thomasjpfan Feb 20, 2022
f641535
Merge remote-tracking branch 'upstream/main' into array_api_lda_pr
thomasjpfan Feb 22, 2022
8240c95
ENH Full support for array api
thomasjpfan Feb 23, 2022
d485018
TST Adds array api test
thomasjpfan Feb 23, 2022
101bbd9
TST Remove assert
thomasjpfan Feb 23, 2022
c9459ef
FIX Adds path for smallest normal
thomasjpfan Feb 24, 2022
d6d19f0
FIX Private expit
thomasjpfan Feb 24, 2022
1f84146
TST Adds better test for get_namespace
thomasjpfan Feb 27, 2022
b3b9af9
TST Fix get_namespace
thomasjpfan Feb 27, 2022
abc25db
TST Adds more coverage
thomasjpfan Feb 27, 2022
15b0e0b
Apply suggestions from code review
thomasjpfan Mar 12, 2022
413fd47
CLN Address comments
thomasjpfan Mar 11, 2022
c668966
CLN Address comments
thomasjpfan Mar 12, 2022
e51c908
Merge remote-tracking branch 'upstream/main' into array_api_lda_pr
thomasjpfan Mar 12, 2022
1b6f260
STY Formatting errors
thomasjpfan Mar 12, 2022
14e95c5
TST Adds import skips
thomasjpfan Mar 12, 2022
3648bdc
STY Black
thomasjpfan Mar 12, 2022
cdfd8b9
TST Fixes docstring test
thomasjpfan Mar 12, 2022
13972fb
Merge remote-tracking branch 'upstream/main' into array_api_lda_pr
thomasjpfan Mar 18, 2022
ee1ce31
DOC Adds comment about performance
thomasjpfan Mar 18, 2022
c68f26c
ENH Raises error for casting and order on non-NumPy ArrayAPI arrays
thomasjpfan Mar 18, 2022
638bcce
TST Simplifier filterwarning
thomasjpfan Mar 18, 2022
71bf555
TST Remove error for order
thomasjpfan Mar 18, 2022
0fabeab
ENH Use NumPy API directly
thomasjpfan Mar 20, 2022
d0ac9b6
Merge remote-tracking branch 'upstream/main' into array_api_lda_pr
thomasjpfan Mar 20, 2022
b2622b1
STY Fix
thomasjpfan Mar 20, 2022
bdf04f3
CLN Do not special case array api as much
thomasjpfan Mar 21, 2022
a3d0ddc
CLN Fixes stability in cupy
thomasjpfan Mar 21, 2022
6471045
FIX Fixes test errors
thomasjpfan Mar 21, 2022
3108284
FIX Simplier test
thomasjpfan Mar 21, 2022
55ea1a4
Merge remote-tracking branch 'upstream/main' into array_api_lda_pr
thomasjpfan Mar 23, 2022
9372326
REV Remove unused import
thomasjpfan Mar 24, 2022
a627424
ENH Use a _asarray_with_order
thomasjpfan Mar 28, 2022
d793e7c
ENH Adds out for numpy arrays
thomasjpfan Mar 28, 2022
badc4a0
TST Adds unit test for _asarray_with_order
thomasjpfan Mar 28, 2022
71c0f5f
Merge branch 'main' into array_api_lda_pr
ogrisel Mar 29, 2022
56abf45
Bootstrap minimal documentation for Array API support
ogrisel Mar 29, 2022
ae061b2
Fix doctest
ogrisel Mar 29, 2022
c6c9585
Link the user guide from the config docstring
ogrisel Mar 29, 2022
580a6fc
DOC Add more information and reformat user guide
thomasjpfan Mar 29, 2022
91e9681
ENH Use get to convert cupy to ndarray
thomasjpfan Mar 29, 2022
f6a5af4
ENH Adds _convert_estimator_to_ndarray
thomasjpfan Mar 29, 2022
d66576c
DOC Sphinx warning
thomasjpfan Mar 29, 2022
f155056
DOC Remove unneeded reference
thomasjpfan Mar 29, 2022
769c513
ENH Clone estimator
thomasjpfan Mar 29, 2022
5c938c3
TST Better variable names
thomasjpfan Mar 29, 2022
cc35dbd
DOC Add parent toc
thomasjpfan Mar 29, 2022
daa3e73
Merge remote-tracking branch 'upstream/main' into array_api_lda_pr
thomasjpfan Jul 7, 2022
58cc8c8
ENH Generic converter for estimators
thomasjpfan Jul 7, 2022
77803c6
FIX Fixes failing test
thomasjpfan Jul 7, 2022
08ca7f2
Merge branch 'main' into array_api_lda_pr
ogrisel Aug 11, 2022
15e33a8
Move changelog entry to 1.2
ogrisel Aug 11, 2022
3986c5b
Fix merge conflict resolution typos
ogrisel Aug 11, 2022
93c4c8e
Merge branch 'main' into array_api_lda_pr
ogrisel Aug 23, 2022
c92d189
CLN Better implementation of take
thomasjpfan Aug 25, 2022
c997c30
Merge remote-tracking branch 'upstream/main' into array_api_lda_pr
thomasjpfan Aug 26, 2022
36cbc8a
DOC Update versionadded
thomasjpfan Aug 26, 2022
a5a68bb
Merge remote-tracking branch 'upstream/main' into array_api_lda_pr
thomasjpfan Sep 1, 2022
937f140
FIX Fixes merge error
thomasjpfan Sep 1, 2022
607709d
Merge remote-tracking branch 'upstream/main' into array_api_lda_pr
thomasjpfan Sep 3, 2022
055c6f8
DOC Add note about cupy.array_api and numpy.array_api
thomasjpfan Sep 3, 2022
5bc6f76
Apply suggestions from code review
thomasjpfan Sep 5, 2022
671f4f2
DOC Add docstring to solver parameter
thomasjpfan Sep 5, 2022
b809895
DOC Use better directive
thomasjpfan Sep 5, 2022
e8be4bc
DOC Use versionchanged directive
thomasjpfan Sep 6, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions doc/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ def skip_if_matplotlib_not_installed(fname):
raise SkipTest(f"Skipping doctests for {basename}, matplotlib not installed")


def skip_if_cupy_not_installed(fname):
    """Skip the doctests of ``fname`` when CuPy cannot be imported.

    Parameters
    ----------
    fname : str
        Path of the documentation file whose doctests are about to run.
        Only its base name is used, to build the skip message.

    Raises
    ------
    SkipTest
        If ``import cupy`` raises ``ImportError``.
    """
    try:
        # Imported only to probe availability; the module itself is unused.
        import cupy  # noqa
    except ImportError:
        basename = os.path.basename(fname)
        raise SkipTest(f"Skipping doctests for {basename}, cupy not installed")


def pytest_runtest_setup(item):
fname = item.fspath.strpath
# normalize filename to use forward slashes on Windows for easier handling
Expand Down Expand Up @@ -147,6 +155,9 @@ def pytest_runtest_setup(item):
if fname.endswith(each):
skip_if_matplotlib_not_installed(fname)

if fname.endswith("array_api.rst"):
skip_if_cupy_not_installed(fname)


def pytest_configure(config):
# Use matplotlib agg backend during the tests including doctests
Expand Down
14 changes: 14 additions & 0 deletions doc/dispatching.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
.. Places parent toc into the sidebar

:parenttoc: True

.. include:: includes/big_toc_css.rst

===========
Dispatching
===========

.. toctree::
:maxdepth: 2

modules/array_api
75 changes: 75 additions & 0 deletions doc/modules/array_api.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
.. Places parent toc into the sidebar

:parenttoc: True

.. _array_api:

================================
Array API support (experimental)
================================

.. currentmodule:: sklearn

The `Array API <https://data-apis.org/array-api/latest/>`_ specification defines
a standard API for all array manipulation libraries with a NumPy-like API.

Some scikit-learn estimators that primarily rely on NumPy (as opposed to using
Cython) to implement the algorithmic logic of their `fit`, `predict` or
`transform` methods can be configured to accept any Array API compatible input
data structures and automatically dispatch operations to the underlying namespace
instead of relying on NumPy.

At this stage, this support is **considered experimental** and must be enabled
explicitly as explained below.

.. note::
Currently, only `cupy.array_api` and `numpy.array_api` are known to work
with scikit-learn's estimators.

Example usage
=============

Here is an example code snippet to demonstrate how to use `CuPy
<https://cupy.dev/>`_ to run
:class:`~discriminant_analysis.LinearDiscriminantAnalysis` on a GPU::

>>> from sklearn.datasets import make_classification
>>> from sklearn import config_context
>>> from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
>>> import cupy.array_api as xp

>>> X_np, y_np = make_classification(random_state=0)
>>> X_cu = xp.asarray(X_np)
>>> y_cu = xp.asarray(y_np)
>>> X_cu.device
<CUDA Device 0>

>>> with config_context(array_api_dispatch=True):
... lda = LinearDiscriminantAnalysis()
... X_trans = lda.fit_transform(X_cu, y_cu)
>>> X_trans.device
<CUDA Device 0>

After the model is trained, fitted attributes that are arrays will also be
from the same Array API namespace as the training data. For example, if CuPy's
Array API namespace was used for training, then fitted attributes will be on the
GPU. We provide an experimental `_estimator_with_converted_arrays` utility that
transfers an estimator's attributes from Array API arrays to ndarrays::

>>> from sklearn.utils._array_api import _estimator_with_converted_arrays
>>> cupy_to_ndarray = lambda array : array._array.get()
>>> lda_np = _estimator_with_converted_arrays(lda, cupy_to_ndarray)
>>> X_trans = lda_np.transform(X_np)
>>> type(X_trans)
<class 'numpy.ndarray'>

.. _array_api_estimators:

Estimators with support for `Array API`-compatible inputs
=========================================================

- :class:`discriminant_analysis.LinearDiscriminantAnalysis` (with `solver="svd"`)

Coverage for more estimators is expected to grow over time. Please follow the
dedicated `meta-issue on GitHub
<https://github.com/scikit-learn/scikit-learn/issues/22352>`_ to track progress.
1 change: 1 addition & 0 deletions doc/user_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ User Guide
computing.rst
model_persistence.rst
common_pitfalls.rst
dispatching.rst
6 changes: 6 additions & 0 deletions doc/whats_new/v1.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,12 @@ Changelog
:mod:`sklearn.discriminant_analysis`
....................................

- |MajorFeature| :class:`discriminant_analysis.LinearDiscriminantAnalysis` now
supports the `Array API <https://data-apis.org/array-api/latest/>`_ for
`solver="svd"`. Array API support is considered experimental and might evolve
without being subject to our usual rolling deprecation cycle policy. See
:ref:`array_api` for more details. :pr:`22554` by `Thomas Fan`_.

- |Fix| Validate parameters only in `fit` and not in `__init__`
for :class:`discriminant_analysis.QuadraticDiscriminantAnalysis`.
:pr:`24218` by :user:`Stefanie Molin <stefmolin>`.
Expand Down
22 changes: 22 additions & 0 deletions sklearn/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
os.environ.get("SKLEARN_PAIRWISE_DIST_CHUNK_SIZE", 256)
),
"enable_cython_pairwise_dist": True,
"array_api_dispatch": False,
}
_threadlocal = threading.local()

Expand Down Expand Up @@ -50,6 +51,7 @@ def set_config(
display=None,
pairwise_dist_chunk_size=None,
enable_cython_pairwise_dist=None,
array_api_dispatch=None,
):
"""Set global scikit-learn configuration

Expand Down Expand Up @@ -110,6 +112,14 @@ def set_config(

.. versionadded:: 1.1

array_api_dispatch : bool, default=None
Use Array API dispatching when inputs follow the Array API standard.
Default is False.

See the :ref:`User Guide <array_api>` for more details.

.. versionadded:: 1.2

See Also
--------
config_context : Context manager for global scikit-learn configuration.
Expand All @@ -129,6 +139,8 @@ def set_config(
local_config["pairwise_dist_chunk_size"] = pairwise_dist_chunk_size
if enable_cython_pairwise_dist is not None:
local_config["enable_cython_pairwise_dist"] = enable_cython_pairwise_dist
if array_api_dispatch is not None:
local_config["array_api_dispatch"] = array_api_dispatch


@contextmanager
Expand All @@ -140,6 +152,7 @@ def config_context(
display=None,
pairwise_dist_chunk_size=None,
enable_cython_pairwise_dist=None,
array_api_dispatch=None,
):
"""Context manager for global scikit-learn configuration.

Expand Down Expand Up @@ -199,6 +212,14 @@ def config_context(

.. versionadded:: 1.1

array_api_dispatch : bool, default=None
Use Array API dispatching when inputs follow the Array API standard.
Default is False.

See the :ref:`User Guide <array_api>` for more details.

.. versionadded:: 1.2

Yields
------
None.
Expand Down Expand Up @@ -234,6 +255,7 @@ def config_context(
display=display,
pairwise_dist_chunk_size=pairwise_dist_chunk_size,
enable_cython_pairwise_dist=enable_cython_pairwise_dist,
array_api_dispatch=array_api_dispatch,
)

try:
Expand Down
Loading