torch.fx
Overview
This feature is under a Beta release and its API may change.
FX is a toolkit for developers to use to transform nn.Module
instances. FX consists of three main components: a symbolic tracer,
an intermediate representation, and Python code generation. A
demonstration of these components in action:
import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)
    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)
module = MyModule()
from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)
# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph(x):
    %param : [#users=1] = self.param
    %add_1 : [#users=1] = call_function[target=<built-in function add>](args = (%x, %param), kwargs = {})
    %linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
    %clamp_1 : [#users=1] = call_method[target=clamp](args = (%linear_1,), kwargs = {min: 0.0, max: 1.0})
    return clamp_1
"""
# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add_1 = x + param; x = param = None
    linear_1 = self.linear(add_1); add_1 = None
    clamp_1 = linear_1.clamp(min = 0.0, max = 1.0); linear_1 = None
    return clamp_1
"""
The symbolic tracer performs “abstract interpretation” of the Python code. It feeds fake values, called Proxies, through the code. Operations on these Proxies are recorded. More information about symbolic tracing can be found in the symbolic_trace and Tracer documentation.
The intermediate representation is the container for the operations
that were recorded during symbolic tracing. It consists of a list of
Nodes that represent function inputs, callsites (to functions, methods,
or nn.Module instances), and return values. More information about
the IR can be found in the documentation for
Graph. The
IR is the format on which transformations are applied.
Python code generation is what makes FX a Python-to-Python (or
Module-to-Module) transformation toolkit. For each Graph IR, we can
create valid Python code matching the Graph’s semantics. This
functionality is wrapped up in
GraphModule,
which is an nn.Module instance that holds a Graph as well as a
forward method generated from the Graph.
Taken together, this pipeline of components (symbolic tracing → intermediate representation → transforms → Python code generation) constitutes the Python-to-Python transformation pipeline of FX.
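As a minimal sketch of that pipeline (not an official FX recipe), the example below traces a small function, transforms the Graph by changing the target of each addition node from operator.add to operator.mul, and then regenerates the Python code with recompile(). The function fn is purely illustrative. The sketch assumes that `+` on Proxies is recorded as a call_function node targeting operator.add, which is what the printout above shows as <built-in function add>.

```python
import operator
import torch
import torch.fx

def fn(x, y):
    return x + y

# Symbolic tracing: record the operations of fn into a Graph held by a GraphModule
gm = torch.fx.symbolic_trace(fn)

# Transform: rewrite every addition node into a multiplication node
for node in gm.graph.nodes:
    if node.op == 'call_function' and node.target == operator.add:
        node.target = operator.mul

# Code generation: rebuild forward() from the edited Graph
gm.recompile()
print(gm.code)

x = torch.tensor([2.0, 3.0])
y = torch.tensor([4.0, 5.0])
print(gm(x, y))
```

Editing node.target in place leaves every other node untouched, so only the regenerated forward() changes.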
Writing Transformations
TODO
Debugging Transformations
TODO
Limitations of Symbolic Tracing
FX uses a system of symbolic tracing (a.k.a. symbolic execution)
to capture the semantics of programs in a transformable/analyzable form.
The system is tracing in that it executes the program (really an
nn.Module or function) to gather this information. It is
symbolic in that the data flowing through the program during this
execution is not real data, but rather symbols (“Proxy” in FX parlance).
Although symbolic tracing works for most neural net code, it has some limitations.
Dynamic Control Flow
The main limitation of symbolic tracing is that it does not currently support
dynamic control flow, that is, loops or if statements whose
condition may depend on the input values of the program.
For example, let’s examine the following program:
import torch
import torch.fx
def func_to_trace(x):
    dim0 = x.size()[0]
    if dim0 == 3:
        return torch.relu(x)
    else:
        return torch.neg(x)
traced = torch.fx.symbolic_trace(func_to_trace)
"""
  <...>
  File "dyn.py", line 6, in func_to_trace
    if dim0 == 3:
  File "pytorch/torch/fx/proxy.py", line 155, in __bool__
    return self.tracer.to_bool(self)
  File "pytorch/torch/fx/proxy.py", line 85, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""
The condition to the if statement relies on the value of dim0,
which eventually relies on the value of x, a function input. Since
x can change (i.e. if you pass a new input tensor to the traced
function), this is dynamic control flow. The traceback walks back up
through your code to show you where this situation happens.
Static Control Flow
On the other hand, so-called static control flow is supported. Static
control flow is loops or if statements whose condition cannot change
across invocations. Typically, in PyTorch programs, this control flow
arises for code making decisions about a model’s architecture based on
hyper-parameters. As a concrete example:
import torch
import torch.fx
class MyModule(torch.nn.Module):
    def __init__(self, do_activation : bool = False):
        super().__init__()
        self.do_activation = do_activation
        self.linear = torch.nn.Linear(512, 512)
    def forward(self, x):
        x = self.linear(x)
        # This if-statement is so-called static control flow.
        # Its condition does not depend on any input values
        if self.do_activation:
            x = torch.relu(x)
        return x
without_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)
traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):
    linear_1 = self.linear(x); x = None
    return linear_1
"""
traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):
    linear_1 = self.linear(x); x = None
    relu_1 = torch.relu(linear_1); linear_1 = None
    return relu_1
"""
The if-statement if self.do_activation does not depend on any
function inputs, thus it is static. do_activation can be considered
to be a hyper-parameter, and the traces of different instances of
MyModule with different values for that parameter have different
code. This is a valid pattern that is supported by symbolic tracing.
Many instances of dynamic control flow are semantically static control
flow. These instances can be made to support symbolic tracing by
removing the data dependencies on input values, for example by moving
values to Module attributes or by passing constant values during
symbolic tracing:
import torch
import torch.fx
def f(x, flag):
    if flag: return x
    else: return x*2
torch.fx.symbolic_trace(f) # Fails!
def g(flag):
    return lambda x: f(x, flag)
new_f = g(flag=True)
torch.fx.symbolic_trace(new_f)
In the case of truly dynamic control flow, the sections of the program
that contain this code can be traced as calls to a leaf Module (see
Customizing Tracing with the Tracer class) or a leaf function (see
wrap()) rather than tracing through them.
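A minimal sketch of the wrap() route, using a hypothetical helper clamp_if_positive whose branch depends on tensor values. Because wrap() registers the name at module scope before tracing, the call is recorded as a single call_function node, and the branch is evaluated on real data every time the traced module runs.

```python
import torch
import torch.fx

def clamp_if_positive(x):
    # Data-dependent branch: the condition depends on the values in x,
    # so symbolic tracing cannot trace through this function.
    if x.sum() > 0:
        return x.clamp(max=1.0)
    return x

# Called at module scope: the tracer records clamp_if_positive as a leaf call
torch.fx.wrap('clamp_if_positive')

def model(x):
    return clamp_if_positive(x) * 2

traced = torch.fx.symbolic_trace(model)
print(traced.code)
```

Note that wrap() has to run at the top level of a module, not inside a function body.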
Non-torch Functions
FX uses __torch_function__ as the mechanism by which it intercepts
calls (see the technical overview for more information about this).
Some functions, such as builtin Python functions or those in the math
module, are not covered by __torch_function__, but we would still like
to capture them in symbolic tracing. For example:
import torch
import torch.fx
from math import sqrt
def normalize(x):
    """
    Normalize `x` by the size of the batch dimension
    """
    return x / sqrt(len(x))
# It's valid Python code
normalize(torch.rand(3, 4))
traced = torch.fx.symbolic_trace(normalize)
"""
  <...>
  File "sqrt.py", line 9, in normalize
    return x / sqrt(len(x))
  File "pytorch/torch/fx/proxy.py", line 161, in __len__
    raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""
The error tells us that the built-in function len is not supported.
We can make it so that functions like this are recorded in the trace as
direct calls using the wrap() API:
torch.fx.wrap('len')
torch.fx.wrap('sqrt')
traced = torch.fx.symbolic_trace(normalize)
print(traced.code)
"""
import math
def forward(self, x):
    len_1 = len(x)
    sqrt_1 = math.sqrt(len_1); len_1 = None
    truediv = x / sqrt_1; x = sqrt_1 = None
    return truediv
"""
Customizing Tracing with the Tracer class
The Tracer class is the class that underlies the
implementation of symbolic_trace. The behavior of tracing can be
customized by subclassing Tracer, like so:
import torch
import torch.fx
class MyCustomTracer(torch.fx.Tracer):
    # Inside here you can override various methods
    # to customize tracing. See the `Tracer` API
    # reference
    pass
# Let's use this custom tracer to trace through this module
class MyModule(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + torch.ones(3, 4)
mod = MyModule()
traced_graph = MyCustomTracer().trace(mod)
# trace() returns a Graph. Let's wrap it up in a
# GraphModule to make it runnable
traced = torch.fx.GraphModule(mod, traced_graph)
Leaf Modules
Leaf Modules are the modules that appear as calls in the symbolic trace
rather than being traced through. The default set of leaf modules is the
set of standard torch.nn module instances. For example:
import torch
import torch.fx
class MySpecialSubmodule(torch.nn.Module):
    def forward(self, x):
        return torch.neg(x)
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)
        self.submod = MySpecialSubmodule()
    def forward(self, x):
        return self.submod(self.linear(x))
traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` is preserved as a call, yet `submod` is traced through.
# This is because the default set of "Leaf Modules" includes all
# standard `torch.nn` modules.
"""
import torch
def forward(self, x):
    linear_1 = self.linear(x); x = None
    neg_1 = torch.neg(linear_1); linear_1 = None
    return neg_1
"""
The set of leaf modules can be customized by overriding
Tracer.is_leaf_module().
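Continuing the example above, one way to keep MySpecialSubmodule as a call instead of tracing through it is to subclass Tracer and override is_leaf_module(). This is a sketch: the class name LeafSubmodTracer is hypothetical, and everything other than MySpecialSubmodule falls back to the default leaf rule.

```python
import torch
import torch.fx

class MySpecialSubmodule(torch.nn.Module):
    def forward(self, x):
        return torch.neg(x)

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)
        self.submod = MySpecialSubmodule()

    def forward(self, x):
        return self.submod(self.linear(x))

class LeafSubmodTracer(torch.fx.Tracer):
    def is_leaf_module(self, m, module_qualified_name):
        # Treat MySpecialSubmodule as a leaf, in addition to the default torch.nn modules
        if isinstance(m, MySpecialSubmodule):
            return True
        return super().is_leaf_module(m, module_qualified_name)

root = MyModule()
graph = LeafSubmodTracer().trace(root)
# trace() returns a Graph; wrap it in a GraphModule to make it runnable
traced = torch.fx.GraphModule(root, graph)
print(traced.code)
```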
Miscellanea
- Tensor constructors (e.g. torch.zeros, torch.ones, torch.rand, torch.randn, torch.sparse_coo_tensor) are currently not traceable.
  - The deterministic constructors (zeros, ones) can be used and the value they produce will be embedded in the trace as a constant. This is only problematic if the arguments to these constructors refer to dynamic input sizes. In this case, ones_like or zeros_like may be a viable substitute.
  - Nondeterministic constructors (rand, randn) will have a single random value embedded in the trace. This is likely not the intended behavior.
  - This behavior may be fixed in a future release.
- Type annotations
  - Python 3-style type annotations (e.g. func(x : torch.Tensor, y : int) -> torch.Tensor) are supported and will be preserved by symbolic tracing.
  - Python 2-style comment type annotations # type: (torch.Tensor, int) -> torch.Tensor are not currently supported.
  - Annotations on local names within a function are not currently supported.
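To illustrate the constructor note above, here is a sketch that contrasts torch.ones with torch.ones_like. torch.ones(3, 4) is computed once at trace time and embedded as a constant. torch.ones_like receives the traced input, so it adapts to whatever shape is passed at run time. The function names are illustrative only.

```python
import torch
import torch.fx

def add_ones_const(x):
    # torch.ones(3, 4) runs once during tracing; its value is stored in the trace
    return x + torch.ones(3, 4)

def add_ones_dynamic(x):
    # ones_like takes the traced input, so the shape is taken from x at run time
    return x + torch.ones_like(x)

traced_const = torch.fx.symbolic_trace(add_ones_const)
traced_dynamic = torch.fx.symbolic_trace(add_ones_dynamic)

print(traced_const.code)
print(traced_dynamic.code)
```

The constant version only works for inputs that broadcast against a 3x4 tensor. The ones_like version works for any input shape.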
API Reference
-
torch.fx.symbolic_trace(root)
Symbolic tracing API
Given an nn.Module or function instance root, this function will return a GraphModule constructed by recording operations seen while tracing through root.
- Parameters
root (Union[torch.nn.Module, Callable]) – Module or function to be traced and converted into a Graph representation.
- Returns
a Module created from the recorded operations from root.
- Return type
GraphModule
-
torch.fx.wrap(fn_or_name)
This function can be called at module-level scope to register fn_or_name as a “leaf function”. A “leaf function” will be preserved as a call_function node in the FX trace instead of being traced through:
# foo/bar/baz.py
def my_custom_function(x, y):
    return x * x + y * y
torch.fx.wrap('my_custom_function')
def fn_to_be_traced(x, y):
    # When symbolic tracing, the below call to my_custom_function will be inserted into
    # the graph rather than tracing it.
    return my_custom_function(x, y)
This function can also equivalently be used as a decorator:
# foo/bar/baz.py
@torch.fx.wrap
def my_custom_function(x, y):
    return x * x + y * y
A wrapped function can be thought of as a “leaf function”, analogous to the concept of “leaf modules”, that is, they are functions that are left as calls in the FX trace rather than traced through.
- Parameters
fn_or_name (Union[str, Callable]) – The function or name of the global function to insert into the graph when it’s called
-
class
torch.fx.GraphModule(*args, **kwargs)
GraphModule is an nn.Module generated from an fx.Graph. GraphModule has a graph attribute, as well as code and forward attributes generated from that graph.
Warning
When graph is reassigned, code and forward will be automatically regenerated. However, if you edit the contents of the graph without reassigning the graph attribute itself, you must call recompile() to update the generated code.
-
__init__(root, graph, class_name='GraphModule')
Construct a GraphModule.
- Parameters
root (Union[torch.nn.Module, Dict[str, Any]]) – root can either be an nn.Module instance or a Dict mapping strings to any attribute type. In the case that root is a Module, any references to Module-based objects (via qualified name) in the Graph’s Nodes’ target field will be copied over from the respective place within root’s Module hierarchy into the GraphModule’s module hierarchy. In the case that root is a dict, the qualified name found in a Node’s target will be looked up directly in the dict’s keys. The object mapped to by the Dict will be copied over into the appropriate place within the GraphModule’s module hierarchy.
graph (Graph) – graph contains the nodes this GraphModule should use for code generation.
class_name (str) – class_name denotes the name of this GraphModule for debugging purposes. If it’s unset, all error messages will report as originating from GraphModule. It may be helpful to set this to root’s original name or a name that makes sense within the context of your transform.
-
property
code
Return the Python code generated from the Graph underlying this GraphModule.
-
property
graph
Return the Graph underlying this GraphModule.
-
recompile()
Recompile this GraphModule from its graph attribute. This should be called after editing the contained graph, otherwise the generated code of this GraphModule will be out of date.
-
to_folder(folder, module_name='FxModule')
Dumps out module to folder with module_name so that it can be imported with from <folder> import <module_name>.
- Parameters
folder (Union[str, os.PathLike]) – The folder to write the code out to.
module_name (str) – Top-level name to use for the Module while writing out the code.
-
-
class
torch.fx.Graph
Graph is the main data structure used in the FX Intermediate Representation. It consists of a series of Nodes, each representing callsites (or other syntactic constructs). The list of Nodes, taken together, constitute a valid Python function.
For example, the following code
import torch
import torch.fx
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)
    def forward(self, x):
        return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)
m = MyModule()
gm = torch.fx.symbolic_trace(m)
will produce the following Graph:
print(gm.graph)
graph(x):
    %linear_weight : [#users=1] = self.linear.weight
    %add_1 : [#users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
    %linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
    %relu_1 : [#users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
    %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
    %topk_1 : [#users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
    return topk_1
For the semantics of operations represented in the Graph, please see Node.
-
call_function(the_function, args=None, kwargs=None, type_expr=None)
Insert a call_function Node into the Graph. A call_function node represents a call to a Python callable, specified by the_function.
- Parameters
the_function (Callable[.., Any]) – The function to be called. Can be any PyTorch operator, Python function, or member of the builtins or operator namespaces.
args (Optional[Tuple[Argument, ..]]) – The positional arguments to be passed to the called function.
kwargs (Optional[Dict[str, Argument]]) – The keyword arguments to be passed to the called function.
type_expr (Optional[Any]) – an optional type annotation representing the Python type the output of this node will have.
- Returns
The newly created and inserted call_function node.
Note
The same insertion point and type expression rules apply for this method as Graph.create_node().
-
call_method(method_name, args=None, kwargs=None, type_expr=None)
Insert a call_method Node into the Graph. A call_method node represents a call to a given method on the 0th element of args.
- Parameters
method_name (str) – The name of the method to apply to the self argument. For example, if args[0] is a Node representing a Tensor, then to call relu() on that Tensor, pass relu to method_name.
args (Optional[Tuple[Argument, ..]]) – The positional arguments to be passed to the called method. Note that this should include a self argument.
kwargs (Optional[Dict[str, Argument]]) – The keyword arguments to be passed to the called method.
type_expr (Optional[Any]) – an optional type annotation representing the Python type the output of this node will have.
- Returns
The newly created and inserted call_method node.
Note
The same insertion point and type expression rules apply for this method as Graph.create_node().
-
call_module(module_name, args=None, kwargs=None, type_expr=None)
Insert a call_module Node into the Graph. A call_module node represents a call to the forward() function of a Module in the Module hierarchy.
- Parameters
module_name (str) – The qualified name of the Module in the Module hierarchy to be called. For example, if the traced Module has a submodule named foo, which has a submodule named bar, the qualified name foo.bar should be passed as module_name to call that module.
args (Optional[Tuple[Argument, ..]]) – The positional arguments to be passed to the called method. Note that this should not include a self argument.
kwargs (Optional[Dict[str, Argument]]) – The keyword arguments to be passed to the called method.
type_expr (Optional[Any]) – an optional type annotation representing the Python type the output of this node will have.
- Returns
The newly-created and inserted call_module node.
Note
The same insertion point and type expression rules apply for this method as Graph.create_node().
-
create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)
Create a Node and add it to the Graph at the current insert-point. Note that the current insert-point can be set via Graph.inserting_before() and Graph.inserting_after().
- Parameters
op (str) – the opcode for this Node. One of ‘call_function’, ‘call_method’, ‘get_attr’, ‘call_module’, ‘placeholder’, or ‘output’. The semantics of these opcodes are described in the Node docstring.
args (Optional[Tuple[Argument, ..]]) – is a tuple of arguments to this node.
kwargs (Optional[Dict[str, Argument]]) – the kwargs of this Node.
name (Optional[str]) – an optional string name for the Node. This will influence the name of the value assigned to in the Python generated code.
type_expr (Optional[Any]) – an optional type annotation representing the Python type the output of this node will have.
- Returns
The newly-created and inserted node.
-
erase_node(to_erase)
Erases a Node from the Graph. Throws an exception if there are still users of that node in the Graph.
- Parameters
to_erase (Node) – The Node to erase from the Graph.
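A common use of erase_node is dead-code removal. The sketch below traces a function with an unused intermediate, walks the graph backwards, and erases every node that is neither a placeholder nor the output and has no users. It relies on Node.users, a mapping of the nodes that consume a node's value. That attribute is an assumption here, since it is not covered in this reference.

```python
import operator
import torch
import torch.fx

def fn(x):
    unused = x * 2  # recorded by the tracer even though the result is never used
    return x + 1

gm = torch.fx.symbolic_trace(fn)

# Visit nodes in reverse so that users are removed before the nodes they consume.
# Node.users is assumed to hold the consumers of each node (not documented here).
for node in list(reversed(gm.graph.nodes)):
    if node.op not in ('placeholder', 'output') and len(node.users) == 0:
        gm.graph.erase_node(node)

gm.recompile()
print(gm.code)
```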
-
get_attr(qualified_name, type_expr=None)
Insert a get_attr node into the Graph. A get_attr Node represents the fetch of an attribute from the Module hierarchy.
- Parameters
qualified_name (str) – the fully-qualified name of the attribute to be retrieved. For example, if the traced Module has a submodule named foo, which has a submodule named bar, which has an attribute named baz, the qualified name foo.bar.baz should be passed as qualified_name.
type_expr (Optional[Any]) – an optional type annotation representing the Python type the output of this node will have.
- Returns
The newly-created and inserted get_attr node.
Note
The same insertion point and type expression rules apply for this method as Graph.create_node.
-
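graph_copy(g, val_map)
Copy all nodes from a given graph into self.
- Parameters
g (Graph) – The source graph from which to copy Nodes.
val_map (Dict[Node, Node]) – a dictionary that will be populated with a mapping from nodes in g to nodes in self.
- Returns
The value in self that is now equivalent to the output value in g, if g had an output node. None otherwise.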
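The sketch below shows one way to use graph_copy. It copies every node of a traced graph into a fresh Graph and then adds the output node from the returned value. This step is needed because graph_copy returns the equivalent output value rather than the output node itself. The function fn is illustrative.

```python
import torch
import torch.fx

def fn(x):
    return torch.relu(x) + 1.0

gm = torch.fx.symbolic_trace(fn)

new_graph = torch.fx.Graph()
val_map = {}  # populated by graph_copy: node in gm.graph -> copied node in new_graph
out = new_graph.graph_copy(gm.graph, val_map)
# Re-create the return statement from the value graph_copy handed back
new_graph.output(out)

new_gm = torch.fx.GraphModule(gm, new_graph)
print(new_gm.code)
```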
-
inserting_after(n=None)
Set the point at which create_node and companion methods will insert into the graph. When used within a ‘with’ statement, this will temporarily set the insert point and then restore it when the with statement exits:
with g.inserting_after(n):
    ... # inserting after node n
... # insert point restored to what it was previously
g.inserting_after(n) # set the insert point permanently
- Parameters
n (Optional[Node]) – The node after which to insert. If None this will insert after the beginning of the entire graph.
- Returns
A resource manager that will restore the insert point on __exit__.
-
inserting_before(n=None)
Set the point at which create_node and companion methods will insert into the graph. When used within a ‘with’ statement, this will temporarily set the insert point and then restore it when the with statement exits:
with g.inserting_before(n):
    ... # inserting before node n
... # insert point restored to what it was previously
g.inserting_before(n) # set the insert point permanently
- Parameters
n (Optional[Node]) – The node before which to insert. If None this will insert before the beginning of the entire graph.
- Returns
A resource manager that will restore the insert point on __exit__.
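As a sketch of how the insert point and Node.args assignment work together, the code below appends a torch.relu call to the end of a traced function. It inserts the new node just before the output node and then makes the output return it. The function fn is illustrative.

```python
import torch
import torch.fx

def fn(x):
    return x * 3

gm = torch.fx.symbolic_trace(fn)
graph = gm.graph

# The output node's args[0] is the value currently being returned
output_node = next(n for n in graph.nodes if n.op == 'output')

# Temporarily move the insert point so the new node lands before the return
with graph.inserting_before(output_node):
    relu_node = graph.call_function(torch.relu, args=(output_node.args[0],))

# Return the relu result instead; assigning args updates the use/user bookkeeping
output_node.args = (relu_node,)

graph.lint()
gm.recompile()
print(gm.code)
```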
-
lint(root=None)
Runs various checks on this Graph to make sure it is well-formed. In particular:
- Checks Nodes have correct ownership (owned by this graph)
- Checks Nodes appear in topological order
- If root is provided, checks that targets exist in root
- Parameters
root (Optional[torch.nn.Module]) – The root module with which to check for targets. This is equivalent to the root argument that is passed when constructing a GraphModule.
-
node_copy(node, arg_transform=<function Graph.<lambda>>)
Copy a node from one graph into another. arg_transform needs to transform arguments from the graph of node to the graph of self. Example:
# Copying all the nodes in `g` into `new_graph`
g : torch.fx.Graph = ...
new_graph = torch.fx.Graph()
value_remap = {}
for node in g.nodes:
    value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n])
- Parameters
node (Node) – The node to copy into self.
arg_transform (Callable[[Node], Argument]) – A function that transforms Node arguments in node’s args and kwargs into the equivalent argument in self. In the simplest case, this should retrieve a value out of a table mapping Nodes in the original graph to self.
-
property
nodes
Get the list of Nodes that constitute this Graph.
Note that this Node list representation is a doubly-linked list. Mutations during iteration (e.g. delete a Node, add a Node) are safe.
- Returns
A doubly-linked list of Nodes. Note that reversed can be called on this list to switch iteration order.
-
output(result, type_expr=None)
Insert an output Node into the Graph. An output node represents a return statement in Python code. result is the value that should be returned.
- Parameters
result (Argument) – The value to be returned.
type_expr (Optional[Any]) – an optional type annotation representing the Python type the output of this node will have.
Note
The same insertion point and type expression rules apply for this method as Graph.create_node.
-
placeholder(name, type_expr=None)
Insert a placeholder node into the Graph. A placeholder represents a function input.
- Parameters
name (str) – A name for the input value. This corresponds to the name of the positional argument to the function this Graph represents.
type_expr (Optional[Any]) – an optional type annotation representing the Python type the output of this node will have. This is needed in some cases for proper code generation (e.g. when the function is used subsequently in TorchScript compilation).
Note
The same insertion point and type expression rules apply for this method as Graph.create_node.
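The Graph methods above are enough to build a graph by hand. The sketch below creates two placeholders, adds them with a call_function node, applies relu through a call_method node, and returns the result. The graph has no get_attr or call_module nodes, so an empty nn.Module serves as the GraphModule root.

```python
import operator
import torch
import torch.fx

graph = torch.fx.Graph()
# Function inputs
x = graph.placeholder('x')
y = graph.placeholder('y')
# x + y as a call_function node, then .relu() on the sum as a call_method node
added = graph.call_function(operator.add, args=(x, y))
activated = graph.call_method('relu', args=(added,))
# The return statement
graph.output(activated)

graph.lint()
# No attributes or submodules are referenced, so an empty Module is a valid root
gm = torch.fx.GraphModule(torch.nn.Module(), graph)
print(gm.code)
```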
-
-
class
torch.fx.Node(graph, name, op, target, args, kwargs, type=None)
Node is the data structure that represents individual operations within a Graph. For the most part, Nodes represent callsites to various entities, such as operators, methods, and Modules (some exceptions include nodes that specify function inputs and outputs). Each Node has a function specified by its op property. The Node semantics for each value of op are as follows:
- placeholder represents a function input. The name attribute specifies the name this value will take on. target is similarly the name of the argument. args holds either: 1) nothing, or 2) a single argument denoting the default parameter of the function input. kwargs is don’t-care. Placeholders correspond to the function parameters (e.g. x) in the graph printout.
- get_attr retrieves a parameter from the module hierarchy. name is similarly the name the result of the fetch is assigned to. target is the fully-qualified name of the parameter’s position in the module hierarchy. args and kwargs are don’t-care.
- call_function applies a free function to some values. name is similarly the name of the value to assign to. target is the function to be applied. args and kwargs represent the arguments to the function, following the Python calling convention.
- call_module applies a module in the module hierarchy’s forward() method to given arguments. name is as previous. target is the fully-qualified name of the module in the module hierarchy to call. args and kwargs represent the arguments to invoke the module on, excluding the self argument.
- call_method calls a method on a value. name is as previous. target is the string name of the method to apply to the self argument. args and kwargs represent the arguments to invoke the method on, including the self argument.
- output contains the output of the traced function in its args[0] attribute. This corresponds to the “return” statement in the Graph printout.
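To see these opcodes on a concrete graph, the sketch below traces the MyModule example from the overview and prints each node's op, name, target, args and kwargs. Node names can differ between PyTorch versions, so the sketch only relies on the sequence of opcodes.

```python
import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

gm = torch.fx.symbolic_trace(MyModule())
for node in gm.graph.nodes:
    print(node.op, node.name, node.target, node.args, node.kwargs)

# One node per opcode: input, parameter fetch, add, linear, clamp, return
ops = [node.op for node in gm.graph.nodes]
```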
-
property
all_input_nodes
Return all Nodes that are inputs to this Node. This is equivalent to iterating over args and kwargs and only collecting the values that are Nodes.
- Returns
List of Nodes that appear in the args and kwargs of this Node, in that order.
-
append(x)
Insert x after this node in the list of nodes in the graph. Equivalent to self.next.prepend(x).
- Parameters
x (Node) – The node to put after this node. Must be a member of the same graph.
-
property
args
The tuple of arguments to this Node. The interpretation of arguments depends on the node’s opcode. See the Node docstring for more information.
Assignment to this property is allowed. All accounting of uses and users is updated automatically on assignment.
-
property
kwargs
The dict of keyword arguments to this Node. The interpretation of arguments depends on the node’s opcode. See the Node docstring for more information.
Assignment to this property is allowed. All accounting of uses and users is updated automatically on assignment.
-
property
next
Returns the next Node in the linked list of Nodes.
- Returns
The next Node in the linked list of Nodes.
-
prepend(x)
Insert x before this node in the list of nodes in the graph. Example:
Before: p -> self
        bx -> x -> ax
After:  p -> x -> self
        bx -> ax
- Parameters
x (Node) – The node to put before this node. Must be a member of the same graph.
-
property
prev
Returns the previous Node in the linked list of Nodes.
- Returns
The previous Node in the linked list of Nodes.
-
class
torch.fx.Tracer(autowrap_modules=(math,))
Tracer is the class that implements the symbolic tracing functionality of torch.fx.symbolic_trace. A call to symbolic_trace(m) is equivalent to wrapping the Graph returned by Tracer().trace(m) in a GraphModule.
Tracer can be subclassed to override various behaviors of the tracing process. The different behaviors that can be overridden are described in the docstrings of the methods on this class.
-
call_module(m, forward, args, kwargs)
Method that specifies the behavior of this Tracer when it encounters a call to an nn.Module instance.
By default, the behavior is to check if the called module is a leaf module via is_leaf_module. If it is, emit a call_module node referring to m in the Graph. Otherwise, call the Module normally, tracing through the operations in its forward function.
This method can be overridden to, for example, create nested traced GraphModules, or any other behavior you would want while tracing across Module boundaries.
- Parameters
m (Module) – The module for which a call is being emitted
forward (Callable) – The forward() method of the Module to be invoked
args (Tuple) – args of the module callsite
kwargs (Dict) – kwargs of the module callsite
- Returns
The return value from the Module call. In the case that a call_module node was emitted, this is a Proxy value. Otherwise, it is whatever value was returned from the Module invocation.
-
create_arg(a)
A method to specify the behavior of tracing when preparing values to be used as arguments to nodes in the Graph.
By default, the behavior includes:
- Iterate through collection types (e.g. tuple, list, dict) and recursively call create_arg on the elements.
- Given a Proxy object, return a reference to the underlying IR Node.
- Given a non-Proxy Tensor object, emit IR for various cases:
  - For a Parameter, emit a get_attr node referring to that Parameter.
  - For a non-Parameter Tensor, store the Tensor away in a special attribute and emit a get_attr node referring to that attribute.
This method can be overridden to support more types.
- Parameters
a (Any) – The value to be emitted as an Argument in the Graph.
- Returns
The value a converted into the appropriate Argument.
-
create_args_for_root(root_fn, is_module)
Create placeholder nodes corresponding to the signature of the root Module. This method introspects root’s signature and emits those nodes accordingly, also supporting *args and **kwargs.
-
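is_leaf_module(m, module_qualified_name)
A method to specify whether a given nn.Module is a “leaf” module.
Leaf modules are the atomic units that appear in the IR, referenced by call_module calls. By default, Modules in the PyTorch standard library namespace (torch.nn) are leaf modules. All other modules are traced through and their constituent ops are recorded, unless specified otherwise via this method.
- Parameters
m (Module) – The module being queried about.
module_qualified_name (str) – The path to root of this module. For example, if you have a module hierarchy where submodule foo contains submodule bar, which contains submodule baz, that module will appear with the qualified name foo.bar.baz here.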
-
path_of_module(mod)
Helper method to find the qualified name of mod in the Module hierarchy of root. For example, if root has a submodule named foo, which has a submodule named bar, passing bar into this function will return the string “foo.bar”.
- Parameters
mod (Module) – The Module to retrieve the qualified name for.
-
trace(root)
Trace root and return the corresponding FX Graph representation. root can either be an nn.Module instance or a Python callable.
Note that after this call, self.root may be different from the root passed in here. For example, when a free function is passed to trace(), we will create an nn.Module instance to use as the root and add embedded constants to it.
- Parameters
root (Union[Module, Callable]) – Either a Module or a function to be traced through.
- Returns
A Graph representing the semantics of the passed-in root.
-
-
class
torch.fx.Proxy(node, tracer=None)
Proxy objects are Node wrappers that flow through the program during symbolic tracing and record all the operations (torch function calls, method calls, operators) that they touch into the growing FX Graph.
If you’re doing graph transforms, you can wrap a Proxy around a raw Node so that you can use the overloaded operators to add additional things to a Graph.
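A sketch of that pattern: wrap the Node that a traced function returns in a Proxy, apply ordinary torch functions and operators to it inside an inserting_before() block, and make the output node return the result. This assumes that when tracer is left as None, the Proxy records new nodes into the Node's own graph at the current insert point, which is how recent FX versions behave. The function fn is illustrative.

```python
import torch
import torch.fx

def fn(x):
    return x - 1

gm = torch.fx.symbolic_trace(fn)
graph = gm.graph
output_node = next(n for n in graph.nodes if n.op == 'output')

with graph.inserting_before(output_node):
    # Assumption: with tracer=None the Proxy appends to output_node.args[0]'s graph
    result = torch.fx.Proxy(output_node.args[0])
    # torch functions and overloaded operators on the Proxy record new nodes
    new_result = torch.sigmoid(result) * 2

# Point the return statement at the Node behind the new Proxy
output_node.args = (new_result.node,)
gm.recompile()
print(gm.code)
```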
-
class
torch.fx.Interpreter(module)
An Interpreter executes an FX graph Node-by-Node. This pattern can be useful for many things, including writing code transformations as well as analysis passes.
Methods in the Interpreter class can be overridden to customize the behavior of execution. The map of overridable methods in terms of call hierarchy:
run()
    +-- run_node
        +-- placeholder()
        +-- get_attr()
        +-- call_function()
        +-- call_method()
        +-- call_module()
        +-- output()
Example
Suppose we want to swap all instances of torch.neg with torch.sigmoid and vice versa (including their Tensor method equivalents). We could subclass Interpreter like so:
import torch
from typing import Any, Dict, Tuple
from torch.fx import Interpreter
from torch.fx.node import Target

class NegSigmSwapInterpreter(Interpreter):
    def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
        # torch.sigmoid(...) calls become torch.neg(...)
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(target, args, kwargs)

    def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
        # x.neg() method calls become x.sigmoid()
        if target == 'neg':
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(target, args, kwargs)

def fn(x):
    return torch.sigmoid(x).neg()

gm = torch.fx.symbolic_trace(fn)
input = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(input)
torch.testing.assert_close(result, torch.neg(input).sigmoid())
Parameters
module (GraphModule) – The module to be executed.
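Besides overriding individual call_* methods, the call hierarchy also allows overriding run_node to observe every node. This sketch records the shape of each tensor produced during execution, a simple shape-propagation pass; ShapeRecorder and Net are hypothetical names, and node names (such as linear vs. linear_1) may vary between PyTorch versions.

```python
import torch
import torch.fx
from typing import Any, Dict


class ShapeRecorder(torch.fx.Interpreter):
    def __init__(self, module):
        super().__init__(module)
        self.shapes: Dict[str, torch.Size] = {}

    def run_node(self, n: torch.fx.Node) -> Any:
        # Execute the node as usual, then remember the shape of tensor results
        result = super().run_node(n)
        if isinstance(result, torch.Tensor):
            self.shapes[n.name] = result.shape
        return result


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x).relu()


gm = torch.fx.symbolic_trace(Net())
recorder = ShapeRecorder(gm)
recorder.run(torch.randn(2, 4))
print(recorder.shapes)
```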
call_function(target, args, kwargs)
Execute a call_function node and return the result.
Parameters
target (Target) – The call target for this node. See Node for details on semantics.
args (Tuple) – Tuple of positional args for this invocation.
kwargs (Dict) – Dict of keyword arguments for this invocation.
Returns
The value returned by the function invocation.
Return type
Any
call_method(target, args, kwargs)
Execute a call_method node and return the result.
Parameters
target (Target) – The call target for this node. See Node for details on semantics.
args (Tuple) – Tuple of positional args for this invocation.
kwargs (Dict) – Dict of keyword arguments for this invocation.
Returns
The value returned by the method invocation.
Return type
Any
call_module(target, args, kwargs)
Execute a call_module node and return the result.
Parameters
target (Target) – The call target for this node. See Node for details on semantics.
args (Tuple) – Tuple of positional args for this invocation.
kwargs (Dict) – Dict of keyword arguments for this invocation.
Returns
The value returned by the module invocation.
Return type
Any
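A small sketch of overriding call_module to log which submodules execute. For call_module nodes the target is the submodule's qualified name as a string; ModuleCallLogger is a hypothetical name.

```python
import torch
import torch.fx
from typing import Any, Dict, List, Tuple


class ModuleCallLogger(torch.fx.Interpreter):
    """Records the qualified name of every submodule invoked during run()."""

    def __init__(self, module: torch.fx.GraphModule):
        super().__init__(module)
        self.called: List[str] = []

    def call_module(self, target, args: Tuple, kwargs: Dict) -> Any:
        self.called.append(target)  # e.g. "0" for the first layer of a Sequential
        return super().call_module(target, args, kwargs)


model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
gm = torch.fx.symbolic_trace(model)
logger = ModuleCallLogger(gm)
out = logger.run(torch.randn(1, 4))
print(logger.called)  # ['0', '1', '2']
```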
fetch_args_kwargs_from_env(n)
Fetch the concrete values of args and kwargs of node n from the current execution environment.
Parameters
n (Node) – The node for which args and kwargs should be fetched.
Returns
args and kwargs with concrete values for n.
Return type
Tuple[Tuple, Dict]
fetch_attr(target)
Fetch an attribute from the Module hierarchy of self.module.
Parameters
target (str) – The fully-qualified name of the attribute to fetch.
Returns
The value of the attribute.
Return type
Any
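fetch_attr can also be called directly on an Interpreter to resolve a dotted name against the wrapped GraphModule, for example to inspect a parameter outside of run(). The module Net is illustrative.

```python
import torch
import torch.fx


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x)


gm = torch.fx.symbolic_trace(Net())
interp = torch.fx.Interpreter(gm)

# Walks gm -> linear -> weight, one dotted component at a time
weight = interp.fetch_attr("linear.weight")
print(weight.shape)  # torch.Size([5, 4])
```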
get_attr(target, args, kwargs)
Execute a get_attr node. Will retrieve an attribute value from the Module hierarchy of self.module.
Parameters
target (Target) – The call target for this node. See Node for details on semantics.
args (Tuple) – Tuple of positional args for this invocation.
kwargs (Dict) – Dict of keyword arguments for this invocation.
Returns
The value of the attribute that was retrieved.
Return type
Any
map_nodes_to_values(args, n)
Recursively descend through args and look up the concrete value for each Node in the current execution environment.
Parameters
args (Argument) – Data structure within which to look up concrete values.
n (Node) – Node to which args belongs. This is only used for error reporting.
output(target, args, kwargs)
Execute an output node. This simply retrieves the value referenced by the output node and returns it.
Parameters
target (Target) – The call target for this node. See Node for details on semantics.
args (Tuple) – Tuple of positional args for this invocation.
kwargs (Dict) – Dict of keyword arguments for this invocation.
Returns
The return value referenced by the output node.
Return type
Any
placeholder(target, args, kwargs)
Execute a placeholder node. Note that this is stateful: Interpreter maintains an internal iterator over the arguments passed to run, and this method returns next() on that iterator.
Parameters
target (Target) – The call target for this node. See Node for details on semantics.
args (Tuple) – Tuple of positional args for this invocation.
kwargs (Dict) – Dict of keyword arguments for this invocation.
Returns
The argument value that was retrieved.
Return type
Any
run(*args, initial_env=None)
Run module via interpretation and return the result.
Parameters
*args – The arguments to the Module to run, in positional order.
initial_env (Optional[Dict[Node, Any]]) – An optional starting environment for execution. This is a dict mapping Node to any value. This can be used, for example, to pre-populate results for certain Nodes so as to do only partial evaluation within the interpreter.
Returns
The value returned from executing the Module.
Return type
Any
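A minimal sketch of partial evaluation with initial_env: pre-populating the value of one node makes run() use the supplied value instead of executing that node. It relies on run() skipping nodes already present in the environment, as the initial_env description suggests; the function f is hypothetical.

```python
import torch
import torch.fx


def f(x):
    return torch.relu(x) + 1


gm = torch.fx.symbolic_trace(f)
relu_node = next(n for n in gm.graph.nodes if n.op == "call_function" and n.target == torch.relu)

# Pretend relu produced zeros; only the remaining nodes are actually executed
result = torch.fx.Interpreter(gm).run(torch.randn(3), initial_env={relu_node: torch.zeros(3)})
print(result)  # tensor([1., 1., 1.])
```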
class torch.fx.Transformer(module)
Transformer is a special type of interpreter that produces a new Module. It exposes a transform() method that returns the transformed Module. Unlike Interpreter, Transformer does not require arguments to run; it works entirely symbolically.
Example
Suppose we want to swap all instances of torch.neg with torch.sigmoid and vice versa (including their Tensor method equivalents). We could subclass Transformer like so:
import torch
from typing import Any, Dict, Tuple
from torch.fx import Transformer
from torch.fx.node import Target

class NegSigmSwapXformer(Transformer):
    def call_function(self, target : Target, args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
        # torch.sigmoid(...) calls become torch.neg(...)
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(target, args, kwargs)

    def call_method(self, target : Target, args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
        # x.neg() method calls become x.sigmoid()
        if target == 'neg':
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(target, args, kwargs)

def fn(x):
    return torch.sigmoid(x).neg()

gm = torch.fx.symbolic_trace(fn)

transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()
input = torch.randn(3, 4)
torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid())
Parameters
module (GraphModule) – The Module to be transformed.
get_attr(target, args, kwargs)
Execute a get_attr node. In Transformer, this is overridden to insert a new get_attr node into the output graph.
Parameters
target (Target) – The call target for this node. See Node for details on semantics.
args (Tuple) – Tuple of positional args for this invocation.
kwargs (Dict) – Dict of keyword arguments for this invocation.
placeholder(target, args, kwargs)
Execute a placeholder node. In Transformer, this is overridden to insert a new placeholder node into the output graph.
Parameters
target (Target) – The call target for this node. See Node for details on semantics.
args (Tuple) – Tuple of positional args for this invocation.
kwargs (Dict) – Dict of keyword arguments for this invocation.
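Because Transformer overrides placeholder and get_attr to emit nodes into a new graph, and records the remaining calls through Proxies, a Transformer with no overrides should rebuild an equivalent graph. The sketch below uses that as a quick check; the module M is hypothetical, and node names in the rebuilt graph may differ between PyTorch versions.

```python
import torch
import torch.fx


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))

    def forward(self, x):
        # Reading self.param directly produces a get_attr node
        return x + self.param


gm = torch.fx.symbolic_trace(M())
copied = torch.fx.Transformer(gm).transform()

# Expected node kinds: placeholder, get_attr, call_function, output
print([n.op for n in copied.graph.nodes])

x = torch.randn(3, 4)
print(torch.allclose(copied(x), gm(x)))  # True
```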