
Conversation

@albanD (Collaborator) commented Mar 25, 2017

My first cwrap PR, thorough review needed!

Extend the cwrap structure to assign the arguments to variables named arg_* before the function call.
This simplifies before_call snippets, which can now access the arguments directly (the previous ${arg0} behaviour is still allowed for backward compatibility).
It also allows before_call to modify the arguments!

This also fixes #146, since the argument unpacking is now done before releasing the GIL.

Sample before:

      PyThreadState *_save = NULL;
      try {
        long ndim = ((THPTensor*)self)->cdata->nDimension;
        THPUtils_assert(ndim == 2, "t_() expects a 2D tensor, but self is %ldD", ndim);
        
        Py_UNBLOCK_THREADS;
        THTensor_(transpose)(LIBRARY_STATE ((THPTensor*)self)->cdata, ((THPTensor*)self)->cdata, 0, 1);
        Py_BLOCK_THREADS;
        Py_INCREF(self);
        return (PyObject*)self;
      } catch (...) {
        if (_save) {
          Py_BLOCK_THREADS;
        }
        throw;
      }

Sample after:

      THTensor* arg_self = ((THPTensor*)self)->cdata;
      
      PyThreadState *_save = NULL;
      try {
        long ndim = arg_self->nDimension;
        THPUtils_assert(ndim == 2, "t_() expects a 2D tensor, but self is %ldD", ndim);
        
        Py_UNBLOCK_THREADS;
        THTensor_(transpose)(LIBRARY_STATE arg_self, arg_self, 0, 1);
        Py_BLOCK_THREADS;
        Py_INCREF(self);
        return (PyObject*)self;
      } catch (...) {
        if (_save) {
          Py_BLOCK_THREADS;
        }
        throw;
      }
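
For context, here is a minimal sketch of the kind of transformation this change performs. The names and structure below are hypothetical, not the actual cwrap plugin code: each unpack expression is bound to an `arg_*` variable emitted before the call, and the call then references that variable, which is why before_call snippets can now read or even overwrite the arguments.

```python
# Hypothetical sketch, not the actual cwrap implementation.
# Turn each unpack expression into a named `arg_*` assignment that is
# emitted before the function call, and make the call use the variable.
def build_arg_assignments(arguments, arg_unpack):
    """`arguments`: formal args (dicts with 'type' and 'name');
    `arg_unpack`: the matching C unpack expressions."""
    assignments = []
    call_args = []
    for arg, unpack in zip(arguments, arg_unpack):
        var = 'arg_' + arg['name']
        # e.g. "THTensor* arg_self = ((THPTensor*)self)->cdata;"
        assignments.append('{} {} = {};'.format(arg['type'], var, unpack))
        call_args.append(var)
    return assignments, call_args
```

With something along these lines, the generated C first declares the `arg_*` variables and the wrapped call only ever sees those variables, as in the "Sample after" snippet above.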

@albanD force-pushed the cwrap_arg_assign branch from 3e8cd49 to c8ee0a1 on March 25, 2017 12:47
@apaszke (Contributor) left a comment


Looks great! I only have some suggestions about the naming, but that's it. I just want to take a look at the generated files before merging it.

call_arg = ', '.join(call_arg)
for plugin in self.plugins:
    arg_unpack = plugin.process_all_unpacks(arg_unpack, option)
    call_arg = plugin.process_all_unpacks(call_arg, option)

This comment was marked as off-topic.

result.append(res)
return result

def build_arg_assign(self, arguments, arg_unpack):

This comment was marked as off-topic.

assignement = []
call_arg = []
# If type names needs to be changed
arguments = self.get_formal_args(arguments)

This comment was marked as off-topic.
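
For illustration, here is a hedged sketch of a plugin participating in the `process_all_unpacks` loop shown in the first excerpt above. Only the hook name and its `(code, option)` call shape come from the excerpt; the class name and behaviour below are made up:

```python
# Hypothetical plugin sketch: the hook name and call shape match the
# review excerpt above; the behaviour is purely illustrative.
class CommentUnpacksPlugin(object):
    def process_all_unpacks(self, code, option):
        # Receive the generated unpack/assignment code as a string and
        # return a (possibly rewritten) version before it is inserted
        # into the wrapper template.
        return '/* argument unpacks */\n' + code
```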

@apaszke (Contributor) commented Mar 25, 2017

The generated code looks good too!

@apaszke (Contributor) commented Mar 26, 2017

@pytorchbot add to whitelist

@apaszke apaszke merged commit bb71117 into pytorch:master Mar 26, 2017
@apaszke (Contributor) commented Mar 26, 2017

Thank you!

@albanD albanD deleted the cwrap_arg_assign branch March 26, 2017 11:53
jjsjann123 pushed a commit to jjsjann123/pytorch that referenced this pull request Nov 5, 2021
…ions (pytorch#1131)

Fixes pytorch#1102 

This PR implements the second approach mentioned in pytorch#1102. For example, indexing and predicates are changed from:

```
      = T0[(((((nvfuser_index_t)blockIdx.x) * ((nvfuser_index_t)blockDim.y)) + ((nvfuser_index_t)threadIdx.y)) * T0.stride[0])]
```

to:

```
 = T0[(((((nvfuser_index_t)blockIdx.x) * 4) + ((nvfuser_index_t)threadIdx.y)) * T0.stride[0])]
```

The use of `blockDim.y` is replaced by the extent of the second axis of `T0`, which is `4` in this case. This change only matters when a parallel type is not exact (in this case `TIDy`). 

The indexing change only required modifying `getExtent` in index_compute.cpp. However, we also need to predicate `threadIdx` and `blockIdx` to be smaller than the IterDomain extents. That's implemented as `ParallelizedDomainPredicate` in predicate_compute.h.
syed-ahmed pushed a commit to syed-ahmed/pytorch that referenced this pull request Sep 22, 2022
* Validate MacOS conda packages and wheels
* Call all matrices using reusable workflow
akashveramd pushed a commit to ROCm/pytorch that referenced this pull request Jun 13, 2025
Implement iRoPE for llama4

Needs the fix from PyTorch:
pytorch#151270

Successfully merging this pull request may close these issues:

THPPlugin calls Python API without GIL