Skip to content

Conversation

@henrylhtsang
Copy link
Contributor

@henrylhtsang henrylhtsang commented Oct 23, 2024

Why?

I want the following code to work.

minimal repro:

class M(torch.nn.Module):
    def forward(self, dilate_flag):
        return dilate_flag.item()

input1 = (torch.tensor([1], dtype=torch.bool, device="cuda"),)
model = M().cuda()

ep = torch.export.export(model, input1, strict=True)
path = torch._inductor.aot_compile(ep.module(), input1)
aot_model = torch._export.aot_load(path, device="cuda")
actual_output = aot_model(*input1)

error: AssertionError: Encountered an unsupported object of type <class 'torch.SymBool'> while writing the metadata for exported program

second error will be handled by #138760

Motivation

I could technically bypass it with a torch.int tensor. However, it doesn't work with torch.cond. I want the following to work. It would also require #138760 for aot compile to work.

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.dilate_flag = 0

    def forward(self, dilate_flag):
        self.dilate_flag = dilate_flag.item()

        def true_fn(dilate_flag):
            return dilate_flag.clone()

        def false_fn(dilate_flag):
            return dilate_flag.clone()

        torch.cond(
            self.dilate_flag,
            true_fn,
            false_fn,
            (dilate_flag,),
        )
        return self.dilate_flag

input1 = (torch.tensor([1], dtype=torch.bool, device="cuda"),)
input2 = (torch.tensor([0], dtype=torch.bool, device="cuda"),)
inputs = (input1, input2)
model = M().cuda()

for input in inputs:
    expected_output = model(*input)

    ep = torch.export.export(model, input, strict=False)
    path = torch._inductor.aot_compile(ep.module(), input)
    aot_model = torch._export.aot_load(path, device="cuda")
    actual_output = aot_model(*input)

    assert (
        expected_output == actual_output
    ), f"henry they are not equal {expected_output} != {actual_output}"

Differential Revision: D64867504

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138765

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit dbd5906 with merge base 30a83ca (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D64867504

@henrylhtsang henrylhtsang added the topic: not user facing topic category label Oct 23, 2024
@henrylhtsang henrylhtsang changed the title [export] support symbool for torch.cond [export] Add support for symbool to make torch.cond usable Oct 23, 2024
@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 23, 2024
@henrylhtsang henrylhtsang changed the title [export] Add support for symbool to make torch.cond usable [export] Add support for symbool to make it usable for torch.cond Oct 23, 2024
henrylhtsang added a commit to henrylhtsang/pytorch that referenced this pull request Oct 23, 2024
Summary:

todo

Test Plan: ci

Differential Revision: D64867504
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D64867504

henrylhtsang added a commit to henrylhtsang/pytorch that referenced this pull request Oct 24, 2024
Summary:
expect failure right now 

need 
pytorch#138765
pytorch#138760

Differential Revision: D64936442
elif isinstance(val, SymInt):
return SymIntArgument(name=node.name)
elif isinstance(val, SymBool):
return SymIntArgument(name=node.name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maye need to add a SymBoolArgument and some test.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @angelayi for help on the implementation

henrylhtsang added a commit to henrylhtsang/pytorch that referenced this pull request Oct 25, 2024
Summary:

todo

Test Plan: ci

Differential Revision: D64867504
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D64867504

@henrylhtsang
Copy link
Contributor Author

@ydwu4 Added Symbool argument, added tests in export and aot inductor.

henrylhtsang added a commit to henrylhtsang/pytorch that referenced this pull request Oct 28, 2024
Summary:

todo

Test Plan: ci

Differential Revision: D64867504
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D64867504

@henrylhtsang
Copy link
Contributor Author

@ydwu4 I am getting the following rst error. Any idea?


Testing of coverage in the sources finished, look at the results in build/coverage/python.txt.
/opt/conda/envs/py_3.9/lib/python3.9/tempfile.py:830: ResourceWarning: Implicitly cleaning up <TemporaryDirectory '/tmp/tmpjw7iy0ci'>
  _warnings.warn(warn_message, ResourceWarning)
++ wc -l build/coverage/python.txt
++ cut -f1 '-d '
+ lines=7
+ undocumented=5
+ '[' 5 -lt 0 ']'
+ '[' 5 -gt 0 ']'
+ echo undocumented objects found:
undocumented objects found:
+ cat build/coverage/python.txt
Undocumented Python objects
===========================
torch.export.graph_signature
----------------------------
Classes:
 * SymBoolArgument

+ echo 'Make sure you'\''ve updated relevant .rsts in docs/source!'
Make sure you've updated relevant .rsts in docs/source!
+ echo 'You can reproduce locally by running '\''cd docs && make coverage && cat build/coverage/python.txt'\'''
You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'
+ exit 1

@henrylhtsang
Copy link
Contributor Author

@pytorchbot rebase -b viable/strict

@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Successfully rebased export-D64867504 onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout export-D64867504 && git pull --rebase)

pytorch-bot bot pushed a commit that referenced this pull request Oct 29, 2024
Summary:

todo

Test Plan: ci

Differential Revision: D64867504
),
"test_size_from_multi_output": fail_stack_allocation(is_skip=True),
"test_torchvision_transforms_functional_tensor_resize": fail_minimal_arrayref_interface(),
# TODO: AttributeError: 'ShapeAsConstantBuffer' object has no attribute 'dtype'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you paste the full error stack trace? cc @desertfire for review on this change.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stack trace: https://gist.github.com/henrylhtsang/a6ef1a85c56c9fafef1cbc188ca82e67

/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py", line 134, in <genexpr>
    f"ArrayRefTensor<{DTYPE_TO_CPP[x.get_dtype()]}>"
  /torch/_inductor/ir.py", line 393, in get_dtype
    return self.dtype
AttributeError: 'ShapeAsConstantBuffer' object has no attribute 'dtype'


for input_spec, node in zip(gs.input_specs, input_node_names):
if isinstance(input_spec.arg, (TensorArgument, SymIntArgument)):
if isinstance(input_spec.arg, (TensorArgument, SymIntArgument, SymBoolArgument)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When we have SymBoolArgument inputs? Can add a test for it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ydwu4 I am not sure, I was just following the pattern of SymIntArgument.

Any SymIntArgument test that I can take a look?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

may check test_lifted_constants::Foo

@ydwu4
Copy link
Contributor

ydwu4 commented Oct 29, 2024

the following rst error

Appears to be some doc issue. Not familiar with doc issues. Probably can follow https://github.com/pytorch/pytorch/blob/main/docs/source/export.rst?plain=1#L896 . But not sure why this is not triggered for SymIntArgument. Worth digging deeper.

henrylhtsang added a commit to henrylhtsang/pytorch that referenced this pull request Oct 30, 2024
Summary:

todo

Test Plan: ci

Differential Revision: D64867504
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D64867504

1 similar comment
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D64867504

pytorch-bot bot pushed a commit that referenced this pull request Oct 31, 2024
Summary:
Pull Request resolved: #138765

todo

Test Plan: ci

Differential Revision: D64867504
Copy link
Contributor

@ydwu4 ydwu4 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add another test case for test_lifted_constants of boolean?

Summary:

todo

Test Plan: ci

Differential Revision: D64867504
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D64867504

@henrylhtsang
Copy link
Contributor Author

Add another test case for test_lifted_constants of boolean?

my bad, added now

@henrylhtsang
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants