Open in Studio Open in Colab
```python
# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.
```

Pretrained Weights

Written by: Nils Lehmann

In this tutorial, we demonstrate some of the pretrained weights available in TorchGeo. The implementation follows torchvision’s Multi-Weight API, introduced in torchvision 0.13.

It’s recommended to run this notebook on Google Colab if you don’t have your own GPU. Click the “Open in Colab” button above to get started.

Setup

First, we install TorchGeo.

```python
%pip install torchgeo
```

Imports

Next, we import timm and the TorchGeo models and weights used in this tutorial.

```python
%matplotlib inline

import timm

from torchgeo.models import ResNet18_Weights, ResNet50_Weights, resnet18
```

Pretrained Weights

Pretrained weights for torchgeo.models are grouped by satellite or sensor type: sensor-agnostic, Landsat, NAIP, Sentinel-1, and Sentinel-2. Refer to the model documentation for a complete list of weights, and choose the weights that best match your specific use case.

Some weights accept only RGB input, while others have been pretrained on all 13 bands of Sentinel-2 imagery and can therefore be useful for transfer learning tasks involving multispectral Sentinel-2 data.

To use these weights, select them by model and name as follows:

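```python
# ResNet-18 weights pretrained with MoCo on all 13 Sentinel-2 bands
all_weights = ResNet18_Weights.SENTINEL2_ALL_MOCO

# ResNet-50 weights pretrained with MoCo on the Sentinel-2 RGB bands
rgb_weights = ResNet50_Weights.SENTINEL2_RGB_MOCO
```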

Weight Metadata

Each of these weights is a member of a torchvision WeightsEnum and holds information such as the download URL and additional metadata. TorchGeo takes care of downloading the weights and initializing models with them.

Let’s inspect the metadata of the two sets of pretrained weights we have just selected:

```python
# ResNet18_Weights.SENTINEL2_ALL_MOCO
print(f'Weight URL: {all_weights.url}')

print('Weight metadata:')
for key, value in all_weights.meta.items():
    print(f' {key}: {value}')
```

```python
# ResNet50_Weights.SENTINEL2_RGB_MOCO
print(f'Weight URL: {rgb_weights.url}')

print('Weight metadata:')
for key, value in rgb_weights.meta.items():
    print(f' {key}: {value}')
```

Using Pretrained Weights for Training

We can load the pretrained weights ResNet18_Weights.SENTINEL2_ALL_MOCO into a ResNet-18 model like this:

```python
model = resnet18(weights=all_weights)
```

Here, TorchGeo simply acts as a wrapper around timm. If you don’t want to use this wrapper, you can create a timm model directly and load the pretrained weights from TorchGeo as follows:

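```python
in_chans = all_weights.meta['in_chans']  # number of input bands the weights expect
model = timm.create_model('resnet18', in_chans=in_chans, num_classes=10)
# The checkpoint has no classifier head, so strict=False leaves the new fc layer randomly initialized
result = model.load_state_dict(all_weights.get_state_dict(progress=True), strict=False)
print(result.missing_keys)
```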

To train our pretrained model on a dataset, we will use Lightning’s Trainer. For a more detailed explanation of how TorchGeo uses Lightning, check out the next tutorial.