
Conversation

@AlexDenisov
Contributor

@AlexDenisov AlexDenisov commented Jul 10, 2024

@pytorch-bot

pytorch-bot bot commented Jul 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/130429

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit c6dc5e8 with merge base 4dbecf3:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@AlexDenisov AlexDenisov force-pushed the alexdenisov/inductor-annotations branch from dda23e2 to deb2f25 Compare August 13, 2024 08:16
@AlexDenisov AlexDenisov marked this pull request as ready for review August 13, 2024 08:17
@AlexDenisov
Contributor Author

Making it ready for review as a gentle ping

@janeyx99 janeyx99 added the triaged label Aug 14, 2024
@janeyx99 janeyx99 requested a review from aorenste August 14, 2024 00:36
@janeyx99
Contributor

@aorenste assigning you as reviewer but pls reassign if there's a better reviewer

@aorenste aorenste requested a review from eellison August 14, 2024 04:06
@aorenste
Contributor

This looks totally reasonable to me - but I don't know enough about the interactions to be comfortable reviewing this. Assigning to @eellison to either review or forward to someone who knows this bit better.

@aorenste aorenste removed their request for review August 14, 2024 04:06
Contributor

@eellison eellison left a comment

have a few comments - mostly, I wonder if we could share the codegen of the buffer annotations with our pytorch profiler codegen, which is doing a similar thing

Comment on lines 367 to 373
@property
def is_inference(self):
    return _is_inference._get_handler()

@property
def is_backward(self):
    return _is_backward._get_handler()
Contributor

nit: V.graph.is_inference already exists, see:

self.is_inference = is_inference

Could we just add is_backward to the GraphLowering object and query that for the is_backward property, instead of adding this?

Contributor Author

Totally, I missed the is_inference property. Will add is_backward there as well and revert this commit!

Contributor Author

@AlexDenisov AlexDenisov Aug 16, 2024

Reverted this commit and added is_backward to GraphLowering 6a27b68 720b971
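
For context, a minimal sketch of the shape of that change (class and function bodies here are illustrative, not the actual commits):

# Hypothetical sketch, not the actual diff: GraphLowering carries both flags,
# and the training phase is derived from them wherever it is needed.
class GraphLowering:
    def __init__(self, gm, is_inference: bool = False, is_backward: bool = False):
        self.gm = gm
        self.is_inference = is_inference  # already existed, per the review comment
        self.is_backward = is_backward    # flag added by this PR


def get_training_phase(graph: "GraphLowering") -> str:
    # Matches the annotation names used in the generated code shown later in this thread.
    if graph.is_inference:
        return "inference"
    if graph.is_backward:
        return "backward"
    return "forward"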


def device_range_start(msg) -> int:
    """
    TBD
Contributor

add a full docstring?

Contributor Author

This depends on whether the RangeHandle approach from above is the right way to go. I'll update the doc to cover the current implementation.

Contributor

I don't know a ton about nvtx... but maybe @ezyang or @eqy might be able to weigh in

Contributor Author

Added some docs 6390c08
Please, let me know if I can make it clearer 🙇
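
For readers without that commit handy, a sketch of what the added documentation covers (the wording is illustrative, not the exact docstring from 6390c08):

def device_range_start(msg: str) -> int:
    """Start an NVTX range around device-side work (the counterpart of the
    host-side ``range_start``). Returns an opaque handle that must later be
    passed to ``device_range_end``. Illustrative wording only.
    """
    ...


def device_range_end(handle: int) -> None:
    """Close the NVTX range identified by ``handle``."""
    ...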


namespace torch::cuda::shared {

struct RangeHandle {
Contributor

cc @ezyang, I see that you first added the nvtx bindings (7 years ago). Do you want to take a look?

import random
import os
import tempfile
from torch.cuda import nvtx
Contributor

can you add this conditionally? see

if config.benchmark_kernel:
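
To illustrate the suggestion (the config flag name and the writer helper are assumptions here, not the actual diff), the import would be emitted into the generated code only when the feature is on, mirroring the benchmark_kernel guard:

# Sketch only: emit the nvtx import into the generated code's header
# conditionally, following the existing config.benchmark_kernel precedent.
from torch._inductor import config


def write_nvtx_import(wrapper_code) -> None:
    if getattr(config, "annotate_training", False):  # flag introduced by this PR
        wrapper_code.writeline("from torch.cuda import nvtx")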

Contributor Author

Comment on lines 3065 to 3074
if config.annotate_buffers:
    V.graph.wrapper_code.writeline("nvtx.device_range_end(buffer_annotation)")

if self.current_device and device_need_guard(self.current_device.type):
    # exit the outermost CUDA device guard. this is
    # important for nested indentation codegen-ing.
    V.graph.wrapper_code.codegen_device_guard_exit()

if config.annotate_training:
    V.graph.wrapper_code.writeline("nvtx.device_range_end(training_annotation)")
Contributor

would you mind posting example output code for a fused kernel? this would also be a good candidate for a test, see run_and_get_code

Contributor Author

Added a couple of test cases here 7d623ef
Example output is in the comment below: #130429 (comment)

Comment on lines 3065 to 3066
if config.annotate_buffers:
    V.graph.wrapper_code.writeline("nvtx.device_range_end(buffer_annotation)")
Contributor

This is pretty CUDA-specific codegen for the general scheduler... also, I wonder if any code here could be shared with the pytorch profiler. cc @davidberard98

Contributor

with torch._C._profiler._RecordFunctionFast(

for profiler, we put the code in the runtime so we don't fill the codegen with profiling annotations - is this viable for your use case?

I'm not sure how to make it less cuda-specific unfortunately though. @sraikund16 tells me that the NVTX handling in profiler is somewhat cuda-specific
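
For reference, the runtime-side pattern described above looks roughly like this (the _RecordFunctionFast call is simplified; the real call site in the Triton runtime passes more metadata):

# Rough sketch: wrap the kernel launch at runtime instead of in the codegen.
import torch


def launch_with_profiler_annotation(kernel, kernel_name, *args, **kwargs):
    # _RecordFunctionFast is cheap when no profiler is attached, so the
    # generated code stays free of annotation boilerplate.
    with torch._C._profiler._RecordFunctionFast(kernel_name):
        return kernel(*args, **kwargs)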

Contributor Author

for profiler, we put the code in the runtime so we don't fill the codegen with profiling annotations - is this viable for your use case?

It seems the profiler works at a more fine-grained level? i.e. it "wraps" kernel runs into profiling events? These annotations work at a slightly higher level/granularity. I think the training annotations (bw/fw/inference) can be moved outside of the codegen, but I'm not sure about the "buffer" annotations 🤔

Re: CUDA-specific: this is indeed the case, not sure how to handle this properly. I can think of emitting some "abstract" Begin/EndAnnotationLine with the nvtx calls hidden there, but I'm not certain it brings much value? Happy to find the right solution 🙌

Contributor

The buffer annotation, like the profiler, is wrapping a single run of a kernel.

        buffer_annotation = nvtx.device_range_start('op1')
        buf1 = empty_strided_cuda((5, ), (1, ), torch.float32)
        # Topologically Sorted Source Nodes: [mul], Original ATen: [aten.mul]
        triton_poi_fused_mul_1.run(arg0_1, arg1_1, buf1, 5, grid=grid(5), stream=stream0)
        del arg0_1
        del arg1_1
        nvtx.device_range_end(buffer_annotation)

I think it would be an improvement to put the logic here in the same place as the profiler.

As far as the CUDA-specific part goes - I commented elsewhere about moving the codegen logic.

Contributor Author

I agree that it would be an improvement, but it would change the granularity of the annotations a bit: the current "buffer annotations" cover the kernel run and all the memory allocations/deallocations, i.e. "annotate everything that happens to compute the buffer." Wrapping kernel runs would be more of an "annotate kernels" mode, which is somewhat orthogonal to the current version. Perhaps it could be a third annotation option? 🤔

Contributor

@eellison eellison Aug 22, 2024

TBH, I'm not convinced that this is really that meaningful. It's not tracking memory (i.e., we don't know when a particular buffer is allocated/deallocated), and the memory allocations/deallocations themselves are extremely cheap and happen on a different timeline than CUDA because CUDA is async. In any case, let's move this out of scheduler.py if you are convinced on keeping the buffer annotations, or put it in the profiler.

Contributor Author

In any case, let's move this out of scheduler.py if you are convinced on keeping the buffer annotations, or put it in the profiler

Apologies for the back and forth, just to ensure I understand this correctly: moving this into the profiler implies that models must use the profiler, and the annotations would only appear when profiling is enabled?
If that's the case, I'd prefer not to move it into the profiler, as it makes integration into an arbitrary model harder.

Additionally, there are two more concerns:

  • it doesn't seem like the buffer names are available around the run method? So the best we can do is to add annotations based on the kernel name, but the same kernel can be used to compute several distinct buffers. It's of course possible to pass the buffer names around, but it doesn't look like a particularly good idea?
  • the profiler is called from within triton, which would miss non-triton kernels in the case of a mixed execution environment

For moving it into wrapper.py, I guess it'd require wrapping all the kernel calls into special lines (e.g. KernelCallLine) and adding another check here? I don't see any special *Line classes for such invocations (the last time I checked they were simply Python strings at the wrapper level). The buffer names are also missing at this level, though.

Does this make sense? Or am I missing something?
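
A rough illustration of what such a wrapper-level line could look like (KernelCallLine is hypothetical; the existing *Line classes in wrapper.py are the pattern being imitated):

# Hypothetical sketch only: carry the kernel call as a structured wrapper line
# so the wrapper codegen, not the scheduler, decides whether to emit nvtx ranges.
from dataclasses import dataclass
from typing import Optional


@dataclass
class KernelCallLine:
    call_str: str              # e.g. "triton_poi_fused_mul_1.run(...)"
    annotation: Optional[str]  # e.g. "op1"; None when buffer annotations are off

    def codegen(self, code) -> None:
        if self.annotation is not None:
            code.writeline(f"buffer_annotation = nvtx.device_range_start('{self.annotation}')")
        code.writeline(self.call_str)
        if self.annotation is not None:
            code.writeline("nvtx.device_range_end(buffer_annotation)")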

Contributor Author

@eellison I moved training annotations into wrapper.py 78c1c7e

Regarding the buffer annotations: I cannot find a better place to emit this code than within Scheduler's codegen.
It's possible to make it more general by introducing special, generic *Lines and moving the actual emission (together with the config check) into wrapper.py, but that would still leave "traces" in the scheduler.

If adding this to Scheduler is a strong no-go, then I could remove it from this PR leaving only "training annotations" here. I'd be happy to open a followup PR with buffer/kernel annotations for further discussion.

@AlexDenisov
Contributor Author

Hi @eellison, thank you so much for the review, highly appreciated!

My initial goal was to have a discussion on whether such a feature would be useful and I take the comments here so far as a "yes."

I'll address the comments and bring the PR into a better shape.

One question: is it OK to rebase + force-push? I cannot find guidance on this matter in the docs/wiki.

@AlexDenisov
Contributor Author

Example code for a small snippet:

def f(a, b):
    return a + b, a * b
  1. Training annotations:
def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    assert_size_stride(arg0_1, (5, ), (1, ))
    assert_size_stride(arg1_1, (5, ), (1, ))
    training_annotation = nvtx.device_range_start('inference')
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((5, ), (1, ), torch.float32)
        # Topologically Sorted Source Nodes: [add], Original ATen: [aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused_add_0.run(arg0_1, arg1_1, buf0, 5, grid=grid(5), stream=stream0)
        del arg0_1
        del arg1_1
    nvtx.device_range_end(training_annotation)
    return (buf0, )
  2. Buffer annotations (no fusion):
def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    assert_size_stride(arg0_1, (5, ), (1, ))
    assert_size_stride(arg1_1, (5, ), (1, ))
    buffer_annotation = nvtx.device_range_start('op0')
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((5, ), (1, ), torch.float32)
        # Topologically Sorted Source Nodes: [add], Original ATen: [aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused_add_0.run(arg0_1, arg1_1, buf0, 5, grid=grid(5), stream=stream0)
        nvtx.device_range_end(buffer_annotation)
        buffer_annotation = nvtx.device_range_start('op1')
        buf1 = empty_strided_cuda((5, ), (1, ), torch.float32)
        # Topologically Sorted Source Nodes: [mul], Original ATen: [aten.mul]
        triton_poi_fused_mul_1.run(arg0_1, arg1_1, buf1, 5, grid=grid(5), stream=stream0)
        del arg0_1
        del arg1_1
        nvtx.device_range_end(buffer_annotation)
    return (buf0, buf1, )
  3. Buffer annotations (with fusion):
def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    assert_size_stride(arg0_1, (5, ), (1, ))
    assert_size_stride(arg1_1, (5, ), (1, ))
    buffer_annotation = nvtx.device_range_start('op0_op1')
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((5, ), (1, ), torch.float32)
        buf1 = empty_strided_cuda((5, ), (1, ), torch.float32)
        # Topologically Sorted Source Nodes: [add, mul], Original ATen: [aten.add, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_add_mul_0.run(arg0_1, arg1_1, buf0, buf1, 5, grid=grid(5), stream=stream0)
        del arg0_1
        del arg1_1
        nvtx.device_range_end(buffer_annotation)
    return (buf0, buf1, )
  4. Buffer and training annotation combined:
def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    assert_size_stride(arg0_1, (5, ), (1, ))
    assert_size_stride(arg1_1, (5, ), (1, ))
    training_annotation = nvtx.device_range_start('inference')
    buffer_annotation = nvtx.device_range_start('op0_op1')
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((5, ), (1, ), torch.float32)
        buf1 = empty_strided_cuda((5, ), (1, ), torch.float32)
        # Topologically Sorted Source Nodes: [add, mul], Original ATen: [aten.add, aten.mul]
        stream0 = get_raw_stream(0)
        triton_poi_fused_add_mul_0.run(arg0_1, arg1_1, buf0, buf1, 5, grid=grid(5), stream=stream0)
        del arg0_1
        del arg1_1
        nvtx.device_range_end(buffer_annotation)
    nvtx.device_range_end(training_annotation)
    return (buf0, buf1, )

@AlexDenisov AlexDenisov force-pushed the alexdenisov/inductor-annotations branch from d137069 to a3404af Compare August 16, 2024 13:55
@aorenste aorenste self-requested a review August 19, 2024 15:37
@AlexDenisov AlexDenisov force-pushed the alexdenisov/inductor-annotations branch from c6d6232 to 42266bc Compare August 20, 2024 08:47
def _codegen(self) -> None:
    phase = self.get_training_phase()
    if config.annotate_training:
        V.graph.wrapper_code.writeline(f"training_annotation = nvtx.device_range_start('{phase}')")
Contributor

Let's try to keep codegen aspects out of the scheduler and keep the high-level scheduler logic lean. This annotation can go here:

if V.graph.graph_inputs:

Similarly, the end can go in _generate.
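
Roughly, the placement being suggested (the hook names below are placeholders; the real insertion points are the graph_inputs check and _generate in torch/_inductor/codegen/wrapper.py):

# Sketch of wrapper-level emission: the range starts in the call prologue and
# ends in the epilogue, so scheduler.py stays free of annotation codegen.
from torch._inductor import config
from torch._inductor.virtualized import V


def write_training_annotation_start(wrapper_code) -> None:
    if getattr(config, "annotate_training", False):  # flag from this PR
        phase = (
            "inference" if V.graph.is_inference
            else "backward" if V.graph.is_backward
            else "forward"
        )
        wrapper_code.writeline(f"training_annotation = nvtx.device_range_start('{phase}')")


def write_training_annotation_end(wrapper_code) -> None:
    if getattr(config, "annotate_training", False):
        wrapper_code.writeline("nvtx.device_range_end(training_annotation)")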


@AlexDenisov AlexDenisov force-pushed the alexdenisov/inductor-annotations branch from 42266bc to 67a0a7d Compare August 27, 2024 08:54
@linux-foundation-easycla

linux-foundation-easycla bot commented Sep 4, 2024

CLA Signed

The committers listed above are authorized under a signed CLA.

@AlexDenisov
Contributor Author

@eellison I removed the buffer/kernel annotations leaving only the bare minimum for the training annotations. AMD/ROCm build should also work now.
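
For anyone wanting to try the training annotations, usage is roughly the following (the config flag name is taken from the diff above; viewing the ranges assumes an NVTX-aware profiler such as Nsight Systems, e.g. nsys profile python script.py):

# Rough usage sketch: enable the inductor flag, compile, and run under nsys to
# see forward/backward/inference nvtx ranges around the generated call().
import torch
import torch._inductor.config as inductor_config

inductor_config.annotate_training = True  # flag introduced by this PR

model = torch.nn.Linear(16, 16).cuda()
compiled = torch.compile(model)

x = torch.randn(8, 16, device="cuda")
out = compiled(x)
out.sum().backward()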

@eellison eellison self-requested a review September 17, 2024 15:00
Contributor

@eellison eellison left a comment

Looks good! Was on PTO for a bit.

@AlexDenisov AlexDenisov force-pushed the alexdenisov/inductor-annotations branch from 1861b1b to 2a7c750 Compare December 6, 2024 19:52
@AlexDenisov
Contributor Author

I've been using the wrong linter command (lintrunner -a vs lintrunner -a -m origin/main); now I believe it's fixed.

The other two ROCm failures are due to some obscure AWS credentials issue 🤷

@AlexDenisov
Contributor Author

@albanD @eellison can we do another try, please? I believe the recent failures are due to flakiness, and the linter should be happy now 🤞 🙇

@albanD
Collaborator

albanD commented Dec 9, 2024

@pytorchbot merge -r

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch push -f https://github.com/flexaihq/pytorch.git pull/130429/head:alexdenisov/inductor-annotations returned non-zero exit code 128

remote: Permission to flexaihq/pytorch.git denied to pytorchmergebot.
fatal: unable to access 'https://github.com/flexaihq/pytorch.git/': The requested URL returned error: 403

This is likely because the author did not allow edits from maintainers on the PR, or because the repo has additional permission settings that mergebot does not qualify for.
Raised by https://github.com/pytorch/pytorch/actions/runs/12244284471

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

@albanD
Collaborator

albanD commented Dec 9, 2024

Oh, the bot couldn't rebase.
It seems CI is in a bad state; you should try to rebase again locally (to trigger fresh CI) and then you can trigger the merge here with the bot command!

@AlexDenisov AlexDenisov force-pushed the alexdenisov/inductor-annotations branch from 2a7c750 to c6dc5e8 Compare December 9, 2024 22:46
@AlexDenisov
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@AlexDenisov
Contributor Author

It seems it went through without breakages and reverts (so far). Thank you so much for your assistance and for bearing with me @eellison @albanD, highly appreciated! 🙌

@AlexDenisov AlexDenisov deleted the alexdenisov/inductor-annotations branch December 11, 2024 08:55

Labels

ciflow/inductor, ciflow/trunk, Merged, module: inductor, open source, topic: not user facing, triaged


9 participants