# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Implementation of Neural Net (NN) functions with distribution strategy."""

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.losses import util as losses_util
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
@tf_export("nn.scale_regularization_loss")
@dispatch.add_dispatch_support
def scale_regularization_loss(regularization_loss):
"""Scales the sum of the given regularization losses by number of replicas.
Usage with distribution strategy and custom training loop:
```python
with strategy.scope():
def compute_loss(self, label, predictions):
per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
labels, predictions)
# Compute loss that is scaled by sample_weight and by global batch size.
loss = tf.nn.compute_average_loss(
per_example_loss,
sample_weight=sample_weight,
global_batch_size=GLOBAL_BATCH_SIZE)
# Add scaled regularization losses.
loss += tf.nn.scale_regularization_loss(tf.nn.l2_loss(weights))
return loss
```
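
  For illustration, a minimal single-replica sketch (assuming no strategy is
  in scope, so `num_replicas_in_sync` is 1 and the summed loss is returned
  unscaled):

  >>> weights = tf.constant([3.0, 4.0])
  >>> # tf.nn.l2_loss computes sum(weights ** 2) / 2 = (9 + 16) / 2 = 12.5.
  >>> tf.nn.scale_regularization_loss(tf.nn.l2_loss(weights)).numpy()
  12.5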

  Args:
    regularization_loss: Regularization loss.

  Returns:
    Scalar loss value.
  """  # pylint: disable=g-doc-exception
  if (
      distribute_lib.has_strategy()
      and distribute_lib.in_cross_replica_context()
  ):
    raise RuntimeError(
        "You are calling `scale_regularization_loss` in cross replica "
        "context, while it was expected to be called in replica context."
    )

  num_replicas = distribute_lib.get_strategy().num_replicas_in_sync
  return math_ops.reduce_sum(regularization_loss) / num_replicas
@tf_export("nn.compute_average_loss")
@dispatch.add_dispatch_support
def compute_average_loss(
per_example_loss, sample_weight=None, global_batch_size=None
):
"""Scales per-example losses with sample_weights and computes their average.
Usage with distribution strategy and custom training loop:
```python
with strategy.scope():
def compute_loss(labels, predictions, sample_weight=None):
# If you are using a `Loss` class instead, set reduction to `NONE` so that
# we can do the reduction afterwards and divide by global batch size.
per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
labels, predictions)
# Compute loss that is scaled by sample_weight and by global batch size.
return tf.nn.compute_average_loss(
per_example_loss,
sample_weight=sample_weight,
global_batch_size=GLOBAL_BATCH_SIZE)
```
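
  For illustration, a minimal eager sketch (assuming no strategy is in scope,
  so there is a single replica and the global batch size defaults to the
  per-replica batch size of 3):

  >>> per_example_loss = tf.constant([2.0, 4.0, 6.0])
  >>> # Sum is 12.0 and the global batch size is 3, so the average is 4.0.
  >>> tf.nn.compute_average_loss(per_example_loss).numpy()
  4.0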

  Args:
    per_example_loss: Per-example loss.
    sample_weight: Optional weighting for each example.
    global_batch_size: Optional global batch size value. Defaults to (size of
      first dimension of `per_example_loss`) * (number of replicas).

  Returns:
    Scalar loss value, obtained by summing the `per_example_loss` and dividing
    by `global_batch_size`. If `global_batch_size` is zero, the result is
    zero.
  """  # pylint: disable=g-doc-exception
  per_example_loss = ops.convert_to_tensor(per_example_loss)
  input_dtype = per_example_loss.dtype

  with losses_util.check_per_example_loss_rank(per_example_loss):
    if sample_weight is not None:
      sample_weight = ops.convert_to_tensor(sample_weight)
      per_example_loss = losses_util.scale_losses_by_sample_weight(
          per_example_loss, sample_weight
      )
    per_example_loss = math_ops.cast(per_example_loss, input_dtype)

    if global_batch_size is None:
      if (
          distribute_lib.has_strategy()
          and distribute_lib.in_cross_replica_context()
      ):
        raise RuntimeError(
            "You are calling `compute_average_loss` in cross replica "
            "context, while it was expected to be called in replica context."
        )

      num_replicas = distribute_lib.get_strategy().num_replicas_in_sync
      per_replica_batch_size = array_ops.shape_v2(per_example_loss)[0]
      global_batch_size = per_replica_batch_size * num_replicas

    check_ops.assert_scalar_v2(
        global_batch_size, message="global_batch_size must be scalar."
    )
    check_ops.assert_integer_v2(
        global_batch_size, message="global_batch_size must be an integer."
    )
    check_ops.assert_non_negative_v2(
        global_batch_size, message="global_batch_size must be non-negative."
    )

    loss = math_ops.reduce_sum(per_example_loss)
    global_batch_size = math_ops.cast(global_batch_size, input_dtype)
    return math_ops.div_no_nan(loss, global_batch_size)
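

# For illustration, a minimal end-to-end sketch of `compute_average_loss`
# under a multi-replica strategy. This assumes a `MirroredStrategy` that
# places replicas on two devices; on a single-device host the strategy falls
# back to one replica and the numbers below change accordingly.
#
#   import tensorflow as tf
#
#   strategy = tf.distribute.MirroredStrategy()  # assume 2 replicas
#
#   @tf.function
#   def step(per_example_loss):
#     # Per replica: sum = 1.0 + 3.0 = 4.0; global batch size = 2 * 2 = 4.
#     return tf.nn.compute_average_loss(per_example_loss)
#
#   per_replica = strategy.run(step, args=(tf.constant([1.0, 3.0]),))
#   total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
#   # Each replica returns 4.0 / 4 = 1.0, so `total` is 2.0: the mean over
#   # all four examples across both replicas.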