Conversation

@wonjoo-wj (Collaborator) commented Oct 20, 2022

Update codegen to use xla::Shape


This PR updates codegen to use xla::Shape instead of torch::lazy::Shape by overriding the upstream GenLazyNativeFuncDefinition and GenLazyIr generators. It overrides the GenLazyNativeFuncDefinition.shape_inference() and GenLazyNativeFuncDefinition.build_ir_node() functions to remove the torch::lazy::Shape-related shape-inference and constructor codegen, and it overrides GenXlaLazyIR.gen() to remove the std::vector<torch::lazy::Shape> argument from the Node constructor.

This depends on an open upstream PyTorch PR pytorch/pytorch#87823 to make GenLazyNativeFuncDefinition customizable.
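For context, a minimal sketch of the override pattern described above, assuming the shape_inference() and build_ir_node() hooks that pytorch/pytorch#87823 adds to torchgen's GenLazyNativeFuncDefinition; the emitted C++ strings here are illustrative, not this PR's exact output:

    # Sketch only. GenLazyNativeFuncDefinition, LazyIrSchema, and
    # node_ctor_inputs come from torchgen; the hook signatures assume
    # pytorch/pytorch#87823 has landed.
    from torchgen.api.lazy import LazyIrSchema
    from torchgen.dest.lazy_ir import GenLazyNativeFuncDefinition, node_ctor_inputs

    class GenXlaLazyNativeFuncDefinition(GenLazyNativeFuncDefinition):
        def shape_inference(self, func, schema: LazyIrSchema) -> str:
            # XlaNode computes its xla::Shape via a lambda in its own
            # constructor, so no torch::lazy::Shape inference code is emitted.
            return ""

        def build_ir_node(self, func, schema: LazyIrSchema) -> str:
            # Build the IR node without the std::vector<torch::lazy::Shape>
            # argument that the upstream generator would pass.
            node_args = node_ctor_inputs(schema)
            return f"""torch::lazy::NodePtr node = torch::lazy::ReuseNode<{schema.node_name}>({node_args});
    if (!node) {{
        node = torch::lazy::MakeNode<{schema.node_name}>({node_args});
        CacheNode(node);
    }}"""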


New LazyIr.h example:

class Frac : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::frac);
  }

  Frac(const torch::lazy::Value& self)
      : XlaNode(torch::lazy::OpKind(at::aten::frac),
                {self},
                [&]() { return FracOutputShape(self); },
                /* num_outputs */ 1,
                torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();
    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

Changes:

  • Removed the std::vector<torch::lazy::Shape> argument from the constructor
  • Updated the call to the XlaNode constructor to omit the torch::lazy::Shape

Old LazyIr.h example: #3930


New XlaNativeFunctions.cpp example:

    at::Tensor XLANativeFunctions::frac(const at::Tensor& self) {
      XLA_FN_COUNTER("xla::");
      auto common_device = torch_xla::bridge::GetXlaDevice(self);
      TORCH_INTERNAL_ASSERT(common_device);
      torch_xla::XLATensorPtr lazy_self =
          torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *common_device);
      torch::lazy::NodePtr node =
          torch::lazy::ReuseNode<Frac>(lazy_self->GetIrValue());
      if (!node) {
        node = torch::lazy::MakeNode<Frac>(lazy_self->GetIrValue());
        CacheNode(node);
      }
      auto result = torch_xla::bridge::AtenFromXlaTensor(
          torch_xla::XLATensor::Create(std::move(node), *common_device));
      return result;
    }

Changes:

  • Removed torch::lazy::Shape-related creation/assertion logic
  • Removed the symbolicShapeEnabled logic (need to double-check that this is okay)

Old XlaNativeFunctions.cpp example: #3930

@wonjoo-wj changed the title from "[WIP] Update GenXlaLazyIR codegen to use xla::Shape" to "Update GenXlaLazyIR codegen to use xla::Shape" on Oct 27, 2022
@wonjoo-wj changed the title from "Update GenXlaLazyIR codegen to use xla::Shape" to "Update codegen to use xla::Shape" on Oct 27, 2022
@wonjoo-wj marked this pull request as ready for review on October 27, 2022 09:45
@wonjoo-wj self-assigned this on Oct 27, 2022
@wonjoo-wj added the REMOVE_TORCH_PIN and tracing (Lazy Tensor tracing) labels on Oct 27, 2022
@wonjoo-wj (Collaborator, Author)

@alanwaketan, this should be ready for a review. Thanks!

@wonjoo-wj requested a review from alanwaketan on October 27, 2022 18:25
@JackCaoG (Collaborator)

Let's also remove

    if (ir_value.node->shapes().size() &&
        ir_value.shape().scalar_type() != c10::ScalarType::Undefined) {
      logical_element_type = ir_value.shape().scalar_type();
    }

I added it a while back to use the lazy shape to construct logical_element_type.

An inline review thread was opened on this part of the codegen script:

        /* num_outputs */ {len(schema.returns)},
        torch::lazy::MHash({scalar_hashes}))"""

    def gen(self, schema: LazyIrSchema) -> List[str]:
@wonjoo-wj (Collaborator, Author) commented Oct 27, 2022

Most of this function is the same as the upstream codegen. Some bits are updated to remove torch::lazy::Shape-related logic.

A collaborator replied:

You did a good job on GenXlaLazyNativeFuncDefinition. We can repeat that.

@wonjoo-wj (Collaborator, Author) replied:

Thanks for the comment. It seems a bit harder to make this overriding clean for the GenLazyIR class. The upstream GenLazyNativeFuncDefinition class has a specific function named shape_inference() that we can override to remove the torch::lazy::Shape-related logic. But for the upstream GenLazyIR class, the shape-inference logic is part of one big gen() function, so we are unable to override only the torch::lazy::Shape-related codegen. I've added a comment to this function describing this; let me know if you think this is okay.
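For illustration, the rough shape of that override as a sketch, assuming the upstream torchgen classes (the gen() signature matches the snippet quoted above; the body is elided):

    # Sketch only: upstream GenLazyIR has no dedicated shape-inference hook,
    # so the XLA subclass must re-emit the whole constructor template in
    # gen(), minus the std::vector<torch::lazy::Shape> parameter.
    from dataclasses import dataclass
    from typing import List
    from torchgen.api.lazy import LazyIrSchema
    from torchgen.dest.lazy_ir import GenLazyIR

    @dataclass(frozen=True)
    class GenXlaLazyIR(GenLazyIR):
        def gen(self, schema: LazyIrSchema) -> List[str]:
            # Largely duplicated from GenLazyIR.gen(); only the
            # torch::lazy::Shape-related pieces are removed.
            ...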

@wonjoo-wj (Collaborator, Author) commented Oct 31, 2022

Oops, accidentally deleted one of your original comments on this thread while I was trying to delete one of my older replies lol. Just for future reference, Jiewen's original comment was:

It looks like maybe we can make the parts we don't need as a method in the GenLazyIR class such that we can substitute them with things that output empty strings for instances.

@wonjoo-wj (Collaborator, Author)

Building the C++ tests now succeeds locally; I'll let CI verify the remaining tests.

@wonjoo-wj requested a review from alanwaketan on October 31, 2022 23:14
@wonjoo-wj (Collaborator, Author)

Some unit test failures, looking into it.

@wonjoo-wj (Collaborator, Author) commented Nov 1, 2022

The bitwise_scalar tests (TestBitwiseAndScalar, TestBitwiseOrScalar, and TestBitwiseXorScalar) were failing due to a shape check:

    int [4, 2]
    -vs-
    long [4, 2]

The resulting shape was cast/promoted to s64, causing an assertion failure. Previously, when using torch::lazy::Shape, XLANativeFunctions.cpp populated logical_element_type from the torch::lazy::Shape's scalar_type:

    if (ir_value.node->shapes().size() &&
        ir_value.shape().scalar_type() != c10::ScalarType::Undefined) {
      logical_element_type = ir_value.shape().scalar_type();
    }

Now that this logic and torch::lazy::Shape are both removed, the codegen'ed bitwise ops cast/promoted the resulting shape to long. The last commit moves these ops to IR codegen only, so we can explicitly call DoBinaryOpWithoutPromo in the hand-written aten_xla_type.cpp. This does not introduce any new behavior or regression for these ops.

@wonjoo-wj (Collaborator, Author)

@alanwaketan this should be ready for another review, thanks!

@alanwaketan (Collaborator)

(Quoting wonjoo-wj's bitwise_scalar analysis above.)

Just for my education, can you point me to where this promotion happens?

@wonjoo-wj (Collaborator, Author)

Just for my education, can you point me to where this promotion happens?

Sure thing! The current codegen'ed bitwise operators call PromotedBinaryOp, as shown at https://github.com/pytorch/xla/blob/master/torch_xla/csrc/ops/ops_lower_fn.cpp#L158. This helper function, defined at https://github.com/pytorch/xla/blob/master/torch_xla/csrc/helpers.cpp#L557, is responsible for the promotion. FYI, the PR that codegens the bitwise operators is https://github.com/pytorch/xla/pull/3815/files.

@alanwaketan (Collaborator)

Thanks, @wonjoolee95.

It looks like we don't want to promote scalar tensors. Because of decomposition, all the scalar variants of binary ops are first converted to the tensor version before reaching us (the scalars get wrapped into scalar tensors). Therefore, we need a way to detect scalar tensors and not promote them.

On the other hand, it looks like we do want to promote scalars? Or is it that scalars are annotated differently in XLA, so they don't get promoted by XLA either?

Can you help me better understand the situation here?

@wonjoo-wj (Collaborator, Author)

The problem here seems to be the lack of DoBinaryOpWithoutPromo calls in the codegen'ed bitwise ops. Before being codegen'ed, all the bitwise ops were wrapped in DoBinaryOpWithoutPromo for both the .Scalar and .Tensor variants:

    at::Tensor XLANativeFunctions::bitwise_or(const at::Tensor& self,
                                              const at::Scalar& other) {
      XLA_FN_COUNTER("xla::");
      return DoBinaryOpWithoutPromo(
          self, other, [&](const XLATensorPtr& xself, const at::Scalar& xother) {
            return XLATensor::bitwise_or(xself, xother);
          });
    }

So under the old behavior, it looks like we don't want to promote either scalars or tensors. When we codegen'ed these ops in https://github.com/pytorch/xla/pull/3815/files, the DoBinaryOpWithoutPromo wrappers were dropped when we removed the bitwise ops from aten_xla_type.cpp.

@alanwaketan (Collaborator)

Yea, if we dig into the tensor variant of DoBinaryOpWithoutPromo, what it does is change the dtype of the tensor if it's a scalar tensor. It does nothing to regular tensors. Correct me if I'm wrong.

For the scalar variant, it does nothing special. That's why I'm confused. @JackCaoG Do you have any insights?

@wonjoo-wj (Collaborator, Author)

Since the original upstream PyTorch PR pytorch/pytorch#87823 got merged, I opened a new one, pytorch/pytorch#88444, to add a use_lazy_shape flag, which allows us to avoid overriding the entire GenLazyIr.gen() function.
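For reference, a minimal sketch of that upstream change, based on the pytorch/pytorch#88444 description (GenLazyIR lives in torchgen/dest/lazy_ir.py; the exact field list here is illustrative):

    # Sketch, assuming the flag described in pytorch/pytorch#88444:
    # use_lazy_shape defaults to True, so existing lazy backends keep the
    # std::vector<torch::lazy::Shape> constructor parameter, while XLA's
    # generator can pass False to drop it instead of overriding all of gen().
    from abc import ABC
    from dataclasses import dataclass
    from torchgen.model import BackendIndex

    @dataclass(frozen=True)
    class GenLazyIR(ABC):
        backend_index: BackendIndex
        node_base: str
        use_lazy_shape: bool = True  # XLA's generator passes False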

@alanwaketan (Collaborator)

(Quoting my earlier DoBinaryOpWithoutPromo comment above.)

Just talked to Jack. It looks like the rule here is just to pass all the tests.

@alanwaketan (Collaborator) left a review:

LGTM. Thanks!

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Nov 4, 2022
Add use_lazy_shape flag to GenLazyIr class to allow XLA to use its custom shape class. The default value is kept to use lazy shape, so this PR does not introduce any new behaviors.

PyTorch/XLA companion PR: pytorch/xla#4111
Pull Request resolved: #88444
Approved by: https://github.com/alanwaketan, https://github.com/wconstab
@wonjoo-wj (Collaborator, Author)

The upstream PR is merged; I'll merge this PR without the torch_pin.

@wonjoo-wj merged commit 587a6b6 into master on Nov 4, 2022
kulinseth pushed commits to kulinseth/pytorch that referenced this pull request on Nov 5, 2022 and Dec 10, 2022, with the same commit message as above.