
Segmentation fault (Python 3.6.0, Anaconda 4.3.0, Ubuntu 16.04.01) #1374

@aizvorski


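Steps to reproduce:

  • Apply the following patch to train.py (it sets pin_memory=False and removes the .cuda() calls so training runs on the CPU):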

diff --git a/train.py b/train.py
index dc5a31d..ceec980 100644
--- a/train.py
+++ b/train.py
@@ -84,9 +84,9 @@ if __name__ == '__main__':
         test_data = dset.CIFAR100(args.data_path, train=False, transform=test_transform, download=True)
         nlabels = 100
     train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
-                                               num_workers=args.prefetch, pin_memory=True)
+                                               num_workers=args.prefetch, pin_memory=False)
     test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.test_bs, shuffle=False,
-                                              num_workers=args.prefetch, pin_memory=True)
+                                              num_workers=args.prefetch, pin_memory=False)
 
     # Init checkpoints
     if not os.path.isdir(args.save):
@@ -109,7 +109,7 @@ if __name__ == '__main__':
         net.train()
         loss_avg = 0.0
         for batch_idx, (data, target) in enumerate(train_loader):
-            data, target = torch.autograd.Variable(data.cuda()), torch.autograd.Variable(target.cuda())
+            data, target = torch.autograd.Variable(data), torch.autograd.Variable(target)
 
             # forward
             output = net(data)
  • Run with python train.py --ngpu 0 --batch_size 8 data cifar10

Result: Segmentation fault (core dumped)
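The patch in the reproduction steps hard-codes the CPU path. Below is a minimal sketch that drives the same switch from the existing --ngpu flag. It assumes train.py's parsed args expose `ngpu` and `prefetch`, as the command line and diff above suggest. `loader_kwargs` and `maybe_cuda` are hypothetical helpers, not part of train.py, and the parser is only a stand-in.

```python
import argparse


def loader_kwargs(args):
    """DataLoader keyword arguments; pin memory only when a GPU is in use (hypothetical helper)."""
    use_cuda = args.ngpu > 0
    return {"num_workers": args.prefetch, "pin_memory": use_cuda}


def maybe_cuda(tensor, use_cuda):
    """Call .cuda() only when CUDA is requested; otherwise return the tensor unchanged (hypothetical helper)."""
    return tensor.cuda() if use_cuda else tensor


if __name__ == "__main__":
    # Stand-in for train.py's argument parser (only the two flags used here; defaults are assumptions).
    parser = argparse.ArgumentParser()
    parser.add_argument("--ngpu", type=int, default=1)
    parser.add_argument("--prefetch", type=int, default=2)
    args = parser.parse_args(["--ngpu", "0"])
    print(loader_kwargs(args))  # {'num_workers': 2, 'pin_memory': False}
```

With this approach, the loaders would take **loader_kwargs(args), and the training loop would wrap its inputs as maybe_cuda(data, args.ngpu > 0). Running with --ngpu 0 then reproduces the patched behaviour without editing the file.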

Debugging info:

ResNeXt.pytorch$ gdb python
GNU gdb (Ubuntu 7.11.1-0ubuntu1~16.04) 7.11.1
Copyright (C) 2016 Free Software Foundation, Inc.
License GPLv3+: GNU GPL version 3 or later <http://gnu.org/licenses/gpl.html>
This is free software: you are free to change and redistribute it.
There is NO WARRANTY, to the extent permitted by law.  Type "show copying"
and "show warranty" for details.
This GDB was configured as "x86_64-linux-gnu".
Type "show configuration" for configuration details.
For bug reporting instructions, please see:
<http://www.gnu.org/software/gdb/bugs/>.
Find the GDB manual and other documentation resources online at:
<http://www.gnu.org/software/gdb/documentation/>.
For help, type "help".
Type "apropos word" to search for commands related to "word"...
Reading symbols from python...done.
(gdb) r train.py --ngpu 0 --batch_size 8 data cifar10
Starting program: /home/alex/anaconda3/bin/python train.py --ngpu 0 --batch_size 8 data cifar10
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib/x86_64-linux-gnu/libthread_db.so.1".
Files already downloaded and verified
Files already downloaded and verified
CifarResNeXt (
  (conv_1_3x3): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn_1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
  (stage_1): Sequential (
    (stage_1_bottleneck_0): ResNeXtBottleneck (
      (conv_reduce): Conv2d(64, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
      (conv_conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
      (conv_expand): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
      (shortcut): Sequential (
        (shortcut_conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (shortcut_bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
      )
    )
    (stage_1_bottleneck_1): ResNeXtBottleneck (
      (conv_reduce): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
      (conv_conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
      (conv_expand): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
      (shortcut): Sequential (
      )
    )
    (stage_1_bottleneck_2): ResNeXtBottleneck (
      (conv_reduce): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
      (conv_conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
      (conv_expand): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
      (shortcut): Sequential (
      )
    )
  )
  (stage_2): Sequential (
    (stage_2_bottleneck_0): ResNeXtBottleneck (
      (conv_reduce): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True)
      (conv_conv): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=8, bias=False)
      (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True)
      (conv_expand): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
      (shortcut): Sequential (
        (shortcut_conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (shortcut_bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
      )
    )
    (stage_2_bottleneck_1): ResNeXtBottleneck (
      (conv_reduce): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True)
      (conv_conv): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)
      (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True)
      (conv_expand): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
      (shortcut): Sequential (
      )
    )
    (stage_2_bottleneck_2): ResNeXtBottleneck (
      (conv_reduce): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True)
      (conv_conv): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)
      (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True)
      (conv_expand): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
      (shortcut): Sequential (
      )
    )
  )
  (stage_3): Sequential (
    (stage_3_bottleneck_0): ResNeXtBottleneck (
      (conv_reduce): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True)
      (conv_conv): Conv2d(2048, 2048, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=8, bias=False)
      (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True)
      (conv_expand): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True)
      (shortcut): Sequential (
        (shortcut_conv): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (shortcut_bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True)
      )
    )
    (stage_3_bottleneck_1): ResNeXtBottleneck (
      (conv_reduce): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True)
      (conv_conv): Conv2d(2048, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)
      (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True)
      (conv_expand): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True)
      (shortcut): Sequential (
      )
    )
    (stage_3_bottleneck_2): ResNeXtBottleneck (
      (conv_reduce): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True)
      (conv_conv): Conv2d(2048, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=8, bias=False)
      (bn): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True)
      (conv_expand): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True)
      (shortcut): Sequential (
      )
    )
  )
  (classifier): Linear (1024 -> 10)
)
[New Thread 0x7fffbefa5780 (LWP 18919)]
[New Thread 0x7fffbeba4800 (LWP 18920)]
[New Thread 0x7fffbe7a3880 (LWP 18921)]
[New Thread 0x7fffab9b1700 (LWP 18923)]
[New Thread 0x7fffab1af980 (LWP 18924)]
[New Thread 0x7fffaadaea00 (LWP 18925)]
[New Thread 0x7fffaa9ada80 (LWP 18926)]

Thread 5 "python" received signal SIGSEGV, Segmentation fault.
[Switching to Thread 0x7fffab9b1700 (LWP 18923)]
0x00007fffedba4d04 in torch::autograd::cat (tensors=..., dim=dim@entry=0) from /home/alex/anaconda3/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so
(gdb) where
#0  0x00007fffedba4d04 in torch::autograd::cat (tensors=..., dim=dim@entry=0) from /home/alex/anaconda3/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so
#1  0x00007fffedba6f1c in torch::autograd::ConvBackward::apply (this=0x2db1a278, grad_outputs=...) from /home/alex/anaconda3/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so
#2  0x00007fffedb8d138 in torch::autograd::call_function (task=...) from /home/alex/anaconda3/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so
#3  torch::autograd::Engine::evaluate_function (this=this@entry=0x7fffee408d00 <engine>, task=...) at torch/csrc/autograd/engine.cpp:136
#4  0x00007fffedb8ed3a in torch::autograd::Engine::thread_main (this=this@entry=0x7fffee408d00 <engine>, queue=...) from /home/alex/anaconda3/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so
#5  0x00007fffedb9f89a in PythonEngine::thread_main (this=0x7fffee408d00 <engine>, queue=...) from /home/alex/anaconda3/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so
#6  0x00007fffd20a5870 in ?? () from /home/alex/anaconda3/lib/python3.6/site-packages/torch/lib/../../../../libstdc++.so.6
#7  0x00007ffff76bc6ba in start_thread (arg=0x7fffab9b1700) at pthread_create.c:333
#8  0x00007ffff6ada82d in clone () at ../sysdeps/unix/sysv/linux/x86_64/clone.S:109
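gdb shows only native frames. To also see which Python line was executing when the fault occurred, the script can be rerun with the standard-library faulthandler module, enabled with `python -X faulthandler`. Below is a minimal sketch; `run_with_faulthandler` and `died_of_sigsegv` are hypothetical helpers. The fault here happens in an autograd worker thread (frames #4 and #5), so the dump will probably show the main thread's Python stack, for example the line that called backward(), rather than the faulting C++ frame. It complements gdb; it does not replace it.

```python
import signal
import subprocess
import sys


def run_with_faulthandler(script_args):
    """Run a Python command line in a child interpreter with faulthandler enabled.

    Returns (returncode, stderr). If the child was killed by a segmentation
    fault, stderr holds the Python tracebacks that faulthandler dumped.
    """
    cmd = [sys.executable, "-X", "faulthandler"] + list(script_args)
    proc = subprocess.run(cmd, stderr=subprocess.PIPE, universal_newlines=True)
    return proc.returncode, proc.stderr


def died_of_sigsegv(returncode):
    """True if a subprocess return code means termination by SIGSEGV (POSIX only)."""
    return returncode == -signal.SIGSEGV
```

For this issue, the call would be run_with_faulthandler(["train.py", "--ngpu", "0", "--batch_size", "8", "data", "cifar10"]). If died_of_sigsegv() returns True for the return code, print the returned stderr.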

Versions:

$ python --version
Python 3.6.0 :: Anaconda 4.3.0 (64-bit)
$ python -c "import torch; print(torch.__version__)"
0.1.11+b13b701
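The version details above can be collected in one step, which makes it easier to compare against other reporters' setups. Below is a minimal sketch that uses only the standard library, plus torch if it is importable; `collect_versions` is a hypothetical helper.

```python
import platform
import sys


def collect_versions():
    """Gather interpreter, OS and (if importable) PyTorch versions for a bug report."""
    info = {
        "python": platform.python_version(),
        # Full build string; Anaconda builds of this era may include an "Anaconda x.y.z" tag here.
        "python_build": sys.version.replace("\n", " "),
        "platform": platform.platform(),
    }
    try:
        import torch
        info["torch"] = torch.__version__
    except ImportError:
        info["torch"] = "not installed"
    return info


if __name__ == "__main__":
    for key, value in collect_versions().items():
        print("{}: {}".format(key, value))
```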

Note: other networks, for example some of the PyTorch examples, seem to run fine. Only this network crashes.
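One difference from the PyTorch examples is visible in the model printout: every conv_conv layer is a grouped convolution (groups=8). Frames #0 and #1 also show the fault in torch::autograd::cat, called from ConvBackward::apply with dim=0. A plausible reading, which the trace does not confirm, is that the CPU backward computes each group separately and then concatenates the per-group results. Dim 0 of a Conv2d weight (shape out_channels x in_channels/groups x kH x kW) is the output-channel axis, so per-group weight gradients would be stacked along it.

The pure-Python sketch below illustrates that split-and-concatenate pattern for a 1x1 grouped convolution at a single spatial position with batch size 1. `grouped_1x1_weight_grad` is hypothetical and is not PyTorch's implementation.

```python
def grouped_1x1_weight_grad(x, grad_out, groups):
    """Weight gradient of a 1x1 grouped convolution (one position, batch size 1).

    x:        input activations, length in_channels
    grad_out: gradient w.r.t. the output, length out_channels
    Returns out_channels rows, each of length in_channels // groups,
    i.e. the layout of a Conv2d weight with its 1x1 spatial dims dropped.
    """
    in_ch, out_ch = len(x), len(grad_out)
    if in_ch % groups or out_ch % groups:
        raise ValueError("channel counts must be divisible by groups")
    in_g, out_g = in_ch // groups, out_ch // groups

    per_group = []
    for g in range(groups):
        # Each group only sees its own slice of input and output channels.
        x_g = x[g * in_g:(g + 1) * in_g]
        go_g = grad_out[g * out_g:(g + 1) * out_g]
        # dL/dW_g[o][i] = grad_out_g[o] * x_g[i]  (outer product within the group)
        per_group.append([[go * xi for xi in x_g] for go in go_g])

    # Concatenate the per-group blocks along dim 0 (the output-channel axis).
    return [row for block in per_group for row in block]


# Two groups: output channel 0 sees inputs [1, 2], output channel 1 sees [3, 4].
print(grouped_1x1_weight_grad([1, 2, 3, 4], [1, 10], groups=2))  # [[1, 2], [30, 40]]
```

If this reading is right, a natural next step would be to isolate the problem with a single grouped Conv2d (groups=8) run forward and backward on the CPU. That shows whether the crash occurs without the rest of the network.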
