
torch.fx

Overview

This feature is under a Beta release and its API may change.

FX is a toolkit that developers can use to transform nn.Module instances. FX consists of three main components: a symbolic tracer, an intermediate representation, and Python code generation. A demonstration of these components in action:

import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

module = MyModule()

from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph(x):
    %param : [#users=1] = self.param
    %add_1 : [#users=1] = call_function[target=<built-in function add>](args = (%x, %param), kwargs = {})
    %linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
    %clamp_1 : [#users=1] = call_method[target=clamp](args = (%linear_1,), kwargs = {min: 0.0, max: 1.0})
    return clamp_1
"""

# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add_1 = x + param;  x = param = None
    linear_1 = self.linear(add_1);  add_1 = None
    clamp_1 = linear_1.clamp(min = 0.0, max = 1.0);  linear_1 = None
    return clamp_1
"""

The symbolic tracer performs “abstract interpretation” of the Python code. It feeds fake values, called Proxies, through the code. Operations on these Proxies are recorded. More information about symbolic tracing can be found in the symbolic_trace and Tracer documentation.

The intermediate representation is the container for the operations that were recorded during symbolic tracing. It consists of a list of Nodes that represent function inputs, callsites (to functions, methods, or nn.Module instances), and return values. More information about the IR can be found in the documentation for Graph. The IR is the format on which transformations are applied.

Python code generation is what makes FX a Python-to-Python (or Module-to-Module) transformation toolkit. For each Graph IR, we can create valid Python code matching the Graph’s semantics. This functionality is wrapped up in GraphModule, which is an nn.Module instance that holds a Graph as well as a forward method generated from the Graph.

Taken together, this pipeline of components (symbolic tracing → intermediate representation → transforms → Python code generation) constitutes the Python-to-Python transformation pipeline of FX.
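As a minimal sketch of that pipeline, assuming a toy module that calls torch.add, the function below traces the module, rewrites the IR, and regenerates Python code from it. The transform itself (swapping torch.add for torch.mul) and the names transform and AddModule are illustrative only.

```python
import torch
import torch.fx


def transform(m: torch.nn.Module) -> torch.nn.Module:
    # 1. Symbolic tracing: record the module's operations into a GraphModule.
    gm = torch.fx.symbolic_trace(m)

    # 2. Transform the intermediate representation: every call to
    #    torch.add becomes a call to torch.mul.
    for node in gm.graph.nodes:
        if node.op == 'call_function' and node.target is torch.add:
            node.target = torch.mul

    # 3. Python code generation: wrapping the edited Graph in a new
    #    GraphModule generates fresh `code` and `forward` from it.
    return torch.fx.GraphModule(gm, gm.graph)


class AddModule(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)


new_module = transform(AddModule())
print(new_module.code)
print(new_module(torch.tensor([2.0]), torch.tensor([3.0])))  # multiplies instead of adding
```

Building a new GraphModule (rather than mutating the traced one and calling recompile()) keeps the original traced module untouched if you still need it.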

Writing Transformations

TODO

Debugging Transformations

TODO

Limitations of Symbolic Tracing

FX uses a system of symbolic tracing (a.k.a. symbolic execution) to capture the semantics of programs in a transformable/analyzable form. The system is tracing in that it executes the program (really an nn.Module or function) to gather this information. It is symbolic in that the data flowing through the program during this execution is not real data, but rather symbols (“Proxy” in FX parlance).

Although symbolic tracing works for most neural net code, it has some limitations.

Dynamic Control Flow

The main limitation of symbolic tracing is that it does not currently support dynamic control flow, that is, loops or if statements whose condition may depend on the input values of the program.

For example, let’s examine the following program:

import torch
import torch.fx

def func_to_trace(x):
    dim0 = x.size()[0]
    if dim0 == 3:
        return torch.relu(x)
    else:
        return torch.neg(x)

traced = torch.fx.symbolic_trace(func_to_trace)
"""
  <...>
  File "dyn.py", line 6, in func_to_trace
    if dim0 == 3:
  File "pytorch/torch/fx/proxy.py", line 155, in __bool__
    return self.tracer.to_bool(self)
  File "pytorch/torch/fx/proxy.py", line 85, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
"""

The condition to the if statement relies on the value of dim0, which eventually relies on the value of x, a function input. Since x can change (e.g. if you pass a new input tensor to the traced function), this is dynamic control flow. The traceback walks back up through your code to show you where this situation happens.

Static Control Flow

On the other hand, so-called static control flow is supported. Static control flow refers to loops or if statements whose condition cannot change across invocations. Typically, in PyTorch programs, this control flow arises in code that makes decisions about a model’s architecture based on hyper-parameters. As a concrete example:

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self, do_activation : bool = False):
        super().__init__()
        self.do_activation = do_activation
        self.linear = torch.nn.Linear(512, 512)

    def forward(self, x):
        x = self.linear(x)
        # This if-statement is so-called static control flow.
        # Its condition does not depend on any input values
        if self.do_activation:
            x = torch.relu(x)
        return x

without_activation = MyModule(do_activation=False)
with_activation = MyModule(do_activation=True)

traced_without_activation = torch.fx.symbolic_trace(without_activation)
print(traced_without_activation.code)
"""
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    return linear_1
"""

traced_with_activation = torch.fx.symbolic_trace(with_activation)
print(traced_with_activation.code)
"""
import torch
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    relu_1 = torch.relu(linear_1);  linear_1 = None
    return relu_1
"""

The if-statement if self.do_activation does not depend on any function inputs, thus it is static. do_activation can be considered to be a hyper-parameter, and the traces of different instances of MyModule with different values for that parameter have different code. This is a valid pattern that is supported by symbolic tracing.

Many instances of dynamic control flow are semantically static control flow. These instances can be made to support symbolic tracing by removing the data dependencies on input values, for example by moving values to Module attributes or by passing constant values during symbolic tracing:

import torch
import torch.fx as fx

def f(x, flag):
    if flag: return x
    else: return x*2

fx.symbolic_trace(f) # Fails! `flag` is a Proxy during tracing

def g(flag):
    return lambda x: f(x, flag)

new_f = g(flag=True)
fx.symbolic_trace(new_f)

In the case of truly dynamic control flow, the sections of the program that contain this code can be traced as calls to a Module (see Customizing Tracing with the Tracer class) or to a function (see wrap()) rather than being traced through.
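For example, the data-dependent branch from func_to_trace above can be moved into its own function and registered with wrap(), so the trace records a single call_function node and the branch is evaluated later on real tensors. This is a sketch: relu_or_neg is a hypothetical helper name, and wrap() has to be applied at module scope.

```python
import torch
import torch.fx


@torch.fx.wrap
def relu_or_neg(x):
    # Dynamic control flow: the branch depends on the input's shape.
    # Because the function is wrapped, FX records a call to it instead
    # of tracing through this `if`.
    if x.size(0) == 3:
        return torch.relu(x)
    return torch.neg(x)


def func_to_trace(x):
    return relu_or_neg(x) + 1


traced = torch.fx.symbolic_trace(func_to_trace)
print(traced.code)

# The branch is taken at run time, per input.
print(traced(torch.tensor([-1.0, 2.0, -3.0])))  # size 3 -> relu, then + 1
print(traced(torch.tensor([-1.0, 2.0])))        # size 2 -> neg, then + 1
```

The trade-off is that the wrapped function becomes opaque to FX: transforms cannot see or rewrite the operations inside it.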

Non-torch Functions

FX uses __torch_function__ as the mechanism by which it intercepts calls (see the technical overview for more information about this). Some functions, such as built-in Python functions or those in the math module, are not covered by __torch_function__, but we would still like to capture them in symbolic tracing. For example:

import torch
import torch.fx
from math import sqrt

def normalize(x):
    """
    Normalize `x` by the size of the batch dimension
    """
    return x / sqrt(len(x))

# It's valid Python code
normalize(torch.rand(3, 4))

traced = torch.fx.symbolic_trace(normalize)
"""
  <...>
  File "sqrt.py", line 9, in normalize
    return x / sqrt(len(x))
  File "pytorch/torch/fx/proxy.py", line 161, in __len__
    raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
"""

The error tells us that the built-in function len is not supported. We can make it so that functions like this are recorded in the trace as direct calls using the wrap() API:

torch.fx.wrap('len')
torch.fx.wrap('sqrt')

traced = torch.fx.symbolic_trace(normalize)

print(traced.code)
"""
import math
def forward(self, x):
    len_1 = len(x)
    sqrt_1 = math.sqrt(len_1);  len_1 = None
    truediv = x / sqrt_1;  x = sqrt_1 = None
    return truediv
"""

Customizing Tracing with the Tracer class

The Tracer class is the class that underlies the implementation of symbolic_trace. The behavior of tracing can be customized by subclassing Tracer, like so:

import torch
import torch.fx

class MyCustomTracer(torch.fx.Tracer):
    # Inside here you can override various methods
    # to customize tracing. See the `Tracer` API
    # reference
    pass


# Let's use this custom tracer to trace through this module
class MyModule(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + torch.ones(3, 4)

mod = MyModule()

traced_graph = MyCustomTracer().trace(mod)
# trace() returns a Graph. Let's wrap it up in a
# GraphModule to make it runnable
traced = torch.fx.GraphModule(mod, traced_graph)

Leaf Modules

Leaf Modules are the modules that appear as calls in the symbolic trace rather than being traced through. The default set of leaf modules is the set of standard torch.nn module instances. For example:

import torch
import torch.fx

class MySpecialSubmodule(torch.nn.Module):
    def forward(self, x):
        return torch.neg(x)

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)
        self.submod = MySpecialSubmodule()

    def forward(self, x):
        return self.submod(self.linear(x))

traced = torch.fx.symbolic_trace(MyModule())
print(traced.code)
# `linear` is preserved as a call, yet `submod` is traced through.
# This is because the default set of "Leaf Modules" includes all
# standard `torch.nn` modules.
"""
import torch
def forward(self, x):
    linear_1 = self.linear(x);  x = None
    neg_1 = torch.neg(linear_1);  linear_1 = None
    return neg_1
"""

The set of leaf modules can be customized by overriding Tracer.is_leaf_module().
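A minimal sketch of such an override, reusing MySpecialSubmodule and MyModule from the example above (redefined here so the snippet runs on its own). LeafTracer is a hypothetical name; it keeps the default torch.nn leaves and additionally treats MySpecialSubmodule as a leaf.

```python
import torch
import torch.fx


class MySpecialSubmodule(torch.nn.Module):
    def forward(self, x):
        return torch.neg(x)


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 4)
        self.submod = MySpecialSubmodule()

    def forward(self, x):
        return self.submod(self.linear(x))


class LeafTracer(torch.fx.Tracer):
    def is_leaf_module(self, m, module_qualified_name):
        # Keep MySpecialSubmodule as a call_module node ...
        if isinstance(m, MySpecialSubmodule):
            return True
        # ... and fall back to the default rule (standard torch.nn modules).
        return super().is_leaf_module(m, module_qualified_name)


root = MyModule()
graph = LeafTracer().trace(root)
traced = torch.fx.GraphModule(root, graph)
print(traced.code)  # both `linear` and `submod` now appear as calls
```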

Miscellanea

  • Tensor constructors (e.g. torch.zeros, torch.ones, torch.rand, torch.randn, torch.sparse_coo_tensor) are currently not traceable.

    • The deterministic constructors (zeros, ones) can be used and the value they produce will be embedded in the trace as a constant. This is only problematic if the arguments to these constructors refer to dynamic input sizes. In this case, ones_like or zeros_like may be a viable substitute.

    • Nondeterministic constructors (rand, randn) will have a single random value embedded in the trace. This is likely not the intended behavior.

    • This behavior may be fixed in a future release.

  • Type annotations

    • Python 3-style type annotations (e.g. func(x : torch.Tensor, y : int) -> torch.Tensor) are supported and will be preserved by symbolic tracing.

    • Python 2-style comment type annotations # type: (torch.Tensor, int) -> torch.Tensor are not currently supported.

    • Annotations on local names within a function are not currently supported.
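The sketch below illustrates two of these points with a hypothetical add_one function: torch.ones_like takes its shape from the traced input instead of freezing a constant-size tensor into the trace, and the Python 3-style annotation on x is kept on the graph's placeholder node.

```python
import torch
import torch.fx


def add_one(x: torch.Tensor) -> torch.Tensor:
    # ones_like derives its shape from `x` at run time, so no constant-size
    # tensor is embedded in the trace (unlike torch.ones(3, 4)).
    return x + torch.ones_like(x)


gm = torch.fx.symbolic_trace(add_one)
print(gm.code)

# The traced module works for inputs of any shape.
print(gm(torch.zeros(2, 5)).shape)
print(gm(torch.zeros(7)).shape)

# The annotation on `x` is recorded on the corresponding placeholder node.
placeholder = next(n for n in gm.graph.nodes if n.op == 'placeholder')
print(placeholder.type)
```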

API Reference

torch.fx.symbolic_trace(root)

Symbolic tracing API

Given an nn.Module or function instance root, this function will return a GraphModule constructed by recording operations seen while tracing through root.

Parameters

root (Union[torch.nn.Module, Callable]) – Module or function to be traced and converted into a Graph representation.

Returns

a Module created from the recorded operations from root.

Return type

GraphModule

torch.fx.wrap(fn_or_name)

This function can be called at module-level scope to register fn_or_name as a “leaf function”. A “leaf function” will be preserved as a call_function node in the FX trace instead of being traced through:

# foo/bar/baz.py
import torch.fx

def my_custom_function(x, y):
    return x * x + y * y

torch.fx.wrap('my_custom_function')

def fn_to_be_traced(x, y):
    # When symbolic tracing, the below call to my_custom_function will be inserted into
    # the graph rather than tracing it.
    return my_custom_function(x, y)

This function can also equivalently be used as a decorator:

# foo/bar/baz.py
import torch.fx

@torch.fx.wrap
def my_custom_function(x, y):
    return x * x + y * y

A wrapped function can be thought of as a “leaf function”, analogous to the concept of “leaf modules”; that is, it is a function that is left as a call in the FX trace rather than traced through.

Parameters

fn_or_name (Union[str, Callable]) – The function or name of the global function to insert into the graph when it’s called

class torch.fx.GraphModule(*args, **kwargs)

GraphModule is an nn.Module generated from an fx.Graph. GraphModule has a graph attribute, as well as code and forward attributes generated from that graph.

Warning

When graph is reassigned, code and forward will be automatically regenerated. However, if you edit the contents of the graph without reassigning the graph attribute itself, you must call recompile() to update the generated code.

__init__(root, graph, class_name='GraphModule')

Construct a GraphModule.

Parameters
  • root (Union[torch.nn.Module, Dict[str, Any]]) – root can either be an nn.Module instance or a Dict mapping strings to any attribute type. In the case that root is a Module, any references to Module-based objects (via qualified name) in the Graph’s Nodes’ target field will be copied over from the respective place within root’s Module hierarchy into the GraphModule’s module hierarchy. In the case that root is a dict, the qualified name found in a Node’s target will be looked up directly in the dict’s keys. The object mapped to by the Dict will be copied over into the appropriate place within the GraphModule’s module hierarchy.

  • graph (Graph) – graph contains the nodes this GraphModule should use for code generation

  • class_name (str) – class_name denotes the name of this GraphModule for debugging purposes. If it’s unset, all error messages will report as originating from GraphModule. It may be helpful to set this to root’s original name or a name that makes sense within the context of your transform.

property code

Return the Python code generated from the Graph underlying this GraphModule.

property graph

Return the Graph underlying this GraphModule

recompile()

Recompile this GraphModule from its graph attribute. This should be called after editing the contained graph, otherwise the generated code of this GraphModule will be out of date.
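A sketch of the edit-then-recompile workflow, assuming a hypothetical module ClampModule: the clamp call's keyword arguments are rewritten through Node.kwargs, which changes the graph in place without reassigning the graph attribute, so recompile() is needed before the new bound takes effect.

```python
import torch
import torch.fx


class ClampModule(torch.nn.Module):
    def forward(self, x):
        return x.clamp(min=0.0, max=1.0)


gm = torch.fx.symbolic_trace(ClampModule())

# Edit the graph in place: tighten the upper bound of the clamp call.
for node in gm.graph.nodes:
    if node.op == 'call_method' and node.target == 'clamp':
        node.kwargs = {'min': 0.0, 'max': 0.5}

# `code` and `forward` still reflect max=1.0 until they are regenerated.
gm.recompile()
print(gm.code)
print(gm(torch.tensor([-1.0, 0.25, 2.0])))
```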

to_folder(folder, module_name='FxModule')

Dumps out module to folder with module_name so that it can be imported with from <folder> import <module_name>

Parameters
  • folder (Union[str, os.PathLike]) – The folder to write the code out to

  • module_name (str) – Top-level name to use for the Module while writing out the code

class torch.fx.Graph

Graph is the main data structure used in the FX Intermediate Representation. It consists of a series of Nodes, each representing callsites (or other syntactic constructs). The list of Nodes, taken together, constitutes a valid Python function.

For example, the following code

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

Will produce the following Graph:

print(gm.graph)
graph(x):
    %linear_weight : [#users=1] = self.linear.weight
    %add_1 : [#users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
    %linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
    %relu_1 : [#users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
    %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
    %topk_1 : [#users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
    return topk_1

For the semantics of operations represented in the Graph, please see Node.

__init__()

Construct an empty Graph.
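As a sketch of building a Graph by hand with the methods documented below (placeholder, call_function, call_method, output), the snippet constructs the equivalent of relu(x + y) and wraps it in a GraphModule. Passing an empty nn.Module as root is an assumption that works here only because the graph references no submodules or attributes.

```python
import operator

import torch
import torch.fx

graph = torch.fx.Graph()
x = graph.placeholder('x')
y = graph.placeholder('y')
total = graph.call_function(operator.add, args=(x, y))  # x + y
activated = graph.call_method('relu', args=(total,))    # self argument is args[0]
graph.output(activated)
graph.lint()  # check that the graph is well-formed

gm = torch.fx.GraphModule(torch.nn.Module(), graph)
print(gm.code)
print(gm(torch.tensor([-1.0, 2.0]), torch.tensor([0.5, 0.5])))
```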

call_function(the_function, args=None, kwargs=None, type_expr=None)

Insert a call_function Node into the Graph. A call_function node represents a call to a Python callable, specified by the_function. the_function can be any PyTorch operator, Python function, or member of the builtins or operator namespaces.

Parameters
  • the_function (Callable[..., Any]) – The function to be called. Can be any PyTorch operator, Python function, or member of the builtins or operator namespaces.

  • args (Optional[Tuple[Argument, ...]]) – The positional arguments to be passed to the called function.

  • kwargs (Optional[Dict[str, Argument]]) – The keyword arguments to be passed to the called function

  • type_expr (Optional[Any]) – an optional type annotation representing the Python type the output of this node will have.

Returns

The newly created and inserted call_function node.

Note

The same insertion point and type expression rules apply for this method as Graph.create_node().

call_method(method_name, args=None, kwargs=None, type_expr=None)

Insert a call_method Node into the Graph. A call_method node represents a call to a given method on the 0th element of args.

Parameters
  • method_name (str) – The name of the method to apply to the self argument. For example, if args[0] is a Node representing a Tensor, then to call relu() on that Tensor, pass relu to method_name.

  • args (Optional[Tuple[Argument, ...]]) – The positional arguments to be passed to the called method. Note that this should include a self argument.

  • kwargs (Optional[Dict[str, Argument]]) – The keyword arguments to be passed to the called method

  • type_expr (Optional[Any]) – an optional type annotation representing the Python type the output of this node will have.

Returns

The newly created and inserted call_method node.

Note

The same insertion point and type expression rules apply for this method as Graph.create_node().

call_module(module_name, args=None, kwargs=None, type_expr=None)

Insert a call_module Node into the Graph. A call_module node represents a call to the forward() function of a Module in the Module hierarchy.

Parameters
  • module_name (str) – The qualified name of the Module in the Module hierarchy to be called. For example, if the traced Module has a submodule named foo, which has a submodule named bar, the qualified name foo.bar should be passed as module_name to call that module.

  • args (Optional[Tuple[Argument, ...]]) – The positional arguments to be passed to the called module. Note that this should not include a self argument.

  • kwargs (Optional[Dict[str, Argument]]) – The keyword arguments to be passed to the called module

  • type_expr (Optional[Any]) – an optional type annotation representing the Python type the output of this node will have.

Returns

The newly-created and inserted call_module node.

Note

The same insertion point and type expression rules apply for this method as Graph.create_node().

create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)

Create a Node and add it to the Graph at the current insert-point. Note that the current insert-point can be set via Graph.inserting_before() and Graph.inserting_after().

Parameters
  • op (str) – the opcode for this Node. One of ‘call_function’, ‘call_method’, ‘get_attr’, ‘call_module’, ‘placeholder’, or ‘output’. The semantics of these opcodes are described in the Node docstring.

  • args (Optional[Tuple[Argument, ...]]) – a tuple of arguments to this node.

  • kwargs (Optional[Dict[str, Argument]]) – the kwargs of this Node

  • name (Optional[str]) – an optional string name for the Node. This will influence the name of the value assigned to in the Python generated code.

  • type_expr (Optional[Any]) – an optional type annotation representing the Python type the output of this node will have.

Returns

The newly-created and inserted node.

erase_node(to_erase)

Erases a Node from the Graph. Throws an exception if there are still users of that node in the Graph.

Parameters

to_erase (Node) – The Node to erase from the Graph.

get_attr(qualified_name, type_expr=None)

Insert a get_attr node into the Graph. A get_attr Node represents the fetch of an attribute from the Module hierarchy.

Parameters
  • qualified_name (str) – the fully-qualified name of the attribute to be retrieved. For example, if the traced Module has a submodule named foo, which has a submodule named bar, which has an attribute named baz, the qualified name foo.bar.baz should be passed as qualified_name.

  • type_expr (Optional[Any]) – an optional type annotation representing the Python type the output of this node will have.

Returns

The newly-created and inserted get_attr node.

Note

The same insertion point and type expression rules apply for this method as Graph.create_node.
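The following sketch combines call_module and get_attr in a hand-built graph. The root module and its attribute names linear and scale are assumptions for illustration; GraphModule copies both from root because they appear as node targets.

```python
import torch
import torch.fx

# A root module that owns the submodule and parameter the graph refers to.
root = torch.nn.Module()
root.linear = torch.nn.Linear(4, 4)
root.scale = torch.nn.Parameter(torch.full((4,), 2.0))

graph = torch.fx.Graph()
x = graph.placeholder('x')
projected = graph.call_module('linear', args=(x,))  # no self argument
scale = graph.get_attr('scale')                     # fetch root.scale
out = graph.call_function(torch.mul, args=(projected, scale))
graph.output(out)

gm = torch.fx.GraphModule(root, graph)
print(gm.code)
```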

graph_copy(g, val_map)

Copy all nodes from a given graph into self.

Parameters
  • g (Graph) – The source graph from which to copy Nodes.

  • val_map (Dict[Node, Node]) – a dictionary that will be populated with a mapping from nodes in g to nodes in self. Note that val_map can be passed in with values in it already to override copying of certain values.

Returns

The value in self that is now equivalent to the output value in g, if g had an output node. None otherwise.

inserting_after(n=None)

Set the point at which create_node and companion methods will insert into the graph. When used within a ‘with’ statement, this will temporarily set the insert point and then restore it when the with statement exits:

with g.inserting_after(n):
    ... # inserting after node n
... # insert point restored to what it was previously
g.inserting_after(n) #  set the insert point permanently
Parameters

n (Optional[Node]) – The node after which to insert. If None, this will insert after the beginning of the entire graph.

Returns

A resource manager that will restore the insert point on __exit__.
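A sketch using inserting_after to add a relu right after a call to a hypothetical linear submodule. Because replace_all_uses_with also rewires the new relu node's own input (it is itself a user of the linear node), that argument is restored afterwards.

```python
import torch
import torch.fx


class LinearOnly(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


gm = torch.fx.symbolic_trace(LinearOnly())
graph = gm.graph

for node in list(graph.nodes):
    if node.op == 'call_module' and node.target == 'linear':
        # Nodes created inside this block go directly after `node`.
        with graph.inserting_after(node):
            relu = graph.call_function(torch.relu, args=(node,))
        # Route all former users of `node` through the relu ...
        node.replace_all_uses_with(relu)
        # ... which also pointed relu's own input at itself, so restore it.
        relu.args = (node,)

graph.lint()
gm.recompile()
print(gm.code)
```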

inserting_before(n=None)

Set the point at which create_node and companion methods will insert into the graph. When used within a ‘with’ statement, this will temporarily set the insert point and then restore it when the with statement exits:

with g.inserting_before(n):
    ... # inserting before node n
... # insert point restored to what it was previously
g.inserting_before(n) #  set the insert point permanently
Parameters

n (Optional[Node]) – The node before which to insert. If None this will insert before the beginning of the entire graph.

Returns

A resource manager that will restore the insert point on __exit__.

lint(root=None)

Runs various checks on this Graph to make sure it is well-formed. In particular:

  • Checks Nodes have correct ownership (owned by this graph)

  • Checks Nodes appear in topological order

  • If root is provided, checks that targets exist in root

Parameters

root (Optional[torch.nn.Module]) – The root module with which to check for targets. This is equivalent to the root argument that is passed when constructing a GraphModule.

node_copy(node, arg_transform=<function Graph.<lambda>>)

Copy a node from one graph into another. arg_transform needs to transform arguments from the graph of node to the graph of self. Example:

# Copying all the nodes in `g` into `new_graph`
g : torch.fx.Graph = ...
new_graph = torch.fx.Graph()
value_remap = {}
for node in g.nodes:
    value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n])
Parameters
  • node (Node) – The node to copy into self.

  • arg_transform (Callable[[Node], Argument]) – A function that transforms Node arguments in node’s args and kwargs into the equivalent argument in self. In the simplest case, this should retrieve a value out of a table mapping Nodes in the original graph to self.

property nodes

Get the list of Nodes that constitute this Graph.

Note that this Node list representation is a doubly-linked list. Mutations during iteration (e.g. delete a Node, add a Node) are safe.

Returns

A doubly-linked list of Nodes. Note that reversed can be called on this list to switch iteration order.

output(result, type_expr=None)

Insert an output Node into the Graph. An output node represents a return statement in Python code. result is the value that should be returned.

Parameters
  • result (Argument) – The value to be returned.

  • type_expr (Optional[Any]) – an optional type annotation representing the Python type the output of this node will have.

Note

The same insertion point and type expression rules apply for this method as Graph.create_node.

placeholder(name, type_expr=None)

Insert a placeholder node into the Graph. A placeholder represents a function input.

Parameters
  • name (str) – A name for the input value. This corresponds to the name of the positional argument to the function this Graph represents.

  • type_expr (Optional[Any]) – an optional type annotation representing the Python type the output of this node will have. This is needed in some cases for proper code generation (e.g. when the function is used subsequently in TorchScript compilation).

Note

The same insertion point and type expression rules apply for this method as Graph.create_node.

print_tabular()

Prints the intermediate representation of the graph in tabular format.

python_code(root_module)

Turn this Graph into valid Python code.

Parameters

root_module (str) – The name of the root module on which to look-up qualified name targets. This is usually ‘self’.

Returns

The string source code generated from this Graph.

class torch.fx.Node(graph, name, op, target, args, kwargs, type=None)

Node is the data structure that represents individual operations within a Graph. For the most part, Nodes represent callsites to various entities, such as operators, methods, and Modules (some exceptions include nodes that specify function inputs and outputs). Each Node has a function specified by its op property. The Node semantics for each value of op are as follows:

  • placeholder represents a function input. The name attribute specifies the name this value will take on. target is similarly the name of the argument. args holds either: 1) nothing, or 2) a single argument denoting the default parameter of the function input. kwargs is don’t-care. Placeholders correspond to the function parameters (e.g. x) in the graph printout.

  • get_attr retrieves a parameter from the module hierarchy. name is similarly the name the result of the fetch is assigned to. target is the fully-qualified name of the parameter’s position in the module hierarchy. args and kwargs are don’t-care.

  • call_function applies a free function to some values. name is similarly the name of the value to assign to. target is the function to be applied. args and kwargs represent the arguments to the function, following the Python calling convention.

  • call_module applies a module in the module hierarchy’s forward() method to given arguments. name is as previous. target is the fully-qualified name of the module in the module hierarchy to call. args and kwargs represent the arguments to invoke the module on, excluding the self argument.

  • call_method calls a method on a value. name is as previous. target is the string name of the method to apply to the self argument. args and kwargs represent the arguments to invoke the method on, including the self argument.

  • output contains the output of the traced function in its args[0] attribute. This corresponds to the “return” statement in the Graph printout.
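Tracing the MyModule example from the Graph documentation and printing each node's op, name, and target is a quick way to see these opcodes side by side. This sketch redefines the module so it runs on its own; the exact node names printed may differ between versions.

```python
import torch
import torch.fx


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)


gm = torch.fx.symbolic_trace(MyModule())
for node in gm.graph.nodes:
    # `op` selects the semantics; `target` and `args` are interpreted per opcode.
    print(f'{node.op:<14} {node.name:<14} {node.target}')

ops = [node.op for node in gm.graph.nodes]
```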

property all_input_nodes

Return all Nodes that are inputs to this Node. This is equivalent to iterating over args and kwargs and only collecting the values that are Nodes.

Returns

List of Nodes that appear in the args and kwargs of this Node, in that order.

append(x)

Insert x after this node in the list of nodes in the graph. Equivalent to self.next.prepend(x)

Parameters

x (Node) – The node to put after this node. Must be a member of the same graph.

property args

The tuple of arguments to this Node. The interpretation of arguments depends on the node’s opcode. See the Node docstring for more information.

Assignment to this property is allowed. All accounting of uses and users is updated automatically on assignment.

property kwargs

The dict of keyword arguments to this Node. The interpretation of arguments depends on the node’s opcode. See the Node docstring for more information.

Assignment to this property is allowed. All accounting of uses and users is updated automatically on assignment.

property next

Returns the next Node in the linked list of Nodes.

Returns

The next Node in the linked list of Nodes.

prepend(x)

Insert x before this node in the list of nodes in the graph. Example:

Before: p -> self
        bx -> x -> ax
After:  p -> x -> self
        bx -> ax
Parameters

x (Node) – The node to put before this node. Must be a member of the same graph.

property prev

Returns the previous Node in the linked list of Nodes.

Returns

The previous Node in the linked list of Nodes.

replace_all_uses_with(replace_with)

Replace all uses of self in the Graph with the Node replace_with.

Parameters

replace_with (Node) – The node to replace all uses of self with.

Returns

The list of Nodes on which this change was made.
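A sketch combining replace_all_uses_with and Graph.erase_node to drop a double negation; NegNeg and is_neg are hypothetical names. The outer neg is erased first because erase_node refuses to delete a node that still has users.

```python
import torch
import torch.fx


class NegNeg(torch.nn.Module):
    def forward(self, x):
        return torch.neg(torch.neg(x)) + 1.0


gm = torch.fx.symbolic_trace(NegNeg())
graph = gm.graph


def is_neg(n):
    return isinstance(n, torch.fx.Node) and n.op == 'call_function' and n.target is torch.neg


for node in list(graph.nodes):
    if is_neg(node) and is_neg(node.args[0]):
        inner = node.args[0]
        # Users of neg(neg(a)) now read `a` directly.
        node.replace_all_uses_with(inner.args[0])
        graph.erase_node(node)   # no users left
        graph.erase_node(inner)  # its only user was just erased

gm.recompile()
print(gm.code)
```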

class torch.fx.Tracer(autowrap_modules=(math,))

Tracer is the class that implements the symbolic tracing functionality of torch.fx.symbolic_trace. A call to symbolic_trace(m) is equivalent to Tracer().trace(m).

Tracer can be subclassed to override various behaviors of the tracing process. The different behaviors that can be overridden are described in the docstrings of the methods on this class.

call_module(m, forward, args, kwargs)

Method that specifies the behavior of this Tracer when it encounters a call to an nn.Module instance.

By default, the behavior is to check if the called module is a leaf module via is_leaf_module. If it is, emit a call_module node referring to m in the Graph. Otherwise, call the Module normally, tracing through the operations in its forward function.

This method can be overridden, for example, to create nested traced GraphModules, or to implement any other behavior you want while tracing across Module boundaries.

Parameters
  • m (Module) – The module for which a call is being emitted

  • forward (Callable) – The forward() method of the Module to be invoked

  • args (Tuple) – args of the module callsite

  • kwargs (Dict) – kwargs of the module callsite

Returns

The return value from the Module call. In the case that a call_module node was emitted, this is a Proxy value. Otherwise, it is whatever value was returned from the Module invocation.
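A sketch of overriding call_module to record which module calls the tracer encounters while delegating to the default behavior. RecordingTracer, its seen list, and the Block/Net modules are hypothetical; path_of_module supplies the qualified names.

```python
import torch
import torch.fx


class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.act = torch.nn.ReLU()

    def forward(self, x):
        return self.act(x) * 2


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.block = Block()

    def forward(self, x):
        return self.block(self.linear(x))


class RecordingTracer(torch.fx.Tracer):
    def __init__(self):
        super().__init__()
        self.seen = []  # qualified names, in the order calls are encountered

    def call_module(self, m, forward, args, kwargs):
        self.seen.append(self.path_of_module(m))
        # Default behavior: leaf modules become call_module nodes,
        # everything else (here, `block`) is traced through.
        return super().call_module(m, forward, args, kwargs)


tracer = RecordingTracer()
graph = tracer.trace(Net())
print(tracer.seen)
```

Because Block is not a torch.nn module, it is traced through, so its child block.act is visited too, while only the leaves end up as call_module nodes.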

create_arg(a)

A method to specify the behavior of tracing when preparing values to be used as arguments to nodes in the Graph.

By default, the behavior includes:

  1. Iterate through collection types (e.g. tuple, list, dict) and recursively call create_arg on the elements.

  2. Given a Proxy object, return a reference to the underlying IR Node

  3. Given a non-Proxy Tensor object, emit IR for various cases:

    • For a Parameter, emit a get_attr node referring to that Parameter

    • For a non-Parameter Tensor, store the Tensor in a special attribute on the root module and emit a get_attr node referring to that attribute.

This method can be overridden to support more types.

Parameters

a (Any) – The value to be emitted as an Argument in the Graph.

Returns

The value a converted into the appropriate Argument

create_args_for_root(root_fn, is_module)

Create placeholder nodes corresponding to the signature of the root Module. This method introspects root’s signature and emits those nodes accordingly, also supporting *args and **kwargs.
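To see the placeholders that create_args_for_root emits, you can trace a function and list its placeholder nodes. This sketch assumes that the placeholder names mirror the parameter names of the traced function, which is how the default tracer names them in current releases.

```python
import torch
from torch.fx import symbolic_trace


def f(x, y):
    return x * y


gm = symbolic_trace(f)
# One placeholder node per parameter in f's signature
placeholders = [n.name for n in gm.graph.nodes if n.op == "placeholder"]
print(placeholders)  # ['x', 'y']
```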

is_leaf_module(m, module_qualified_name)

A method to specify whether a given nn.Module is a “leaf” module.

Leaf modules are the atomic units that appear in the IR, referenced by call_module calls. By default, Modules in the PyTorch standard library namespace (torch.nn) are leaf modules. All other modules are traced through and their constituent ops are recorded, unless this method is overridden to specify otherwise.

Parameters
  • m (Module) – The module being queried about

  • module_qualified_name (str) – The path from the root module to this module. For example, if you have a module hierarchy where submodule foo contains submodule bar, which contains submodule baz, that module will appear with the qualified name foo.bar.baz here.
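The sketch below overrides is_leaf_module so that a custom module stays a single call_module node instead of being traced through. Scale, Net, and ScaleIsLeafTracer are hypothetical names. Every other module falls back to the default rule through super().

```python
import torch
from torch.fx import GraphModule, Tracer, symbolic_trace


class Scale(torch.nn.Module):
    def forward(self, x):
        return x * 2


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = Scale()

    def forward(self, x):
        return self.scale(x) + 1


class ScaleIsLeafTracer(Tracer):
    def is_leaf_module(self, m, module_qualified_name):
        # Keep every Scale instance atomic, in addition to the default leaves
        if isinstance(m, Scale):
            return True
        return super().is_leaf_module(m, module_qualified_name)


net = Net()
graph = ScaleIsLeafTracer().trace(net)
gm = GraphModule(net, graph)
ops = [(n.op, n.target) for n in graph.nodes]
# With the default tracer, Scale is traced through into a multiplication
default_ops = [n.op for n in symbolic_trace(net).graph.nodes]
```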

path_of_module(mod)

Helper method to find the qualified name of mod in the Module hierarchy of root. For example, if root has a submodule named foo, which has a submodule named bar, passing bar into this function will return the string “foo.bar”.

Parameters

mod (Module) – The Module to retrieve the qualified name for.
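Here is a small sketch of path_of_module. It assumes the tracer has already traced root, so that self.root is set. The Root module and its foo.bar hierarchy are hypothetical.

```python
import torch
from torch.fx import Tracer


class Root(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.foo = torch.nn.Module()          # plain container submodule
        self.foo.bar = torch.nn.Linear(2, 2)  # registered as "foo.bar"

    def forward(self, x):
        return self.foo.bar(x)


root = Root()
tracer = Tracer()
graph = tracer.trace(root)  # sets tracer.root to root
print(tracer.path_of_module(root.foo.bar))  # foo.bar
```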

trace(root)

Trace root and return the corresponding FX Graph representation. root can either be an nn.Module instance or a Python callable.

Note that after this call, self.root may be different from the root passed in here. For example, when a free function is passed to trace(), an nn.Module instance is created to serve as the root, and embedded constants are added to it.

Parameters

root (Union[Module, Callable]) – Either a Module or a function to be traced through.

Returns

A Graph representing the semantics of the passed-in root.
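Because trace() returns a Graph and not a Module, a common pattern is to combine the graph with tracer.root to build a GraphModule. The sketch below applies this to a free function. As described above, it assumes that tracer.root is the nn.Module created during tracing.

```python
import torch
from torch.fx import GraphModule, Tracer


def fn(x):
    return torch.relu(x) + 1.0


tracer = Tracer()
graph = tracer.trace(fn)
# For a free function, tracer.root is a Module created during tracing;
# it owns any constants the graph refers to
gm = GraphModule(tracer.root, graph)
print(gm.code)
```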

class torch.fx.Proxy(node, tracer=None)

Proxy objects are Node wrappers that flow through the program during symbolic tracing and record all the operations (torch function calls, method calls, operators) that they touch into the growing FX Graph.

If you’re doing graph transforms, you can wrap a raw Node in your own Proxy so that you can use the overloaded operators to add additional nodes to a Graph.
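Here is a minimal sketch of building a Graph by hand with Proxy. A placeholder Node is wrapped in a Proxy, with tracer left as None. Ordinary Python operators and torch functions applied to that Proxy then append new nodes to the same Graph.

```python
import torch
import torch.fx

graph = torch.fx.Graph()
x_node = graph.placeholder("x")

# With tracer=None, the Proxy appends new nodes to x_node's graph
x = torch.fx.Proxy(x_node)
y = torch.relu(x * 2 + 1)  # records mul, add, and relu nodes

graph.output(y.node)
gm = torch.fx.GraphModule(torch.nn.Module(), graph)
print(gm.code)
```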

class torch.fx.Interpreter(module)

An Interpreter executes an FX graph Node-by-Node. This pattern can be useful for many things, including writing code transformations as well as analysis passes.

Methods in the Interpreter class can be overridden to customize the behavior of execution. The overridable methods, arranged by call hierarchy, are:

run()
    +-- run_node
        +-- placeholder()
        +-- get_attr()
        +-- call_function()
        +-- call_method()
        +-- call_module()
        +-- output()

Example

Suppose we want to swap all instances of torch.neg with torch.sigmoid and vice versa (including their Tensor method equivalents). We could subclass Interpreter like so:

import torch
from typing import Any, Dict, Tuple
from torch.fx import Interpreter
from torch.fx.node import Target

class NegSigmSwapInterpreter(Interpreter):
    def call_function(self, target : Target,
                      args : Tuple, kwargs : Dict) -> Any:
        # Run torch.neg wherever the graph calls torch.sigmoid
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(target, args, kwargs)

    def call_method(self, target : Target,
                    args : Tuple, kwargs : Dict) -> Any:
        # Run Tensor.sigmoid() wherever the graph calls Tensor.neg()
        if target == 'neg':
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(target, args, kwargs)

def fn(x):
    return torch.sigmoid(x).neg()

gm = torch.fx.symbolic_trace(fn)
input = torch.randn(3, 4)
result = NegSigmSwapInterpreter(gm).run(input)
torch.testing.assert_allclose(result, torch.neg(input).sigmoid())

Parameters

module (GraphModule) – The module to be executed

call_function(target, args, kwargs)

Execute a call_function node and return the result.

Parameters
  • target (Target) – The call target for this node. See Node for details on semantics

  • args (Tuple) – Tuple of positional args for this invocation

  • kwargs (Dict) – Dict of keyword arguments for this invocation

Returns

The value returned by the function invocation

Return type

Any

call_method(target, args, kwargs)

Execute a call_method node and return the result.

Parameters
  • target (Target) – The call target for this node. See Node for details on semantics

  • args (Tuple) – Tuple of positional args for this invocation

  • kwargs (Dict) – Dict of keyword arguments for this invocation

Returns

The value returned by the method invocation

Return type

Any

call_module(target, args, kwargs)

Execute a call_module node and return the result.

Parameters
  • target (Target) – The call target for this node. See Node for details on semantics

  • args (Tuple) – Tuple of positional args for this invocation

  • kwargs (Dict) – Dict of keyword arguments for this invocation

Returns

The value returned by the module invocation

Return type

Any
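The following sketch shows how to override call_module in an Interpreter. It counts how many times each submodule is invoked during a single run. ModuleCallCounter and its counts attribute are hypothetical names. The module call itself is still performed by super().call_module.

```python
import collections

import torch
from torch.fx import Interpreter, symbolic_trace


class ModuleCallCounter(Interpreter):
    def __init__(self, module):
        super().__init__(module)
        self.counts = collections.Counter()  # target name -> number of calls

    def call_module(self, target, args, kwargs):
        self.counts[target] += 1
        return super().call_module(target, args, kwargs)


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        # The same submodule is called twice
        return self.linear(self.linear(x))


gm = symbolic_trace(Net())
counter = ModuleCallCounter(gm)
x = torch.randn(2, 4)
out = counter.run(x)
print(dict(counter.counts))  # {'linear': 2}
```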

fetch_args_kwargs_from_env(n)

Fetch the concrete values of args and kwargs of node n from the current execution environment.

Parameters

n (Node) – The node for which args and kwargs should be fetched.

Returns

args and kwargs with concrete values for n.

Return type

Tuple[Tuple, Dict]

fetch_attr(target)

Fetch an attribute from the Module hierarchy of self.module.

Parameters

target (str) – The fully qualified name of the attribute to fetch

Returns

The value of the attribute.

Return type

Any
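fetch_attr resolves a dotted name one attribute at a time, starting from self.module. The minimal sketch below uses a hypothetical Net with a single Linear submodule:

```python
import torch
from torch.fx import Interpreter, symbolic_trace


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 2)

    def forward(self, x):
        return self.linear(x)


gm = symbolic_trace(Net())
interp = Interpreter(gm)
# "linear.weight" is resolved as getattr(getattr(gm, "linear"), "weight")
weight = interp.fetch_attr("linear.weight")
print(weight.shape)  # torch.Size([2, 3])
```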

get_attr(target, args, kwargs)

Execute a get_attr node, retrieving an attribute value from the Module hierarchy of self.module.

Parameters
  • target (Target) – The call target for this node. See Node for details on semantics

  • args (Tuple) – Tuple of positional args for this invocation

  • kwargs (Dict) – Dict of keyword arguments for this invocation

Returns

The value of the attribute that was retrieved

Return type

Any

map_nodes_to_values(args, n)

Recursively descend through args and look up the concrete value for each Node in the current execution environment.

Parameters
  • args (Argument) – Data structure within which to look up concrete values

  • n (Node) – Node to which args belongs. This is only used for error reporting.

output(target, args, kwargs)

Execute an output node. This simply retrieves the value referenced by the output node and returns it.

Parameters
  • target (Target) – The call target for this node. See Node for details on semantics

  • args (Tuple) – Tuple of positional args for this invocation

  • kwargs (Dict) – Dict of keyword arguments for this invocation

Returns

The return value referenced by the output node

Return type

Any

placeholder(target, args, kwargs)

Execute a placeholder node. Note that this is stateful: Interpreter maintains an internal iterator over arguments passed to run and this method returns next() on that iterator.

Parameters
  • target (Target) – The call target for this node. See Node for details on semantics

  • args (Tuple) – Tuple of positional args for this invocation

  • kwargs (Dict) – Dict of keyword arguments for this invocation

Returns

The argument value that was retrieved.

Return type

Any

run(*args, initial_env=None)

Run module via interpretation and return the result.

Parameters
  • *args – The arguments to the Module to run, in positional order

  • initial_env (Optional[Dict[Node, Any]]) – An optional starting environment for execution. This is a dict mapping Node to any value. This can be used, for example, to pre-populate results for certain Nodes so as to do only partial evaluation within the interpreter.

Returns

The value returned from executing the Module

Return type

Any
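The sketch below illustrates partial evaluation with initial_env. Pre-populating the value of one Node makes run() skip executing that Node and use the supplied value instead. The function fn and the chosen values are illustrative. Positional arguments are still passed, because the placeholder node is not in the pre-populated environment.

```python
import torch
from torch.fx import Interpreter, symbolic_trace


def fn(x):
    y = x + 1
    return y * 2


gm = symbolic_trace(fn)
# The first call_function node is the traced `x + 1`
add_node = next(n for n in gm.graph.nodes if n.op == "call_function")

x = torch.tensor(1.0)
full = Interpreter(gm).run(x)  # (1 + 1) * 2
# Supply a value for `x + 1` directly; only `y * 2` is actually computed
partial = Interpreter(gm).run(x, initial_env={add_node: torch.tensor(10.0)})
print(full, partial)  # tensor(4.) tensor(20.)
```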

run_node(n)

Run a specific node n and return the result. Calls into placeholder, get_attr, call_function, call_method, call_module, or output, depending on n.op.

Parameters

n (Node) – The Node to execute

Returns

The result of executing n

Return type

Any
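Every node passes through run_node, so overriding it is a convenient hook for analysis passes. The sketch below records the output shape of each node that produces a Tensor, which is a simplified form of shape propagation. ShapeRecorder and its shapes attribute are hypothetical.

```python
import torch
from torch.fx import Interpreter, symbolic_trace


class ShapeRecorder(Interpreter):
    def __init__(self, module):
        super().__init__(module)
        self.shapes = {}  # node name -> output shape

    def run_node(self, n):
        # Let the base class dispatch to placeholder/call_module/etc.
        result = super().run_node(n)
        if isinstance(result, torch.Tensor):
            self.shapes[n.name] = tuple(result.shape)
        return result


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.relu(self.linear(x))


gm = symbolic_trace(Net())
recorder = ShapeRecorder(gm)
recorder.run(torch.randn(2, 4))
print(recorder.shapes)
```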

class torch.fx.Transformer(module)

Transformer is a special type of interpreter that produces a new Module. It exposes a transform() method that returns the transformed Module. Unlike Interpreter, Transformer does not require arguments to run; it works entirely symbolically.

Example

Suppose we want to swap all instances of torch.neg with torch.sigmoid and vice versa (including their Tensor method equivalents). We could subclass Transformer like so:

import torch
from typing import Any, Dict, Tuple
from torch.fx import Transformer
from torch.fx.node import Target

class NegSigmSwapXformer(Transformer):
    def call_function(self, target : Target, args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
        # Emit torch.neg wherever the original graph called torch.sigmoid
        if target == torch.sigmoid:
            return torch.neg(*args, **kwargs)
        return super().call_function(target, args, kwargs)

    def call_method(self, target : Target, args : Tuple[Any], kwargs : Dict[str, Any]) -> Any:
        # Emit a .sigmoid() call wherever the original graph called .neg()
        if target == 'neg':
            call_self, *args_tail = args
            return call_self.sigmoid(*args_tail, **kwargs)
        return super().call_method(target, args, kwargs)

def fn(x):
    return torch.sigmoid(x).neg()

gm = torch.fx.symbolic_trace(fn)

transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()
input = torch.randn(3, 4)
torch.testing.assert_allclose(transformed(input), torch.neg(input).sigmoid())

Parameters

module (GraphModule) – The Module to be transformed.

get_attr(target, args, kwargs)

Execute a get_attr node. In Transformer, this is overridden to insert a new get_attr node into the output graph.

Parameters
  • target (Target) – The call target for this node. See Node for details on semantics

  • args (Tuple) – Tuple of positional args for this invocation

  • kwargs (Dict) – Dict of keyword arguments for this invocation

placeholder(target, args, kwargs)

Execute a placeholder node. In Transformer, this is overridden to insert a new placeholder into the output graph.

Parameters
  • target (Target) – The call target for this node. See Node for details on semantics

  • args (Tuple) – Tuple of positional args for this invocation

  • kwargs (Dict) – Dict of keyword arguments for this invocation

transform()

Transform self.module and return the transformed GraphModule.
