Skip to content

Commit 4547234

Browse files
committed
ffi::numpy
1 parent 0d4a5ad commit 4547234

20 files changed

Lines changed: 990 additions & 730 deletions

File tree

src/ffi/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ mod buffer;
88
mod bytes;
99
pub(crate) mod compat;
1010
mod fragment;
11+
mod numpy;
1112
mod pyboolref;
1213
#[cfg(all(CPython, not(Py_GIL_DISABLED)))]
1314
mod pybytearrayref;
@@ -28,6 +29,13 @@ mod pytupleref;
2829
mod pyuuidref;
2930
mod utf8;
3031

32+
pub(crate) use numpy::{
33+
NPY_ARRAY_C_CONTIGUOUS, NPY_ARRAY_NOTSWAPPED, NumpyBool, NumpyDateTimeError, NumpyDatetime64,
34+
NumpyDatetime64Repr, NumpyDatetimeUnit, NumpyFloat16, NumpyFloat32, NumpyFloat64, NumpyInt8,
35+
NumpyInt16, NumpyInt32, NumpyInt64, NumpyUint8, NumpyUint16, NumpyUint32, NumpyUint64,
36+
PyArrayInterface, PyCapsule,
37+
};
38+
3139
pub(crate) use compat::*;
3240

3341
#[allow(unused_imports)]

src/ffi/numpy/array.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// SPDX-License-Identifier: MPL-2.0
2+
// Copyright ijl (2020-2026)
3+
4+
use crate::ffi::{Py_intptr_t, PyObject};
5+
use core::ffi::{c_char, c_int, c_void};
6+
7+
/// C-layout view of a capsule object: the Python object header followed by
/// the capsule's pointer, name, context and destructor slots.
///
/// NOTE(review): the field order presumably mirrors CPython's private
/// `PyCapsule` struct; confirm against the interpreter versions supported.
#[repr(C)]
8+
pub(crate) struct PyCapsule {
9+
// Object header; private because it is only present to position the fields below.
head: PyObject,
10+
// Payload pointer stored in the capsule.
pub pointer: *mut c_void,
11+
pub name: *const c_char,
12+
pub context: *mut c_void,
13+
// Declared as an untyped pointer rather than a function pointer.
pub destructor: *mut c_void, // should be typedef void (*PyCapsule_Destructor)(PyObject *);
14+
}
15+
16+
// https://docs.scipy.org/doc/numpy/reference/arrays.interface.html#c.__array_struct__
17+
18+
// Bit masks typed as `c_int`, the same type as `PyArrayInterface::flags`.
// NOTE(review): values presumably match numpy's NPY_ARRAY_* flag definitions; confirm against the numpy headers.
pub(crate) const NPY_ARRAY_C_CONTIGUOUS: c_int = 0x1;
19+
pub(crate) const NPY_ARRAY_NOTSWAPPED: c_int = 0x200;
20+
21+
/// C-layout struct for numpy's `__array_struct__` array interface.
///
/// NOTE(review): field order and types are presumably those of the C
/// `PyArrayInterface` struct described in the numpy array-interface
/// documentation; the per-field notes below are assumptions to confirm there.
#[repr(C)]
22+
pub(crate) struct PyArrayInterface {
23+
pub two: c_int,
24+
// presumably the number of dimensions
pub nd: c_int,
25+
pub typekind: c_char,
26+
pub itemsize: c_int,
27+
// presumably a bit set tested against the NPY_ARRAY_* masks
pub flags: c_int,
28+
pub shape: *mut Py_intptr_t,
29+
pub strides: *mut Py_intptr_t,
30+
pub data: *mut c_void,
31+
pub descr: *mut PyObject,
32+
}

src/ffi/numpy/datetime.rs

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
2+
// Copyright ijl (2022-2026), Ben Sully (2021)
3+
4+
use crate::ffi::{PyListRef, PyObject, PyStrRef, PyTupleRef};
5+
use crate::opt::Opt;
6+
use crate::typeref::{DESCR_STR, DTYPE_STR};
7+
use jiff::Timestamp;
8+
use jiff::civil::DateTime;
9+
10+
/// This mimics the units supported by numpy's datetime64 type.
11+
///
12+
/// See
13+
/// https://github.com/numpy/numpy/blob/fc8e3bbe419748ac5c6b7f3d0845e4bafa74644b/numpy/core/include/numpy/ndarraytypes.h#L268-L282.
///
/// Only `Years` through `Nanoseconds` are converted by `datetime()`; every
/// other variant yields `NumpyDateTimeError::UnsupportedUnit`.
14+
#[derive(Clone, Copy, PartialEq)]
15+
pub(crate) enum NumpyDatetimeUnit {
16+
NaT,
17+
Years,
18+
Months,
19+
Weeks,
20+
Days,
21+
Hours,
22+
Minutes,
23+
Seconds,
24+
Milliseconds,
25+
Microseconds,
26+
Nanoseconds,
27+
Picoseconds,
28+
Femtoseconds,
29+
Attoseconds,
30+
Generic,
31+
}
32+
33+
impl NumpyDatetimeUnit {
34+
#[cold]
35+
pub const fn as_str(self) -> &'static str {
36+
match self {
37+
Self::NaT => "NaT",
38+
Self::Years => "years",
39+
Self::Months => "months",
40+
Self::Weeks => "weeks",
41+
Self::Days => "days",
42+
Self::Hours => "hours",
43+
Self::Minutes => "minutes",
44+
Self::Seconds => "seconds",
45+
Self::Milliseconds => "milliseconds",
46+
Self::Microseconds => "microseconds",
47+
Self::Nanoseconds => "nanoseconds",
48+
Self::Picoseconds => "picoseconds",
49+
Self::Femtoseconds => "femtoseconds",
50+
Self::Attoseconds => "attoseconds",
51+
Self::Generic => "generic",
52+
}
53+
}
54+
}
55+
56+
/// Reasons a numpy datetime64 value could not be converted by
/// `NumpyDatetimeUnit::datetime`.
#[derive(Clone, Copy)]
57+
pub(crate) enum NumpyDateTimeError {
58+
// The unit has no conversion arm in `datetime` (it falls into the catch-all).
UnsupportedUnit(NumpyDatetimeUnit),
59+
// The value `val`, counted in `unit`, does not map to a datetime with a year in 0..=9999.
Unrepresentable { unit: NumpyDatetimeUnit, val: i64 },
60+
}
61+
62+
// Converts a fallible timestamp expression into `Ok(civil datetime in UTC)`.
//
// `$timestamp` is a `Result` holding a timestamp; `$self` and `$val` are the
// unit and raw value recorded in the error. On failure the macro uses `?`, so
// it returns `NumpyDateTimeError::Unrepresentable` from the *enclosing*
// function, which must therefore return a `Result` with that error type.
macro_rules! to_jiff_datetime {
63+
($timestamp:expr, $self:expr, $val:expr) => {
64+
Ok(
65+
// The underlying error is discarded; only the unit and value are reported.
($timestamp.map_err(|_| NumpyDateTimeError::Unrepresentable {
66+
unit: $self,
67+
val: $val,
68+
})?)
69+
.to_zoned(jiff::tz::TimeZone::UTC)
70+
.datetime(),
71+
)
72+
};
73+
}
74+
75+
impl NumpyDatetimeUnit {
76+
/// Create a `NumpyDatetimeUnit` from a pointer to a Python object holding a
77+
/// numpy array.
78+
///
79+
/// This function must only be called with pointers to numpy arrays.
80+
///
81+
/// We need to look inside the `obj.dtype.descr` attribute of the Python
82+
/// object rather than using the `descr` field of the `__array_struct__`
83+
/// because that field isn't populated for datetime64 arrays; see
84+
/// https://github.com/numpy/numpy/issues/5350.
85+
#[cold]
86+
#[inline(never)]
87+
pub fn from_pyobject(ptr: *mut PyObject) -> Self {
88+
let dtype = ffi!(PyObject_GetAttr(ptr, DTYPE_STR));
89+
let descr = ffi!(PyObject_GetAttr(dtype, DESCR_STR));
90+
let el0 = unsafe { PyListRef::from_ptr_unchecked(descr).get(0) };
91+
let descr_str = unsafe { PyTupleRef::from_ptr_unchecked(el0).get(1) };
92+
match PyStrRef::from_ptr(descr_str) {
93+
Ok(uni) => {
94+
match uni.as_str() {
95+
Some(as_str) => {
96+
if as_str.len() < 5 {
97+
return Self::NaT;
98+
}
99+
// unit descriptions are found at
100+
// https://github.com/numpy/numpy/blob/b235f9e701e14ed6f6f6dcba885f7986a833743f/numpy/core/src/multiarray/datetime.c#L79-L96.
101+
let ret = match &as_str[4..as_str.len() - 1] {
102+
"Y" => Self::Years,
103+
"M" => Self::Months,
104+
"W" => Self::Weeks,
105+
"D" => Self::Days,
106+
"h" => Self::Hours,
107+
"m" => Self::Minutes,
108+
"s" => Self::Seconds,
109+
"ms" => Self::Milliseconds,
110+
"us" => Self::Microseconds,
111+
"ns" => Self::Nanoseconds,
112+
"ps" => Self::Picoseconds,
113+
"fs" => Self::Femtoseconds,
114+
"as" => Self::Attoseconds,
115+
"generic" => Self::Generic,
116+
_ => unreachable!(),
117+
};
118+
ffi!(Py_DECREF(dtype));
119+
ffi!(Py_DECREF(descr));
120+
ret
121+
}
122+
None => Self::NaT,
123+
}
124+
}
125+
Err(_) => Self::NaT,
126+
}
127+
}
128+
129+
#[cold]
130+
#[cfg_attr(feature = "optimize", optimize(size))]
131+
pub fn datetime(self, val: i64, opts: Opt) -> Result<NumpyDatetime64Repr, NumpyDateTimeError> {
132+
let datetime = match self {
133+
Self::Years => {
134+
let year = val + 1970;
135+
if !(0..=9999).contains(&year) {
136+
cold_path!();
137+
return Err(NumpyDateTimeError::Unrepresentable { unit: self, val });
138+
} else {
139+
Ok(DateTime::new(year as i16, 1, 1, 0, 0, 0, 0).unwrap())
140+
}
141+
}
142+
Self::Months => {
143+
let year = val / 12 + 1970;
144+
let month = val % 12 + 1;
145+
if !(0..=9999).contains(&year) || !(0..=12).contains(&month) {
146+
cold_path!();
147+
return Err(NumpyDateTimeError::Unrepresentable { unit: self, val });
148+
} else {
149+
Ok(DateTime::new(year as i16, month as i8, 1, 0, 0, 0, 0).unwrap())
150+
}
151+
}
152+
Self::Weeks => {
153+
to_jiff_datetime!(Timestamp::from_second(val * 7 * 24 * 60 * 60), self, val)
154+
}
155+
Self::Days => to_jiff_datetime!(Timestamp::from_second(val * 24 * 60 * 60), self, val),
156+
Self::Hours => to_jiff_datetime!(Timestamp::from_second(val * 60 * 60), self, val),
157+
Self::Minutes => to_jiff_datetime!(Timestamp::from_second(val * 60), self, val),
158+
Self::Seconds => to_jiff_datetime!(Timestamp::from_second(val), self, val),
159+
Self::Milliseconds => to_jiff_datetime!(Timestamp::from_millisecond(val), self, val),
160+
Self::Microseconds => to_jiff_datetime!(Timestamp::from_microsecond(val), self, val),
161+
Self::Nanoseconds => {
162+
to_jiff_datetime!(Timestamp::from_nanosecond(i128::from(val)), self, val)
163+
}
164+
_ => Err(NumpyDateTimeError::UnsupportedUnit(self)),
165+
};
166+
match datetime {
167+
Ok(dt) => match dt.year() {
168+
0..=9999 => Ok(NumpyDatetime64Repr { dt, opts }),
169+
_ => Err(NumpyDateTimeError::Unrepresentable { unit: self, val }),
170+
},
171+
Err(err) => Err(err),
172+
}
173+
}
174+
}
175+
176+
// Generates a `pub fn $meth(&self) -> $ty` accessor that calls the method of
// the same name on `self.dt` and casts the result to `$ty`.
//
// The value is asserted non-negative in debug builds only, which is why the
// sign-loss lint is silenced on the cast.
macro_rules! forward_inner {
177+
($meth: ident, $ty: ident) => {
178+
pub fn $meth(&self) -> $ty {
179+
debug_assert!(self.dt.$meth() >= 0);
180+
#[allow(clippy::cast_sign_loss)]
181+
// Bound to a local so the attribute sits on a statement, not an expression.
let ret = self.dt.$meth() as $ty; // stmt_expr_attributes
182+
ret
183+
}
184+
};
185+
}
186+
187+
/// A converted datetime64 value, as built by `NumpyDatetimeUnit::datetime`.
pub(crate) struct NumpyDatetime64Repr {
188+
// Civil datetime; `datetime` only constructs this with a year in 0..=9999.
pub dt: DateTime,
189+
// Options passed through unchanged from the `datetime` call.
pub opts: Opt,
190+
}
191+
192+
impl NumpyDatetime64Repr {
193+
forward_inner!(year, i32);
194+
forward_inner!(month, u8);
195+
forward_inner!(day, u8);
196+
forward_inner!(hour, u8);
197+
forward_inner!(minute, u8);
198+
forward_inner!(second, u8);
199+
200+
pub fn nanosecond(&self) -> u32 {
201+
debug_assert!(self.dt.subsec_nanosecond() >= 0);
202+
self.dt.subsec_nanosecond().cast_unsigned()
203+
}
204+
205+
pub fn microsecond(&self) -> u32 {
206+
self.nanosecond() / 1_000
207+
}
208+
}

src/ffi/numpy/mod.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// SPDX-License-Identifier: MPL-2.0
2+
// Copyright ijl (2026)
3+
4+
mod array;
5+
mod datetime;
6+
mod scalar;
7+
8+
pub(crate) use array::{NPY_ARRAY_C_CONTIGUOUS, NPY_ARRAY_NOTSWAPPED, PyArrayInterface, PyCapsule};
9+
pub(crate) use datetime::{NumpyDateTimeError, NumpyDatetime64Repr, NumpyDatetimeUnit};
10+
pub(crate) use scalar::{
11+
NumpyBool, NumpyDatetime64, NumpyFloat16, NumpyFloat32, NumpyFloat64, NumpyInt8, NumpyInt16,
12+
NumpyInt32, NumpyInt64, NumpyUint8, NumpyUint16, NumpyUint32, NumpyUint64,
13+
};

src/ffi/numpy/scalar.rs

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// SPDX-License-Identifier: MPL-2.0
2+
// Copyright ijl (2026)
3+
4+
use crate::ffi::PyObject;
5+
6+
/// C-layout view of an object whose header is followed by an `f64` payload.
/// NOTE(review): presumably matches numpy's float64 scalar object layout; confirm.
#[repr(C)]
7+
pub(crate) struct NumpyFloat64 {
8+
head: PyObject,
9+
pub value: f64,
10+
}
11+
12+
/// C-layout view of an object whose header is followed by an `f32` payload.
/// NOTE(review): presumably matches numpy's float32 scalar object layout; confirm.
#[repr(C)]
13+
pub(crate) struct NumpyFloat32 {
14+
head: PyObject,
15+
pub value: f32,
16+
}
17+
18+
/// C-layout view of an object whose header is followed by a 16-bit payload.
/// NOTE(review): presumably matches numpy's float16 scalar object layout; confirm.
#[repr(C)]
19+
pub(crate) struct NumpyFloat16 {
20+
head: PyObject,
21+
// Held as raw bits in a `u16`; presumably the half-precision encoding, decoded by the reader.
pub value: u16,
22+
}
23+
24+
/// C-layout view of an object whose header is followed by a `u64` payload.
/// NOTE(review): presumably matches numpy's uint64 scalar object layout; confirm.
#[repr(C)]
25+
pub(crate) struct NumpyUint64 {
26+
head: PyObject,
27+
pub value: u64,
28+
}
29+
30+
/// C-layout view of an object whose header is followed by a `u32` payload.
/// NOTE(review): presumably matches numpy's uint32 scalar object layout; confirm.
#[repr(C)]
31+
pub(crate) struct NumpyUint32 {
32+
head: PyObject,
33+
pub value: u32,
34+
}
35+
36+
/// C-layout view of an object whose header is followed by a `u16` payload.
/// NOTE(review): presumably matches numpy's uint16 scalar object layout; confirm.
#[repr(C)]
37+
pub(crate) struct NumpyUint16 {
38+
head: PyObject,
39+
pub value: u16,
40+
}
41+
42+
/// C-layout view of an object whose header is followed by a `u8` payload.
/// NOTE(review): presumably matches numpy's uint8 scalar object layout; confirm.
#[repr(C)]
43+
pub(crate) struct NumpyUint8 {
44+
head: PyObject,
45+
pub value: u8,
46+
}
47+
48+
/// C-layout view of an object whose header is followed by an `i64` payload.
/// NOTE(review): presumably matches numpy's int64 scalar object layout; confirm.
#[repr(C)]
49+
pub(crate) struct NumpyInt64 {
50+
head: PyObject,
51+
pub value: i64,
52+
}
53+
54+
/// C-layout view of an object whose header is followed by an `i32` payload.
/// NOTE(review): presumably matches numpy's int32 scalar object layout; confirm.
#[repr(C)]
55+
pub(crate) struct NumpyInt32 {
56+
head: PyObject,
57+
pub value: i32,
58+
}
59+
60+
/// C-layout view of an object whose header is followed by an `i16` payload.
/// NOTE(review): presumably matches numpy's int16 scalar object layout; confirm.
#[repr(C)]
61+
pub(crate) struct NumpyInt16 {
62+
head: PyObject,
63+
pub value: i16,
64+
}
65+
66+
/// C-layout view of an object whose header is followed by an `i8` payload.
/// NOTE(review): presumably matches numpy's int8 scalar object layout; confirm.
#[repr(C)]
67+
pub(crate) struct NumpyInt8 {
68+
head: PyObject,
69+
pub value: i8,
70+
}
71+
72+
/// C-layout view of an object whose header is followed by a `bool` payload.
/// NOTE(review): presumably matches numpy's bool scalar object layout; confirm.
#[repr(C)]
73+
pub(crate) struct NumpyBool {
74+
head: PyObject,
75+
// NOTE(review): reading this as a Rust `bool` assumes the stored byte is always 0 or 1; confirm.
pub value: bool,
76+
}
77+
78+
/// C-layout view of an object whose header is followed by an `i64` payload.
/// NOTE(review): presumably matches the leading fields of numpy's datetime64 scalar object; confirm.
#[repr(C)]
79+
pub(crate) struct NumpyDatetime64 {
80+
head: PyObject,
81+
// Raw count; presumably interpreted per unit via `NumpyDatetimeUnit::datetime`; verify against the caller.
pub value: i64,
82+
}

src/ffi/pystrref/object.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ pub(crate) enum PyStrRefError {
5353
NotStrType,
5454
}
5555

56-
#[derive(Clone)]
56+
#[derive(Copy, Clone)]
5757
#[repr(transparent)]
5858
pub(crate) struct PyStrRef {
5959
ptr: core::ptr::NonNull<pyo3_ffi::PyObject>,

src/serialize/error.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// SPDX-License-Identifier: MPL-2.0
2-
// Copyright ijl (2021-2025)
2+
// Copyright ijl (2021-2026)
33

4+
use crate::ffi::PyStrRef;
45
use core::ffi::CStr;
56
use core::ptr::NonNull;
67

@@ -20,6 +21,7 @@ pub(crate) enum SerializeError {
2021
NumpyNotCContiguous,
2122
NumpyNotNativeEndian,
2223
NumpyUnsupportedDatatype,
24+
NumpyUnsupportedDatetimeUnit(PyStrRef),
2325
UnsupportedType(NonNull<crate::ffi::PyObject>),
2426
}
2527

@@ -61,6 +63,9 @@ impl core::fmt::Display for SerializeError {
6163
SerializeError::NumpyUnsupportedDatatype => {
6264
write!(f, "unsupported datatype in numpy array")
6365
}
66+
SerializeError::NumpyUnsupportedDatetimeUnit(msg) => {
67+
write!(f, "{}", msg.as_str().unwrap())
68+
}
6469
SerializeError::UnsupportedType(ptr) => {
6570
let name =
6671
unsafe { CStr::from_ptr((*ob_type!(ptr.as_ptr())).tp_name).to_string_lossy() };

0 commit comments

Comments
 (0)