Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/whats_new/v1.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,14 @@ Changelog
:class:`~sklearn.random_projection.SparseRandomProjection`. :pr:`21330` by
:user:`Loïc Estève <lesteve>`.

:mod:`sklearn.tree`
...................

- |Fix| Support loading pickles of decision tree models when the pickle has
been generated on a platform with a different bitness. A typical example is
to train and pickle the model on a 64 bit machine and load the model on a 32
bit machine for prediction. :pr:`21552` by :user:`Loïc Estève <lesteve>`.

Code and Documentation Contributors
-----------------------------------

Expand Down
157 changes: 142 additions & 15 deletions sklearn/tree/_tree.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ from libc.string cimport memcpy
from libc.string cimport memset
from libc.stdint cimport SIZE_MAX

import struct

import numpy as np
cimport numpy as np
np.import_array()
Expand Down Expand Up @@ -583,9 +585,13 @@ cdef class Tree:
def __get__(self):
return self._get_value_ndarray()[:self.node_count]

def __cinit__(self, int n_features, np.ndarray[SIZE_t, ndim=1] n_classes,
int n_outputs):
def __cinit__(self, int n_features, np.ndarray n_classes, int n_outputs):
"""Constructor."""
cdef SIZE_t dummy = 0
size_t_dtype = np.array(dummy).dtype

n_classes = _check_n_classes(n_classes, size_t_dtype)

# Input/Output layout
self.n_features = n_features
self.n_outputs = n_outputs
Expand Down Expand Up @@ -644,19 +650,12 @@ cdef class Tree:
value_shape = (node_ndarray.shape[0], self.n_outputs,
self.max_n_classes)

if (node_ndarray.dtype != NODE_DTYPE):
# possible mismatch of big/little endian due to serialization
# on a different architecture. Try swapping the byte order.
node_ndarray = node_ndarray.byteswap().newbyteorder()
if (node_ndarray.dtype != NODE_DTYPE):
raise ValueError('Did not recognise loaded array dytpe')

if (node_ndarray.ndim != 1 or
not node_ndarray.flags.c_contiguous or
value_ndarray.shape != value_shape or
not value_ndarray.flags.c_contiguous or
value_ndarray.dtype != np.float64):
raise ValueError('Did not recognise loaded array layout')
node_ndarray = _check_node_ndarray(node_ndarray, expected_dtype=NODE_DTYPE)
value_ndarray = _check_value_ndarray(
value_ndarray,
expected_dtype=np.dtype(np.float64),
expected_shape=value_shape
)

self.capacity = node_ndarray.shape[0]
if self._resize_c(self.capacity) != 0:
Expand Down Expand Up @@ -1219,6 +1218,134 @@ cdef class Tree:
total_weight)


def _check_n_classes(n_classes, expected_dtype):
if n_classes.ndim != 1:
raise ValueError(
f"Wrong dimensions for n_classes from the pickle: "
f"expected 1, got {n_classes.ndim}"
)

if n_classes.dtype == expected_dtype:
return n_classes

# Handles both different endianness and different bitness
if n_classes.dtype.kind == "i" and n_classes.dtype.itemsize in [4, 8]:
return n_classes.astype(expected_dtype, casting="same_kind")

raise ValueError(
"n_classes from the pickle has an incompatible dtype:\n"
f"- expected: {expected_dtype}\n"
f"- got: {n_classes.dtype}"
)


def _check_value_ndarray(value_ndarray, expected_dtype, expected_shape):
if value_ndarray.shape != expected_shape:
raise ValueError(
"Wrong shape for value array from the pickle: "
f"expected {expected_shape}, got {value_ndarray.shape}"
)

if not value_ndarray.flags.c_contiguous:
raise ValueError(
"value array from the pickle should be a C-contiguous array"
)

if value_ndarray.dtype == expected_dtype:
return value_ndarray

# Handles different endianness
if value_ndarray.dtype.str.endswith('f8'):
return value_ndarray.astype(expected_dtype, casting='equiv')

raise ValueError(
"value array from the pickle has an incompatible dtype:\n"
f"- expected: {expected_dtype}\n"
f"- got: {value_ndarray.dtype}"
)


def _dtype_to_dict(dtype):
return {name: dt.str for name, (dt, *rest) in dtype.fields.items()}


def _dtype_dict_with_modified_bitness(dtype_dict):
# field names in Node struct with SIZE_t types (see sklearn/tree/_tree.pxd)
indexing_field_names = ["left_child", "right_child", "feature", "n_node_samples"]

expected_dtype_size = str(struct.calcsize("P"))
allowed_dtype_size = "8" if expected_dtype_size == "4" else "4"

allowed_dtype_dict = dtype_dict.copy()
for name in indexing_field_names:
allowed_dtype_dict[name] = allowed_dtype_dict[name].replace(
expected_dtype_size, allowed_dtype_size
)

return allowed_dtype_dict


def _all_compatible_dtype_dicts(dtype):
    """Return every dtype dict (field name -> type string) accepted
    when loading a pickled node array.

    The Cython code for decision trees uses platform-specific SIZE_t
    typed indexing fields that correspond to either i4 or i8 dtypes for
    the matching fields in the numpy array depending on the bitness of
    the platform (32 bit or 64 bit respectively).

    We need to cast the indexing fields of the NODE_DTYPE-dtyped array
    at pickle load time to enable cross-bitness deployment scenarios:
    typically, fit an expensive tree estimator on a 64 bit server,
    pickle it, and run predict on a low power 32 bit edge platform.

    Endianness can differ the same way: the machine where the pickle
    was saved may not have the byte order of the machine loading it.
    """
    native = _dtype_to_dict(dtype)
    swapped = _dtype_to_dict(dtype.newbyteorder())

    # All four combinations of {native, other} bitness x endianness.
    return [
        native,
        _dtype_dict_with_modified_bitness(native),
        swapped,
        _dtype_dict_with_modified_bitness(swapped),
    ]


def _check_node_ndarray(node_ndarray, expected_dtype):
    """Validate the node array restored from a pickle.

    The array must be 1-D and C-contiguous.  A dtype that differs from
    ``expected_dtype`` only by the bitness and/or endianness of its
    fields (pickle generated on a different platform) is cast to
    ``expected_dtype``; any other mismatch raises a ``ValueError``.
    """
    if node_ndarray.ndim != 1:
        raise ValueError(
            "Wrong dimensions for node array from the pickle: "
            f"expected 1, got {node_ndarray.ndim}"
        )

    if not node_ndarray.flags.c_contiguous:
        raise ValueError(
            "node array from the pickle should be a C-contiguous array"
        )

    dtype = node_ndarray.dtype
    if dtype == expected_dtype:
        return node_ndarray

    # Compare field-by-field type strings against every dtype variant
    # that a cross-platform pickle could legitimately carry.
    if _dtype_to_dict(dtype) not in _all_compatible_dtype_dicts(expected_dtype):
        raise ValueError(
            "node array from the pickle has an incompatible dtype:\n"
            f"- expected: {expected_dtype}\n"
            f"- got : {node_ndarray.dtype}"
        )

    return node_ndarray.astype(expected_dtype, casting="same_kind")


# =============================================================================
# Build Pruned Tree
# =============================================================================
Expand Down
Loading