.. currentmodule:: torch.fx

torch.fx
=============

Overview
--------
.. automodule:: torch.fx

Writing Transformations
-----------------------

TODO

Debugging Transformations
-------------------------

TODO

Limitations of Symbolic Tracing
-------------------------------

FX uses a system of **symbolic tracing** (a.k.a `symbolic
execution <https://en.wikipedia.org/wiki/Symbolic_execution>`__)
to capture the semantics of programs in a transformable/analyzable
form. The system is **tracing** in that it executes the program
(really an ``nn.Module`` or function) to gather this information.
It is **symbolic** in that the data flowing through the program
during this execution is not real data, but rather symbols
(“Proxy” in FX parlance).

Although symbolic tracing works for most neural net code, it has
some limitations.

Dynamic Control Flow
^^^^^^^^^^^^^^^^^^^^

The main limitation of symbolic tracing is that it does not currently
support *dynamic control flow*, that is, loops or ``if`` statements
whose condition may depend on the input values of the program.

For example, let’s examine the following program:

::

    def func_to_trace(x):
        dim0 = x.size(0)
        if dim0 == 3:
            return torch.relu(x)
        else:
            return torch.neg(x)

    traced = torch.fx.symbolic_trace(func_to_trace)
    """
      <...>
      File "dyn.py", line 6, in func_to_trace
        if dim0 == 3:
      File "pytorch/torch/fx/proxy.py", line 155, in __bool__
        return self.tracer.to_bool(self)
      File "pytorch/torch/fx/proxy.py", line 85, in to_bool
        raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
    torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
    """

The condition to the ``if`` statement relies on the value of ``dim0``,
which eventually relies on the value of ``x``, a function input. Since
``x`` can change (i.e. if you pass a new input tensor to the traced
function), this is *dynamic control flow*. The traceback walks back up
through your code to show you where this situation happens.

Static Control Flow
~~~~~~~~~~~~~~~~~~~

On the other hand, so-called *static control flow* is supported. Static
control flow is loops or ``if`` statements whose condition cannot change
across invocations. Typically, in PyTorch programs, this control flow
arises in code that makes decisions about a model’s architecture based
on hyper-parameters. As a concrete example:

::

    import torch
    import torch.fx

    class MyModule(torch.nn.Module):
        def __init__(self, do_activation : bool = False):
            super().__init__()
            self.do_activation = do_activation
            self.linear = torch.nn.Linear(512, 512)

        def forward(self, x):
            x = self.linear(x)
            # This if-statement is so-called static control flow.
            # Its condition does not depend on any input values
            if self.do_activation:
                x = torch.relu(x)
            return x

    without_activation = MyModule(do_activation=False)
    with_activation = MyModule(do_activation=True)

    traced_without_activation = torch.fx.symbolic_trace(without_activation)
    print(traced_without_activation.code)
    """
    def forward(self, x):
        linear_1 = self.linear(x);  x = None
        return linear_1
    """

    traced_with_activation = torch.fx.symbolic_trace(with_activation)
    print(traced_with_activation.code)
    """
    import torch
    def forward(self, x):
        linear_1 = self.linear(x);  x = None
        relu_1 = torch.relu(linear_1);  linear_1 = None
        return relu_1
    """

The if-statement ``if self.do_activation`` does not depend on any
function inputs, thus it is static. ``do_activation`` can be considered
a hyper-parameter, and the traces of different instances of
``MyModule`` with different values for that parameter have different
code. This is a valid pattern that is supported by symbolic tracing.
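To make the difference concrete, one can walk the nodes recorded in
each trace. The following is a small sketch using the public
:class:`Graph` API on the two instances traced above; the exact node
lists may vary slightly across versions:

::

    # Each recorded node has an opcode (`placeholder`, `call_module`,
    # `call_function`, `output`, ...); the trace taken with
    # do_activation=True records one extra `call_function` node for relu
    for label, gm in [("without", traced_without_activation),
                      ("with   ", traced_with_activation)]:
        print(label, [node.op for node in gm.graph.nodes])

    # Expected output (roughly):
    # without ['placeholder', 'call_module', 'output']
    # with    ['placeholder', 'call_module', 'call_function', 'output']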
Many instances of dynamic control flow are semantically static control
flow. These instances can be made to support symbolic tracing by
removing the data dependencies on input values, for example by moving
values to ``Module`` attributes or by binding concrete values to
arguments during symbolic tracing:

::

    import torch
    import torch.fx as fx

    def f(x, flag):
        if flag:
            return x
        else:
            return x * 2

    fx.symbolic_trace(f)  # Fails!

    # Bind a concrete value to `flag` so the condition no longer
    # depends on a traced input
    def g(flag):
        return lambda x: f(x, flag)

    new_f = g(flag=True)
    fx.symbolic_trace(new_f)

In the case of truly dynamic control flow, the sections of the program
that contain this code can be traced as opaque calls to a module (see
:ref:`Customizing Tracing`) or to a function (see :func:`wrap`) rather
than being traced through.

Non-\ ``torch`` Functions
^^^^^^^^^^^^^^^^^^^^^^^^^

FX uses ``__torch_function__`` as the mechanism by which it intercepts
calls (see the `technical
overview <https://github.com/pytorch/pytorch/blob/master/torch/fx/OVERVIEW.md#technical-details>`__
for more information about this). Some functions, such as builtin
Python functions or those in the ``math`` module, are not covered by
``__torch_function__``, but we would still like to capture them in
symbolic tracing. For example:

::

    import torch
    from math import sqrt

    def normalize(x):
        """
        Normalize `x` by the size of the batch dimension
        """
        return x / sqrt(len(x))

    # It's valid Python code
    normalize(torch.rand(3, 4))

    traced = torch.fx.symbolic_trace(normalize)
    """
      <...>
      File "sqrt.py", line 9, in normalize
        return x / sqrt(len(x))
      File "pytorch/torch/fx/proxy.py", line 161, in __len__
        raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
    RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope
    """

The error tells us that the built-in function ``len`` is not supported.
We can make functions like this be recorded in the trace as direct
calls using the :func:`wrap` API:

::

    torch.fx.wrap('len')
    torch.fx.wrap('sqrt')

    traced = torch.fx.symbolic_trace(normalize)

    print(traced.code)
    """
    import math
    def forward(self, x):
        len_1 = len(x)
        sqrt_1 = math.sqrt(len_1);  len_1 = None
        truediv = x / sqrt_1;  x = sqrt_1 = None
        return truediv
    """

.. _Customizing Tracing:

Customizing Tracing with the ``Tracer`` class
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The :class:`Tracer` class is the class that underlies the
implementation of ``symbolic_trace``. The behavior of tracing can be
customized by subclassing ``Tracer``, like so:

::

    class MyCustomTracer(torch.fx.Tracer):
        # Inside here you can override various methods
        # to customize tracing. See the `Tracer` API
        # reference
        pass


    # Let's use this custom tracer to trace through this module
    class MyModule(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + torch.ones(3, 4)

    mod = MyModule()

    traced_graph = MyCustomTracer().trace(mod)
    # trace() returns a Graph. Let's wrap it up in a
    # GraphModule to make it runnable
    traced = torch.fx.GraphModule(mod, traced_graph)
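The wrapped ``GraphModule`` can be used as a drop-in replacement for
the original module. As a quick sanity check (a sketch assuming the
``mod`` and ``traced`` objects defined just above):

::

    inp = torch.rand(3, 4)
    # No transformation was applied to the graph, so the traced module
    # should compute exactly the same result as the original
    assert torch.equal(traced(inp), mod(inp))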
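One commonly overridden method is :meth:`Tracer.is_leaf_module`, which
controls the *leaf modules* described in the next section. The
following is a minimal sketch (``OpaqueOp`` and ``Net`` are invented
for the example) of a tracer that keeps a custom submodule as a single
node in the trace:

::

    import torch
    import torch.fx

    class OpaqueOp(torch.nn.Module):
        """A custom module we want to keep as a single call in the trace."""
        def forward(self, x):
            return torch.sigmoid(x) * x

    class KeepOpaqueTracer(torch.fx.Tracer):
        def is_leaf_module(self, m, module_qualified_name):
            # Treat OpaqueOp instances as leaves, in addition to the
            # default set (standard `torch.nn` modules)
            if isinstance(m, OpaqueOp):
                return True
            return super().is_leaf_module(m, module_qualified_name)

    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.op = OpaqueOp()

        def forward(self, x):
            return self.op(x) + 1

    graph = KeepOpaqueTracer().trace(Net())
    # `op` is recorded as one `call_module` node rather than the
    # `sigmoid` and `mul` nodes that tracing through it would produce
    print([node.op for node in graph.nodes])
    # Expected (roughly): ['placeholder', 'call_module', 'call_function', 'output']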
Leaf Modules
~~~~~~~~~~~~

Leaf Modules are the modules that appear as calls in the symbolic trace
rather than being traced through. The default set of leaf modules is
the set of standard ``torch.nn`` module instances. For example:

::

    class MySpecialSubmodule(torch.nn.Module):
        def forward(self, x):
            return torch.neg(x)

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 4)
            self.submod = MySpecialSubmodule()

        def forward(self, x):
            return self.submod(self.linear(x))

    traced = torch.fx.symbolic_trace(MyModule())
    print(traced.code)
    # `linear` is preserved as a call, yet `submod` is traced through.
    # This is because the default set of "Leaf Modules" includes all
    # standard `torch.nn` modules.
    """
    import torch
    def forward(self, x):
        linear_1 = self.linear(x);  x = None
        neg_1 = torch.neg(linear_1);  linear_1 = None
        return neg_1
    """

The set of leaf modules can be customized by overriding
:meth:`Tracer.is_leaf_module`, as in the sketch above.

Miscellanea
^^^^^^^^^^^

-  Tensor constructors (e.g. ``torch.zeros``, ``torch.ones``,
   ``torch.rand``, ``torch.randn``, ``torch.sparse_coo_tensor``)
   are currently not traceable.

   -  The deterministic constructors (``zeros``, ``ones``) can be used
      and the value they produce will be embedded in the trace as a
      constant. This is only problematic if the arguments to these
      constructors refer to dynamic input sizes. In this case,
      ``ones_like`` or ``zeros_like`` may be a viable substitute.
   -  Nondeterministic constructors (``rand``, ``randn``) will have a
      single random value embedded in the trace. This is likely not the
      intended behavior.
   -  This behavior may be fixed in a future release.

-  Type annotations

   -  Python 3-style type annotations (e.g.
      ``func(x : torch.Tensor, y : int) -> torch.Tensor``) are supported
      and will be preserved by symbolic tracing.
   -  Python 2-style comment type annotations
      ``# type: (torch.Tensor, int) -> torch.Tensor`` are not currently
      supported.
   -  Annotations on local names within a function are not currently
      supported.

API Reference
-------------

.. autofunction:: torch.fx.symbolic_trace

.. autofunction:: torch.fx.wrap

.. autoclass:: torch.fx.GraphModule
  :members:

  .. automethod:: __init__

.. autoclass:: torch.fx.Graph
  :members:

  .. automethod:: __init__

.. autoclass:: torch.fx.Node
  :members:

.. autoclass:: torch.fx.Tracer
  :members:

.. autoclass:: torch.fx.Proxy

.. autoclass:: torch.fx.Interpreter
  :members:

.. autoclass:: torch.fx.Transformer
  :members: