- I have searched the issues of this repo and believe that this is not a duplicate.
- I have searched the documentation and believe that my question is not covered.
Issue
Hi everyone,
I am installing JAX using Poetry. Running poetry add jax works fine, but, as expected, it installs the CPU-only version. To install the GPU/TPU version of JAX, the documentation says to run:
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
I understand I could run these commands directly in my environment. However, I believe that if I do, Poetry will not track the installed packages, since they would not be recorded in pyproject.toml or poetry.lock. Is there a more Poetry-native way to install JAX for GPU/TPU?
Thank you for the help :)