Describe the bug
```python
import numpy as np
import torch
from monai.transforms import SplitChannel

t = SplitChannel()
_np = np.random.randint(2, size=(3, 3, 4))  # random 0/1 array of shape (3, 3, 4)
_torch = torch.Tensor(_np)  # same values as a float32 torch tensor
print(t(_np)[0].shape)
print(t(_torch)[0].shape)
```
The output is:
```
(1, 3, 4)
torch.Size([3, 1, 4])
```
I suspect this is because slicing behaves differently on torch tensors and NumPy arrays. Is this behaviour expected? I'm not sure whether it was intentional to let each backend follow its own convention, or whether the two are expected to give the same result.