Skip to content

Commit 0bd86b1

Browse files
authored
1685-upgrade base image to 21.02 (#1687)
* fixes #1685

  Signed-off-by: Wenqi Li <[email protected]>

* add temp test

  Signed-off-by: Wenqi Li <[email protected]>

* adds docstring

  Signed-off-by: Wenqi Li <[email protected]>

* fixes dist sampler

  Signed-off-by: Wenqi Li <[email protected]>

* remove temp tests

  Signed-off-by: Wenqi Li <[email protected]>

* fixes type hint issue

  Signed-off-by: Wenqi Li <[email protected]>
1 parent 114faf0 commit 0bd86b1

File tree

6 files changed

+16
-10
lines changed

6 files changed

+16
-10
lines changed

.github/workflows/cron.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ jobs:
5656
cron-pt-image:
5757
if: github.repository == 'Project-MONAI/MONAI'
5858
container:
59-
image: nvcr.io/nvidia/pytorch:20.12-py3 # testing with the latest pytorch base image
59+
image: nvcr.io/nvidia/pytorch:21.02-py3 # testing with the latest pytorch base image
6060
options: "--gpus all"
6161
runs-on: [self-hosted, linux, x64, common]
6262
steps:

Dockerfile

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:20.12-py3
13-
12+
# To build with a different base image
13+
# please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag.
14+
ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:21.02-py3
1415
FROM ${PYTORCH_IMAGE}
1516

1617
LABEL maintainer="[email protected]"

monai/data/utils.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -922,10 +922,6 @@ class DistributedSampler(_TorchDistributedSampler):
922922
"""
923923

924924
def __init__(self, even_divisible: bool = True, *args, **kwargs):
925-
self.total_size: int = 0
926-
self.rank: int = 0
927-
self.num_samples: int = 0
928-
self.num_replicas: int = 0
929925
super().__init__(*args, **kwargs)
930926

931927
if not even_divisible:

monai/networks/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import warnings
1616
from contextlib import contextmanager
17-
from typing import Any, Callable, Optional, Sequence, cast
17+
from typing import Any, Callable, Optional, Sequence
1818

1919
import torch
2020
import torch.nn as nn
@@ -86,10 +86,10 @@ def predict_segmentation(
8686
threshold: thresholding the prediction values if multi-labels task.
8787
"""
8888
if not mutually_exclusive:
89-
return (cast(torch.Tensor, logits >= threshold)).int()
89+
return (logits >= threshold).int()
9090
if logits.shape[1] == 1:
9191
warnings.warn("single channel prediction, `mutually_exclusive=True` ignored, use threshold instead.")
92-
return (cast(torch.Tensor, logits >= threshold)).int()
92+
return (logits >= threshold).int()
9393
return logits.argmax(1, keepdim=True)
9494

9595

tests/test_distributed_sampler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def test_even(self):
2424
data = [1, 2, 3, 4, 5]
2525
sampler = DistributedSampler(dataset=data, shuffle=False)
2626
samples = np.array([data[i] for i in list(sampler)])
27+
self.assertEqual(dist.get_rank(), sampler.rank)
2728
if dist.get_rank() == 0:
2829
np.testing.assert_allclose(samples, np.array([1, 3, 5]))
2930

@@ -35,6 +36,7 @@ def test_uneven(self):
3536
data = [1, 2, 3, 4, 5]
3637
sampler = DistributedSampler(dataset=data, shuffle=False, even_divisible=False)
3738
samples = np.array([data[i] for i in list(sampler)])
39+
self.assertEqual(dist.get_rank(), sampler.rank)
3840
if dist.get_rank() == 0:
3941
np.testing.assert_allclose(samples, np.array([1, 3, 5]))
4042

tests/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import queue
1717
import sys
1818
import tempfile
19+
import time
1920
import traceback
2021
import unittest
2122
import warnings
@@ -273,6 +274,7 @@ def run_process(self, func, local_rank, args, kwargs, results):
273274
os.environ["RANK"] = str(self.nproc_per_node * self.node_rank + local_rank)
274275

275276
if torch.cuda.is_available():
277+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
276278
torch.cuda.set_device(int(local_rank))
277279

278280
dist.init_process_group(
@@ -283,6 +285,11 @@ def run_process(self, func, local_rank, args, kwargs, results):
283285
rank=int(os.environ["RANK"]),
284286
)
285287
func(*args, **kwargs)
288+
# the primary node lives longer to
289+
# avoid _store_based_barrier, RuntimeError: Broken pipe
290+
# as the TCP store daemon is on the rank 0
291+
if int(os.environ["RANK"]) == 0:
292+
time.sleep(0.1)
286293
results.put(True)
287294
except Exception as e:
288295
results.put(False)

0 commit comments

Comments (0)