
Slower training speed? #156

@xyzhang626

Thanks to the authors for the great work and for open-sourcing it.
I built a minimal trainable Mamba on TinyStories here, based on llama2.c with only a few lines changed. However, I found that it trains noticeably slower on my V100s than a Transformer compiled with `torch.compile`: 800 ms vs. 650 ms per iteration at `seq_len` 2048, i.e. each step takes about 23% longer.
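
For reference, the relative slowdown can be computed directly from the two per-iteration times quoted above. This is a minimal sketch; the function name `compare_step_times` is hypothetical, and it assumes both runs use the same batch size and `seq_len`, so step time is inversely proportional to throughput.

```python
def compare_step_times(slow_ms: float, fast_ms: float) -> dict:
    """Compare two per-iteration step times given in milliseconds.

    Hypothetical helper; assumes identical batch size and seq_len for both runs.
    """
    if slow_ms <= 0 or fast_ms <= 0:
        raise ValueError("step times must be positive")
    return {
        # Extra time per step, relative to the faster model's step time.
        "step_time_overhead": slow_ms / fast_ms - 1.0,
        # Fraction of the faster model's throughput (tokens/s) that is lost.
        "throughput_drop": 1.0 - fast_ms / slow_ms,
    }


if __name__ == "__main__":
    result = compare_step_times(slow_ms=800.0, fast_ms=650.0)
    for name, value in result.items():
        print(f"{name}: {value:.1%}")
```

With 800 ms vs. 650 ms, the step time is roughly 23% longer, which corresponds to roughly 19% lower training throughput; the two percentages differ because they use different baselines.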

Is this expected? I also noticed that `torch.compile` does not work directly with the current Mamba model in this repo. Could that be one factor, given that Mamba is already equipped with several fused ops?
