
Mamba2 9 times slower inference time than Mamba1 #355

@realwenlongwang

Description


After changing d_model, Mamba2 worked in the simple test environment provided in the README, but I noticed that Mamba2 runs much slower than Mamba1. I tested it; here is my code:

import torch
from mamba_ssm import Mamba2 as Mamba
# from mamba_ssm import Mamba

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
batch, length, dim = 2, 64, 256
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
end.record()
torch.cuda.synchronize()
inference_time = start.elapsed_time(end)
assert y.shape == x.shape
print(f'parameter number: {sum(p.numel() for p in model.parameters())}')
print(f'inference time: {inference_time}')

Here are the results I got:

Mamba1 parameter number: 511488
Mamba1 inference time: 539.1769409179688
Mamba2 parameter number: 431768
Mamba2 inference time: 4322.52294921875

I don't know whether this is a bug or whether I made a mistake. Please feel free to share your thoughts.
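One thing worth noting about the measurement above: start.record() fires before the input tensor and the model are even created, so the elapsed time includes one-time costs (model construction, CUDA initialization, and any lazy kernel compilation on the first forward pass), not just inference. A fairer comparison warms the model up first and averages repeated forward passes. Below is a minimal sketch of such a harness; the helper name time_forward and the nn.Linear stand-in are illustrative, not part of mamba_ssm — on a CUDA machine you would pass the Mamba/Mamba2 modules from the snippet above instead:

```python
import time
import torch

def time_forward(model, x, warmup=10, iters=100):
    """Average forward-pass time in milliseconds, after warmup iterations."""
    with torch.no_grad():
        # Warmup: absorbs one-time costs (lazy init, kernel compilation).
        for _ in range(warmup):
            model(x)
        if x.is_cuda:
            # GPU path: time with CUDA events, which measure device-side time.
            torch.cuda.synchronize()
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            for _ in range(iters):
                model(x)
            end.record()
            torch.cuda.synchronize()
            return start.elapsed_time(end) / iters
        # CPU fallback: wall-clock timing.
        t0 = time.perf_counter()
        for _ in range(iters):
            model(x)
        return (time.perf_counter() - t0) * 1000 / iters

# Usage with a stand-in module (substitute Mamba or Mamba2 on a CUDA box):
model = torch.nn.Linear(256, 256)
x = torch.randn(2, 64, 256)
print(f"avg forward time: {time_forward(model, x):.3f} ms")
```

With warmup like this, the per-call numbers for Mamba1 vs. Mamba2 should be much closer to their steady-state inference cost, which may narrow (or explain) the gap reported above.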
