Skip to content

Commit 51b90a7

Browse files
authored
Merge branch 'main' into ppwwyyxx-patch-1
2 parents 3bc6b89 + 1c9ccb7 commit 51b90a7

Some content is hidden

Large commits have some of their content hidden by default. Use the search box below to find content that may be hidden.

49 files changed

+1296
-5981
lines changed

.circleci/config.yml

Lines changed: 13 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.circleci/config.yml.in

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,14 +263,24 @@ jobs:
263263
prototype_test:
264264
docker:
265265
- image: circleci/python:3.7
266+
resource_class: xlarge
266267
steps:
267268
- run:
268269
name: Install torch
269-
command: pip install --user --progress-bar=off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
270+
command: |
271+
pip install --user --progress-bar=off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
270272
- run:
271273
name: Install prototype dependencies
272274
command: pip install --user --progress-bar=off git+https://github.com/pytorch/data.git
273275
- checkout
276+
- run:
277+
name: Download model weights
278+
background: true
279+
command: |
280+
sudo apt update -qy && sudo apt install -qy parallel wget
281+
mkdir -p ~/.cache/torch/hub/checkpoints
282+
python scripts/collect_model_urls.py torchvision/prototype/models \
283+
| parallel -j0 'wget --no-verbose -O ~/.cache/torch/hub/checkpoints/`basename {}` {}\?source=ci'
274284
- run:
275285
name: Install torchvision
276286
command: pip install --user --progress-bar off --no-build-isolation .
@@ -279,6 +289,8 @@ jobs:
279289
command: pip install --user --progress-bar=off pytest pytest-mock scipy iopath
280290
- run:
281291
name: Run tests
292+
environment:
293+
PYTORCH_TEST_WITH_PROTOTYPE: 1
282294
command: pytest --junitxml=test-results/junit.xml -v --durations 20 test/test_prototype_*.py
283295
- store_test_results:
284296
path: test-results

ios/VisionTestApp/VisionTestApp.xcodeproj/project.pbxproj

Lines changed: 14 additions & 5536 deletions
Large diffs are not rendered by default.

references/segmentation/train.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,19 +72,25 @@ def evaluate(model, data_loader, device, num_classes):
7272
return confmat
7373

7474

75-
def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq):
75+
def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq, scaler=None):
7676
model.train()
7777
metric_logger = utils.MetricLogger(delimiter=" ")
7878
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
7979
header = f"Epoch: [{epoch}]"
8080
for image, target in metric_logger.log_every(data_loader, print_freq, header):
8181
image, target = image.to(device), target.to(device)
82-
output = model(image)
83-
loss = criterion(output, target)
82+
with torch.cuda.amp.autocast(enabled=scaler is not None):
83+
output = model(image)
84+
loss = criterion(output, target)
8485

8586
optimizer.zero_grad()
86-
loss.backward()
87-
optimizer.step()
87+
if scaler is not None:
88+
scaler.scale(loss).backward()
89+
scaler.step(optimizer)
90+
scaler.update()
91+
else:
92+
loss.backward()
93+
optimizer.step()
8894

8995
lr_scheduler.step()
9096

@@ -153,6 +159,8 @@ def main(args):
153159
params_to_optimize.append({"params": params, "lr": args.lr * 10})
154160
optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
155161

162+
scaler = torch.cuda.amp.GradScaler() if args.amp else None
163+
156164
iters_per_epoch = len(data_loader)
157165
main_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
158166
optimizer, lambda x: (1 - x / (iters_per_epoch * (args.epochs - args.lr_warmup_epochs))) ** 0.9
@@ -186,6 +194,8 @@ def main(args):
186194
optimizer.load_state_dict(checkpoint["optimizer"])
187195
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
188196
args.start_epoch = checkpoint["epoch"] + 1
197+
if args.amp:
198+
scaler.load_state_dict(checkpoint["scaler"])
189199

190200
if args.test_only:
191201
confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
@@ -196,7 +206,7 @@ def main(args):
196206
for epoch in range(args.start_epoch, args.epochs):
197207
if args.distributed:
198208
train_sampler.set_epoch(epoch)
199-
train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq)
209+
train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq, scaler)
200210
confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
201211
print(confmat)
202212
checkpoint = {
@@ -206,6 +216,8 @@ def main(args):
206216
"epoch": epoch,
207217
"args": args,
208218
}
219+
if args.amp:
220+
checkpoint["scaler"] = scaler.state_dict()
209221
utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
210222
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
211223

@@ -269,6 +281,9 @@ def get_args_parser(add_help=True):
269281
# Prototype models only
270282
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
271283

284+
# Mixed precision training parameters
285+
parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
286+
272287
return parser
273288

274289

scripts/collect_model_urls.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import pathlib
2+
import re
3+
import sys
4+
5+
MODEL_URL_PATTERN = re.compile(r"https://download[.]pytorch[.]org/models/.*?[.]pth")
6+
7+
8+
def main(root):
9+
model_urls = set()
10+
for path in pathlib.Path(root).glob("**/*"):
11+
if path.name.startswith("_") or not path.suffix == ".py":
12+
continue
13+
14+
with open(path, "r") as file:
15+
for line in file:
16+
model_urls.update(MODEL_URL_PATTERN.findall(line))
17+
18+
print("\n".join(sorted(model_urls)))
19+
20+
21+
if __name__ == "__main__":
22+
main(sys.argv[1])

test/common_utils.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,29 @@
44
import random
55
import shutil
66
import tempfile
7+
from distutils.util import strtobool
78

89
import numpy as np
10+
import pytest
911
import torch
1012
from PIL import Image
1113
from torchvision import io
1214

1315
import __main__ # noqa: 401
1416

1517

16-
IN_CIRCLE_CI = os.getenv("CIRCLECI", False) == "true"
17-
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
18-
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
18+
def get_bool_env_var(name, *, exist_ok=False, default=False):
    """Interpret the environment variable *name* as a boolean.

    Returns *default* when the variable is unset. When *exist_ok* is true,
    the mere presence of the variable (whatever its value) counts as True.
    Otherwise the value is parsed with the classic ``strtobool`` truth
    table: "y", "yes", "t", "true", "on", "1" are True; "n", "no", "f",
    "false", "off", "0" are False; anything else raises ValueError.
    """
    value = os.getenv(name)
    if value is None:
        return default
    if exist_ok:
        return True
    # distutils.util.strtobool is deprecated (PEP 632) and distutils is
    # removed in Python 3.12, so its truth table is inlined here with
    # identical semantics (lowercase comparison, no whitespace stripping).
    normalized = value.lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return True
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"invalid truth value {value!r}")
25+
26+
27+
IN_CIRCLE_CI = get_bool_env_var("CIRCLECI")
28+
IN_RE_WORKER = get_bool_env_var("INSIDE_RE_WORKER", exist_ok=True)
29+
IN_FBCODE = get_bool_env_var("IN_FBCODE_TORCHVISION")
1930
CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
2031
CIRCLECI_GPU_NO_CUDA_MSG = "We're in a CircleCI GPU machine, and this test doesn't need cuda."
2132

@@ -202,3 +213,7 @@ def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
202213
# scriptable function test
203214
s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
204215
torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)
216+
217+
218+
def run_on_env_var(name, *, skip_reason=None, exist_ok=False, default=False):
    """Return a pytest skipif marker that skips unless *name* is truthy in the environment."""
    env_var_enabled = get_bool_env_var(name, exist_ok=exist_ok, default=default)
    return pytest.mark.skipif(not env_var_enabled, reason=skip_reason)
939 Bytes
Binary file not shown.
939 Bytes
Binary file not shown.
939 Bytes
Binary file not shown.
939 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)