# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DTensor variable and saveable."""
import functools
from tensorflow.dtensor.python import api
from tensorflow.dtensor.python import layout as layout_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.trackable import base as trackable
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.util.tf_export import tf_export
class DSaveSpec(saveable_object.SaveSpec):
"""DTensor SaveSpec that additionaly captures global_shape and layout."""
def __init__(self,
tensor,
slice_spec,
name,
global_shape,
layout,
dtype=None,
device=None):
super().__init__(
tensor=tensor,
slice_spec=slice_spec,
name=name,
dtype=dtype,
device=device)
self.global_shape = global_shape
self.layout = layout
class _DVariableSaveable(saveable_object.SaveableObject):
"""Class for defining how to save/restore DTensor variable."""
def __init__(self, dvariable, name):
with ops.device(dvariable.device):
original_layout = api.fetch_layout(dvariable)
# Record original layout to allow restore.
self._original_layout = original_layout
self._dvariable = dvariable
def pack(tensors, layout):
with ops.device(dvariable.device):
return api.pack(tensors, layout)
host_layout = layout_lib.Layout(original_layout.sharding_specs,
original_layout.mesh.host_mesh())
def get_host_dtensor():
# Copy to host mesh if needed.
if original_layout.mesh.device_type().upper() != 'CPU':
        # Prefer pack and unpack in eager mode because they support sharded
        # layouts.
if context.executing_eagerly():
host_dtensor = api.pack(
api.unpack(dvariable.read_value()), host_layout)
else:
host_dtensor = api.copy_to_mesh(dvariable.read_value(), host_layout)
else:
host_dtensor = dvariable.read_value()
return (math_ops.cast(host_dtensor, dtypes.bfloat16)
if self.should_cast(host_dtensor) else host_dtensor)
num_local_devices = original_layout.mesh.num_local_devices()
super(_DVariableSaveable, self).__init__(
None,
[
DSaveSpec(
tensor=get_host_dtensor,
slice_spec=pack([''] * num_local_devices,
layout_lib.Layout.replicated(
original_layout.mesh.host_mesh(), rank=0)),
name=pack([name] * num_local_devices,
layout_lib.Layout.replicated(
original_layout.mesh.host_mesh(), rank=0)),
global_shape=dvariable.shape,
                # Layout is attached as an attribute; there is no need to put
                # it as a Tensor on the DTensorDevice.
layout=host_layout.to_string(),
dtype=dtypes.bfloat16
if self.should_cast(dvariable) else dvariable.dtype,
device=dvariable.device)
],
name)
def should_cast(self, v):
"""Returns True if v has float32 dtype and is intructed to save as bf16.
Args:
v : The variable that determines whether to cast.
Returns:
True if current savable DVariable is instructed to save as bfloat16 and
the variable has dtype float32.
"""
return self._dvariable.save_as_bf16 and v.dtype == dtypes.float32
def restore(self, restored_tensors, restored_shapes):
"""Restores the same value into all variables."""
tensor, = restored_tensors
@def_function.function
def _restore(t):
with ops.device(self._dvariable.device):
return api.copy_to_mesh(t, self._original_layout)
    # This assign establishes connections between the restored tensor and the
    # tensors being restored to, so that restore in SPMD can backtrack to the
    # DVariable and its layout, given that we're using tf.function-style
    # restore.
    # Note that the restored value is always on CPU, as the RestoreV2 op must
    # run on CPU.
    # TODO(b/159035705): Allow restore for Tensor objects as well?
    # Restore the DVariable back to its original layout.
if self._original_layout.mesh.device_type().upper() != 'CPU':
tensor = _restore(tensor)
return self._dvariable.assign(
math_ops.cast(tensor, dtype=self._dvariable.dtype) if self._dvariable
.save_as_bf16 else tensor)
@tf_export('experimental.dtensor.DVariable', v1=[])
class DVariable(resource_variable_ops.ResourceVariable):
"""A replacement for tf.Variable which follows initial value placement.

  The class also handles save/restore operations in DTensor. Note that
  DVariable may currently fall back to a normal tf.Variable if
  `initial_value` is not a DTensor.
"""
def __init__(self, initial_value, *args, dtype=None, **kwargs):
"""Overrides tf.Variable to fix VarHandleOp placements."""
# Variables by default use the current device scope for placement. This
# wrapper has them follow the initial value's placement instead (which will
# be the DTensor device if the initial value has a layout).
    # Pop 'layout' from kwargs since Keras' make_variable may pass a 'layout'
    # keyword argument. We need to pop it because kwargs are forwarded to the
    # superclass constructor.
layout = kwargs.pop('layout', None)
shape = kwargs.get('shape', None)
if callable(initial_value):
unwrapped = initial_value
if issubclass(type(initial_value), functools.partial):
unwrapped = initial_value.func
      # If unwrapped is a CheckpointInitialValueCallable, we are creating a
      # Variable during a checkpoint restore. Thus the restore will happen now
      # through this callable, and we will create the DVariable with the
      # restored DTensor.
if issubclass(type(unwrapped), trackable.CheckpointInitialValueCallable):
if not shape or not layout:
raise ValueError('Expected shape and layout to be not None.')
        # CheckpointInitialValueCallable will call an eager tf.RestoreV2, which
        # has no shape or layout information attached. Thus we do two things to
        # have them correctly specified:
        #
        # 1. The default layout scope allows us to specify the output layout of
        #    the tf.RestoreV2 that will be called.
        # 2. Passing shard_info with the correct shape allows the tf.RestoreV2
        #    shape inference to extract the shape.
initial_value = api.call_with_layout(
initial_value,
layout,
shard_info=trackable.ShardInfo(
shape=shape, offset=[0] * len(shape)))
else:
initial_value = initial_value()
# When the initial value came from a Checkpoint restoration, fetch tensor.
if isinstance(initial_value, trackable.CheckpointInitialValue):
initial_value = initial_value.wrapped_value
initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
variable_device = initial_value.device
self._save_as_bf16 = False
# TODO(b/159035705): The following code enables variable creation inside
# a tf.function. However, it requires a global dtensor device.
# if not variable_device and not tf.executing_eagerly():
# try:
# initial_value.op.get_attr("_layout")
# except ValueError:
# pass
# else:
# # The initial value is a DTensor, but because the DTensor device is
# # only active during eager execution at the moment we need to
# # translate that into a placement for the eager VarHandleOp.
# variable_device = _dtensor_device().name
with ops.device(variable_device):
      # If the initial tensor assigned to the DVariable is a DTensor, record
      # the layout of the resource so that it can be queried.
if context.executing_eagerly():
if api.is_dtensor(initial_value):
value_layout = api.fetch_layout(initial_value)
if layout is not None and layout != value_layout:
            raise errors_impl.InvalidArgumentError(
                None,
                None,
                'Conflicting layouts are provided for the initial '
                f'value ({value_layout}) and the variable ({layout}).',
            )
layout = value_layout
elif layout is not None:
initial_value = api.relayout(initial_value, layout)
else:
          raise errors_impl.InvalidArgumentError(
              None,
              None,
              'Neither a layout nor a DTensor initial value was provided.',
          )
self.layout = layout
with api.default_mesh(layout.mesh):
super(DVariable, self).__init__(
initial_value, *args, dtype=dtype, **kwargs
)
else:
# FIXME(175928457): Record value layout in graph mode.
if layout is not None:
initial_value = api.relayout(initial_value, layout)
super(DVariable, self).__init__(
initial_value, *args, dtype=dtype, **kwargs)
@property
def save_as_bf16(self):
return self._save_as_bf16
@save_as_bf16.setter
def save_as_bf16(self, save_as_bf16):
"""Enables saving float32 as bfloat16."""
self._save_as_bf16 = save_as_bf16 and self.dtype == dtypes.float32
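  # Illustrative toggle (the variable `v` is a placeholder): requesting that a
  # float32 DVariable be written to checkpoints as bfloat16. For non-float32
  # dtypes the setter above leaves the flag False.
  #
  #   v.save_as_bf16 = True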
def _gather_saveables_for_checkpoint(self):
return {
trackable.VARIABLE_VALUE_KEY:
functools.partial(_DVariableSaveable, self)
}
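# A sketch of the save/restore round trip enabled by _DVariableSaveable above
# (illustrative only: `v` is a DVariable created as in the class docstring, the
# checkpoint prefix is a placeholder, and depending on the TF version a
# DTensor-aware checkpoint such as tf.experimental.dtensor.DTensorCheckpoint
# may be required instead of tf.train.Checkpoint):
#
#   checkpoint = tf.train.Checkpoint(v=v)
#   path = checkpoint.save('/tmp/dvariable_ckpt')
#   v.assign(dtensor.call_with_layout(tf.zeros, v.layout, shape=v.shape))
#   checkpoint.restore(path)  # The value comes back with its original layout.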