import torch.cuda
import torch.nn as nn
print("serdes linear")
lin = nn.Linear(5,3)
lin.cuda() # necessary
with open('lin.pt', 'wb') as f: torch.save(lin, f)
with open('lin.pt', 'rb') as f: lin2 = torch.load(f)
lin2.cuda()
# segfault or cuda error
print("serdes rnn")
rnn = nn.RNN(5,3,2)
rnn.cuda() # necessary
with open('rnn.pt', 'wb') as f: torch.save(rnn, f)
with open('rnn.pt', 'rb') as f: rnn2 = torch.load(f)
rnn2.cuda()
# segfault or cuda error
print("done")