Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 3e676fc

Browse files
authored
Fix memory leaks in Gluon (#18328)
Fix a leak of NDArray objects in the frontend caused by a reference cycle.
1 parent fec534a commit 3e676fc

File tree

3 files changed

+56
-12
lines changed

3 files changed

+56
-12
lines changed

python/mxnet/gluon/block.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@
2323
import threading
2424
import copy
2525
import warnings
26-
import re
26+
import weakref
2727
from collections import OrderedDict, defaultdict
28+
29+
import re
2830
import numpy as np
2931

3032
from ..base import mx_real_t, MXNetError, NDArrayHandle, py_str
@@ -48,7 +50,7 @@ class _BlockScope(object):
4850
_current = threading.local()
4951

5052
def __init__(self, block):
51-
self._block = block
53+
self._block = weakref.ref(block) if block is not None else None
5254
self._counter = {}
5355
self._old_scope = None
5456
self._name_scope = None
@@ -60,7 +62,8 @@ def create(prefix, params, hint):
6062
The profiler scope is to support the GPU memory profiler.
6163
"""
6264
current = getattr(_BlockScope._current, "value", None)
63-
if current is None:
65+
block = current._block() if current is not None else None
66+
if current is None or block is None:
6467
if prefix is None:
6568
if not hasattr(_name.NameManager._current, "value"):
6669
_name.NameManager._current.value = _name.NameManager()
@@ -79,29 +82,31 @@ def create(prefix, params, hint):
7982
prefix = '%s%d_'%(hint, count)
8083
current._counter[hint] = count + 1
8184
if params is None:
82-
parent = current._block.params
85+
parent = block.params
8386
params = ParameterDict(parent.prefix+prefix, parent._shared)
8487
else:
8588
params = ParameterDict(params.prefix, params)
8689
# replace the trailing underscore with colon
8790
profiler_scope_name = (prefix[:-1] if prefix.endswith('_') \
8891
else prefix) + ":"
89-
return current._block.prefix + prefix, params, \
90-
current._block._profiler_scope_name + profiler_scope_name
92+
return block.prefix + prefix, params, \
93+
block._profiler_scope_name + profiler_scope_name
9194

9295
def __enter__(self):
93-
if self._block._empty_prefix:
96+
block = self._block()
97+
if block is None or block._empty_prefix:
9498
return self
9599
self._old_scope = getattr(_BlockScope._current, "value", None)
96100
_BlockScope._current.value = self
97-
self._name_scope = _name.Prefix(self._block.prefix)
101+
self._name_scope = _name.Prefix(block.prefix)
98102
self._name_scope.__enter__()
99-
self._profiler_scope = _profiler.Scope(self._block._profiler_scope_name)
103+
self._profiler_scope = _profiler.Scope(block._profiler_scope_name)
100104
self._profiler_scope.__enter__()
101105
return self
102106

103107
def __exit__(self, ptype, value, trace):
104-
if self._block._empty_prefix:
108+
block = self._block()
109+
if block is None or block._empty_prefix:
105110
return
106111
self._name_scope.__exit__(ptype, value, trace)
107112
self._name_scope = None

tests/python/unittest/test_gluon.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717

1818
import os
19+
import gc
1920

2021
import mxnet as mx
2122
from mxnet import gluon
@@ -3229,3 +3230,40 @@ def hybrid_forward(self, F, x):
32293230

32303231
mx.test_utils.assert_almost_equal(grad1, grad2)
32313232

3233+
def test_no_memory_leak_in_gluon():
    """Check that creating and deleting a Gluon block leaks no NDArrays.

    The garbage collector is switched to ``gc.DEBUG_SAVEALL`` so that every
    object it finds unreachable is stored in ``gc.garbage`` instead of being
    freed. A ``Dense`` block is then created, initialized and deleted; the
    test fails if an NDArray can be reached from anything the collector
    saved, because that means the array was only kept alive by a reference
    cycle.
    """
    import types  # local import: only this test needs it

    def reaches_ndarray(roots):
        """Return True if an NDArray is reachable from any object in ``roots``.

        The walk is iterative and tracks visited objects by identity, so
        self-referential (and unhashable) containers cannot cause unbounded
        recursion.
        """
        # Keyed by id(), but the object itself is stored as the value so that
        # temporaries produced while iterating stay alive and their ids cannot
        # be reused by a different object during the walk.
        visited = {}
        pending = list(roots)
        while pending:
            element = pending.pop()
            if id(element) in visited:
                continue
            visited[id(element)] = element

            if isinstance(element, mx.nd._internal.NDArrayBase):
                return True
            # Strings cannot hold arrays; classes and modules are live global
            # state rather than leak candidates, so do not descend into them.
            if isinstance(element, (str, bytes, type, types.ModuleType)):
                continue

            has_attrs = hasattr(element, '__dict__')
            if has_attrs:
                # Inspect the attribute *values*; iterating ``vars(element)``
                # directly would only yield the attribute names.
                pending.extend(vars(element).values())
            if isinstance(element, dict):
                pending.extend(element.keys())
                pending.extend(element.values())
            elif not has_attrs:
                try:
                    pending.extend(element)
                except (TypeError, KeyError):  # not iterable
                    pass
        return False

    # Collect all other garbage prior to this test. Otherwise the test may fail
    # due to unrelated memory leaks.
    gc.collect()

    gc_flags = gc.get_debug()
    gc.set_debug(gc.DEBUG_SAVEALL)
    try:
        net = mx.gluon.nn.Dense(10, in_units=10)
        net.initialize()
        del net
        gc.collect()
    finally:
        # Always reset the gc flags, even if block creation fails, so that
        # DEBUG_SAVEALL does not stay active for the tests that follow.
        gc.set_debug(gc_flags)

    try:
        # Check for leaked NDArrays
        assert not reaches_ndarray(gc.garbage), \
            'Found leaked NDArrays due to reference cycles'
    finally:
        # Release the saved objects whether or not the assertion passed.
        del gc.garbage[:]

tests/python/unittest/test_thread_local.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,9 @@ def __init__(self, prefix):
125125
status = [False]
126126
event = threading.Event()
127127
def f():
128-
with block._BlockScope(dummy_block("spawned_")):
129-
x= NameManager.current.get(None, "hello")
128+
net = dummy_block("spawned_") # BlockScope only keeps a weakref to the Block
129+
with block._BlockScope(net):
130+
x = NameManager.current.get(None, "hello")
130131
event.wait()
131132
if x == "spawned_hello0":
132133
status[0] = True

0 commit comments

Comments
 (0)