Skip to content

Commit e2831d0

Browse files
authored
Lift closure cell update to earliest function (#461)
1 parent 8163e08 commit e2831d0

File tree

2 files changed

+41
-5
lines changed

2 files changed

+41
-5
lines changed

dill/_dill.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,8 @@ def __init__(self, *args, **kwds):
518518
self._strictio = False #_strictio
519519
self._fmode = settings['fmode'] if _fmode is None else _fmode
520520
self._recurse = settings['recurse'] if _recurse is None else _recurse
521-
self._postproc = {}
521+
from collections import OrderedDict
522+
self._postproc = OrderedDict()
522523

523524
def dump(self, obj): #NOTE: if settings change, need to update attributes
524525
# register if the object is a numpy ufunc
@@ -1424,14 +1425,14 @@ def save_cell(pickler, obj):
14241425
log.info("# Ce3")
14251426
return
14261427
if is_dill(pickler, child=True):
1427-
postproc = pickler._postproc.get(id(f))
1428+
postproc = next(iter(pickler._postproc.values()), None)
14281429
if postproc is not None:
14291430
log.info("Ce2: %s" % obj)
14301431
# _CELL_REF is defined in _shims.py to support older versions of
14311432
# dill. When breaking changes are made to dill, (_CELL_REF,) can
14321433
# be replaced by ()
1433-
postproc.append((_shims._setattr, (obj, 'cell_contents', f)))
14341434
pickler.save_reduce(_create_cell, (_CELL_REF,), obj=obj)
1435+
postproc.append((_shims._setattr, (obj, 'cell_contents', f)))
14351436
log.info("# Ce2")
14361437
return
14371438
log.info("Ce1: %s" % obj)
@@ -1748,16 +1749,37 @@ def save_function(pickler, obj):
17481749
postproc_list.append((dict.update, (globs, globs_copy)))
17491750

17501751
if PY3:
1752+
closure = obj.__closure__
17511753
fkwdefaults = getattr(obj, '__kwdefaults__', None)
17521754
_save_with_postproc(pickler, (_create_function, (
17531755
obj.__code__, globs, obj.__name__, obj.__defaults__,
1754-
obj.__closure__, obj.__dict__, fkwdefaults
1756+
closure, obj.__dict__, fkwdefaults
17551757
)), obj=obj, postproc_list=postproc_list)
17561758
else:
1759+
closure = obj.func_closure
17571760
_save_with_postproc(pickler, (_create_function, (
17581761
obj.func_code, globs, obj.func_name, obj.func_defaults,
1759-
obj.func_closure, obj.__dict__
1762+
closure, obj.__dict__
17601763
)), obj=obj, postproc_list=postproc_list)
1764+
1765+
# Lift closure cell update to earliest function (#458)
1766+
topmost_postproc = next(iter(pickler._postproc.values()), None)
1767+
if closure and topmost_postproc:
1768+
for cell in closure:
1769+
possible_postproc = (setattr, (cell, 'cell_contents', obj))
1770+
try:
1771+
topmost_postproc.remove(possible_postproc)
1772+
except ValueError:
1773+
continue
1774+
1775+
# Change the value of the cell
1776+
pickler.save_reduce(*possible_postproc)
1777+
# pop None created by calling preprocessing step off stack
1778+
if PY3:
1779+
pickler.write(bytes('0', 'UTF-8'))
1780+
else:
1781+
pickler.write('0')
1782+
17611783
log.info("# F1")
17621784
else:
17631785
log.info("F2: %s" % obj)

tests/test_recursive.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,19 @@ def test_recursive_function():
154154
fib = fib4
155155

156156

157+
def collection_function_recursion():
158+
d = {}
159+
def g():
160+
return d
161+
d['g'] = g
162+
return g
163+
164+
165+
def test_collection_function_recursion():
166+
g = copy(collection_function_recursion())
167+
assert g()['g'] is g
168+
169+
157170
if __name__ == '__main__':
158171
with warnings.catch_warnings():
159172
warnings.simplefilter('error')
@@ -163,3 +176,4 @@ def test_recursive_function():
163176
test_circular_reference()
164177
test_function_cells()
165178
test_recursive_function()
179+
test_collection_function_recursion()

0 commit comments

Comments
 (0)