2323import threading
2424import copy
2525import warnings
26- import re
26+ import weakref
2727from collections import OrderedDict , defaultdict
28+
29+ import re
2830import numpy as np
2931
3032from ..base import mx_real_t , MXNetError , NDArrayHandle , py_str
@@ -48,7 +50,7 @@ class _BlockScope(object):
4850 _current = threading .local ()
4951
5052 def __init__ (self , block ):
51- self ._block = block
53+ self ._block = weakref . ref ( block ) if block is not None else None
5254 self ._counter = {}
5355 self ._old_scope = None
5456 self ._name_scope = None
@@ -60,7 +62,8 @@ def create(prefix, params, hint):
6062 The profiler scope is to support the GPU memory profiler.
6163 """
6264 current = getattr (_BlockScope ._current , "value" , None )
63- if current is None :
65+ block = current ._block () if current is not None else None
66+ if current is None or block is None :
6467 if prefix is None :
6568 if not hasattr (_name .NameManager ._current , "value" ):
6669 _name .NameManager ._current .value = _name .NameManager ()
@@ -79,29 +82,31 @@ def create(prefix, params, hint):
7982 prefix = '%s%d_' % (hint , count )
8083 current ._counter [hint ] = count + 1
8184 if params is None :
82- parent = current . _block .params
85+ parent = block .params
8386 params = ParameterDict (parent .prefix + prefix , parent ._shared )
8487 else :
8588 params = ParameterDict (params .prefix , params )
8689 # replace the trailing underscore with colon
8790 profiler_scope_name = (prefix [:- 1 ] if prefix .endswith ('_' ) \
8891 else prefix ) + ":"
89- return current . _block .prefix + prefix , params , \
90- current . _block ._profiler_scope_name + profiler_scope_name
92+ return block .prefix + prefix , params , \
93+ block ._profiler_scope_name + profiler_scope_name
9194
9295 def __enter__ (self ):
93- if self ._block ._empty_prefix :
96+ block = self ._block ()
97+ if block is None or block ._empty_prefix :
9498 return self
9599 self ._old_scope = getattr (_BlockScope ._current , "value" , None )
96100 _BlockScope ._current .value = self
97- self ._name_scope = _name .Prefix (self . _block .prefix )
101+ self ._name_scope = _name .Prefix (block .prefix )
98102 self ._name_scope .__enter__ ()
99- self ._profiler_scope = _profiler .Scope (self . _block ._profiler_scope_name )
103+ self ._profiler_scope = _profiler .Scope (block ._profiler_scope_name )
100104 self ._profiler_scope .__enter__ ()
101105 return self
102106
103107 def __exit__ (self , ptype , value , trace ):
104- if self ._block ._empty_prefix :
108+ block = self ._block ()
109+ if block is None or block ._empty_prefix :
105110 return
106111 self ._name_scope .__exit__ (ptype , value , trace )
107112 self ._name_scope = None
0 commit comments