
Kernel fails while using one-direction LSTM on Intel-based Mac #91694

@hawkiyc

Description


🐛 Describe the bug

Dear Developers,

I have an Intel-based MacBook Pro with an AMD GPU, and there is a bug when running a unidirectional LSTM. The kernel always dies when I use the MPS device, while the same model runs fine on the CPU. In addition, everything works on both the CPU and the GPU if bidirectional is set to True. The bug happens in both Spyder and JupyterLab, so it is not caused by the IDE. According to the traceback below, the crash occurs during the backward pass.

Here is my code:

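import torch
import torch.nn as nn

# Use the MPS backend if it is available, otherwise fall back to the CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

class BasicRNN(nn.Module):

    def __init__(self, d_input, n_hidden,
                 n_layers, n_output,
                 dropout, bidirection=False):
        super().__init__()
        self.d_input = d_input
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.n_output = n_output
        self.drop = dropout
        # Number of directions: 1 for a unidirectional LSTM, 2 for a bidirectional one
        self.bidirection = 2 if bidirection else 1

        self.rnn = nn.LSTM(input_size=self.d_input,
                           hidden_size=self.n_hidden,
                           num_layers=self.n_layers,
                           batch_first=True,
                           dropout=self.drop,
                           bidirectional=bidirection)

        # Each time step of the LSTM output has n_hidden * num_directions features
        self.fc = nn.Linear(self.n_hidden * self.bidirection, self.n_output)

    def forward(self, x):
        # Initial states: (num_layers * num_directions, batch, n_hidden),
        # allocated on the same device as the input
        h0 = torch.zeros(self.n_layers * self.bidirection,
                         x.size(0), self.n_hidden, device=x.device)
        c0 = torch.zeros(self.n_layers * self.bidirection,
                         x.size(0), self.n_hidden, device=x.device)

        out, _ = self.rnn(x, (h0, c0))
        # Feed only the output of the last time step to the linear layer
        out = self.fc(out[:, -1, :])

        return out

model = BasicRNN(d_input=1, n_hidden=12,
                 n_layers=8, n_output=1,
                 dropout=.2, bidirection=False)
model.to(device)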
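To make the tensor shapes in the model above explicit, here is a minimal pure-Python sketch of the shapes `nn.LSTM` works with when `batch_first=True`, following the conventions in the PyTorch documentation (the initial states are never batch-first). The helper `lstm_shapes` is hypothetical and not part of PyTorch, and the batch size and sequence length in the usage line are illustrative values, since the report does not state them.

```python
# Hypothetical helper (not a PyTorch API): expected shapes for nn.LSTM
# with batch_first=True, following the PyTorch documentation.
def lstm_shapes(batch, seq_len, n_hidden, n_layers, bidirectional):
    num_directions = 2 if bidirectional else 1
    return {
        # h0 and c0 ignore batch_first: (num_layers * num_directions, batch, hidden)
        "h0": (n_layers * num_directions, batch, n_hidden),
        "c0": (n_layers * num_directions, batch, n_hidden),
        # Output with batch_first=True: (batch, seq_len, num_directions * hidden)
        "output": (batch, seq_len, num_directions * n_hidden),
        # in_features required by the final nn.Linear layer
        "fc_in_features": num_directions * n_hidden,
    }

# Illustrative values: the reporter's settings, with an assumed batch of 4 and length of 20
print(lstm_shapes(batch=4, seq_len=20, n_hidden=12, n_layers=8, bidirectional=False))
```

With `bidirectional=True`, the first dimension of the states doubles to 16 and the output feature dimension doubles to 24, which is why the model multiplies `n_hidden` by the number of directions in `nn.Linear`.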

And here is the error output:

loc("total derivative last state"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/810eba08-405a-11ed-86e9-6af958a02716/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":228:0)): error: input types 'tensor<1x10x12xf32>' and 'tensor<1x195x12xf32>' are not broadcast compatible
LLVM ERROR: Failed to infer result type(s).

/Applications/Spyder.app/Contents/Resources/lib/python3.9/spyder/plugins/ipythonconsole/scripts/conda-activate.sh: line 18:  6845 Abort trap: 6           $CONDA_ENV_PYTHON -m spyder_kernels.console -f $SPYDER_KERNEL_SPEC


Fatal Python error: Aborted


Current thread 0x000070001adab000 (most recent call first):
<no Python frame>



Main thread:
Thread 0x00007ff852fc74c0 (most recent call first):
  File "/Users/hawkiyc/opt/anaconda3/envs/env_torch_test/lib/python3.8/site-packages/torch/autograd/__init__.py", line 197 in backward
  File "/Users/hawkiyc/opt/anaconda3/envs/env_torch_test/lib/python3.8/site-packages/torch/_tensor.py", line 488 in backward
  File "/Users/hawkiyc/Documents/python/Pytorch/LazyProgrammer/my_own_practise/RNN/LSTM_ Quadratic.py", line 134 in <module>
  File "/Users/hawkiyc/opt/anaconda3/envs/env_torch_test/lib/python3.8/site-packages/spyder_kernels/py3compat.py", line 356 in compat_exec
  File "/Users/hawkiyc/opt/anaconda3/envs/env_torch_test/lib/python3.8/site-packages/spyder_kernels/customize/spydercustomize.py", line 469 in exec_code
  File "/Users/hawkiyc/opt/anaconda3/envs/env_torch_test/lib/python3.8/site-packages/spyder_kernels/customize/spydercustomize.py", line 611 in _exec_file
  File "/Users/hawkiyc/opt/anaconda3/envs/env_torch_test/lib/python3.8/site-packages/spyder_kernels/customize/spydercustomize.py", line 524 in runfile
  File "/var/folders/b1/symv8mxd49s2gj03flyq7jj00000gn/T/ipykernel_6845/1566561879.py", line 1 in <module>


Restarting kernel...
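The MPSGraph message reports that `tensor<1x10x12xf32>` and `tensor<1x195x12xf32>` are "not broadcast compatible". Assuming MPSGraph follows the same broadcasting rule as NumPy and PyTorch (shapes are aligned from the trailing dimension, and each pair of sizes must be equal or one of them must be 1), the sketch below checks that rule in pure Python. The function name `broadcast_shape` is hypothetical. It shows that the middle dimensions 10 and 195 cannot be broadcast; the report does not say what those dimensions represent inside the LSTM backward graph.

```python
from itertools import zip_longest

def broadcast_shape(a, b):
    """Return the broadcast shape of a and b, or raise ValueError if incompatible.

    Shapes are compared from the trailing dimension; a missing leading
    dimension counts as 1. Each pair must be equal or contain a 1.
    """
    result = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x == y or y == 1:
            result.append(x)
        elif x == 1:
            result.append(y)
        else:
            raise ValueError(f"shapes {a} and {b} are not broadcast compatible")
    return tuple(reversed(result))

print(broadcast_shape((1, 10, 12), (1, 1, 12)))  # a size-1 dimension broadcasts
try:
    broadcast_shape((1, 10, 12), (1, 195, 12))   # shapes from the error message
except ValueError as err:
    print(err)
```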

I would deeply appreciate your assistance and look forward to a fix or a workaround.

Sincerely,

Versions

(env_torch_test) xxxxxx@xxxxxx ~ % python collect_env.py
Collecting environment information...
PyTorch version: 1.13.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.0.1 (x86_64)
GCC version: Could not collect
Clang version: 14.0.0 (clang-1400.0.29.202)
CMake version: Could not collect
Libc version: N/A

Python version: 3.8.13 (default, Mar 28 2022, 06:16:26) [Clang 12.0.0 ] (64-bit runtime)
Python platform: macOS-10.16-x86_64-i386-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.21.5
[pip3] torch==1.13.1
[pip3] torchaudio==0.13.1
[pip3] torchsummary==1.5.1
[pip3] torchvision==0.14.1
[conda] blas 1.0 mkl
[conda] ffmpeg 4.3 h0a44026_0 pytorch
[conda] mkl 2021.4.0 hecd8cb5_637
[conda] mkl-service 2.4.0 py38h9ed2024_0
[conda] mkl_fft 1.3.1 py38h4ab4a9b_0
[conda] mkl_random 1.2.2 py38hb2f4e1b_0
[conda] numpy 1.21.5 py38h2e5f0a9_3
[conda] numpy-base 1.21.5 py38h3b1a694_3
[conda] pytorch 1.13.1 py3.8_0 pytorch
[conda] torchaudio 0.13.1 py38_cpu pytorch
[conda] torchsummary 1.5.1 pypi_0 pypi
[conda] torchvision 0.14.1 py38_cpu pytorch

cc @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev

Metadata

Assignees: No one assigned

Labels: module: mps (Related to Apple Metal Performance Shaders framework), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
