Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
sayakpaul
left a comment
There was a problem hiding this comment.
Thank you! Appreciate the detailed comments.
| pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256) | ||
|
|
||
| def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16): | ||
| def _test_torch_compile_with_group_offload_leaf(self, quantization_config, torch_dtype=torch.bfloat16): |
There was a problem hiding this comment.
Maybe we can test with parameterized where we test with and without streams?
|
@sayakpaul I'm not sure what's causing the tests to fail 🤔 This PR guards the compile test with torchao version/installation requirement but still seemingly causes tests to fail. I'll try to take a look later today if we don't have a quick understanding of what happened here |
|
Exactly! Nothing comes to mind as to what could trigger this! |
There was a problem hiding this comment.
Was able to spend some time and the following diff solves the problem:
Expand
diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index ddf97aca5..28454aae9 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -631,11 +631,14 @@ class TorchAoSerializationTest(unittest.TestCase):
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoCompileTest(QuantCompileTests):
- quantization_config = PipelineQuantizationConfig(
- quant_mapping={
- "transformer": TorchAoConfig(quant_type="int8_weight_only"),
- },
- )
+ @property
+ def quantization_config(self):
+ config = PipelineQuantizationConfig(
+ quant_mapping={
+ "transformer": TorchAoConfig(quant_type="int8_weight_only"),
+ },
+ )
+ return config
def test_torch_compile(self):
super()._test_torch_compile(quantization_config=self.quantization_config)ChatGPT does a nice job of explaining what is happening:
https://chatgpt.com/share/685951bc-7c88-8013-b317-62683d1a1fa9. What I didn't investigate is that how come the other TorchAO tests are not getting flagged because of torchao installation errors 🤷
6d5f77e to
39faf5f
Compare
|
@sayakpaul Thanks for looking into it! I've converted all occurrences to properties (include bitsandbytes ones) |
| @property | ||
| def quantization_config(self): | ||
| raise NotImplementedError( | ||
| "This property should be implemented in the subclass to return the appropriate quantization config." | ||
| ) |
No description provided.