[ ]:
# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.
Pretrained Weights#
Written by: Nils Lehmann
In this tutorial, we demonstrate some of the pretrained weights available in TorchGeo. The implementation follows torchvision’s recently introduced Multi-Weight API.
It’s recommended to run this notebook on Google Colab if you don’t have your own GPU. Click the “Open in Colab” button above to get started.
Setup#
First, we install TorchGeo.
[ ]:
%pip install torchgeo
Imports#
Next, we import TorchGeo and the other libraries we need.
[ ]:
%matplotlib inline
import timm
from torchgeo.models import ResNet18_Weights, ResNet50_Weights, resnet18
Pretrained Weights#
Pretrained weights for torchgeo.models are grouped by satellite or sensor type: sensor-agnostic, Landsat, NAIP, Sentinel-1, and Sentinel-2. Refer to the model documentation for a complete list of weights, and choose the set that best matches your use case.
Some weights accept only RGB input, while others have been pretrained on Sentinel-2 imagery with all 13 input channels and can therefore be useful for transfer learning tasks involving Sentinel-2 data.
To use these weights, you can load them as follows:
[ ]:
all_weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
rgb_weights = ResNet50_Weights.SENTINEL2_RGB_MOCO
Weight Metadata#
Each set of weights is a member of a torchvision WeightsEnum and holds information such as the download URL and additional metadata. TorchGeo takes care of downloading the weights and initializing models with them.
Let’s inspect the metadata of the two pretrained weights we have just loaded:
[ ]:
# ResNet18_Weights.SENTINEL2_ALL_MOCO
print(f'Weight URL: {all_weights.url}')
print('Weight metadata:')
for key, value in all_weights.meta.items():
    print(f'  {key}: {value}')
[ ]:
# ResNet50_Weights.SENTINEL2_RGB_MOCO
print(f'Weight URL: {rgb_weights.url}')
print('Weight metadata:')
for key, value in rgb_weights.meta.items():
    print(f'  {key}: {value}')
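To make the idea of an enum that bundles a URL with metadata more concrete, here is a minimal sketch using only the Python standard library. It is loosely modeled on the pattern described above and is not the actual torchvision or TorchGeo implementation: the names `WeightsEntry` and `DemoWeights`, the example.com URLs, and the `model` metadata key are all made up for illustration. Only the `in_chans` key mirrors a metadata field used in this tutorial.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Any


@dataclass
class WeightsEntry:
    """Hypothetical record pairing a checkpoint URL with its metadata."""

    url: str
    meta: dict[str, Any] = field(default_factory=dict)


class DemoWeights(Enum):
    """Hypothetical weights enum; URLs are placeholders, not real checkpoints."""

    SENSOR_ALL = WeightsEntry(
        url='https://example.com/all.pth',
        meta={'in_chans': 13, 'model': 'resnet18'},
    )
    SENSOR_RGB = WeightsEntry(
        url='https://example.com/rgb.pth',
        meta={'in_chans': 3, 'model': 'resnet18'},
    )

    @property
    def url(self) -> str:
        # Forward attribute access to the wrapped record.
        return self.value.url

    @property
    def meta(self) -> dict[str, Any]:
        return self.value.meta


# Because the enum is iterable, weights can be filtered by their metadata,
# for example to keep only those that expect more than three input channels.
multispectral = [w for w in DemoWeights if w.meta['in_chans'] > 3]

for weights in multispectral:
    print(f'{weights.name}: {weights.url} ({weights.meta["in_chans"]} channels)')
```

The same iteration pattern should carry over to the real weight enums, since they are standard Python enums whose members expose `url` and `meta`.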
Using Pretrained Weights for Training#
We can load the pretrained weights ResNet18_Weights.SENTINEL2_ALL_MOCO into a ResNet-18 model like this:
[ ]:
model = resnet18(all_weights)
Here, TorchGeo simply acts as a wrapper around timm. If you don’t want to use this wrapper, you can create a timm model directly and load the pretrained weights from TorchGeo as follows:
[ ]:
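# Match the number of input channels the weights were pretrained with
in_chans = all_weights.meta['in_chans']
# num_classes=10 is an example value; set it to the number of classes in your task
model = timm.create_model('resnet18', in_chans=in_chans, num_classes=10)
# strict=False tolerates keys that exist in the model but not in the checkpoint
model.load_state_dict(all_weights.get_state_dict(progress=True), strict=False)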
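When loading with `strict=False`, it can help to check which parameters were actually restored. The following is a rough sketch on a toy module rather than a real ResNet: the `TinyNet` class, its layer names, and the simulated checkpoint are invented for illustration. It relies only on the fact that `load_state_dict` returns an object with `missing_keys` and `unexpected_keys`.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Toy stand-in for a backbone plus classification head (hypothetical)."""

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.backbone = nn.Linear(4, 3)
        self.fc = nn.Linear(3, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(self.backbone(x))


# Simulate a checkpoint that contains the backbone but no classification head.
checkpoint = {
    key: value.clone()
    for key, value in TinyNet().state_dict().items()
    if key.startswith('backbone.')
}

model = TinyNet()
result = model.load_state_dict(checkpoint, strict=False)

# Keys the model expects but the checkpoint does not provide
print('Missing keys:', result.missing_keys)
# Keys the checkpoint provides but the model does not have
print('Unexpected keys:', result.unexpected_keys)
```

Any parameter reported as missing keeps its random initialization, so those layers still need to be trained on the downstream task. Applying the same check to the result of the `load_state_dict` call above is a quick way to confirm that the backbone weights were loaded.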
To train our pretrained model on a dataset, we will make use of Lightning’s Trainer. For a more detailed explanation of how TorchGeo uses Lightning, check out the next tutorial.