Skip to content

[DataLoader] Short circuit pin_memory recursion when operating on bytes#97737

Closed
ezhang887 wants to merge 1 commit intopytorch:masterfrom
ezhang887:pin_memory_bytes_fix
Closed

[DataLoader] Short circuit pin_memory recursion when operating on bytes#97737
ezhang887 wants to merge 1 commit intopytorch:masterfrom
ezhang887:pin_memory_bytes_fix

Conversation

@ezhang887
Copy link
Copy Markdown
Contributor

Slack thread: https://pytorch.slack.com/archives/GEEQ2K4MD/p1679962409906099

I was seeing some massive (~2x) slowdowns on a job after running it on PyTorch 2.0. From some profiling in py-spy it looked like the pin_memory thread was doing a lot more work than before. Looking at a trace in nsys I saw the thread doing the forward pass having a bunch of pthread_cond_timedwait with GIL reacquire calls in it’s call stack, and it seemed like the thread doing the forward pass was getting blocked (waiting for the GIL) by the pin memory thread (which was holding the GIL).

After some debugging I found out the issue. If a bytes was passed into pin_memory, previously in 1.13 (before #94709) it would short-circuit and return here

elif isinstance(data, string_classes):
return data

since bytes was in torch._six.string_classes:

>>> from torch._six import string_classes
>>> string_classes
(<class 'str'>, <class 'bytes'>)
>>>

However after #94709, if a bytes was passed into pin_memory it would fall into here instead

elif isinstance(data, collections.abc.Sequence):
try:
return type(data)([pin_memory(sample, device) for sample in data]) # type: ignore[call-arg]
except TypeError:
# The sequence type may not support `__init__(iterable)` (e.g., `range`).
return [pin_memory(sample, device) for sample in data]

because the previous check is now doing isinstance(data, str) instead of isinstance(data, (str, bytes))!
elif isinstance(data, str):
return data

As a result, pin_memory gets called recursively for each element in the bytes leading to a ton of wasted recursion. This also explains the slowdown / GIL contention I was seeing.

This PR simply changes isinstance(data, str) to isinstance(data, (str, bytes)) to match the behavior before #94709

@pytorch-bot
Copy link
Copy Markdown

pytorch-bot bot commented Mar 28, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/97737

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 2d9866f:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: dataloader release notes category label Mar 28, 2023
@linux-foundation-easycla
Copy link
Copy Markdown

linux-foundation-easycla bot commented Mar 28, 2023

CLA Signed

The committers listed above are authorized under a signed CLA.

  • ✅ login: ezhang887 / name: Eric Zhang (2d9866f)

Copy link
Copy Markdown
Collaborator

@albanD albanD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Is a there a way to test this by any chance?

@albanD
Copy link
Copy Markdown
Collaborator

albanD commented Mar 28, 2023

cc @malfet for another silent breakage from the six removal PR :(

@albanD
Copy link
Copy Markdown
Collaborator

albanD commented Mar 28, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 28, 2023
@albanD albanD added this to the 2.0.1 milestone Mar 28, 2023
@pytorchmergebot
Copy link
Copy Markdown
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Copy link
Copy Markdown
Contributor

@NivekT NivekT left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for finding and fixing this bug.

pytorchmergebot pushed a commit that referenced this pull request Mar 28, 2023
Similar to #97737, a previous auto-refactor changed how `bytes` are handled during collation, which can potentially lead to performance regression. This PR undoes that.

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request Mar 28, 2023
Similar to #97737, a previous auto-refactor changed how `bytes` are handled during collation, which can potentially lead to performance regression. This PR undoes that.

[ghstack-poisoned]
@ezhang887
Copy link
Copy Markdown
Contributor Author

Is a there a way to test this by any chance?

Hmm, for unit test can probably just check how many times that function got called recursively and make sure it doesn't get excessively called? Not sure if there's an easy way to do that though without modifying the source code though, maybe can be done with some monkeypatching magic.

@albanD
Copy link
Copy Markdown
Collaborator

albanD commented Mar 28, 2023

What about: assert pin_memory(some_bytes) is some_bytes ?
To make sure this is a no-op for bytes?

@ezhang887
Copy link
Copy Markdown
Contributor Author

ezhang887 commented Mar 28, 2023

What about: assert pin_memory(some_bytes) is some_bytes ? To make sure this is a no-op for bytes?

Oh true, yup this works! (Tested locally)

pytorchmergebot pushed a commit that referenced this pull request Mar 28, 2023
Similar to #97737, a previous auto-refactor changed how `bytes` are handled during collation, which can potentially lead to performance regression. This PR undoes that.
Pull Request resolved: #97789
Approved by: https://github.com/albanD
pytorchmergebot pushed a commit that referenced this pull request Mar 30, 2023
Revisit `torch._six.string_classes` (which is `(str, bytes)`) removal: `isinstance(obj, string_classes) -> isinstance(obj, str)`.

Both `str` and `bytes` are `Sequence` classes.

```python
In [1]: from typing import Sequence

In [2]: issubclass(bytes, Sequence)
Out[2]: True

In [3]: issubclass(str, Sequence)
Out[3]: True
```

Re-add `bytes` to type guards like:

```python
def is_seq(obj):
    return isinstance(obj, Sequence) and not isinstance(obj, (str, bytes))
```

Ref:

- #94709 (comment)
- #97737
- #97789
Pull Request resolved: #97863
Approved by: https://github.com/Skylion007, https://github.com/albanD
XuehaiPan pushed a commit to XuehaiPan/pytorch that referenced this pull request Mar 31, 2023
…es (pytorch#97737)

Slack thread: https://pytorch.slack.com/archives/GEEQ2K4MD/p1679962409906099

I was seeing some massive (~2x) slowdowns on a job after running it on PyTorch 2.0. From some profiling in `py-spy` it looked like the pin_memory thread was doing a lot more work than before. Looking at a trace in `nsys` I saw the thread doing the forward pass having a bunch of `pthread_cond_timedwait` with GIL reacquire calls in it’s call stack, and it seemed like the thread doing the forward pass was getting blocked (waiting for the GIL) by the pin memory thread (which was holding the GIL).

After some debugging I found out the issue. If a `bytes` was passed into `pin_memory`, previously in 1.13 (before pytorch#94709) it would short-circuit and return here
https://github.com/pytorch/pytorch/blob/d922c29a22e4bf0fba49526f7536395eb8cd66f4/torch/utils/data/_utils/pin_memory.py#L54-L55
since `bytes` was in `torch._six.string_classes`:
```
>>> from torch._six import string_classes
>>> string_classes
(<class 'str'>, <class 'bytes'>)
>>>
```

However after pytorch#94709, if a `bytes` was passed into `pin_memory` it would fall into here instead
https://github.com/pytorch/pytorch/blob/c263bd43e8e8502d4726643bc6fd046f0130ac0e/torch/utils/data/_utils/pin_memory.py#L68-L73
because the previous check is now doing `isinstance(data, str)` instead of `isinstance(data, (str, bytes))`!
https://github.com/pytorch/pytorch/blob/c263bd43e8e8502d4726643bc6fd046f0130ac0e/torch/utils/data/_utils/pin_memory.py#L56-L57

As a result, `pin_memory` gets called recursively for each element in the `bytes` leading to a ton of wasted recursion. This also explains the slowdown / GIL contention I was seeing.

This PR simply changes `isinstance(data, str)` to `isinstance(data, (str, bytes))` to match the behavior before pytorch#94709

Pull Request resolved: pytorch#97737
Approved by: https://github.com/albanD, https://github.com/NivekT
XuehaiPan pushed a commit to XuehaiPan/pytorch that referenced this pull request Mar 31, 2023
Similar to pytorch#97737, a previous auto-refactor changed how `bytes` are handled during collation, which can potentially lead to performance regression. This PR undoes that.
Pull Request resolved: pytorch#97789
Approved by: https://github.com/albanD
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Mar 31, 2023
…97863)

Revisit `torch._six.string_classes` (which is `(str, bytes)`) removal: `isinstance(obj, string_classes) -> isinstance(obj, str)`.

Both `str` and `bytes` are `Sequence` classes.

```python
In [1]: from typing import Sequence

In [2]: issubclass(bytes, Sequence)
Out[2]: True

In [3]: issubclass(str, Sequence)
Out[3]: True
```

Re-add `bytes` to type guards like:

```python
def is_seq(obj):
    return isinstance(obj, Sequence) and not isinstance(obj, (str, bytes))
```

Ref:

- pytorch#94709 (comment)
- pytorch#97737
- pytorch#97789
Pull Request resolved: pytorch#97863
Approved by: https://github.com/Skylion007, https://github.com/albanD
NivekT added a commit that referenced this pull request Apr 4, 2023
Similar to #97737, a previous auto-refactor changed how `bytes` are handled during collation, which can potentially lead to performance regression. This PR undoes that.
Pull Request resolved: #97789
Approved by: https://github.com/albanD
NivekT pushed a commit that referenced this pull request Apr 4, 2023
…es (#97737)

Slack thread: https://pytorch.slack.com/archives/GEEQ2K4MD/p1679962409906099

I was seeing some massive (~2x) slowdowns on a job after running it on PyTorch 2.0. From some profiling in `py-spy` it looked like the pin_memory thread was doing a lot more work than before. Looking at a trace in `nsys` I saw the thread doing the forward pass having a bunch of `pthread_cond_timedwait` with GIL reacquire calls in it’s call stack, and it seemed like the thread doing the forward pass was getting blocked (waiting for the GIL) by the pin memory thread (which was holding the GIL).

After some debugging I found out the issue. If a `bytes` was passed into `pin_memory`, previously in 1.13 (before #94709) it would short-circuit and return here
https://github.com/pytorch/pytorch/blob/d922c29a22e4bf0fba49526f7536395eb8cd66f4/torch/utils/data/_utils/pin_memory.py#L54-L55
since `bytes` was in `torch._six.string_classes`:
```
>>> from torch._six import string_classes
>>> string_classes
(<class 'str'>, <class 'bytes'>)
>>>
```

However after #94709, if a `bytes` was passed into `pin_memory` it would fall into here instead
https://github.com/pytorch/pytorch/blob/c263bd43e8e8502d4726643bc6fd046f0130ac0e/torch/utils/data/_utils/pin_memory.py#L68-L73
because the previous check is now doing `isinstance(data, str)` instead of `isinstance(data, (str, bytes))`!
https://github.com/pytorch/pytorch/blob/c263bd43e8e8502d4726643bc6fd046f0130ac0e/torch/utils/data/_utils/pin_memory.py#L56-L57

As a result, `pin_memory` gets called recursively for each element in the `bytes` leading to a ton of wasted recursion. This also explains the slowdown / GIL contention I was seeing.

This PR simply changes `isinstance(data, str)` to `isinstance(data, (str, bytes))` to match the behavior before #94709

Pull Request resolved: #97737
Approved by: https://github.com/albanD, https://github.com/NivekT
atalman pushed a commit that referenced this pull request Apr 5, 2023
…97789, #97863) (#98055)

* [DataLoader] Short circuit pin_memory recursion when operating on bytes (#97737)

Slack thread: https://pytorch.slack.com/archives/GEEQ2K4MD/p1679962409906099

I was seeing some massive (~2x) slowdowns on a job after running it on PyTorch 2.0. From some profiling in `py-spy` it looked like the pin_memory thread was doing a lot more work than before. Looking at a trace in `nsys` I saw the thread doing the forward pass having a bunch of `pthread_cond_timedwait` with GIL reacquire calls in it’s call stack, and it seemed like the thread doing the forward pass was getting blocked (waiting for the GIL) by the pin memory thread (which was holding the GIL).

After some debugging I found out the issue. If a `bytes` was passed into `pin_memory`, previously in 1.13 (before #94709) it would short-circuit and return here
https://github.com/pytorch/pytorch/blob/d922c29a22e4bf0fba49526f7536395eb8cd66f4/torch/utils/data/_utils/pin_memory.py#L54-L55
since `bytes` was in `torch._six.string_classes`:
```
>>> from torch._six import string_classes
>>> string_classes
(<class 'str'>, <class 'bytes'>)
>>>
```

However after #94709, if a `bytes` was passed into `pin_memory` it would fall into here instead
https://github.com/pytorch/pytorch/blob/c263bd43e8e8502d4726643bc6fd046f0130ac0e/torch/utils/data/_utils/pin_memory.py#L68-L73
because the previous check is now doing `isinstance(data, str)` instead of `isinstance(data, (str, bytes))`!
https://github.com/pytorch/pytorch/blob/c263bd43e8e8502d4726643bc6fd046f0130ac0e/torch/utils/data/_utils/pin_memory.py#L56-L57

As a result, `pin_memory` gets called recursively for each element in the `bytes` leading to a ton of wasted recursion. This also explains the slowdown / GIL contention I was seeing.

This PR simply changes `isinstance(data, str)` to `isinstance(data, (str, bytes))` to match the behavior before #94709

Pull Request resolved: #97737
Approved by: https://github.com/albanD, https://github.com/NivekT

* [DataLoader] Fix  collation logic (#97789)

Similar to #97737, a previous auto-refactor changed how `bytes` are handled during collation, which can potentially lead to performance regression. This PR undoes that.
Pull Request resolved: #97789
Approved by: https://github.com/albanD

* Revisit `torch._six.string_classes` removal (#94709) (#97863)

Revisit `torch._six.string_classes` (which is `(str, bytes)`) removal: `isinstance(obj, string_classes) -> isinstance(obj, str)`.

Both `str` and `bytes` are `Sequence` classes.

```python
In [1]: from typing import Sequence

In [2]: issubclass(bytes, Sequence)
Out[2]: True

In [3]: issubclass(str, Sequence)
Out[3]: True
```

Re-add `bytes` to type guards like:

```python
def is_seq(obj):
    return isinstance(obj, Sequence) and not isinstance(obj, (str, bytes))
```

Ref:

- #94709 (comment)
- #97737
- #97789
Pull Request resolved: #97863
Approved by: https://github.com/Skylion007, https://github.com/albanD

---------

Co-authored-by: Eric Zhang <[email protected]>
Co-authored-by: Kevin Tse <[email protected]>
@ezhang887 ezhang887 deleted the pin_memory_bytes_fix branch May 1, 2023 08:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged open source release notes: dataloader release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants