Mamba2 9 times slower inference time than Mamba1 #355
Closed
Description
After changing d_model, Mamba2 worked in the simple test environment provided in the README. But I noticed that Mamba2 runs much slower than Mamba1. I tested it; here is my code:
```python
import torch
from mamba_ssm import Mamba2 as Mamba
# from mamba_ssm import Mamba

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()

batch, length, dim = 2, 64, 256
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=64,   # SSM state expansion factor
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
).to("cuda")
y = model(x)

end.record()
torch.cuda.synchronize()
inference_time = start.elapsed_time(end)
assert y.shape == x.shape
print(f'parameter number: {sum(p.numel() for p in model.parameters())}')
print(f'inference time: {inference_time}')
```

The result I got is this:
Mamba1 parameter number: 511488
Mamba1 inference time: 539.1769409179688
Mamba2 parameter number: 431768
Mamba2 inference time: 4322.52294921875
I don't know if this is a bug or if I made a mistake. Please feel free to share your thoughts.
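One thing worth noting about the measurement itself: the CUDA events above bracket model construction, the host-to-device copy, and the very first forward pass, so any one-time cost (CUDA context setup, kernel compilation/autotuning on the first call) is counted as "inference time". A minimal sketch of a timing helper that discards warm-up passes before measuring (the `benchmark` helper and its parameters are illustrative, not part of mamba_ssm):

```python
import time
import torch

def benchmark(fn, *args, warmup=3, iters=10):
    """Average per-call latency of fn(*args) in milliseconds.

    Runs a few warm-up passes first so one-time costs (kernel
    compilation, autotuning, CUDA context init) are excluded from
    the measured window.
    """
    for _ in range(warmup):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # drain queued GPU work before timing
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for GPU work to finish
    return (time.perf_counter() - t0) * 1000.0 / iters
```

Used as `benchmark(model, x)` after constructing the model, this would time only steady-state forward passes; if the Mamba2/Mamba1 gap shrinks dramatically with warm-up, the slowdown is likely first-call compilation rather than steady-state throughput.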