-
-
Notifications
You must be signed in to change notification settings - Fork 26.5k
ENH Adds Array API support to LinearDiscriminantAnalysis #22554
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
66 commits
Select commit
Hold shift + click to select a range
749d5ec
ENH Adds ArrayAPI support to LinearDiscriminantAnalysis
thomasjpfan d8cab77
DOC Adds PR number
thomasjpfan f641535
Merge remote-tracking branch 'upstream/main' into array_api_lda_pr
thomasjpfan 8240c95
ENH Full support for array api
thomasjpfan d485018
TST Adds array api test
thomasjpfan 101bbd9
TST Remove assert
thomasjpfan c9459ef
FIX Adds path for smallest normal
thomasjpfan d6d19f0
FIX Private expit
thomasjpfan 1f84146
TST Adds better test for get_namespace
thomasjpfan b3b9af9
TST Fix get_namespace
thomasjpfan abc25db
TST Adds more coverage
thomasjpfan 15b0e0b
Apply suggestions from code review
thomasjpfan 413fd47
CLN Address comments
thomasjpfan c668966
CLN Address comments
thomasjpfan e51c908
Merge remote-tracking branch 'upstream/main' into array_api_lda_pr
thomasjpfan 1b6f260
STY Formatting errors
thomasjpfan 14e95c5
TST Adds import skips
thomasjpfan 3648bdc
STY Black
thomasjpfan cdfd8b9
TST Fixes docstring test
thomasjpfan 13972fb
Merge remote-tracking branch 'upstream/main' into array_api_lda_pr
thomasjpfan ee1ce31
DOC Adds comment about performance
thomasjpfan c68f26c
ENH Raises error for casting and order on non-NumPy ArrayAPI arrays
thomasjpfan 638bcce
TST Simplifier filterwarning
thomasjpfan 71bf555
TST Remove error for order
thomasjpfan 0fabeab
ENH Use NumPy API directly
thomasjpfan d0ac9b6
Merge remote-tracking branch 'upstream/main' into array_api_lda_pr
thomasjpfan b2622b1
STY Fix
thomasjpfan bdf04f3
CLN Do not special case array api as much
thomasjpfan a3d0ddc
CLN Fixes stability in cupy
thomasjpfan 6471045
FIX Fixes test errors
thomasjpfan 3108284
FIX Simplier test
thomasjpfan 55ea1a4
Merge remote-tracking branch 'upstream/main' into array_api_lda_pr
thomasjpfan 9372326
REV Remove unused import
thomasjpfan a627424
ENH Use a _asarray_with_order
thomasjpfan d793e7c
ENH Adds out for numpy arrays
thomasjpfan badc4a0
TST Adds unit test for _asarray_with_order
thomasjpfan 71c0f5f
Merge branch 'main' into array_api_lda_pr
ogrisel 56abf45
Bootstrap minimal documentation for Array API support
ogrisel ae061b2
Fix doctest
ogrisel c6c9585
Link the user guide from the config docstring
ogrisel 580a6fc
DOC Add more information and reformat user guide
thomasjpfan 91e9681
ENH Use get to convert cupy to ndarray
thomasjpfan f6a5af4
ENH Adds _convert_estimator_to_ndarray
thomasjpfan d66576c
DOC Sphinx warning
thomasjpfan f155056
DOC Remove unneeded reference
thomasjpfan 769c513
ENH Clone estimator
thomasjpfan 5c938c3
TST Better variable names
thomasjpfan cc35dbd
DOC Add parent toc
thomasjpfan daa3e73
Merge remote-tracking branch 'upstream/main' into array_api_lda_pr
thomasjpfan 58cc8c8
ENH Generic converter for estimators
thomasjpfan 77803c6
FIX Fixes failing test
thomasjpfan 08ca7f2
Merge branch 'main' into array_api_lda_pr
ogrisel 15e33a8
Move changelog entry to 1.2
ogrisel 3986c5b
Fix merge conflict resolution typos
ogrisel 93c4c8e
Merge branch 'main' into array_api_lda_pr
ogrisel c92d189
CLN Better implementation of take
thomasjpfan c997c30
Merge remote-tracking branch 'upstream/main' into array_api_lda_pr
thomasjpfan 36cbc8a
DOC Update versionadded
thomasjpfan a5a68bb
Merge remote-tracking branch 'upstream/main' into array_api_lda_pr
thomasjpfan 937f140
FIX Fixes merge error
thomasjpfan 607709d
Merge remote-tracking branch 'upstream/main' into array_api_lda_pr
thomasjpfan 055c6f8
DOC Add note about cupy.array_api and numpy.array_api
thomasjpfan 5bc6f76
Apply suggestions from code review
thomasjpfan 671f4f2
DOC Add docstring to solver parameter
thomasjpfan b809895
DOC Use better directive
thomasjpfan e8be4bc
DOC Use versionchanged directive
thomasjpfan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| .. Places parent toc into the sidebar | ||
|
|
||
| :parenttoc: True | ||
|
|
||
| .. include:: includes/big_toc_css.rst | ||
|
|
||
| =========== | ||
| Dispatching | ||
| =========== | ||
|
|
||
| .. toctree:: | ||
| :maxdepth: 2 | ||
|
|
||
| modules/array_api |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| .. Places parent toc into the sidebar | ||
|
|
||
| :parenttoc: True | ||
|
|
||
| .. _array_api: | ||
|
|
||
| ================================ | ||
| Array API support (experimental) | ||
| ================================ | ||
|
|
||
| .. currentmodule:: sklearn | ||
|
|
||
| The `Array API <https://data-apis.org/array-api/latest/>`_ specification defines | ||
| a standard API for all array manipulation libraries with a NumPy-like API. | ||
|
|
||
| Some scikit-learn estimators that primarily rely on NumPy (as opposed to using | ||
| Cython) to implement the algorithmic logic of their `fit`, `predict` or | ||
| `transform` methods can be configured to accept any Array API compatible input | ||
| datastructures and automatically dispatch operations to the underlying namespace | ||
| instead of relying on NumPy. | ||
|
|
||
| At this stage, this support is **considered experimental** and must be enabled | ||
| explicitly as explained in the following. | ||
|
|
||
| .. note:: | ||
| Currently, only `cupy.array_api` and `numpy.array_api` are known to work | ||
| with scikit-learn's estimators. | ||
|
|
||
| Example usage | ||
| ============= | ||
|
|
||
| Here is an example code snippet to demonstrate how to use `CuPy | ||
| <https://cupy.dev/>`_ to run | ||
| :class:`~discriminant_analysis.LinearDiscriminantAnalysis` on a GPU:: | ||
|
|
||
| >>> from sklearn.datasets import make_classification | ||
| >>> from sklearn import config_context | ||
| >>> from sklearn.discriminant_analysis import LinearDiscriminantAnalysis | ||
| >>> import cupy.array_api as xp | ||
|
|
||
| >>> X_np, y_np = make_classification(random_state=0) | ||
| >>> X_cu = xp.asarray(X_np) | ||
| >>> y_cu = xp.asarray(y_np) | ||
| >>> X_cu.device | ||
| <CUDA Device 0> | ||
|
|
||
| >>> with config_context(array_api_dispatch=True): | ||
| ... lda = LinearDiscriminantAnalysis() | ||
| ... X_trans = lda.fit_transform(X_cu, y_cu) | ||
| >>> X_trans.device | ||
| <CUDA Device 0> | ||
|
|
||
| After the model is trained, fitted attributes that are arrays will also be | ||
| from the same Array API namespace as the training data. For example, if CuPy's | ||
| Array API namespace was used for training, then fitted attributes will be on the | ||
| GPU. We provide a experimental `_estimator_with_converted_arrays` utility that | ||
| transfers an estimator attributes from Array API to a ndarray:: | ||
|
|
||
| >>> from sklearn.utils._array_api import _estimator_with_converted_arrays | ||
| >>> cupy_to_ndarray = lambda array : array._array.get() | ||
| >>> lda_np = _estimator_with_converted_arrays(lda, cupy_to_ndarray) | ||
| >>> X_trans = lda_np.transform(X_np) | ||
| >>> type(X_trans) | ||
| <class 'numpy.ndarray'> | ||
ogrisel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| .. _array_api_estimators: | ||
|
|
||
| Estimators with support for `Array API`-compatible inputs | ||
| ========================================================= | ||
|
|
||
| - :class:`discriminant_analysis.LinearDiscriminantAnalysis` (with `solver="svd"`) | ||
|
|
||
| Coverage for more estimators is expected to grow over time. Please follow the | ||
| dedicated `meta-issue on GitHub | ||
| <https://github.com/scikit-learn/scikit-learn/issues/22352>`_ to track progress. | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,3 +30,4 @@ User Guide | |
| computing.rst | ||
| model_persistence.rst | ||
| common_pitfalls.rst | ||
| dispatching.rst | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.