torch.overrides
This module exposes various helper functions for the __torch_function__ protocol. See Extending torch for more detail on the __torch_function__ protocol.
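To make the protocol concrete, here is a minimal sketch of a type that implements __torch_function__. ScalarTensor is a hypothetical class that wraps a plain number and handles only torch.add. For every other function it returns NotImplemented, which tells PyTorch that this type does not handle the call.

```python
import torch


class ScalarTensor:
    """Hypothetical tensor-like type that wraps a single Python number."""

    def __init__(self, value):
        self.value = value

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.add:
            # Unwrap ScalarTensor arguments and add the raw numbers.
            a, b = (x.value if isinstance(x, ScalarTensor) else x for x in args[:2])
            return ScalarTensor(a + b)
        # NotImplemented signals that this type does not handle func.
        return NotImplemented


result = torch.add(ScalarTensor(2), ScalarTensor(3))
print(result.value)  # 5
```

Defining __torch_function__ as a classmethod follows the form PyTorch recommends. The helpers below inspect and drive this same dispatch mechanism.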
Functions

torch.overrides.get_ignored_functions()
Return public functions that cannot be overridden by __torch_function__.

Returns
A set of functions that are publicly available in the torch API but cannot be overridden with __torch_function__. Mostly this is because none of the arguments of these functions are tensors or tensor-likes.

Return type
Set[Callable]
Examples
>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
True
>>> torch.add in torch.overrides.get_ignored_functions()
False
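One practical use of the ignored set is guarding a registration helper, so that a tensor-like type never registers an override that would be dead code. In the minimal sketch below, HANDLED_FUNCTIONS, implements and my_add are hypothetical names that follow the registry pattern commonly used with __torch_function__. They are not part of torch.overrides.

```python
import torch
from torch.overrides import get_ignored_functions

# Hypothetical registry mapping torch functions to custom implementations.
HANDLED_FUNCTIONS = {}


def implements(torch_function):
    """Hypothetical decorator that registers an override for torch_function."""
    if torch_function in get_ignored_functions():
        # __torch_function__ is never consulted for these, so refuse early.
        raise ValueError(f"{torch_function!r} cannot be overridden via __torch_function__")

    def decorator(fn):
        HANDLED_FUNCTIONS[torch_function] = fn
        return fn

    return decorator


@implements(torch.add)
def my_add(input, other, *, alpha=1, out=None):
    return "custom add"


try:
    implements(torch.Tensor.as_subclass)
except ValueError as err:
    print(err)
```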

torch.overrides.get_overridable_functions()
List functions that are overridable via __torch_function__.

Returns
A dictionary that maps namespaces that contain overridable functions to functions in that namespace that can be overridden.

Return type
Dict[Any, List[Callable]]
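Because the keys of the returned dictionary are namespace objects such as the torch module or the torch.Tensor class, a reverse lookup shows where a given function is listed as overridable. The helper namespaces_of in this sketch is hypothetical. It only walks the dictionary returned by get_overridable_functions().

```python
import torch
from torch.overrides import get_overridable_functions


def namespaces_of(func):
    """Hypothetical helper: namespaces in which func is listed as overridable."""
    overridable = get_overridable_functions()
    return [ns for ns, funcs in overridable.items() if func in funcs]


print(namespaces_of(torch.add))
# Ignored functions are excluded, so this list is empty.
print(namespaces_of(torch.Tensor.as_subclass))  # []
```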

torch.overrides.get_testing_overrides()
Return a dict containing dummy overrides for all overridable functions.

Returns
A dictionary that maps overridable functions in the PyTorch API to lambda functions that have the same signature as the real function and unconditionally return -1. These lambda functions are useful for testing API coverage for a type that defines __torch_function__.

Return type
Dict[Callable, Callable]
Examples
>>> import inspect
>>> my_add = torch.overrides.get_testing_overrides()[torch.add]
>>> inspect.signature(my_add)
<Signature (input, other, out=None)>
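The coverage use case can be sketched by treating the keys of get_testing_overrides() as the full list of overridable functions and comparing them against whatever a tensor-like type already implements. HANDLED and coverage are hypothetical names used only for illustration.

```python
import torch
from torch.overrides import get_testing_overrides

# Hypothetical set of torch functions that some tensor-like type implements.
HANDLED = {torch.add, torch.mul}


def coverage(handled):
    """Split the overridable API into implemented and missing functions."""
    overrides = get_testing_overrides()
    implemented = [f for f in overrides if f in handled]
    missing = [f for f in overrides if f not in handled]
    return implemented, missing


implemented, missing = coverage(HANDLED)
print(f"{len(implemented)} implemented, {len(missing)} missing")

# The dummy overrides ignore their arguments and always return -1.
print(get_testing_overrides()[torch.add](None, None))  # -1
```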

torch.overrides.handle_torch_function(public_api, relevant_args, *args, **kwargs)
Implement a function with checks for __torch_function__ overrides.

See torch::autograd::handle_torch_function for the equivalent of this function in the C++ implementation.

Parameters
- public_api (function) – Function exposed by the public torch API, originally called like public_api(*args, **kwargs), on which arguments are now being checked.
- relevant_args (iterable) – Iterable of arguments to check for __torch_function__ methods.
- args (tuple) – Arbitrary positional arguments originally passed into public_api.
- kwargs (dict) – Arbitrary keyword arguments originally passed into public_api.

Returns
Result from calling public_api or a __torch_function__ method, as appropriate.

Raises
TypeError – if no implementation is found.
Example
>>> def func(a):
...     if type(a) is not torch.Tensor:  # This will make func dispatchable by __torch_function__
...         return handle_torch_function(func, (a,), a)
...     return a + 0
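A self-contained version of this pattern is sketched below. It uses has_torch_function as the dispatch check. double, Tagged and Unhandled are hypothetical names. The sketch shows three behaviors: a plain tensor takes the normal code path, an override that recognizes the function wins, and a TypeError is raised when every override returns NotImplemented.

```python
import torch
from torch.overrides import handle_torch_function, has_torch_function


def double(a):
    """Hypothetical user-level function made dispatchable via __torch_function__."""
    if has_torch_function((a,)):
        # Hand the call to the argument's __torch_function__, passing the
        # public function itself so overrides can compare it with `is`.
        return handle_torch_function(double, (a,), a)
    return a * 2


class Tagged:
    """Hypothetical tensor-like that only knows how to handle double."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if func is double:
            return "Tagged.double"
        return NotImplemented


class Unhandled:
    """Hypothetical tensor-like that handles nothing."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return NotImplemented


print(double(torch.tensor([1.0, 2.0])))  # tensor([2., 4.])
print(double(Tagged()))                  # Tagged.double

try:
    double(Unhandled())
except TypeError as err:
    # Every override returned NotImplemented, so handle_torch_function raises.
    print("TypeError:", err)
```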

torch.overrides.has_torch_function(relevant_args)
Check for __torch_function__ implementations in the elements of an iterable. Considers exact Tensors and Parameters non-dispatchable.

Parameters
relevant_args (iterable) – Iterable of arguments to check for __torch_function__ methods.

Returns
True if any of the elements of relevant_args have __torch_function__ implementations, False otherwise.

Return type
bool

See also
torch.overrides.is_tensor_like()
Checks if something is a Tensor-like, including an exact Tensor.
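The distinction between exact Tensors and Parameters on one side and Tensor subclasses on the other can be checked directly. In this sketch, SubTensor is a hypothetical empty subclass. It inherits Tensor's __torch_function__, so it stays dispatchable.

```python
import torch
from torch.overrides import has_torch_function


class SubTensor(torch.Tensor):
    """A trivial, hypothetical Tensor subclass."""


plain = torch.zeros(2)
param = torch.nn.Parameter(torch.zeros(2))
sub = torch.zeros(2).as_subclass(SubTensor)

print(has_torch_function((plain,)))         # False: exact Tensor
print(has_torch_function((param,)))         # False: exact Parameter
print(has_torch_function((sub,)))           # True: Tensor subclass
print(has_torch_function((plain, 1, sub)))  # True: any element suffices
print(has_torch_function((1, "a", None)))   # False: no Tensor-likes
```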

torch.overrides.is_tensor_like(inp)
Returns True if the passed-in input is a Tensor-like.

Currently, this occurs whenever there’s a __torch_function__ attribute on the type of the input.

Examples
A subclass of tensor is generally a Tensor-like.
>>> class SubTensor(torch.Tensor): ...
>>> is_tensor_like(SubTensor([0]))
True
Built-in or user types aren’t usually Tensor-like.
>>> is_tensor_like(6)
False
>>> is_tensor_like(None)
False
>>> class NotATensor: ...
>>> is_tensor_like(NotATensor())
False
But they can be made Tensor-like by implementing __torch_function__.
>>> class TensorLike:
...     @classmethod
...     def __torch_function__(cls, func, types, args, kwargs):
...         return -1
>>> is_tensor_like(TensorLike())
True
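A typical use is at the entry point of a function that accepts both Tensor-likes and plain Python data. Tensor-likes are passed through untouched so that their __torch_function__ still runs later, and everything else is converted. to_tensor_unless_tensor_like and TensorLike are hypothetical names in this sketch.

```python
import torch
from torch.overrides import is_tensor_like


def to_tensor_unless_tensor_like(x):
    """Hypothetical helper: keep Tensor-likes as-is, convert anything else."""
    if is_tensor_like(x):
        return x
    return torch.as_tensor(x)


class TensorLike:
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return NotImplemented


t = torch.ones(2)
duck = TensorLike()
print(to_tensor_unless_tensor_like([1, 2, 3]))     # tensor([1, 2, 3])
print(to_tensor_unless_tensor_like(t) is t)        # True: exact Tensors count
print(to_tensor_unless_tensor_like(duck) is duck)  # True
```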

torch.overrides.is_tensor_method_or_property(func)
Returns True if the function passed in is a handler for a method or property belonging to torch.Tensor, as passed into __torch_function__.

Note
For properties, their __get__ method must be passed in.

This may be needed, in particular, for the following reasons:
- Methods/properties sometimes don’t contain a __module__ slot.
- They require that the first passed-in argument is an instance of torch.Tensor.
Examples
>>> is_tensor_method_or_property(torch.Tensor.add)
True
>>> is_tensor_method_or_property(torch.add)
False
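Inside __torch_function__, this check tells a method call such as t.add(1) apart from a free-function call such as torch.add(t, 1). LoggingTensor is a hypothetical Tensor subclass in this sketch. It records how each intercepted call was made and then defers to the default torch.Tensor.__torch_function__ to compute the result.

```python
import torch
from torch.overrides import is_tensor_method_or_property


class LoggingTensor(torch.Tensor):
    """Hypothetical subclass that records how each intercepted call was made."""

    seen_calls = []

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kind = "method/property" if is_tensor_method_or_property(func) else "function"
        cls.seen_calls.append((getattr(func, "__name__", repr(func)), kind))
        # Defer to the default implementation to actually compute the result.
        return super().__torch_function__(func, types, args, kwargs)


t = torch.ones(2).as_subclass(LoggingTensor)
t.add(1)         # expected to be recorded as a method
torch.add(t, 1)  # expected to be recorded as a function
print(LoggingTensor.seen_calls)
```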

torch.overrides.wrap_torch_function(dispatcher)
Wraps a given function with __torch_function__-related functionality.

Parameters
dispatcher (Callable) – A callable that returns an iterable of Tensor-likes passed into the function.

Note
This decorator may reduce the performance of your code. Generally, it’s enough to express your code as a series of functions that, themselves, support __torch_function__. If you find yourself in the rare situation where this is not the case, e.g. if you’re wrapping a low-level library and you also need it to work for Tensor-likes, then this function is available.
Examples
>>> def dispatcher(a):  # Must have the same signature as func
...     return (a,)
>>> @torch.overrides.wrap_torch_function(dispatcher)
... def func(a):  # This will make func dispatchable by __torch_function__
...     return a + 0
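A runnable end-to-end version is sketched below. plus_one, _plus_one_dispatcher and Marker are hypothetical names. The override's func is plus_one check assumes that the decorator hands the decorated (wrapped) function to __torch_function__ as func, which matches how handle_torch_function receives public_api.

```python
import torch
from torch.overrides import wrap_torch_function


def _plus_one_dispatcher(a):
    # Same signature as plus_one; returns the arguments to check for overrides.
    return (a,)


@wrap_torch_function(_plus_one_dispatcher)
def plus_one(a):
    """Hypothetical low-level function made dispatchable by the decorator."""
    return a + 1


class Marker:
    """Hypothetical tensor-like that overrides plus_one."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Assumes func is the decorated plus_one itself.
        if func is plus_one:
            return "Marker.plus_one"
        return NotImplemented


print(plus_one(torch.tensor([1.0])))  # tensor([2.])
print(plus_one(3))                    # 4
print(plus_one(Marker()))             # Marker.plus_one
```

Plain tensors and Python numbers take the normal code path, because has_torch_function reports no overrides for them. Only the Marker argument is routed to __torch_function__.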