Charles Gallay, Michaël Defferrard, Nathanaël Perraudin
This work is a semester project carried out at EPFL as part of the Master's program in Data Science.
The goal of the project is to test different configurations in order to assess how important rotation equivariance is for a CNN.
As a first measure, a classical CNN and a rotation-equivariant CNN are compared on image classification, more specifically on the CIFAR10 dataset.
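A layer is equivariant to a transformation if transforming the input and then applying the layer gives the same result as applying the layer and then transforming the output. The minimal numpy sketch below makes this concrete. It assumes that a grid-graph convolution can be modelled as a weighted sum of each pixel and its four neighbours, where all directions listed in one group (in the spirit of the `underlying_graphs` argument used in the example further down) share a single weight. This is a simplified, single-channel illustration. The helpers `shift`, `grouped_graph_conv` and `commutes`, and the weight values, are hypothetical and do not reflect graphSym's actual implementation.

```python
import numpy as np

# Hypothetical sketch: a single-channel "graph convolution" on a 2D grid in which
# every group of edge directions shares one weight. Direction names mirror the
# strings used in `underlying_graphs`, but this is NOT graphSym's implementation.
OFFSETS = {"left": (0, -1), "right": (0, 1), "top": (-1, 0), "bottom": (1, 0)}


def shift(x, direction):
    """For every pixel, return the value of its neighbour in `direction` (0 outside the grid)."""
    dr, dc = OFFSETS[direction]
    h, w = x.shape
    padded = np.pad(x, 1)  # zero border: edge pixels simply have fewer neighbours
    return padded[1 + dr:1 + dr + h, 1 + dc:1 + dc + w]


def grouped_graph_conv(x, groups, w_self, w_groups):
    """y = w_self * x + sum over groups g of w_g * (sum of neighbours along g's directions)."""
    y = w_self * x
    for group, w in zip(groups, w_groups):
        y = y + w * sum(shift(x, d) for d in group)
    return y


def commutes(transform, groups, w_groups, x):
    """True if transforming before or after the convolution gives the same output."""
    def conv(img):
        # Self-weight fixed to 0.5 purely for illustration.
        return grouped_graph_conv(img, groups, 0.5, w_groups)
    return np.allclose(conv(transform(x)), transform(conv(x)))


x = np.random.default_rng(0).standard_normal((8, 8))

horizontal = [["left", "right"], ["top"], ["bottom"]]     # grouping used in the README network
isotropic = [["left", "right", "top", "bottom"]]          # all directions share one weight
directional = [["left"], ["right"], ["top"], ["bottom"]]  # every direction has its own weight

flip_ok = commutes(np.fliplr, horizontal, [1.0, 2.0, -1.0], x)
rot_ok_horizontal = commutes(np.rot90, horizontal, [1.0, 2.0, -1.0], x)
rot_ok_isotropic = commutes(np.rot90, isotropic, [1.0], x)
flip_ok_directional = commutes(np.fliplr, directional, [1.0, 3.0, 2.0, -1.0], x)

print(flip_ok, rot_ok_horizontal, rot_ok_isotropic, flip_ok_directional)
# → True False True False
```

In this sketch, merging 'left' and 'right' makes the layer blind to horizontal orientation. That yields equivariance to horizontal flips but not to rotations. Merging all four directions yields equivariance to 90° rotations, at the cost of a less expressive, isotropic filter.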
- Clone this repository.
git clone git@github.com:cgallay/GraphSymmetries.git
cd GraphSymmetries
- Install the dependencies.
pip install -r requirements.txt
pip install -e .
import torch.nn as nn

from graphSym.graph_conv import GridGraphConv
from graphSym.graph_pool import GraphMaxPool2d


class Net(nn.Module):
    """
    Network with horizontal symmetry.
    """

    def __init__(self, input_shape=(32, 32), nb_class=5):
        super().__init__()
        # Grouping 'left' and 'right' into one underlying graph treats both
        # directions identically, which gives the network its horizontal symmetry.
        underlying_graphs = [['left', 'right'], ['top'], ['bottom']]
        conv1 = GridGraphConv(3, 30, input_shape=input_shape, merge_way='cat',
                              underlying_graphs=underlying_graphs)
        conv2 = GridGraphConv(30, 60, input_shape=input_shape, merge_way='cat',
                              underlying_graphs=underlying_graphs)
        pool1 = GraphMaxPool2d(input_shape=input_shape)
        # Pooling halves each side of the grid (32x32 -> 16x16 by default).
        out_shape = (input_shape[0] // 2, input_shape[1] // 2)
        conv3 = GridGraphConv(60, 60, input_shape=out_shape, merge_way='cat',
                              underlying_graphs=underlying_graphs)
        # Last layer: one output channel per class.
        conv4 = GridGraphConv(60, nb_class, input_shape=out_shape, merge_way='mean',
                              underlying_graphs=underlying_graphs)
        self.seq = nn.Sequential(conv1, conv2, pool1, conv3, conv4)

    def forward(self, x):
        out = self.seq(x)
        # Average over dimension 2 to obtain one score per class for each sample.
        out = out.mean(2)
        return out
net = Net()

The content of this repository is released under the terms of the MIT license.