Skip to content

Commit b6a1ca4

Browse files
committed
shortcut
1 parent 9e72c9c commit b6a1ca4

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torch/optim/sgd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def _multi_tensor_sgd(params: List[Tensor],
280280

281281
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([params, grads, momentum_buffer_list], with_indices=True)
282282
for ((device_params, device_grads, device_momentum_buffer_list), indices) in grouped_tensors.values():
283-
device_has_sparse_grad = any(grad.is_sparse for grad in device_grads)
283+
device_has_sparse_grad = has_sparse_grad and any(grad.is_sparse for grad in device_grads)
284284

285285
if maximize:
286286
device_grads = torch._foreach_neg(device_grads)

0 commit comments

Comments
 (0)