Installing JAX with gpu/tpu support using poetry #5516

@pablo2909

Description

  • I have searched the issues of this repo and believe that this is not a duplicate.
  • I have searched the documentation and believe that my question is not covered.

Issue

Hi everyone,

I am installing JAX using Poetry. Running `poetry add jax` works fine, but, as expected, it installs the CPU-only version. To install the GPU/TPU version of JAX, the documentation says to run:

pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
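To check which build is actually active in the Poetry environment, a quick test is to ask JAX for its default backend. This is a minimal sketch that relies only on `jax.default_backend()`; the helper name `jax_backend` and the script name `check_backend.py` are hypothetical.

```python
import importlib.util


def jax_backend():
    """Return JAX's default backend name (e.g. "cpu", "gpu", "tpu"), or None if JAX is missing."""
    # Look for the package first so the check does not crash when JAX is not installed.
    if importlib.util.find_spec("jax") is None:
        return None
    import jax  # imported lazily, only once we know it is available
    return jax.default_backend()


if __name__ == "__main__":
    print(jax_backend())
```

Run it inside the Poetry environment with `poetry run python check_backend.py`. With the CPU-only `jaxlib` that `poetry add jax` pulls in, it should print `cpu`.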

I understand I could run these commands directly in my environment, but I believe the packages they install would then not be tracked by Poetry (they would not appear in `pyproject.toml` or `poetry.lock`). Is there a more Poetry-native way to install JAX with GPU/TPU support?

Thank you for the help :)
