@@ -62,46 +62,24 @@ static PyObject *
6262PyArray_GetObjectToGenericCastingImpl (void );
6363
6464
65- /**
66- * Fetch the casting implementation from one DType to another.
67- *
68- * @param from The implementation to cast from
69- * @param to The implementation to cast to
70- *
71- * @returns A castingimpl (PyArrayDTypeMethod *), None or NULL with an
72- * error set.
73- */
74- NPY_NO_EXPORT PyObject *
75- PyArray_GetCastingImpl (PyArray_DTypeMeta * from , PyArray_DTypeMeta * to )
65+ static PyObject *
66+ create_casting_impl (PyArray_DTypeMeta * from , PyArray_DTypeMeta * to )
7667{
77- PyObject * res ;
78- if (from == to ) {
79- res = (PyObject * )NPY_DT_SLOTS (from )-> within_dtype_castingimpl ;
80- }
81- else {
82- res = PyDict_GetItemWithError (NPY_DT_SLOTS (from )-> castingimpls , (PyObject * )to );
83- }
84- if (res != NULL || PyErr_Occurred ()) {
85- Py_XINCREF (res );
86- return res ;
87- }
8868 /*
89- * The following code looks up CastingImpl based on the fact that anything
69+ * Look up CastingImpl based on the fact that anything
9070 * can be cast to and from objects or structured (void) dtypes.
91- *
92- * The last part adds casts dynamically based on legacy definition
9371 */
9472 if (from -> type_num == NPY_OBJECT ) {
95- res = PyArray_GetObjectToGenericCastingImpl ();
73+ return PyArray_GetObjectToGenericCastingImpl ();
9674 }
9775 else if (to -> type_num == NPY_OBJECT ) {
98- res = PyArray_GetGenericToObjectCastingImpl ();
76+ return PyArray_GetGenericToObjectCastingImpl ();
9977 }
10078 else if (from -> type_num == NPY_VOID ) {
101- res = PyArray_GetVoidToGenericCastingImpl ();
79+ return PyArray_GetVoidToGenericCastingImpl ();
10280 }
10381 else if (to -> type_num == NPY_VOID ) {
104- res = PyArray_GetGenericToVoidCastingImpl ();
82+ return PyArray_GetGenericToVoidCastingImpl ();
10583 }
10684 /*
10785 * Reject non-legacy dtypes. They need to use the new API to add casts and
@@ -125,42 +103,105 @@ PyArray_GetCastingImpl(PyArray_DTypeMeta *from, PyArray_DTypeMeta *to)
125103 from -> singleton , to -> type_num );
126104 if (castfunc == NULL ) {
127105 PyErr_Clear ();
128- /* Remember that this cast is not possible */
129- if (PyDict_SetItem (NPY_DT_SLOTS (from )-> castingimpls ,
130- (PyObject * ) to , Py_None ) < 0 ) {
131- return NULL ;
132- }
133106 Py_RETURN_NONE ;
134107 }
135108 }
136-
137- /* PyArray_AddLegacyWrapping_CastingImpl find the correct casting level: */
138- /*
139- * TODO: Possibly move this to the cast registration time. But if we do
140- * that, we have to also update the cast when the casting safety
141- * is registered.
109+ /* Create a cast using the state of the legacy casting setup defined
110+ * during the setup of the DType.
111+ *
112+ * Ideally we would do this when we create the DType, but legacy user
113+ * DTypes don't have a way to signal that a DType is done setting up
114+ * casts. Without such a mechanism, the safest way to know that a
115+ * DType is done setting up is to register the cast lazily the first
116+ * time a user does the cast.
117+ *
118+ * We *could* register the casts when we create the wrapping
119+ * DTypeMeta, but that means the internals of the legacy user DType
120+ * system would need to update the state of the casting safety flags
121+ * in the cast implementations stored on the DTypeMeta. That's an
122+ * inversion of abstractions and would be tricky to do without
123+ * creating circular dependencies inside NumPy.
142124 */
143125 if (PyArray_AddLegacyWrapping_CastingImpl (from , to , -1 ) < 0 ) {
144126 return NULL ;
145127 }
128+ /* castingimpls is unconditionally filled by
129+ * AddLegacyWrapping_CastingImpl, so this won't create a recursive
130+ * critical section
131+ */
146132 return PyArray_GetCastingImpl (from , to );
147133 }
134+ }
148135
149- if (res == NULL ) {
136+ static PyObject *
137+ ensure_castingimpl_exists (PyArray_DTypeMeta * from , PyArray_DTypeMeta * to )
138+ {
139+ int return_error = 0 ;
140+ PyObject * res = NULL ;
141+
142+ /* Need to create the cast. This might happen at runtime so we enter a
143+ critical section to avoid races */
144+
145+ Py_BEGIN_CRITICAL_SECTION (NPY_DT_SLOTS (from )-> castingimpls );
146+
147+ /* check if another thread filled it while this thread was blocked on
148+ acquiring the critical section */
149+ if (PyDict_GetItemRef (NPY_DT_SLOTS (from )-> castingimpls , (PyObject * )to ,
150+ & res ) < 0 ) {
151+ return_error = 1 ;
152+ }
153+ else if (res == NULL ) {
154+ res = create_casting_impl (from , to );
155+ if (res == NULL ) {
156+ return_error = 1 ;
157+ }
158+ else if (PyDict_SetItem (NPY_DT_SLOTS (from )-> castingimpls ,
159+ (PyObject * )to , res ) < 0 ) {
160+ return_error = 1 ;
161+ }
162+ }
163+ Py_END_CRITICAL_SECTION ();
164+ if (return_error ) {
165+ Py_XDECREF (res );
150166 return NULL ;
151167 }
152- if (from == to ) {
168+ if (from == to && res == Py_None ) {
153169 PyErr_Format (PyExc_RuntimeError ,
154170 "Internal NumPy error, within-DType cast missing for %S!" , from );
155171 Py_DECREF (res );
156172 return NULL ;
157173 }
158- if (PyDict_SetItem (NPY_DT_SLOTS (from )-> castingimpls ,
159- (PyObject * )to , res ) < 0 ) {
160- Py_DECREF (res );
174+ return res ;
175+ }
176+
177+ /**
178+ * Fetch the casting implementation from one DType to another.
179+ *
180+ * @param from The implementation to cast from
181+ * @param to The implementation to cast to
182+ *
183+ * @returns A castingimpl (PyArrayDTypeMethod *), None or NULL with an
184+ * error set.
185+ */
186+ NPY_NO_EXPORT PyObject *
187+ PyArray_GetCastingImpl (PyArray_DTypeMeta * from , PyArray_DTypeMeta * to )
188+ {
189+ PyObject * res = NULL ;
190+ if (from == to ) {
191+ if ((NPY_DT_SLOTS (from )-> within_dtype_castingimpl ) != NULL ) {
192+ res = Py_XNewRef (
193+ (PyObject * )NPY_DT_SLOTS (from )-> within_dtype_castingimpl );
194+ }
195+ }
196+ else if (PyDict_GetItemRef (NPY_DT_SLOTS (from )-> castingimpls ,
197+ (PyObject * )to , & res ) < 0 ) {
161198 return NULL ;
162199 }
163- return res ;
200+ if (res != NULL ) {
201+ return res ;
202+ }
203+
204+ return ensure_castingimpl_exists (from , to );
164205}
165206
166207
@@ -409,7 +450,7 @@ _get_cast_safety_from_castingimpl(PyArrayMethodObject *castingimpl,
409450 * implementations fully to have them available for doing the actual cast
410451 * later.
411452 *
412- * @param from The descriptor to cast from
453+ * @param from The descriptor to cast from
413454 * @param to The descriptor to cast to (may be NULL)
414455 * @param to_dtype If `to` is NULL, must pass the to_dtype (otherwise this
415456 * is ignored).
@@ -2031,6 +2072,11 @@ PyArray_AddCastingImplementation(PyBoundArrayMethodObject *meth)
20312072/**
20322073 * Add a new casting implementation using a PyArrayMethod_Spec.
20332074 *
2075+ * Using this function outside of module initialization without holding a
2076+ * critical section on the castingimpls dict may lead to a race to fill the
2077+ * dict. Use PyArray_GetGastingImpl to lazily register casts at runtime
2078+ * safely.
2079+ *
20342080 * @param spec The specification to use as a source
20352081 * @param private If private, allow slots not publicly exposed.
20362082 * @return 0 on success -1 on failure
0 commit comments