🚀 Feature
Note: this is one of the features discussed in gh-22402; see that issue for full context.
Implement a __torch_function__ method on Tensor (public) and a torch_function_dispatch decorator (private). The method provides a dispatch mechanism that hands the call off to Tensor-like objects, which can then handle the torch.somefunction call as they see fit. In other words, torch functions that take at least one Tensor input parameter become overridable.
It should be possible to do this with no additional overhead when the input is a Tensor, and with sub-microsecond overhead when it is Tensor-like.
The mechanism is completely analogous to NumPy's __array_ufunc__ and __array_function__ methods.
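To make the analogy concrete, here is a minimal pure-Python sketch of what the dispatch could look like. The name torch_function_dispatch comes from the proposal above, but the exact __torch_function__ signature, the dispatcher convention, and the dot wrapper are assumptions modeled on NumPy's array_function_dispatch machinery, not a final API.

```python
import functools

import torch


def torch_function_dispatch(dispatcher):
    """Hypothetical decorator that makes a torch function overridable.

    ``dispatcher`` returns the arguments that may carry a __torch_function__
    override, following the same convention as NumPy's
    array_function_dispatch.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Arguments whose type defines __torch_function__ and is not a
            # plain Tensor; for plain Tensor inputs this list is empty, so
            # the call falls straight through to func with no extra work.
            overloaded = [
                a for a in dispatcher(*args, **kwargs)
                if hasattr(type(a), "__torch_function__")
                and type(a) is not torch.Tensor
            ]
            types = tuple(type(a) for a in overloaded)
            for a in overloaded:
                result = a.__torch_function__(func, types, args, kwargs)
                if result is not NotImplemented:
                    return result
            return func(*args, **kwargs)
        return wrapper
    return decorator


# Example: an overridable dot-like wrapper around torch.dot.
def _dot_dispatcher(input, other):
    return (input, other)


@torch_function_dispatch(_dot_dispatcher)
def dot(input, other):
    return torch.dot(input, other)
```

In C++ the same check can be folded into the existing argument-parsing path, which is what keeps the plain-Tensor case free of extra cost.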
Motivation
torch functions need to become overridable. One concrete user that @ezyang mentioned is NestedTensor, and there will be others. See also gh-22402.
Plan
First build a prototype and apply it to only a couple of functions in the main torch namespace (to review/evaluate). Make sure they have different signatures from each other, e.g. max, dot, svd.
Use a toy Tensor ducktype class, along the lines of DiagonalArray in https://numpy.org/devdocs/user/basics.dispatch.html, to implement a working example of overriding the chosen functions (a sketch follows the plan below).
- First make it work (parts can be in Python) with the couple of functions chosen
- Then make it fast - reuse the existing checks to ensure zero overhead on func(Tensor, ...); this needs to all be in C++
- Once that's good, expand coverage to the whole API.
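For the toy ducktype mentioned in the plan, a minimal sketch along the lines of NumPy's DiagonalArray example could look as follows. The class name ScalarTensor, the HANDLED_FUNCTIONS registry, and the implements helper are illustrative assumptions for this issue, not a finished design; it overrides torch.max through the hypothetical protocol sketched in the Feature section above.

```python
import torch

# Registry of torch functions that ScalarTensor knows how to handle.
HANDLED_FUNCTIONS = {}


class ScalarTensor:
    """Toy Tensor ducktype representing value * torch.eye(n) lazily."""

    def __init__(self, n, value):
        self._n = n
        self._value = value

    def tensor(self):
        # Materialize the full diagonal tensor only when asked to.
        return self._value * torch.eye(self._n)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func not in HANDLED_FUNCTIONS:
            # Let other overrides (or the default implementation) take over.
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)


def implements(torch_function):
    """Register a ScalarTensor implementation for a given torch function."""
    def decorator(impl):
        HANDLED_FUNCTIONS[torch_function] = impl
        return impl
    return decorator


@implements(torch.max)
def scalar_max(input):
    # The entries of value * eye(n) are value on the diagonal and 0 elsewhere,
    # so the maximum is max(value, 0) for n > 1 and value itself for n == 1.
    return max(input._value, 0.0) if input._n > 1 else input._value
```

With a dispatch mechanism like the one sketched earlier wired into torch.max, a call such as torch.max(ScalarTensor(5, -2.0)) would be routed to scalar_max instead of raising on a non-Tensor input.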