torch.jit.script
torch.jit.script(obj, optimize=None, _frames_up=0, _rcb=None)
Scripting a function or nn.Module will inspect the source code, compile it as TorchScript code using the TorchScript compiler, and return a ScriptModule or ScriptFunction. TorchScript itself is a subset of the Python language, so not all features in Python work, but we provide enough functionality to compute on tensors and do control-dependent operations. For a complete guide, see the TorchScript Language Reference.

torch.jit.script can be used as a function for modules and functions, and as a decorator @torch.jit.script for TorchScript Classes and functions.

Parameters
    obj (callable, class, or nn.Module) – The nn.Module, function, or class type to compile.
Returns
    If obj is nn.Module, script returns a ScriptModule object. The returned ScriptModule will have the same set of sub-modules and parameters as the original nn.Module. If obj is a standalone function, a ScriptFunction will be returned.
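The return types described above can be checked directly. The following is a minimal sketch, not part of the original reference: TinyNet and double are hypothetical names introduced only for illustration, and the comparison of parameter and sub-module names is one possible way to confirm that the scripted module mirrors the original.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)
        self.scale = nn.Parameter(torch.ones(2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x) * self.scale


def double(x: torch.Tensor) -> torch.Tensor:  # hypothetical function
    return x * 2


eager_module = TinyNet()
scripted_module = torch.jit.script(eager_module)  # nn.Module in, ScriptModule out
scripted_fn = torch.jit.script(double)            # function in, ScriptFunction out

is_script_module = isinstance(scripted_module, torch.jit.ScriptModule)
is_script_function = isinstance(scripted_fn, torch.jit.ScriptFunction)

# Collect parameter and sub-module names from both versions for comparison
eager_params = sorted(name for name, _ in eager_module.named_parameters())
scripted_params = sorted(name for name, _ in scripted_module.named_parameters())
eager_children = sorted(name for name, _ in eager_module.named_children())
scripted_children = sorted(name for name, _ in scripted_module.named_children())

print(is_script_module, is_script_function)
print(scripted_params, scripted_children)
```

Here torch.jit.script is used as a plain function rather than as a decorator, so the original eager module and function remain available alongside their scripted counterparts.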
Scripting a function

The @torch.jit.script decorator will construct a ScriptFunction by compiling the body of the function.

Example (scripting a function):

    import torch

    @torch.jit.script
    def foo(x, y):
        if x.max() > y.max():
            r = x
        else:
            r = y
        return r

    print(type(foo))  # torch.jit.ScriptFunction

    # See the compiled graph as Python code
    print(foo.code)

    # Call the function using the TorchScript interpreter
    foo(torch.ones(2, 2), torch.ones(2, 2))
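Control flow that depends on tensor values is kept in the compiled function, which applies to loops as well as to branches. The sketch below is an illustration under that assumption; halve_until_small and its limit argument are hypothetical and do not appear in the original reference.

```python
import torch


@torch.jit.script
def halve_until_small(x: torch.Tensor, limit: float) -> torch.Tensor:
    # The number of iterations depends on the values stored in `x`,
    # so the loop is evaluated at run time by the TorchScript interpreter.
    while x.abs().max() > limit:
        x = x / 2
    return x


result = halve_until_small(torch.tensor([8.0, 2.0]), 1.0)
print(result)
```

Because the loop condition is evaluated on every call, inputs with larger values simply run more iterations of the same compiled function.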
Scripting an nn.Module

Scripting an nn.Module by default will compile the forward method and recursively compile any methods, submodules, and functions called by forward. If an nn.Module only uses features supported in TorchScript, no changes to the original module code should be necessary. script will construct a ScriptModule that has copies of the attributes, parameters, and methods of the original module.

Example (scripting a simple module with a Parameter):

    import torch

    class MyModule(torch.nn.Module):
        def __init__(self, N, M):
            super(MyModule, self).__init__()
            # This parameter will be copied to the new ScriptModule
            self.weight = torch.nn.Parameter(torch.rand(N, M))

            # When this submodule is used, it will be compiled
            self.linear = torch.nn.Linear(N, M)

        def forward(self, input):
            output = self.weight.mv(input)

            # This calls the `forward` method of the `nn.Linear` module, which will
            # cause the `self.linear` submodule to be compiled to a `ScriptModule` here
            output = self.linear(output)
            return output

    scripted_module = torch.jit.script(MyModule(2, 3))

Example (scripting a module with traced submodules):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            # torch.jit.trace produces ScriptModules for conv1 and conv2
            self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
            self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))

        def forward(self, input):
            input = F.relu(self.conv1(input))
            input = F.relu(self.conv2(input))
            return input

    scripted_module = torch.jit.script(MyModule())

To compile a method other than forward (and recursively compile anything it calls), add the @torch.jit.export decorator to the method.
To opt out of compilation use @torch.jit.ignore or @torch.jit.unused.

Example (an exported and ignored method in a module):

    import torch
    import torch.nn as nn

    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()

        @torch.jit.export
        def some_entry_point(self, input):
            return input + 10

        @torch.jit.ignore
        def python_only_fn(self, input):
            # This function won't be compiled, so any
            # Python APIs can be used
            import pdb
            pdb.set_trace()

        def forward(self, input):
            if self.training:
                self.python_only_fn(input)
            return input * 99

    scripted_module = torch.jit.script(MyModule())
    print(scripted_module.some_entry_point(torch.randn(2, 2)))
    print(scripted_module(torch.randn(2, 2)))
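The second opt-out decorator, @torch.jit.unused, can be sketched in a similar way. To the best of our understanding, a method marked this way is left out of compilation and a call that reaches it from scripted code raises an exception instead of running the Python body. The Scaler module and its debug flag below are hypothetical names chosen for illustration.

```python
import torch
import torch.nn as nn


class Scaler(nn.Module):  # hypothetical module, for illustration only
    def __init__(self, debug: bool):
        super().__init__()
        self.debug = debug

    @torch.jit.unused
    def debug_path(self, x: torch.Tensor) -> torch.Tensor:
        # Not compiled: Python-only code is allowed here, but a call that
        # reaches this method from the scripted module is expected to raise.
        import logging
        logging.getLogger(__name__).info("input shape: %s", tuple(x.shape))
        return x * 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.debug:
            return self.debug_path(x)
        else:
            return x * 2


# With debug=False the unused branch is never reached, so the call succeeds
quiet = torch.jit.script(Scaler(debug=False))
out = quiet(torch.ones(3))

# With debug=True the scripted module reaches the unused method and raises
noisy = torch.jit.script(Scaler(debug=True))
try:
    noisy(torch.ones(3))
    raised = False
except Exception:
    raised = True

print(out, raised)
```

A rough rule of thumb, stated as an assumption rather than as guidance from this page: @torch.jit.ignore suits Python-only code that still needs to run when the scripted module is called from Python, while @torch.jit.unused suits code paths that should never execute in the scripted module.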