Conversation

@drbh (Collaborator) commented Jan 25, 2024

This draft PR is a work-in-progress implementation of the Mamba model. It currently loads weights and produces correct logits after a single forward pass.

The model still needs to be fully integrated so that it produces tokens as expected, and optimized to avoid unnecessary copies and operations at runtime.

Helpful resources

[Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu and Tri Dao)](https://arxiv.org/abs/2312.00752)
https://github.com/johnma2006/mamba-minimal
https://github.com/huggingface/candle/blob/main/candle-examples/examples/mamba-minimal/model.rs
huggingface/transformers#28094

Note: this dev work currently targets `state-spaces/mamba-130m`, so please use that model for testing. Additionally, when starting the router, prefill needs to be limited: `cargo run -- --max-batch-prefill-tokens 768 --max-input-length 768`
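For background, the core recurrence that the fused selective-scan kernel computes (per the Mamba paper listed above) can be sketched in plain NumPy. This is an illustrative toy version with made-up shapes, not the PR's actual CUDA kernel: `selective_scan` here assumes the input-dependent parameters have already been discretized, and it scans a single input channel sequentially.

```python
import numpy as np

def selective_scan(A_bar, B_bar, C, x):
    """Sequential SSM scan: h_t = A_bar_t * h_{t-1} + B_bar_t * x_t, y_t = C_t . h_t.

    A_bar, B_bar, C have shape (T, N) and are already discretized and
    input-dependent (the "selective" part); x has shape (T,).
    Returns y with shape (T,).
    """
    T, N = A_bar.shape
    h = np.zeros(N)
    y = np.zeros(T)
    for t in range(T):
        h = A_bar[t] * h + B_bar[t] * x[t]  # diagonal state update
        y[t] = C[t] @ h                     # readout
    return y

# Tiny example: one state channel, two timesteps.
y = selective_scan(
    A_bar=np.array([[0.5], [0.5]]),
    B_bar=np.array([[1.0], [1.0]]),
    C=np.array([[1.0], [1.0]]),
    x=np.array([1.0, 1.0]),
)
print(y)  # [1.  1.5]
```

The production kernel fuses this loop and runs it in parallel on GPU; the Python loop above only shows the math being computed.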

Update / Current State

Integration tests have been added and basic functionality such as model loading is supported.

```bash
cd integration-tests
pytest -vv models/test_fused_kernel_mamba.py
```

- [x] add tests
- [x] load model
- [x] make simple request
- [ ] resolve warmup issue
- [ ] resolve output issues

Fetching the models tested during development:

```bash
text-generation-server download-weights state-spaces/mamba-130m
text-generation-server download-weights state-spaces/mamba-1.4b
text-generation-server download-weights state-spaces/mamba-2.8b
```

The server can then be run:

```bash
cd server
MASTER_ADDR=127.0.0.1 MASTER_PORT=5555 python text_generation_server/cli.py serve state-spaces/mamba-2.8b
```

Start the router:

```bash
cargo run
```

Make a request:

```bash
curl -s localhost:3000/generate \
    -X POST \
    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
    -H 'Content-Type: application/json' | jq
```

Response:

```json
{
  "generated_text": "\n\nDeep learning is a machine learning technique that uses a deep neural network to learn from data."
}
```
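The same request can also be made from Python with only the standard library. `build_generate_payload` and `generate` are hypothetical helper names; the endpoint, body, and headers mirror the curl call above, and the sketch assumes the router is reachable at the given URL.

```python
import json
import urllib.request

def build_generate_payload(inputs: str, max_new_tokens: int = 20) -> str:
    # Same JSON body as the curl example above.
    return json.dumps({
        "inputs": inputs,
        "parameters": {"max_new_tokens": max_new_tokens},
    })

def generate(base_url: str, inputs: str, max_new_tokens: int = 20) -> dict:
    # POST to the router's /generate endpoint and decode the JSON response.
    req = urllib.request.Request(
        base_url + "/generate",
        data=build_generate_payload(inputs, max_new_tokens).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())

# e.g. generate("http://localhost:3000", "What is Deep Learning?")
```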

@drbh mentioned this pull request Jan 29, 2024
@RonanKMcGovern commented:
thanks, this will be a great addition as we see more mamba architectures

@drbh marked this pull request as ready for review February 7, 2024 04:35
@Narsil (Contributor) left a comment:

LGTM

@Narsil merged commit bd405e0 into main Feb 8, 2024
@Narsil deleted the impl-simple-mamba-model branch February 8, 2024 09:19
kdamaszk pushed a commit to kdamaszk/tgi-gaudi that referenced this pull request Apr 29, 2024

Co-authored-by: Nicolas Patry <[email protected]>