Skip to content

Commit a2aea77

Browse files
committed
BUG: Don't construct formatters until we're sure they're correct
Previously, formatters could incur errors from being run on object arrays, even though the formatter was not used.
1 parent 3b2a7a7 commit a2aea77

File tree

1 file changed

+32
-27
lines changed

1 file changed

+32
-27
lines changed

numpy/core/arrayprint.py

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -235,38 +235,44 @@ def repr_format(x):
235235
return repr(x)
236236

237237
def _get_formatdict(data, precision, suppress_small, formatter):
238-
formatdict = {'bool': _boolFormatter,
239-
'int': IntegerFormat(data),
240-
'float': FloatFormat(data, precision, suppress_small),
241-
'longfloat': LongFloatFormat(precision),
242-
'complexfloat': ComplexFormat(data, precision,
238+
# wrapped in lambdas to avoid taking a code path with the wrong type of data
239+
formatdict = {'bool': lambda: _boolFormatter,
240+
'int': lambda: IntegerFormat(data),
241+
'float': lambda: FloatFormat(data, precision, suppress_small),
242+
'longfloat': lambda: LongFloatFormat(precision),
243+
'complexfloat': lambda: ComplexFormat(data, precision,
243244
suppress_small),
244-
'longcomplexfloat': LongComplexFormat(precision),
245-
'datetime': DatetimeFormat(data),
246-
'timedelta': TimedeltaFormat(data),
247-
'numpystr': repr_format,
248-
'str': str}
245+
'longcomplexfloat': lambda: LongComplexFormat(precision),
246+
'datetime': lambda: DatetimeFormat(data),
247+
'timedelta': lambda: TimedeltaFormat(data),
248+
'numpystr': lambda: repr_format,
249+
'str': lambda: str}
250+
251+
# we need to wrap values in `formatter` in a lambda, so that the interface
252+
# is the same as the above values.
253+
def indirect(x):
254+
return lambda: x
249255

250256
if formatter is not None:
251257
fkeys = [k for k in formatter.keys() if formatter[k] is not None]
252258
if 'all' in fkeys:
253259
for key in formatdict.keys():
254-
formatdict[key] = formatter['all']
260+
formatdict[key] = indirect(formatter['all'])
255261
if 'int_kind' in fkeys:
256262
for key in ['int']:
257-
formatdict[key] = formatter['int_kind']
263+
formatdict[key] = indirect(formatter['int_kind'])
258264
if 'float_kind' in fkeys:
259265
for key in ['float', 'longfloat']:
260-
formatdict[key] = formatter['float_kind']
266+
formatdict[key] = indirect(formatter['float_kind'])
261267
if 'complex_kind' in fkeys:
262268
for key in ['complexfloat', 'longcomplexfloat']:
263-
formatdict[key] = formatter['complex_kind']
269+
formatdict[key] = indirect(formatter['complex_kind'])
264270
if 'str_kind' in fkeys:
265271
for key in ['numpystr', 'str']:
266-
formatdict[key] = formatter['str_kind']
272+
formatdict[key] = indirect(formatter['str_kind'])
267273
for key in formatdict.keys():
268274
if key in fkeys:
269-
formatdict[key] = formatter[key]
275+
formatdict[key] = indirect(formatter[key])
270276

271277
return formatdict
272278

@@ -289,28 +295,28 @@ def _get_format_function(data, precision, suppress_small, formatter):
289295
dtypeobj = dtype_.type
290296
formatdict = _get_formatdict(data, precision, suppress_small, formatter)
291297
if issubclass(dtypeobj, _nt.bool_):
292-
return formatdict['bool']
298+
return formatdict['bool']()
293299
elif issubclass(dtypeobj, _nt.integer):
294300
if issubclass(dtypeobj, _nt.timedelta64):
295-
return formatdict['timedelta']
301+
return formatdict['timedelta']()
296302
else:
297-
return formatdict['int']
303+
return formatdict['int']()
298304
elif issubclass(dtypeobj, _nt.floating):
299305
if issubclass(dtypeobj, _nt.longfloat):
300-
return formatdict['longfloat']
306+
return formatdict['longfloat']()
301307
else:
302-
return formatdict['float']
308+
return formatdict['float']()
303309
elif issubclass(dtypeobj, _nt.complexfloating):
304310
if issubclass(dtypeobj, _nt.clongfloat):
305-
return formatdict['longcomplexfloat']
311+
return formatdict['longcomplexfloat']()
306312
else:
307-
return formatdict['complexfloat']
313+
return formatdict['complexfloat']()
308314
elif issubclass(dtypeobj, (_nt.unicode_, _nt.string_)):
309-
return formatdict['numpystr']
315+
return formatdict['numpystr']()
310316
elif issubclass(dtypeobj, _nt.datetime64):
311-
return formatdict['datetime']
317+
return formatdict['datetime']()
312318
else:
313-
return formatdict['numpystr']
319+
return formatdict['numpystr']()
314320

315321
def _array2string(a, max_line_width, precision, suppress_small, separator=' ',
316322
prefix="", formatter=None):
@@ -336,7 +342,6 @@ def _array2string(a, max_line_width, precision, suppress_small, separator=' ',
336342
_summaryEdgeItems, summary_insert)[:-1]
337343
return lst
338344

339-
340345
def array2string(a, max_line_width=None, precision=None,
341346
suppress_small=None, separator=' ', prefix="",
342347
style=repr, formatter=None):

0 commit comments

Comments
 (0)