@@ -2,7 +2,7 @@ Extending PyTorch
 =================
 
 In this note we'll cover ways of extending :mod:`torch.nn`,
-:mod:`torch.autograd`, :mod:`torch`, and writing custom C extensions utilizing our C
+:mod:`torch.autograd`, and writing custom C extensions utilizing our C
 libraries.
 
 Extending :mod:`torch.autograd`
@@ -204,285 +204,6 @@ This is how a ``Linear`` module can be implemented::
             self.in_features, self.out_features, self.bias is not None
         )
 
-Extending :mod:`torch`
-----------------------
-
-You can create custom types that emulate :class:`Tensor` by defining a custom
-class with methods that match :class:`Tensor`. But what if you want to be able
-to pass these types to functions like :func:`torch.add` in the top-level
-:mod:`torch` namespace that accept :class:`Tensor` operands?
-
-If your custom Python type defines a method named ``__torch_function__``, PyTorch
-will invoke your ``__torch_function__`` implementation when an instance of your
-custom class is passed to a function in the :mod:`torch` namespace. This makes
-it possible to define custom implementations for any of the functions in the
-:mod:`torch` namespace which your ``__torch_function__`` implementation can call,
-allowing your users to make use of your custom type with existing PyTorch
-workflows that they have already written for :class:`Tensor`. This works with
-"duck" types that are unrelated to :class:`Tensor` as well as user-defined
-subclasses of :class:`Tensor`.
-
-Extending :mod:`torch` with a :class:`Tensor`-like type
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-.. note:: This functionality is inspired by the NumPy ``__array_function__``
-          protocol. See `the NumPy documentation
-          <https://docs.scipy.org/doc/numpy/user/basics.dispatch.html#basics-dispatch>`_
-          and `NEP-0018
-          <https://numpy.org/neps/nep-0018-array-function-protocol.html>`_ for
-          more details.
-
-To make this concrete, let's begin with a simple example that illustrates the
-API dispatch mechanism. We'll create a custom type that represents a 2D scalar
-tensor, parametrized by the order ``N`` and value along the diagonal entries,
-``value``::
-
-    class ScalarTensor(object):
-        def __init__(self, N, value):
-            self._N = N
-            self._value = value
-
-        def __repr__(self):
-            return "ScalarTensor(N={}, value={})".format(self._N, self._value)
-
-        def tensor(self):
-            return self._value * torch.eye(self._N)
-
-This first iteration of the design isn't very useful. The main functionality of
-``ScalarTensor`` is to provide a more compact string representation of a scalar
-tensor than in the base tensor class::
-
-    >>> d = ScalarTensor(5, 2)
-    >>> d
-    ScalarTensor(N=5, value=2)
-    >>> d.tensor()
-    tensor([[2., 0., 0., 0., 0.],
-            [0., 2., 0., 0., 0.],
-            [0., 0., 2., 0., 0.],
-            [0., 0., 0., 2., 0.],
-            [0., 0., 0., 0., 2.]])
-
-If we try to use this object with the :mod:`torch` API, we will run
-into issues::
-
-    >>> import torch
-    >>> torch.mean(d)
-    TypeError: mean(): argument 'input' (position 1) must be Tensor, not ScalarTensor
-
-Adding a ``__torch_function__`` implementation to ``ScalarTensor`` makes it
-possible for the above operation to succeed. Let's re-do our implementation,
-this time adding a ``__torch_function__`` implementation::
-
-    HANDLED_FUNCTIONS = {}
-    class ScalarTensor(object):
-        def __init__(self, N, value):
-            self._N = N
-            self._value = value
-
-        def __repr__(self):
-            return "ScalarTensor(N={}, value={})".format(self._N, self._value)
-
-        def tensor(self):
-            return self._value * torch.eye(self._N)
-
-        def __torch_function__(self, func, args=(), kwargs=None):
-            if kwargs is None:
-                kwargs = {}
-            if func not in HANDLED_FUNCTIONS:
-                return NotImplemented
-            return HANDLED_FUNCTIONS[func](*args, **kwargs)
-
-The ``__torch_function__`` method takes three arguments: ``func``, a reference to
-the torch API function that is being overridden, ``args``, the tuple of arguments
-passed to the function, and ``kwargs``, the dict of keyword arguments passed to
-the function. It uses a global dispatch table named ``HANDLED_FUNCTIONS`` to
-store custom implementations. The keys of this dictionary are functions in the
-``torch`` namespace and the values are implementations for ``ScalarTensor``.
-
-.. note:: Using a global dispatch table is not a mandated part of the
-          ``__torch_function__`` API, it is just a useful design pattern for
-          structuring your override implementations.
-
-This class definition isn't quite enough to make ``torch.mean`` do the right
-thing when we pass it a ``ScalarTensor`` -- we also need to define an
-implementation for ``torch.mean`` for ``ScalarTensor`` operands and add the
-implementation to the ``HANDLED_FUNCTIONS`` dispatch table dictionary. One way
-of doing this is to define a decorator::
-
-    import functools
-    def implements(torch_function):
-        """Register a torch function override for ScalarTensor"""
-        def decorator(func):
-            functools.wraps(torch_function)(func)
-            HANDLED_FUNCTIONS[torch_function] = func
-            return func
-        return decorator
-
-which can be applied to the implementation of our override::
-
-    @implements(torch.mean)
-    def mean(input):
-        return float(input._value) / input._N
-
-With this change we can now use ``torch.mean`` with ``ScalarTensor``::
-
-    >>> d = ScalarTensor(5, 2)
-    >>> torch.mean(d)
-    0.4
-
-Of course ``torch.mean`` is an example of the simplest kind of function to
-override since it only takes one operand. We can use the same machinery to
-override a function that takes more than one operand, any one of which might be
-a tensor or tensor-like that defines ``__torch_function__``, for example for
-:func:`torch.add`::
-
-    def ensure_tensor(data):
-        if isinstance(data, ScalarTensor):
-            return data.tensor()
-        return torch.as_tensor(data)
-
-    @implements(torch.add)
-    def add(input, other):
-        try:
-            if input._N == other._N:
-                return ScalarTensor(input._N, input._value + other._value)
-            else:
-                raise ValueError("Shape mismatch!")
-        except AttributeError:
-            return torch.add(ensure_tensor(input), ensure_tensor(other))
-
-This version has a fast path for when both operands are ``ScalarTensor``
-instances and also a slower path that falls back to converting the data to
-tensors when either operand is not a ``ScalarTensor``. This makes the override
-work correctly when either operand is a ``ScalarTensor`` or a regular
-:class:`Tensor`::
-
-    >>> s = ScalarTensor(2, 2)
-    >>> torch.add(s, s)
-    ScalarTensor(N=2, value=4)
-    >>> t = torch.tensor([[1, 1,], [1, 1]])
-    >>> torch.add(s, t)
-    tensor([[3., 1.],
-            [1., 3.]])
-
-Note that our implementation of ``add`` does not take ``alpha`` or ``out`` as
-keyword arguments like :func:`torch.add` does::
-
-    >>> torch.add(s, s, alpha=2)
-    TypeError: add() got an unexpected keyword argument 'alpha'
-
-For speed and flexibility the ``__torch_function__`` dispatch mechanism does not
-check that the signature of an override function matches the signature of the
-function being overridden in the :mod:`torch` API. For some applications ignoring
-optional arguments would be fine, but to ensure full compatibility with
-:class:`Tensor`, user implementations of torch API functions should take care to
-exactly emulate the API of the function that is being overridden.
-
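To illustrate, here is a sketch of a signature-compatible ``torch.add`` override that also accepts ``alpha`` and ``out``. It bundles minimal versions of the helpers from above so it is self-contained; note it uses the classmethod form of ``__torch_function__`` (with a ``types`` argument) used by newer PyTorch releases, rather than the instance-method form shown in this section:

```python
import torch

HANDLED_FUNCTIONS = {}

def implements(torch_function):
    """Register a torch function override for ScalarTensor."""
    def decorator(func):
        HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator

class ScalarTensor:
    def __init__(self, N, value):
        self._N = N
        self._value = value

    def tensor(self):
        return self._value * torch.eye(self._N)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

def ensure_tensor(data):
    if isinstance(data, ScalarTensor):
        return data.tensor()
    return torch.as_tensor(data)

# Signature-compatible override: mirrors torch.add(input, other, alpha=1, out=None).
@implements(torch.add)
def add(input, other, alpha=1, out=None):
    if out is not None:
        # An explicit output tensor forces the dense-tensor path.
        return torch.add(ensure_tensor(input), ensure_tensor(other),
                         alpha=alpha, out=out)
    try:
        if input._N != other._N:
            raise ValueError("Shape mismatch!")
        return ScalarTensor(input._N, input._value + alpha * other._value)
    except AttributeError:
        # Either operand is not a ScalarTensor: degrade to plain tensors.
        return torch.add(ensure_tensor(input), ensure_tensor(other), alpha=alpha)

s = ScalarTensor(2, 2)
result = torch.add(s, s, alpha=2)  # now accepted instead of raising TypeError
```

With this version, calls that pass ``alpha`` or ``out`` dispatch cleanly instead of raising ``TypeError``.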
-For functions in the :mod:`torch` API that do not have explicit overrides,
-``__torch_function__`` will return ``NotImplemented``. If all operands with
-``__torch_function__`` defined on them return ``NotImplemented``, PyTorch will
-raise a ``TypeError``. This means that most of the time operations that do not
-have explicit overrides for a type will raise a ``TypeError`` when an instance
-of such a type is passed::
-
-    >>> torch.mul(s, 3)
-    TypeError: no implementation found for 'torch.mul' on types that
-    implement __torch_function__: [ScalarTensor]
-
-In practice this means that if you would like to implement your overrides using
-a ``__torch_function__`` implementation along these lines, you will need to
-explicitly implement the full :mod:`torch` API or the entire subset of the API
-that you care about for your use case. This may be a tall order as the full
-:mod:`torch` API is quite extensive.
-
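Newer PyTorch releases ship a ``torch.overrides`` module with helpers for exactly this situation; assuming it is available in your installation, you can enumerate the overridable functions to gauge (or unit-test) your coverage. The two-entry dispatch table below is a stand-in for illustration:

```python
import torch
import torch.overrides

# get_overridable_functions returns a dict mapping namespaces
# (torch, torch.Tensor, ...) to lists of functions that support
# the __torch_function__ protocol.
overridable = torch.overrides.get_overridable_functions()
all_funcs = {f for funcs in overridable.values() for f in funcs}

# A type aiming for full compatibility can compare its dispatch
# table against this set to find functions it has not covered yet.
HANDLED_FUNCTIONS = {torch.add: None, torch.mean: None}  # stand-in table
missing = all_funcs - set(HANDLED_FUNCTIONS)
print(f"{len(missing)} of {len(all_funcs)} overridable functions not yet handled")
```

A check like this can run in your test suite to catch regressions when a new PyTorch release adds functions you have not yet overridden.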
-Another option is to not return ``NotImplemented`` for operations that are not
-handled but to instead pass a :class:`Tensor` to the original :mod:`torch`
-function when no override is available. For example, if we change our
-implementation of ``__torch_function__`` for ``ScalarTensor`` to the one below::
-
-    def __torch_function__(self, func, args=(), kwargs=None):
-        if kwargs is None:
-            kwargs = {}
-        if func not in HANDLED_FUNCTIONS:
-            args = [a.tensor() if hasattr(a, 'tensor') else a for a in args]
-            return func(*args, **kwargs)
-        return HANDLED_FUNCTIONS[func](*args, **kwargs)
410-
411- Then :func: `torch.mul ` will work correctly, although the return type will always
412- be a :class: `Tensor ` rather than a :class: `ScalarTensor `, even if both operands
413- are :class: `ScalarTensor ` instances::
414-
415- >>> s = ScalarTensor(2, 2)
416- >>> torch.mul(s, s)
417- tensor([[4., 0.],
418- [0., 4.]])
419-
-Also see the ``MetadataTensor`` example below for another variation on this
-pattern, one that instead always returns a ``MetadataTensor`` to propagate
-metadata through operations in the :mod:`torch` API.
-
-Extending :mod:`torch` with a :class:`Tensor` wrapper type
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-Another useful case is a type that wraps a :class:`Tensor`, either as an
-attribute or via subclassing. Below we implement a special case of this sort of
-type, a ``MetadataTensor`` that attaches a dictionary of metadata to a
-:class:`Tensor` that is propagated through :mod:`torch` operations. Since this
-is a generic sort of wrapping for the full :mod:`torch` API, we do not need to
-individually implement each override so we can make the ``__torch_function__``
-implementation more permissive about what operations are allowed::
-
-    class MetadataTensor(object):
-        def __init__(self, data, metadata=None, **kwargs):
-            self._t = torch.as_tensor(data, **kwargs)
-            self._metadata = metadata
-
-        def __repr__(self):
-            return "Metadata:\n{}\n\ndata:\n{}".format(self._metadata, self._t)
-
-        def __torch_function__(self, func, args=(), kwargs=None):
-            if kwargs is None:
-                kwargs = {}
-            args = [a._t if hasattr(a, '_t') else a for a in args]
-            ret = func(*args, **kwargs)
-            return MetadataTensor(ret, metadata=self._metadata)
-
-This simple implementation won't necessarily work with every function in the
-:mod:`torch` API but it is good enough to capture most common operations::
-
-    >>> metadata = {'owner': 'Ministry of Silly Walks'}
-    >>> m = MetadataTensor([[1, 2], [3, 4]], metadata=metadata)
-    >>> t = torch.tensor([[1, 2], [1, 2]])
-    >>> torch.add(t, m)
-    Metadata:
-    {'owner': 'Ministry of Silly Walks'}
-
-    data:
-    tensor([[2, 4],
-            [4, 6]])
-    >>> torch.mul(t, m)
-    Metadata:
-    {'owner': 'Ministry of Silly Walks'}
-
-    data:
-    tensor([[1, 4],
-            [3, 8]])
-
-Operations on multiple types that define ``__torch_function__``
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-It is possible to use the torch API with multiple distinct types that each have
-a ``__torch_function__`` implementation, but special care must be taken. In such
-a case the rules are:
-
-* The dispatch operation gathers all distinct implementations of
-  ``__torch_function__`` for each operand and calls them in order: subclasses
-  before superclasses, and otherwise left to right in the operator expression.
-* If any value other than ``NotImplemented`` is returned, that value is
-  returned as the result. Implementations can register that they do not
-  implement an operation by returning ``NotImplemented``.
-* If all of the ``__torch_function__`` implementations return
-  ``NotImplemented``, PyTorch raises a ``TypeError``.
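These rules can be seen in action with a pair of minimal "duck" types. This is a sketch; the class names ``Base`` and ``Sub`` are invented for illustration, and it uses the classmethod form of ``__torch_function__`` (with a ``types`` argument) so it runs on newer PyTorch releases:

```python
import torch

class Base:
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if func is torch.add:
            return "Base handled add"
        # Signal that Base does not implement any other operation.
        return NotImplemented

class Sub(Base):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if func is torch.add:
            return "Sub handled add"
        return NotImplemented

# The subclass implementation is tried first, even though the Base
# instance appears first in the argument list:
print(torch.add(Base(), Sub()))

# When every implementation returns NotImplemented, PyTorch raises TypeError:
try:
    torch.mul(Base(), Sub())
except TypeError as e:
    print("TypeError:", e)
```

Here ``torch.add`` returns ``"Sub handled add"`` regardless of operand order, because subclass implementations are consulted before superclass ones; ``torch.mul`` raises ``TypeError`` because neither type handles it.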
 
 Writing custom C++ extensions
 -----------------------------