BUG: OneHotEncoder(string values) handles NaN as category on transform step #12018

@jorisvandenbossche

The OneHotEncoder (or at least the implementation of the encoder for string data) seems to handle missing values (NaN) as unseen categories in the transform step:

In [47]: X_train = np.array([['a', 'b', 'c']], dtype=object).T

In [48]: X_test = np.array([['a', 'b', 'd', np.nan]], dtype=object).T  # <--- has unknown category + NaN

In [49]: ohe = OneHotEncoder()  # default: handle_unknown='error'

In [51]: ohe.fit_transform(X_train).toarray()
Out[51]: 
array([[ 1.,  0.,  0.],
       [ 0.,  1.,  0.],
       [ 0.,  0.,  1.]])

In [52]: ohe.transform(X_test)
...
ValueError: Found unknown categories [nan, 'd'] in column 0 during transform  # <--- both are listed as unknown

In [53]: ohe = OneHotEncoder(handle_unknown='ignore')

In [54]: ohe.fit_transform(X_train).toarray()
Out[54]: 
array([[ 1.,  0.,  0.],
       [ 0.,  1.,  0.],
       [ 0.,  0.,  1.]])

In [56]: ohe.transform(X_test).toarray()
Out[56]: 
array([[ 1.,  0.,  0.],
       [ 0.,  1.,  0.],
       [ 0.,  0.,  0.],
       [ 0.,  0.,  0.]])    # <--- NaN is also handled as unknown category instead of raising an error

I think missing values should not be treated as an unknown category in the transform step; they should behave the same as at fit time (i.e. raise an error for now, although the default behaviour might change depending on #11996).
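To make the expected behaviour concrete, here is a minimal pure-Python sketch of it. This is not scikit-learn code: `is_missing` and `one_hot_rows` are hypothetical helpers written only for illustration, and the error messages are made up. The idea is simply that NaN is checked first and always raises, while `handle_unknown` only governs genuinely unseen categories such as 'd'.

```python
import math


def is_missing(value):
    """True for a float NaN (np.nan is a Python float); strings are never missing."""
    return isinstance(value, float) and math.isnan(value)


def one_hot_rows(values, categories, handle_unknown="error"):
    """Hypothetical helper: one-hot encode `values` against fitted `categories`."""
    # Missing values are rejected up front, independent of handle_unknown.
    missing = [i for i, v in enumerate(values) if is_missing(v)]
    if missing:
        raise ValueError(f"Input contains NaN at positions {missing}")

    index = {cat: j for j, cat in enumerate(categories)}
    unknown = sorted({v for v in values if v not in index})
    if unknown and handle_unknown == "error":
        raise ValueError(f"Found unknown categories {unknown} during transform")

    # With handle_unknown="ignore", unknown categories become all-zero rows.
    rows = []
    for v in values:
        row = [0.0] * len(categories)
        if v in index:
            row[index[v]] = 1.0
        rows.append(row)
    return rows
```

With this sketch, `one_hot_rows(['a', 'b', 'd'], ['a', 'b', 'c'], handle_unknown='ignore')` gives an all-zero row for 'd', while adding `float('nan')` to the input raises a `ValueError` whatever `handle_unknown` is set to.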

So I think this is a bug, but luckily it only affects the new implementation of the encoder for strings (which is new in 0.20, so we can still fix it without having to worry about breaking existing behaviour). For numerical values it works as expected:

In [82]: X_train_num = np.array([[1, 2, 3]]).T

In [83]: X_test_num = np.array([[1, 2, 4, np.nan]]).T

In [84]: ohe = OneHotEncoder(handle_unknown='ignore')

In [85]: ohe.fit_transform(X_train_num).toarray()
Out[85]: 
array([[ 1.,  0.,  0.],
       [ 0.,  1.,  0.],
       [ 0.,  0.,  1.]])

In [86]: ohe.transform(X_test_num).toarray()
...
ValueError: Input contains NaN, infinity or a value too large for dtype('float64').

(from https://medium.com/dunder-data/from-pandas-to-scikit-learn-a-new-exciting-workflow-e88e2271ef62, nice post @tdpetrou)
