Calling cuda() et al. on deserialized cuda module segfaults #188

@adamlerer

Description

import torch
import torch.cuda
import torch.nn as nn

# Case 1: nn.Linear saved while on the GPU, then loaded and moved to the GPU again
print("serdes linear")
lin = nn.Linear(5, 3)
lin.cuda()  # necessary
with open('lin.pt', 'wb') as f:
    torch.save(lin, f)
with open('lin.pt', 'rb') as f:
    lin2 = torch.load(f)
lin2.cuda()
# segfault or cuda error

# Case 2: the same sequence with nn.RNN
print("serdes rnn")
rnn = nn.RNN(5, 3, 2)
rnn.cuda()  # necessary
with open('rnn.pt', 'wb') as f:
    torch.save(rnn, f)
with open('rnn.pt', 'rb') as f:
    rnn2 = torch.load(f)
rnn2.cuda()
# segfault or cuda error

print("done")
