Source code for torch.jit._freeze

"""Freezing

This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""

from typing import Optional, List

import torch
from torch.jit._script import RecursiveScriptModule, ScriptModule


def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize: bool = True):
    r"""
    Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned
    module's submodules, parameters, and attributes as constants in the TorchScript IR Graph.
    By default, `forward` will be preserved, as well as attributes & methods specified in
    `preserved_attrs`. Additionally, any attribute that is modified within a preserved
    method will be preserved.

    Freezing currently only accepts ScriptModules that are in eval mode.

    Args:
        mod (:class:`ScriptModule`): a module to be frozen

        preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition
            to the forward method. Attributes modified in preserved methods will also be preserved.

        optimize (bool): If ``True``, a set of optimization passes will be run to prepare the graph
            for inference, in addition to the graph cleanup that already occurs. The details of the
            optimizations can be found in `torch.jit.optimize_frozen_module`.

    Returns:
        Frozen :class:`ScriptModule`.

    Example (Freezing a simple module with a Parameter):

    .. testcode::

        import torch

        class MyModule(torch.nn.Module):
            def __init__(self, N, M):
                super(MyModule, self).__init__()
                self.weight = torch.nn.Parameter(torch.rand(N, M))
                self.linear = torch.nn.Linear(N, M)

            def forward(self, input):
                output = self.weight.mm(input)
                output = self.linear(output)
                return output

        scripted_module = torch.jit.script(MyModule(2, 3).eval())
        frozen_module = torch.jit.freeze(scripted_module)
        # parameters have been removed and inlined into the Graph as constants
        assert len(list(frozen_module.named_parameters())) == 0
        # See the compiled graph as Python code
        print(frozen_module.code)

    Example (Freezing a module with preserved attributes):

    .. testcode::

        import torch

        class MyModule2(torch.nn.Module):
            def __init__(self):
                super(MyModule2, self).__init__()
                self.modified_tensor = torch.tensor(10.)
                self.version = 1

            def forward(self, input):
                self.modified_tensor += 1
                return input + self.modified_tensor

        scripted_module = torch.jit.script(MyModule2().eval())
        frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
        # we've manually preserved `version`, so it still exists on the frozen module and can be modified
        assert frozen_module.version == 1
        frozen_module.version = 2
        # `modified_tensor` is detected as being mutated in the forward, so freezing preserves
        # it to retain model semantics
        assert frozen_module(torch.tensor(1)) == torch.tensor(12)
        # now that we've run it once, the next result will be incremented by one
        assert frozen_module(torch.tensor(1)) == torch.tensor(13)

    Note:
        If you're not sure why an attribute is not being inlined as a constant, you can run
        `dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the
        attribute is being modified.
    """
    if not isinstance(mod, ScriptModule):
        raise RuntimeError(
            "Freezing expects a ScriptModule as input. "
            "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'."
        )

    if mod.training:
        raise RuntimeError(
            "Freezing is currently only implemented for modules in eval mode. "
            "Please call .eval() on your module before freezing."
        )

    preserved_attrs = preserved_attrs if preserved_attrs is not None else []

    out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
    RecursiveScriptModule._finalize_scriptmodule(out)

    if optimize:
        optimize_frozen_module(out)

    return out

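# A minimal debugging sketch for the Note in the docstring above: when an attribute you
# expected to be folded into the graph is still live after freezing, dumping the alias
# database of the frozen forward graph shows whether freezing detected a mutation of that
# attribute. `dump_alias_db` is the hook named in the docstring; whether it is exposed on
# the Python-side graph object in your build is an assumption here, so treat this as a
# sketch rather than a guaranteed API.
#
#     scripted = torch.jit.script(MyModule2().eval())  # MyModule2 from the docstring example
#     frozen = freeze(scripted, preserved_attrs=["version"])
#     frozen.forward.graph.dump_alias_db()  # prints which values are written to / alias each other
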
def optimize_frozen_module(mod):
    r"""
    Runs a series of optimizations looking for patterns that occur in frozen graphs.
    The current set of optimizations is:
        - Conv -> Batchnorm folding
        - Conv -> Add/Sub folding
        - Conv -> Mul/Div folding

    Args:
        mod (:class:`ScriptModule`): a frozen module to be optimized

    Returns:
        None

    Note:
        On rare occasions, this can result in slower execution.

    Example (Freezing a module with Conv->Batchnorm):

    .. code-block:: python

        import torch

        in_channels, out_channels = 3, 32
        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
        bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
        mod = torch.nn.Sequential(conv, bn)
        # set optimize to False here, by default freezing runs optimize_frozen_module
        frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False)
        # inspect frozen mod
        assert "batch_norm" in str(frozen_mod.graph)

        torch.jit.optimize_frozen_module(frozen_mod)
        assert "batch_norm" not in str(frozen_mod.graph)
    """
    torch._C._jit_pass_optimize_frozen_graph(mod.graph)
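
# A small follow-up sketch (assuming it is run as a standalone script with `import torch`)
# that complements the docstring example above: besides checking that "batch_norm" disappears
# from the graph, one can verify numerically that Conv -> Batchnorm folding preserves the
# module's outputs.
#
#     conv = torch.nn.Conv2d(3, 32, kernel_size=3, stride=2, bias=True)
#     bn = torch.nn.BatchNorm2d(32, eps=.001)
#     mod = torch.nn.Sequential(conv, bn).eval()
#     frozen = torch.jit.freeze(torch.jit.script(mod))  # optimize=True by default, so folding runs
#     x = torch.rand(1, 3, 16, 16)
#     assert torch.allclose(mod(x), frozen(x), atol=1e-5)  # folded graph matches eager output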
