@@ -34,6 +34,8 @@ def __init__(
3434 :param scales: list of scalars or None, if None, do not apply any scaling.
3535 :param kernel: gaussian or cauchy.
3636 :param reduction: using SUM reduction over batch axis,
37+ this is to support multi-device training,
38+ and the loss will be divided by the global batch size,
3739 calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
3840 :param name: str, name of the loss.
3941 """
@@ -132,6 +134,8 @@ def __init__(
132134 :param scales: list of scalars or None, if None, do not apply any scaling.
133135 :param kernel: gaussian or cauchy.
134136 :param reduction: using SUM reduction over batch axis,
137+ this is to support multi-device training,
138+ and the loss will be divided by the global batch size,
135139 calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
136140 :param name: str, name of the loss.
137141 """
@@ -206,6 +210,8 @@ def __init__(
206210 :param scales: list of scalars or None, if None, do not apply any scaling.
207211 :param kernel: gaussian or cauchy.
208212 :param reduction: using SUM reduction over batch axis,
213+ this is to support multi-device training,
214+ and the loss will be divided by the global batch size,
209215 calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
210216 :param name: str, name of the loss.
211217 """
@@ -272,6 +278,8 @@ def __init__(
272278 :param scales: list of scalars or None, if None, do not apply any scaling.
273279 :param kernel: gaussian or cauchy.
274280 :param reduction: using SUM reduction over batch axis,
281+ this is to support multi-device training,
282+ and the loss will be divided by the global batch size,
275283 calling the loss like `loss(y_true, y_pred)` will return a scalar tensor.
276284 :param name: str, name of the loss.
277285 """
0 commit comments