Skip to content

Conversation

@thomasjpfan
Copy link
Member

Reference Issues/PRs

Fixes #24860

What does this implement/fix? Explain your changes.

This PR enables the feature selectors to preserve the DataFrame's dtype in transform. Implementation-wise, SelectorMixin will only preserve the DataFrame's dtype if:

  1. The input to transform is a DataFrame
  2. The selector is configured to output DataFrames with set_output.

Copy link
Member

@adrinjalali adrinjalali left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

My only concern is the logic becoming too complex and harder for third party developers to recreate.

output_config_dense = _get_output_config("transform", estimator=self)["dense"]
if hasattr(X, "iloc") and output_config_dense == "pandas":
# Only check feature names and n_features when output is a dataframe
# This allow _transform to preserve `X`'s dtype
Copy link
Member

@betatim betatim Dec 5, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't quite work out why this comment is true :-/

The important part seems to be that we do not call X = self._validate_data(...), not that we call self._check_feature_names(X, reset=False) or self._check_n_features(X, reset=False). At least it all seems to work if I replace those two lines with a pass.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it does work with pass, but we still want to check if the input is valid. For example, we enforce that the column names are consistent between fit and transform in all the transformers:

from sklearn.feature_selection import SelectPercentile
import pandas as pd

X = pd.DataFrame({"a": [1, 10, 5], "b": [3, 1, 6], "c": [5, 6, 3], "d": [5, 1, 6]})
y = [1, 0, 1]

selector = SelectPercentile(percentile=50).set_output(transform="pandas")
selector.fit_transform(X, y)

X_test = X.copy()
X_test.columns = ["f1", "f2", "f3", "f4"]

# Errors because column names are not consistent.
selector.transform(X_test)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes sense to do the validation. The thing that puzzled me is where/what magic was happening to make things work with respect to the PR. I like your updated comment.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. What makes this PR work is the update to _transform. As long as X is a dataframe, _safe_indexing(X, mask, axis=1) will mask it correctly and return a dataframe.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly. This is what I thought, and then I spent a bunch of time trying to understand why the stuff here was required (the comment made me believe it was required). All good now.

Copy link
Member

@betatim betatim left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks reasonable, but I am a bit confused about why this works. See my comment above

@adrinjalali
Copy link
Member

Better somebody who's more familiar with this part of the code to review. I don't feel comfortable signing off on this w/o needing to spend quite a bit of time to do a nice review.

@thomasjpfan
Copy link
Member Author

In an offline discussion with Adrin, we thought it would be better to add a new keyword to _validate_data to move more complexity into _validate_data itself.

I updated this PR to add a cast_to_ndarray to _validate_data, which by default will perform the checks and cast X and y to ndarrays. If cast_to_ndarray=False, then only feature_names_in_ and n_features_in_ are checked and the data is left unchanged.

Copy link
Member

@adrinjalali adrinjalali left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks much nicer to me now!

Comment on lines +23 to +24
if self.step >= 1:
mask[:: self.step] = True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this a bug you're fixing here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a new feature to StepSelector, where step < 1 means the mask is all False and no features are selected. Note that StepSelector is only used for testing.

I updated the docstring to mention this fact: 1e5b8ee (#25102)

@adrinjalali adrinjalali added the Waiting for Second Reviewer First reviewer is done, need a second one! label Jan 2, 2023
@betatim
Copy link
Member

betatim commented Feb 9, 2023

Needs conflicts resolving and a CI re-run. Otherwise looks good to me.

@lorentzenchr lorentzenchr merged commit 677a4cf into scikit-learn:main Feb 19, 2023
AdarshPrusty7 added a commit to AdarshPrusty7/GSGP that referenced this pull request Mar 6, 2023
* ENH Raise NotFittedError in get_feature_names_out for MissingIndicator, KBinsDiscretizer, SplineTransformer, DictVectorizer (scikit-learn#25402)

Co-authored-by: Alex <[email protected]>
Co-authored-by: Guillaume Lemaitre <[email protected]>

* DOC Update date and contributors list for v1.2.1 (scikit-learn#25459)

* DOC Make MeanShift documentation clearer (scikit-learn#25305)

Co-authored-by: Guillaume Lemaitre <[email protected]>

* Finishes boolean and arithmetic creation

* Skeleton for traditional GP

* DOC Reorder whats_new/v1.2.rst (scikit-learn#25461)

Follow-up of scikit-learn#25459

Co-authored-by: Olivier Grisel <[email protected]>
Co-authored-by: Jérémie du Boisberranger <[email protected]>

Co-authored-by: Olivier Grisel <[email protected]>
Co-authored-by: Jérémie du Boisberranger <[email protected]>

* FIX fix faulty test in `cross_validate` that used the wrong estimator (scikit-learn#25456)

* ENH Raise NotFittedError in get_feature_names_out for estimators that use ClassNamePrefixFeatureOutMixin and SelectorMixin (scikit-learn#25308)

Co-authored-by: Guillaume Lemaitre <[email protected]>

* EFF Improve IsolationForest predict time (scikit-learn#25186)

Co-authored-by: Felipe Breve Siola <[email protected]>
Co-authored-by: Guillaume Lemaitre <[email protected]>
Co-authored-by: Olivier Grisel <[email protected]>
Co-authored-by: Tim Head <[email protected]>

* MAINT refactor spectral_clustering to call SpectralClustering (scikit-learn#25392)

* TST reduce warnings in test_logistic.py (scikit-learn#25469)

* CI Build doc on CircleCI (scikit-learn#25466)

* DOC Update news footer for 1.2.1 (scikit-learn#25472)

* MAINT Validate parameter for `sklearn.cluster.cluster_optics_xi` (scikit-learn#25385)

Co-authored-by: adossantosalfam <[email protected]>
Co-authored-by: Guillaume Lemaitre <[email protected]>

* MAINT Parameters validation for additive_chi2_kernel (scikit-learn#25424)

Co-authored-by: Guillaume Lemaitre <[email protected]>

* Initial Program Creation

* CI Include linting in CircleCI (scikit-learn#25475)

* MAINT Update version number to 1.2.1 in SECURITY.md (scikit-learn#25471)

* TST Sets random_state for test_logistic.py (scikit-learn#25446)

* MAINT Remove -Wcpp warnings when compiling sklearn.decomposition._online_lda_fast (scikit-learn#25020)

Co-authored-by: Julien Jerphanion <[email protected]>

* FIX Support readonly sparse datasets for `manhattan_distances`  (scikit-learn#25432)

* TST Add non-regression test for scikit-learn#7981

This reproducer is adapted from the one of this message:
scikit-learn#7981 (comment)

Co-authored-by: Loïc Estève <[email protected]>

* FIX Support readonly sparse datasets for manhattan

* DOC Add entry in whats_new/v1.2.rst for 1.2.1

* FIX Fix comment

* Update sklearn/metrics/tests/test_pairwise.py

Co-authored-by: Christian Lorentzen <[email protected]>

* DOC Move entry to whats_new/v1.3.rst

* Update sklearn/metrics/tests/test_pairwise.py

Co-authored-by: Olivier Grisel <[email protected]>

Co-authored-by: Loïc Estève <[email protected]>
Co-authored-by: Christian Lorentzen <[email protected]>
Co-authored-by: Olivier Grisel <[email protected]>

* MAINT dynamically expose kulsinski and remove support in BallTree (scikit-learn#25417)

Co-authored-by: Loïc Estève <[email protected]>
Co-authored-by: Julien Jerphanion <[email protected]>
closes scikit-learn#25212

* DOC Adds CirrusCI badge to readme (scikit-learn#25483)

* CI add linter display name (scikit-learn#25485)

* DOC update description of X in `FunctionTransformer.transform()`  (scikit-learn#24844)

* MAINT remove -Wcpp warnings when compiling sklearn.preprocessing._csr_polynomial_expansion (scikit-learn#25041)

* DOC more didactic example of bisecting kmeans (scikit-learn#25494)

Co-authored-by: Olivier Grisel <[email protected]>
Co-authored-by: Arturo Amor <[email protected]>
Co-authored-by: Jérémie du Boisberranger <[email protected]>

* ENH csr_row_norms optimization (scikit-learn#24426)

Co-authored-by: Julien Jerphanion <[email protected]>
Co-authored-by: Guillaume Lemaitre <[email protected]>
Co-authored-by: Jérémie du Boisberranger <[email protected]>

* TST Allow callables as valid parameter regarding cloning estimator (scikit-learn#25498)

Co-authored-by: Guillaume Lemaitre <[email protected]>
Co-authored-by: Loïc Estève <[email protected]>
Co-authored-by: From: Tim Head <[email protected]>

* DOC Fixes sphinx search on website (scikit-learn#25504)

* FIX make IsotonicRegression always predict NumPy arrays (scikit-learn#25500)



Co-authored-by: Thomas J. Fan <[email protected]>

* FEA Add Gamma deviance as loss function to HGBT (scikit-learn#22409)

* FEA add gamma loss to HGBT

* DOC add whatsnew

* CLN address review comments

* TST make test_gamma pass by not testing out-of-sample

* TST compare gamma and poisson to LightGBM

* TST fix test_gamma by comparing to MSE HGBT instead of Poisson HGBT

* TST fix for test_same_predictions_regression for poisson

* CLN address review comments

* CLN nits

* CLN better comments

* TST use pytest.param with skip mark

* TST Correct conditional test parametrization mark

Co-authored-by: Christian Lorentzen <[email protected]>

* CI Trigger CI

Builds currently fail because requests to Azure Ubuntu repository
timeout.

* DOC add comment for lax comparison with LightGBM

* CLN tuple needs trailing comma

---------

Co-authored-by: Julien Jerphanion <[email protected]>

* MAINT Remove -Wsign-compare warnings when compiling sklearn.tree._tree (scikit-learn#25507)

* MAINT add more intuition on OAS computation based on literature (scikit-learn#23867)

Co-authored-by: Guillaume Lemaitre <[email protected]>

* CI Allow cirrus arm tests to run with cd build commit tag (scikit-learn#25514)

* CI Upload ARM wheels from CirrusCI to nightly and staging index (scikit-learn#25513)



Co-authored-by: Olivier Grisel <[email protected]>

* MAINT Remove -Wcpp warnings from sklearn.utils._seq_dataset (scikit-learn#25406)

* FIX Fixes linux ARM CI on CirrusCI (scikit-learn#25536)

* DOC Fix grammatical mistake in `mixture` module (scikit-learn#25541)

* DOC add missing trailing colon (scikit-learn#25542)

* MAINT Parameters validation for sklearn.datasets.make_classification (scikit-learn#25474)

Co-authored-by: Guillaume Lemaitre <[email protected]>

* MNT Expose allow_nan tag in bagging (scikit-learn#25506)

* MAINT Clean-up comments and rename variables in `_middle_term_sparse_sparse_{32, 64}` (scikit-learn#25449)

Co-authored-by: Julien Jerphanion <[email protected]>

* DOC: remove incorrect statement (scikit-learn#25544)

* MAINT Parameters validation for reconstruct_from_patches_2d (scikit-learn#25384)

Co-authored-by: Guillaume Lemaitre <[email protected]>

* MAINT Parameter validation for sklearn.metrics.d2_pinball_score (scikit-learn#25414)

Co-authored-by: Guillaume Lemaitre <[email protected]>
Co-authored-by: Jérémie du Boisberranger <[email protected]>

* MAINT Parameters validation for spectral_clustering (scikit-learn#25378)

Co-authored-by: Guillaume Lemaitre <[email protected]>

* MAINT Parameters validation for sklearn.datasets.fetch_kddcup99 (scikit-learn#25463)

Co-authored-by: Guillaume Lemaitre <[email protected]>

* DOC Update MLPRegressor docs (scikit-learn#25556)

Co-authored-by: Ian Thompson <[email protected]>

* DOC Update docs for KMeans (scikit-learn#25546)

Co-authored-by: Jérémie du Boisberranger <[email protected]>

* FIX BisectingKMeans crashes randomly (scikit-learn#25563)

Fixes scikit-learn#25505

* ENH BaseLabelPropagation to accept sparse matrices (scikit-learn#19664)

Co-authored-by: Kaushik Amar Das <[email protected]>
Co-authored-by: Guillaume Lemaitre <[email protected]>

* MAINT Remove travis ci config and related doc (scikit-learn#25562)

* DOC Add pynndescent to Approximate nearest neighbors in TSNE example (scikit-learn#25480)


Co-authored-by: Olivier Grisel <[email protected]>

* DOC Add docstring example to make_regression (scikit-learn#25551)

Co-authored-by: Thomas J. Fan <[email protected]>
Co-authored-by: Jérémie du Boisberranger <[email protected]>

* MAINT ensure that pos_label support all possible types (scikit-learn#25317)

* MAINT Parameters validation for sklearn.metrics.f1_score (scikit-learn#25557)

Co-authored-by: Jérémie du Boisberranger <[email protected]>

* ENH Adds `class_names` to `tree.export_text` (scikit-learn#25387)

Co-authored-by: Thomas J. Fan <[email protected]>
Co-authored-by: Guillaume Lemaitre <[email protected]>

* MAINT Replace cnp.ndarray with memory views in sklearn.tree._tree (where possible) (scikit-learn#25540)

* DOC Change print format in TSNE example (scikit-learn#25569)

Co-authored-by: Olivier Grisel <[email protected]>

* FIX ColumnTransformer supports empty selection for pandas output (scikit-learn#25570)

Co-authored-by: Julien Jerphanion <[email protected]>

* DOC fix docstring of _plain_sgd (scikit-learn#25573)

* FIX Enable setting of sub-parameters for deprecated base_estimator param (scikit-learn#25477)

* DOC Improve minor and bug-fix release processes documentation (scikit-learn#25457)

Co-authored-by: Olivier Grisel <[email protected]>
Co-authored-by: Thomas J. Fan <[email protected]>
Co-authored-by: Jérémie du Boisberranger <[email protected]>

* MAINT Remove ReadonlyArrayWrapper from _loss module (scikit-learn#25555)

* MAINT Remove ReadonlyArrayWrapper from _loss module

* CLN Remove comments about Cython 3.0

* MAINT Remove ReadonlyArrayWrapper from _kmeans (scikit-learn#25554)

* MAINT Remove ReadonlyArrayWrapper from _kmeans

* more const and remove blas compile warnings

* CLN Adds comment about casting to non const pointers

* Update sklearn/utils/_cython_blas.pyx

* MAINT Remove ReadonlyArrayWrapper from DistanceMetric (scikit-learn#25553)

* DOC improve stop_words description w.r.t. max_df range in CountVectorizer (scikit-learn#25489)

* MAINT Removes ReadOnlyWrapper (scikit-learn#25586)

* MAINT Parameters validation for sklearn.metrics.log_loss (scikit-learn#25577)

* MAINT Adds comments and better naming into tree code (scikit-learn#25576)

* MAINT Adds comments and better naming into tree code

* CLN Use feature_values instead of Xf

* Apply suggestions from code review

Co-authored-by: Adam Li <[email protected]>

* DOC Improve comment from review

* Apply suggestions from code review

Co-authored-by: Julien Jerphanion <[email protected]>
Co-authored-by: Olivier Grisel <[email protected]>

---------

Co-authored-by: Adam Li <[email protected]>
Co-authored-by: Julien Jerphanion <[email protected]>
Co-authored-by: Olivier Grisel <[email protected]>

* FIX error when deserialzing a Tree instance from a read only buffer (scikit-learn#25585)

* DOC: fix typo in California Housing dataset description (scikit-learn#25613)

* ENH: Update KDTree, and example documentation (scikit-learn#25482)

* ENH: Update KDTree, and example documentation

* ENH: Add valid metric function and reference doc

* CHG: Documentation update

Co-authored-by: Adam Li <[email protected]>

* CHG: make valid metric property and fix doc string

* FIX: documentation, and add code example

* ENH: Change valid metric to class method, and doc

* ENH: Change valid metric class variable, and doc

* FIX: documentation error

* FIX: documentation error

* CHG: Use class method for valid metrics

* FIX: CI problems

---------

Co-authored-by: Adam Li <[email protected]>
Co-authored-by: Julien Jerphanion <[email protected]>

* TST Common test for checking estimator deserialization from a read only buffer (scikit-learn#25624)

* DOC fix comment in plot_logistic_l1_l2_sparsity.py (scikit-learn#25633)

Co-authored-by: Thomas J. Fan <[email protected]>

* DOC Places governance in navigation bar (scikit-learn#25618)

* MAINT Check pyproject toml is consistent with min_dependencies (scikit-learn#25610)

* MAINT Check pyproject toml is consistent with min_dependencies

* CLN Make it clear that only SciPy and Cython are checked

* CLN Revert auto formatter

* MAINT Use newest NumPy C API in tree._criterion (scikit-learn#25615)

* MAINT Use newest NumPy C API in tree._criterion

* FIX Use pointer for children

* FIX Fixes check_array nonfinite checks with ArrayAPI specification (scikit-learn#25619)

* FIX Fixes check_array nonfinite checks with ArrayAPI specification

* DOC Adds PR number

* FIX Test on both cupy and numpy

* DOC Correctly docstring in StackingRegressor.fit_transform (scikit-learn#25599)

* MAINT Remove Cython compilation warnings ahead of Cython3.0 release (scikit-learn#25621)

* ENH Preserve DataFrame dtypes in transform for feature selectors (scikit-learn#25102)

* FIX report properly n_iter_ when warm_start=True (scikit-learn#25443)

Co-authored-by: Guillaume Lemaitre <[email protected]>

* DOC fix typo in KMeans's param. (scikit-learn#25649)

* FIX use const memory views in hist_gradient_boosting predictor (scikit-learn#25650)

* DOC modified the graph for better readability (scikit-learn#25644)

* MAINT Removes upper limit on setuptools (scikit-learn#25651)

* DOC improve the `warm_start` glossary entry (scikit-learn#25523)

* DOC Update governance document for SLEP020 (scikit-learn#25663)



Co-authored-by: Tim Head <[email protected]>
Co-authored-by: Christian Lorentzen <[email protected]>

* FIX renormalization of y_pred inside log_loss (scikit-learn#25299)

* Remove renormalization of y_pred inside log_loss

* Deprecate eps parameter in log_loss

* ENH Allows target to be pandas nullable dtypes (scikit-learn#25638)

* DOC unify usage of 'w.r.t.' (scikit-learn#25683)

* MAINT Parameters validation for metrics.max_error (scikit-learn#25679)

* MAINT Parameters validation for datasets.make_friedman1 (scikit-learn#25674)

Co-authored-by: jeremie du boisberranger <[email protected]>

* MAINT Parameters validation for mean_pinball_loss (scikit-learn#25685)

Co-authored-by: jeremie du boisberranger <[email protected]>

* DOC Specify behavior of None for CountVectorizer (scikit-learn#25678)

* DOC Specify behaviour of None for TfIdfVectorizer max_features parameter (scikit-learn#25676)

Co-authored-by: Guillaume Lemaitre <[email protected]>

* MAINT Set random state for plot_anomaly_comparison (scikit-learn#25675)

* MAINT Parameters validation for cluster.mean_shift (scikit-learn#25684)

Co-authored-by: jeremie du boisberranger <[email protected]>

* MAINT Parameters validation for sklearn.metrics.jaccard_score (scikit-learn#25680)

Co-authored-by: jeremie du boisberranger <[email protected]>

* DOC Add the custom compiler section back (scikit-learn#25667)

Co-authored-by: Thomas J. Fan <[email protected]>

* MAINT Parameters validation for precision_recall_fscore_support (scikit-learn#25681)

Co-authored-by: jeremie du boisberranger <[email protected]>

* FIX Allow negative tol in SequentialFeatureSelector (scikit-learn#25664)

* MAINT Replace deprecated cython conditional compilation (scikit-learn#25654)



Co-authored-by: Guillaume Lemaitre <[email protected]>

* DOC fix formatting typo in related_projects (scikit-learn#25706)

* MAINT Parameters validation for metrics.mean_absolute_percentage_error (scikit-learn#25695)

* MAINT Parameters validation for metrics.precision_recall_curve (scikit-learn#25698)

Co-authored-by: jeremie du boisberranger <[email protected]>

* MAINT Parameter Validation for metrics.precision_score (scikit-learn#25708)

Co-authored-by: jeremie du boisberranger <[email protected]>

* CI Stablize build with random_state (scikit-learn#25701)

Co-authored-by: Jérémie du Boisberranger <[email protected]>

* MAINT Remove -Wcpp warnings when compiling arrayfuncs (scikit-learn#25415)

Co-authored-by: Thomas J. Fan <[email protected]>

* DOC Add scikit-learn-intelex to related projects (scikit-learn#23766)

Co-authored-by: Adrin Jalali <[email protected]>
Co-authored-by: Thomas J. Fan <[email protected]>

* ENH Support float32 in SGDClassifier and SGDRegressor (scikit-learn#25587)

* FIX Raise appropriate attribute error in ensemble (scikit-learn#25668)

* FIX Allow OrdinalEncoder's encoded_missing_value set to the cardinality (scikit-learn#25704)

* ENH Let csr_row_norms support multi-thread (scikit-learn#25598)

Co-authored-by: Jérémie du Boisberranger <[email protected]>
Co-authored-by: Vincent M <[email protected]>

* MAINT Parameter Validation for feature_selection.chi2 (scikit-learn#25719)

Co-authored-by: jeremiedbb <[email protected]>

* MAINT Parameter Validation for feature_selection.f_classif (scikit-learn#25720)

Co-authored-by: Jérémie du Boisberranger <[email protected]>

* MAINT Parameters validation for sklearn.metrics.matthews_corrcoef (scikit-learn#25712)

Co-authored-by: jeremiedbb <[email protected]>

* MAINT parameter validation for sklearn.datasets.dump_svmlight_file (scikit-learn#25726)

Co-authored-by: jeremiedbb <[email protected]>

* MAINT Clean dead code in build helpers (scikit-learn#25661)

* MAINT Use newest NumPy C API in metrics._dist_metrics (scikit-learn#25702)

* CI Adds permissions to workflows that use GITHUB_TOKEN (scikit-learn#25600)

Co-authored-by: Olivier Grisel <[email protected]>

* FIX Improves error message in partial_fit when early_stopping=True (scikit-learn#25694)

Co-authored-by: Jérémie du Boisberranger <[email protected]>

* DOC Makes navbar static (scikit-learn#25688)

* MAINT Remove redundant sparse square euclidian distances function (scikit-learn#25731)

* MAINT Use float64 for accumulators in WeightVector* (scikit-learn#25721)

* API make PatchExtractor being a real scikit-learn transformer (scikit-learn#24230)

Co-authored-by: Thomas J. Fan <[email protected]>
Co-authored-by: Jérémie du Boisberranger <[email protected]>

* MAINT Update pyparsing.py to use bool instead of double negation (scikit-learn#25724)

* API Deprecates values in partial_dependence in favor of pdp_values (scikit-learn#21809)

Co-authored-by: Guillaume Lemaitre <[email protected]>
Co-authored-by: Jérémie du Boisberranger <[email protected]>

* API Use grid_values instead of pdp_values in partial_dependence (scikit-learn#25732)

* MAINT remove np.product and inf/nan aliases in favor of canonical names (scikit-learn#25741)

* MAINT Parameters validation for metrics.label_ranking_loss (scikit-learn#25742)

Co-authored-by: Jérémie du Boisberranger <[email protected]>

* MAINT Parameters validation for metrics.coverage_error (scikit-learn#25748)

* MAINT Parameters validation for metrics.dcg_score (scikit-learn#25749)

* MAINT replace cnp.ndarray with memory views in _fast_dict (scikit-learn#25754)

* MAINT Parameter Validation for feature_selection.f_regression (scikit-learn#25736)

Co-authored-by: Jérémie du Boisberranger <[email protected]>

* MAINT Parameters validation for feature_selection.r_regression (scikit-learn#25734)

Co-authored-by: Jérémie du Boisberranger <[email protected]>

* MAINT Parameter Validation for metrics.get_scorer (scikit-learn#25738)

Co-authored-by: Jérémie du Boisberranger <[email protected]>

* DOC Move allowing pandas nullable dtypes to 1.2.2 (scikit-learn#25692)

* MAINT replace cnp.ndarray with memory views in sparsefuncs_fast (scikit-learn#25764)

* MAINT parameter validation for sklearn.datasets.fetch_covtype (scikit-learn#25759)

Co-authored-by: Jérémie du Boisberranger <[email protected]>

* MAINT Define centralized generic, but with explicit precision, types (scikit-learn#25739)

* CI Disable network when SciPy requires it (scikit-learn#25743)

* CI Open issue when arm wheel fails on CirrusCI (scikit-learn#25620)

* ENH Speed-up expected mutual information (scikit-learn#25713)

Co-authored-by: Kshitij Mathur <[email protected]>
Co-authored-by: Guillaume Lemaitre <[email protected]>
Co-authored-by: Omar Salman <[email protected]>

* FIX add retry mechanism to handle quotechar in read_csv (scikit-learn#25511)

* Merge Population Creation (#1)

---------

Co-authored-by: Alex Buzenet <[email protected]>
Co-authored-by: Alex <[email protected]>
Co-authored-by: Guillaume Lemaitre <[email protected]>
Co-authored-by: Julien Jerphanion <[email protected]>
Co-authored-by: Adam Kania <[email protected]>
Co-authored-by: Olivier Grisel <[email protected]>
Co-authored-by: Jérémie du Boisberranger <[email protected]>
Co-authored-by: Shady el Gewily <[email protected]>
Co-authored-by: John Pangas <[email protected]>
Co-authored-by: Felipe Siola <[email protected]>
Co-authored-by: Felipe Breve Siola <[email protected]>
Co-authored-by: Tim Head <[email protected]>
Co-authored-by: Christian Lorentzen <[email protected]>
Co-authored-by: Loïc Estève <[email protected]>
Co-authored-by: Anthony22-dev <[email protected]>
Co-authored-by: adossantosalfam <[email protected]>
Co-authored-by: Xiao Yuan <[email protected]>
Co-authored-by: Thomas J. Fan <[email protected]>
Co-authored-by: Omar Salman <[email protected]>
Co-authored-by: Rahil Parikh <[email protected]>
Co-authored-by: Gael Varoquaux <[email protected]>
Co-authored-by: Arturo Amor <[email protected]>
Co-authored-by: Jérémie du Boisberranger <[email protected]>
Co-authored-by: Meekail Zain <[email protected]>
Co-authored-by: davidblnc <[email protected]>
Co-authored-by: Changyao Chen <[email protected]>
Co-authored-by: Nicola Fanelli <[email protected]>
Co-authored-by: Vincent M <[email protected]>
Co-authored-by: partev <[email protected]>
Co-authored-by: ouss1508 <[email protected]>
Co-authored-by: ashah002 <[email protected]>
Co-authored-by: Ahmedbgh <[email protected]>
Co-authored-by: Pooja M <[email protected]>
Co-authored-by: Ian Thompson <[email protected]>
Co-authored-by: Ian Thompson <[email protected]>
Co-authored-by: SANJAI_3 <[email protected]>
Co-authored-by: Kaushik Amar Das <[email protected]>
Co-authored-by: Kaushik Amar Das <[email protected]>
Co-authored-by: Nawazish Alam <[email protected]>
Co-authored-by: William M <[email protected]>
Co-authored-by: Jérémie du Boisberranger <[email protected]>
Co-authored-by: JanFidor <[email protected]>
Co-authored-by: Adam Li <[email protected]>
Co-authored-by: Logan Thomas <[email protected]>
Co-authored-by: Vyom Pathak <[email protected]>
Co-authored-by: as-90 <[email protected]>
Co-authored-by: Marvin Krawutschke <[email protected]>
Co-authored-by: Haesun Park <[email protected]>
Co-authored-by: Christine P. Chai <[email protected]>
Co-authored-by: Christian Veenhuis <[email protected]>
Co-authored-by: Sortofamudkip <[email protected]>
Co-authored-by: sonnivs <[email protected]>
Co-authored-by: Ali H. El-Kassas <[email protected]>
Co-authored-by: Yusuf Raji <[email protected]>
Co-authored-by: Tabea Kossen <[email protected]>
Co-authored-by: Pooja Subramaniam <[email protected]>
Co-authored-by: JuliaSchoepp <[email protected]>
Co-authored-by: Jack McIvor <[email protected]>
Co-authored-by: zeeshan lone <[email protected]>
Co-authored-by: Max Halford <[email protected]>
Co-authored-by: Adrin Jalali <[email protected]>
Co-authored-by: genvalen <[email protected]>
Co-authored-by: Shiva chauhan <[email protected]>
Co-authored-by: Dayne <[email protected]>
Co-authored-by: Ralf Gommers <[email protected]>
Co-authored-by: Kshitij Mathur <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

module:feature_selection Waiting for Second Reviewer First reviewer is done, need a second one!

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Preserving dtypes for DataFrame output by transformers that do not modify the input values

4 participants