|
1 | | -from typing import List, Tuple, Dict, Optional |
| 1 | +from typing import List, Tuple, Dict, Optional, Union |
2 | 2 |
|
3 | 3 | import torch |
4 | 4 | import torchvision |
@@ -326,3 +326,114 @@ def forward( |
326 | 326 | ) |
327 | 327 |
|
328 | 328 | return image, target |
| 329 | + |
| 330 | + |
class FixedSizeCrop(nn.Module):
    """Crop and/or pad an image, together with its detection target, to a fixed size.

    Along each axis, an image larger than the requested size is cropped at a
    random offset, and an image smaller than the requested size is padded on
    the right/bottom side to reach it.

    Args:
        size: Requested output size; parsed by ``T._setup_size`` into ``(h, w)``.
        fill: Fill value forwarded to ``F.pad`` when padding the image.
        padding_mode: Padding mode forwarded to ``F.pad`` when padding the image.
    """

    def __init__(self, size, fill=0, padding_mode="constant"):
        super().__init__()
        size = tuple(T._setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
        self.crop_height = size[0]
        self.crop_width = size[1]
        self.fill = fill  # TODO: Fill is currently respected only on PIL. Apply tensor patch.
        self.padding_mode = padding_mode

    def _pad(self, img, target, padding):
        """Pad ``img`` and shift the boxes / pad the masks of ``target`` to match.

        Args:
            img: Image to pad.
            target: Optional dict with a ``"boxes"`` entry and, optionally, ``"masks"``.
            padding: An int (same padding on all sides), or a sequence of length
                1 (all sides), 2 (``[left/right, top/bottom]``) or
                4 (``[left, top, right, bottom]``).

        Returns:
            The padded image and the (updated) ``target``.
        """
        # Taken from the functional_tensor.py pad
        if isinstance(padding, int):
            pad_left = pad_right = pad_top = pad_bottom = padding
        elif len(padding) == 1:
            pad_left = pad_right = pad_top = pad_bottom = padding[0]
        elif len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        else:
            pad_left = padding[0]
            pad_top = padding[1]
            pad_right = padding[2]
            pad_bottom = padding[3]

        padding = [pad_left, pad_top, pad_right, pad_bottom]
        img = F.pad(img, padding, self.fill, self.padding_mode)
        if target is not None:
            # Work on a copy so a tensor the caller still references elsewhere
            # is not silently modified; only the dict entry is replaced.
            boxes = target["boxes"].clone()
            # Columns 0/2 are treated as x coordinates, columns 1/3 as y coordinates.
            boxes[:, 0::2] += pad_left
            boxes[:, 1::2] += pad_top
            target["boxes"] = boxes
            if "masks" in target:
                # Masks are always padded with zeros, independently of self.fill.
                target["masks"] = F.pad(target["masks"], padding, 0, "constant")

        return img, target

    def _crop(self, img, target, top, left, height, width):
        """Crop ``img`` and translate, clip and filter the annotations of ``target``.

        Boxes are shifted into the crop's coordinate frame, clamped to
        ``[0, width]`` x ``[0, height]``, and boxes left with zero width or
        height are dropped together with their labels and masks.

        Returns:
            The cropped image and the (updated) ``target``.
        """
        img = F.crop(img, top, left, height, width)
        if target is not None:
            # Clone first: the shift/clamp below are in-place, and the original
            # tensor must not be left in a shifted state for other holders.
            boxes = target["boxes"].clone()
            boxes[:, 0::2] -= left
            boxes[:, 1::2] -= top
            boxes[:, 0::2].clamp_(min=0, max=width)
            boxes[:, 1::2].clamp_(min=0, max=height)

            # A box survives only if it still has positive width and height.
            is_valid = (boxes[:, 0] < boxes[:, 2]) & (boxes[:, 1] < boxes[:, 3])

            target["boxes"] = boxes[is_valid]
            # NOTE(review): only "labels" and "masks" are filtered here; any other
            # per-box entries in target keep their original length.
            if "labels" in target:
                target["labels"] = target["labels"][is_valid]
            if "masks" in target:
                target["masks"] = F.crop(target["masks"][is_valid], top, left, height, width)

        return img, target

    def forward(self, img, target=None):
        """Return ``img`` (and ``target``) cropped and/or padded to the configured size.

        Args:
            img: Input image accepted by ``F.get_dimensions``.
            target: Optional annotation dict; updated in place and returned.

        Returns:
            Tuple ``(img, target)``.
        """
        _, height, width = F.get_dimensions(img)
        new_height = min(height, self.crop_height)
        new_width = min(width, self.crop_width)

        # Crop only when at least one side exceeds the requested size.
        if new_height != height or new_width != width:
            offset_height = max(height - self.crop_height, 0)
            offset_width = max(width - self.crop_width, 0)

            # NOTE(review): a single random draw scales both offsets, so top and
            # left are proportionally coupled — confirm this is intended.
            r = torch.rand(1)
            top = int(offset_height * r)
            left = int(offset_width * r)

            img, target = self._crop(img, target, top, left, new_height, new_width)

        # Pad on the right/bottom whatever is still missing after the crop.
        pad_bottom = max(self.crop_height - new_height, 0)
        pad_right = max(self.crop_width - new_width, 0)
        if pad_bottom != 0 or pad_right != 0:
            img, target = self._pad(img, target, [0, 0, pad_right, pad_bottom])

        return img, target
| 404 | + |
| 405 | + |
class RandomShortestSize(nn.Module):
    """Resize an image so its shorter side matches a randomly chosen target size.

    On every call one value is drawn from ``min_size``. The image is scaled,
    keeping its aspect ratio, by the smaller of two ratios: the one that brings
    the shorter side to that value, and the one that brings the longer side to
    ``max_size``.

    Args:
        min_size: A single candidate size or a sequence of candidate sizes for
            the shorter side.
        max_size: Upper bound used for the longer side when computing the scale.
        interpolation: Interpolation mode passed to ``F.resize`` for the image.
    """

    def __init__(
        self,
        min_size: Union[List[int], Tuple[int, ...], int],
        max_size: int,
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ):
        super().__init__()
        # Normalise to a list so forward() can always index into it.
        self.min_size = [min_size] if isinstance(min_size, int) else list(min_size)
        self.max_size = max_size
        self.interpolation = interpolation

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        """Resize ``image`` and rescale the boxes/masks of ``target`` accordingly.

        Args:
            image: Input image accepted by ``F.get_dimensions``.
            target: Optional annotation dict with ``"boxes"`` and, optionally,
                ``"masks"``; updated in place and returned.

        Returns:
            Tuple ``(image, target)``.
        """
        _, orig_height, orig_width = F.get_dimensions(image)

        # Pick one of the candidate sizes uniformly at random.
        min_size = self.min_size[torch.randint(len(self.min_size), (1,)).item()]
        # Use the smaller ratio so the longer side is also scaled no further than max_size.
        r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))

        new_width = int(orig_width * r)
        new_height = int(orig_height * r)

        image = F.resize(image, [new_height, new_width], interpolation=self.interpolation)

        if target is not None:
            # Scale a copy so a boxes tensor the caller still references elsewhere
            # is not modified behind its back; only the dict entry is replaced.
            boxes = target["boxes"].clone()
            # Per-axis factors come from the truncated integer sizes, so boxes
            # follow the size that was actually passed to F.resize.
            boxes[:, 0::2] *= new_width / orig_width
            boxes[:, 1::2] *= new_height / orig_height
            target["boxes"] = boxes
            if "masks" in target:
                # Masks are resized with nearest-neighbour interpolation.
                target["masks"] = F.resize(
                    target["masks"], [new_height, new_width], interpolation=InterpolationMode.NEAREST
                )

        return image, target
0 commit comments