Skip to content

Commit 41a6458

Browse files
bpo-40408: Fix support of nested type variables in GenericAlias. (GH-19836)
1 parent 603d354 commit 41a6458

File tree

3 files changed

+140
-26
lines changed

3 files changed

+140
-26
lines changed

Lib/test/test_genericalias.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141

4242
from typing import TypeVar
4343
T = TypeVar('T')
44+
K = TypeVar('K')
45+
V = TypeVar('V')
4446

4547
class BaseTest(unittest.TestCase):
4648
"""Test basics."""
@@ -170,10 +172,7 @@ def test_exposed_type(self):
170172
self.assertEqual(a.__parameters__, ())
171173

172174
def test_parameters(self):
173-
from typing import TypeVar
174-
T = TypeVar('T')
175-
K = TypeVar('K')
176-
V = TypeVar('V')
175+
from typing import List, Dict, Callable
177176
D0 = dict[str, int]
178177
self.assertEqual(D0.__args__, (str, int))
179178
self.assertEqual(D0.__parameters__, ())
@@ -195,14 +194,43 @@ def test_parameters(self):
195194
L1 = list[T]
196195
self.assertEqual(L1.__args__, (T,))
197196
self.assertEqual(L1.__parameters__, (T,))
197+
L2 = list[list[T]]
198+
self.assertEqual(L2.__args__, (list[T],))
199+
self.assertEqual(L2.__parameters__, (T,))
200+
L3 = list[List[T]]
201+
self.assertEqual(L3.__args__, (List[T],))
202+
self.assertEqual(L3.__parameters__, (T,))
203+
L4a = list[Dict[K, V]]
204+
self.assertEqual(L4a.__args__, (Dict[K, V],))
205+
self.assertEqual(L4a.__parameters__, (K, V))
206+
L4b = list[Dict[T, int]]
207+
self.assertEqual(L4b.__args__, (Dict[T, int],))
208+
self.assertEqual(L4b.__parameters__, (T,))
209+
L5 = list[Callable[[K, V], K]]
210+
self.assertEqual(L5.__args__, (Callable[[K, V], K],))
211+
self.assertEqual(L5.__parameters__, (K, V))
198212

199213
def test_parameter_chaining(self):
200-
from typing import TypeVar
201-
T = TypeVar('T')
214+
from typing import List, Dict, Union, Callable
202215
self.assertEqual(list[T][int], list[int])
203216
self.assertEqual(dict[str, T][int], dict[str, int])
204217
self.assertEqual(dict[T, int][str], dict[str, int])
218+
self.assertEqual(dict[K, V][str, int], dict[str, int])
205219
self.assertEqual(dict[T, T][int], dict[int, int])
220+
221+
self.assertEqual(list[list[T]][int], list[list[int]])
222+
self.assertEqual(list[dict[T, int]][str], list[dict[str, int]])
223+
self.assertEqual(list[dict[str, T]][int], list[dict[str, int]])
224+
self.assertEqual(list[dict[K, V]][str, int], list[dict[str, int]])
225+
self.assertEqual(dict[T, list[int]][str], dict[str, list[int]])
226+
227+
self.assertEqual(list[List[T]][int], list[List[int]])
228+
self.assertEqual(list[Dict[K, V]][str, int], list[Dict[str, int]])
229+
self.assertEqual(list[Union[K, V]][str, int], list[Union[str, int]])
230+
self.assertEqual(list[Callable[[K, V], K]][str, int],
231+
list[Callable[[str, int], str]])
232+
self.assertEqual(dict[T, List[int]][str], dict[str, List[int]])
233+
206234
with self.assertRaises(TypeError):
207235
list[int][int]
208236
dict[T, int][str, int]
@@ -255,7 +283,6 @@ def test_union(self):
255283
self.assertEqual(a.__parameters__, ())
256284

257285
def test_union_generic(self):
258-
T = typing.TypeVar('T')
259286
a = typing.Union[list[T], tuple[T, ...]]
260287
self.assertEqual(a.__args__, (list[T], tuple[T, ...]))
261288
self.assertEqual(a.__parameters__, (T,))
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Fixed support of nested type variables in GenericAlias (e.g.
2+
``list[list[T]]``).

Objects/genericaliasobject.c

Lines changed: 104 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -182,28 +182,60 @@ tuple_index(PyObject *self, Py_ssize_t len, PyObject *item)
182182
return -1;
183183
}
184184

185-
// tuple(t for t in args if isinstance(t, TypeVar))
185+
static int
186+
tuple_add(PyObject *self, Py_ssize_t len, PyObject *item)
187+
{
188+
if (tuple_index(self, len, item) < 0) {
189+
Py_INCREF(item);
190+
PyTuple_SET_ITEM(self, len, item);
191+
return 1;
192+
}
193+
return 0;
194+
}
195+
186196
static PyObject *
187197
make_parameters(PyObject *args)
188198
{
189-
Py_ssize_t len = PyTuple_GET_SIZE(args);
199+
Py_ssize_t nargs = PyTuple_GET_SIZE(args);
200+
Py_ssize_t len = nargs;
190201
PyObject *parameters = PyTuple_New(len);
191202
if (parameters == NULL)
192203
return NULL;
193204
Py_ssize_t iparam = 0;
194-
for (Py_ssize_t iarg = 0; iarg < len; iarg++) {
205+
for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) {
195206
PyObject *t = PyTuple_GET_ITEM(args, iarg);
196207
int typevar = is_typevar(t);
197208
if (typevar < 0) {
198-
Py_XDECREF(parameters);
209+
Py_DECREF(parameters);
199210
return NULL;
200211
}
201212
if (typevar) {
202-
if (tuple_index(parameters, iparam, t) < 0) {
203-
Py_INCREF(t);
204-
PyTuple_SET_ITEM(parameters, iparam, t);
205-
iparam++;
213+
iparam += tuple_add(parameters, iparam, t);
214+
}
215+
else {
216+
_Py_IDENTIFIER(__parameters__);
217+
PyObject *subparams;
218+
if (_PyObject_LookupAttrId(t, &PyId___parameters__, &subparams) < 0) {
219+
Py_DECREF(parameters);
220+
return NULL;
221+
}
222+
if (subparams && PyTuple_Check(subparams)) {
223+
Py_ssize_t len2 = PyTuple_GET_SIZE(subparams);
224+
Py_ssize_t needed = len2 - 1 - (iarg - iparam);
225+
if (needed > 0) {
226+
len += needed;
227+
if (_PyTuple_Resize(&parameters, len) < 0) {
228+
Py_DECREF(subparams);
229+
Py_DECREF(parameters);
230+
return NULL;
231+
}
232+
}
233+
for (Py_ssize_t j = 0; j < len2; j++) {
234+
PyObject *t2 = PyTuple_GET_ITEM(subparams, j);
235+
iparam += tuple_add(parameters, iparam, t2);
236+
}
206237
}
238+
Py_XDECREF(subparams);
207239
}
208240
}
209241
if (iparam < len) {
@@ -215,6 +247,48 @@ make_parameters(PyObject *args)
215247
return parameters;
216248
}
217249

250+
/* If obj is a generic alias, substitute type variables params
251+
with substitutions argitems. For example, if obj is list[T],
252+
params is (T, S), and argitems is (str, int), return list[str].
253+
If obj doesn't have a __parameters__ attribute or that's not
254+
a non-empty tuple, return a new reference to obj. */
255+
static PyObject *
256+
subs_tvars(PyObject *obj, PyObject *params, PyObject **argitems)
257+
{
258+
_Py_IDENTIFIER(__parameters__);
259+
PyObject *subparams;
260+
if (_PyObject_LookupAttrId(obj, &PyId___parameters__, &subparams) < 0) {
261+
return NULL;
262+
}
263+
if (subparams && PyTuple_Check(subparams) && PyTuple_GET_SIZE(subparams)) {
264+
Py_ssize_t nparams = PyTuple_GET_SIZE(params);
265+
Py_ssize_t nsubargs = PyTuple_GET_SIZE(subparams);
266+
PyObject *subargs = PyTuple_New(nsubargs);
267+
if (subargs == NULL) {
268+
Py_DECREF(subparams);
269+
return NULL;
270+
}
271+
for (Py_ssize_t i = 0; i < nsubargs; ++i) {
272+
PyObject *arg = PyTuple_GET_ITEM(subparams, i);
273+
Py_ssize_t iparam = tuple_index(params, nparams, arg);
274+
if (iparam >= 0) {
275+
arg = argitems[iparam];
276+
}
277+
Py_INCREF(arg);
278+
PyTuple_SET_ITEM(subargs, i, arg);
279+
}
280+
281+
obj = PyObject_GetItem(obj, subargs);
282+
283+
Py_DECREF(subargs);
284+
}
285+
else {
286+
Py_INCREF(obj);
287+
}
288+
Py_XDECREF(subparams);
289+
return obj;
290+
}
291+
218292
static PyObject *
219293
ga_getitem(PyObject *self, PyObject *item)
220294
{
@@ -233,17 +307,25 @@ ga_getitem(PyObject *self, PyObject *item)
233307
self);
234308
}
235309
int is_tuple = PyTuple_Check(item);
236-
Py_ssize_t nitem = is_tuple ? PyTuple_GET_SIZE(item) : 1;
237-
if (nitem != nparams) {
310+
Py_ssize_t nitems = is_tuple ? PyTuple_GET_SIZE(item) : 1;
311+
PyObject **argitems = is_tuple ? &PyTuple_GET_ITEM(item, 0) : &item;
312+
if (nitems != nparams) {
238313
return PyErr_Format(PyExc_TypeError,
239314
"Too %s arguments for %R",
240-
nitem > nparams ? "many" : "few",
315+
nitems > nparams ? "many" : "few",
241316
self);
242317
}
318+
/* Replace all type variables (specified by alias->parameters)
319+
with corresponding values specified by argitems.
320+
t = list[T]; t[int] -> newargs = [int]
321+
t = dict[str, T]; t[int] -> newargs = [str, int]
322+
t = dict[T, list[S]]; t[str, int] -> newargs = [str, list[int]]
323+
*/
243324
Py_ssize_t nargs = PyTuple_GET_SIZE(alias->args);
244325
PyObject *newargs = PyTuple_New(nargs);
245-
if (newargs == NULL)
326+
if (newargs == NULL) {
246327
return NULL;
328+
}
247329
for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) {
248330
PyObject *arg = PyTuple_GET_ITEM(alias->args, iarg);
249331
int typevar = is_typevar(arg);
@@ -254,18 +336,21 @@ ga_getitem(PyObject *self, PyObject *item)
254336
if (typevar) {
255337
Py_ssize_t iparam = tuple_index(alias->parameters, nparams, arg);
256338
assert(iparam >= 0);
257-
if (is_tuple) {
258-
arg = PyTuple_GET_ITEM(item, iparam);
259-
}
260-
else {
261-
assert(iparam == 0);
262-
arg = item;
339+
arg = argitems[iparam];
340+
Py_INCREF(arg);
341+
}
342+
else {
343+
arg = subs_tvars(arg, alias->parameters, argitems);
344+
if (arg == NULL) {
345+
Py_DECREF(newargs);
346+
return NULL;
263347
}
264348
}
265-
Py_INCREF(arg);
266349
PyTuple_SET_ITEM(newargs, iarg, arg);
267350
}
351+
268352
PyObject *res = Py_GenericAlias(alias->origin, newargs);
353+
269354
Py_DECREF(newargs);
270355
return res;
271356
}

0 commit comments

Comments
 (0)