[DataLoader] Short circuit pin_memory recursion when operating on bytes#97737
ezhang887 wants to merge 1 commit into pytorch:master
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/97737
✅ No Failures as of commit 2d9866f. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
cc @malfet for another silent breakage from the six removal PR :(
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
NivekT left a comment:

LGTM! Thanks for finding and fixing this bug.
Similar to #97737, a previous auto-refactor changed how `bytes` are handled during collation, which can potentially lead to performance regression. This PR undoes that. [ghstack-poisoned]
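For illustration, here is a toy collate guard (a sketch, not PyTorch's actual `default_collate` implementation) showing why the `bytes` short-circuit matters during collation: without it, a batch of `bytes` samples would be zipped and recursed element-wise, whereas with the `(str, bytes)` check each sample stays atomic.

```python
from collections.abc import Sequence

def toy_collate(batch):
    """Minimal stand-in for a collate function; illustrative only."""
    elem = batch[0]
    # Short-circuit str *and* bytes so whole samples are kept intact
    # instead of being zipped and recursed per character/byte.
    if isinstance(elem, (str, bytes)):
        return batch
    if isinstance(elem, Sequence):
        return [toy_collate(list(samples)) for samples in zip(*batch)]
    return batch

print(toy_collate([b"ab", b"cd"]))      # [b'ab', b'cd'] -- bytes stay whole
print(toy_collate([[1, 2], [3, 4]]))    # [[1, 3], [2, 4]] -- real seqs transpose
```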
Hmm, for the unit test we can probably just check how many times that function got called recursively and make sure it doesn't get called excessively? Not sure if there's an easy way to do that without modifying the source code, though. Maybe it can be done with some monkeypatching magic.
What about:
Oh true, yup this works! (Tested locally)
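One way to realize the monkeypatching idea, sketched on a toy stand-in (the real test would wrap the function in `torch/utils/data/_utils/pin_memory.py` the same way, e.g. via `unittest.mock.patch`): rebind the module-level name to a counting wrapper, so the recursive calls, which resolve through that name at call time, are counted too.

```python
from collections.abc import Sequence

# Toy stand-in for the pin_memory recursion, just to demonstrate the
# call-counting trick; not the actual PyTorch function.
def toy_pin_memory(data):
    if isinstance(data, (str, bytes)):
        return data
    if isinstance(data, Sequence):
        return [toy_pin_memory(x) for x in data]
    return data

calls = [0]
_orig = toy_pin_memory

def counting_wrapper(data):
    calls[0] += 1
    return _orig(data)

# Rebinding the global name means the recursive calls inside _orig
# (which look up `toy_pin_memory` at call time) also hit the wrapper.
toy_pin_memory = counting_wrapper

toy_pin_memory([b"abcd", b"efgh"])
print(calls[0])  # 3: one call for the list, one per bytes element
```

Without the `bytes` short-circuit, each individual byte would trigger another call, so the count would blow up with payload size, which is exactly what the test should guard against.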
Similar to #97737, a previous auto-refactor changed how `bytes` are handled during collation, which can potentially lead to performance regression. This PR undoes that. Pull Request resolved: #97789 Approved by: https://github.com/albanD
Revisit `torch._six.string_classes` (which is `(str, bytes)`) removal: `isinstance(obj, string_classes) -> isinstance(obj, str)`.
Both `str` and `bytes` are `Sequence` classes.
```python
In [1]: from typing import Sequence
In [2]: issubclass(bytes, Sequence)
Out[2]: True
In [3]: issubclass(str, Sequence)
Out[3]: True
```
Re-add `bytes` to type guards like:
```python
def is_seq(obj):
return isinstance(obj, Sequence) and not isinstance(obj, (str, bytes))
```
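As a quick sanity check of the guard above (using `collections.abc.Sequence`, which gives the same `isinstance` results as the bare `typing.Sequence` here): without the `bytes` term, `is_seq(b"ab")` would be `True` and callers would recurse byte by byte; recursing into `str` is even worse, since every character is itself a length-1 `str`.

```python
from collections.abc import Sequence

def is_seq(obj):
    return isinstance(obj, Sequence) and not isinstance(obj, (str, bytes))

print(is_seq([1, 2]))   # True  - recurse into real sequences
print(is_seq((1, 2)))   # True
print(is_seq("ab"))     # False - atomic leaf, despite being a Sequence
print(is_seq(b"ab"))    # False - likewise, after re-adding bytes
```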
Ref:
- #94709 (comment)
- #97737
- #97789
Pull Request resolved: #97863
Approved by: https://github.com/Skylion007, https://github.com/albanD
[DataLoader] Short circuit pin_memory recursion when operating on bytes (pytorch#97737)

Slack thread: https://pytorch.slack.com/archives/GEEQ2K4MD/p1679962409906099

I was seeing some massive (~2x) slowdowns on a job after running it on PyTorch 2.0. From some profiling in `py-spy` it looked like the pin_memory thread was doing a lot more work than before. Looking at a trace in `nsys`, I saw the thread doing the forward pass having a bunch of `pthread_cond_timedwait` with GIL reacquire calls in its call stack, and it seemed like the thread doing the forward pass was getting blocked (waiting for the GIL) by the pin_memory thread (which was holding the GIL).

After some debugging I found the issue. If a `bytes` was passed into `pin_memory`, previously in 1.13 (before pytorch#94709) it would short-circuit and return here https://github.com/pytorch/pytorch/blob/d922c29a22e4bf0fba49526f7536395eb8cd66f4/torch/utils/data/_utils/pin_memory.py#L54-L55 since `bytes` was in `torch._six.string_classes`:

```
>>> from torch._six import string_classes
>>> string_classes
(<class 'str'>, <class 'bytes'>)
```

However, after pytorch#94709, a `bytes` passed into `pin_memory` would fall through to here instead https://github.com/pytorch/pytorch/blob/c263bd43e8e8502d4726643bc6fd046f0130ac0e/torch/utils/data/_utils/pin_memory.py#L68-L73 because the earlier check https://github.com/pytorch/pytorch/blob/c263bd43e8e8502d4726643bc6fd046f0130ac0e/torch/utils/data/_utils/pin_memory.py#L56-L57 now does `isinstance(data, str)` instead of `isinstance(data, (str, bytes))`! As a result, `pin_memory` gets called recursively for each element in the `bytes`, leading to a ton of wasted recursion. This also explains the slowdown / GIL contention I was seeing.

This PR simply changes `isinstance(data, str)` to `isinstance(data, (str, bytes))` to match the behavior before pytorch#94709.

Pull Request resolved: pytorch#97737
Approved by: https://github.com/albanD, https://github.com/NivekT
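To quantify the waste described above, here is a toy model of the recursion (illustrative only, not the actual `pin_memory` code): with the str-only check a `bytes` payload is walked byte by byte, while the `(str, bytes)` check returns after a single call.

```python
from collections.abc import Sequence

def walk(data, guard, counter):
    """Toy recursion mimicking pin_memory's structure; `guard` is the
    type tuple used for the short-circuit check."""
    counter[0] += 1
    if isinstance(data, guard):
        return data
    if isinstance(data, Sequence):
        return [walk(x, guard, counter) for x in data]
    return data

payload = b"x" * 1000

buggy = [0]
walk(payload, str, buggy)           # falls through, recurses per byte
fixed = [0]
walk(payload, (str, bytes), fixed)  # short-circuits immediately

print(buggy[0], fixed[0])  # 1001 1
```

One extra call per byte of every sample, all under the GIL, is consistent with the pin_memory thread starving the forward pass as described in the profile.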