Introduction to TorchScript
Created On: Aug 09, 2019 | Last Updated: Dec 02, 2024 | Last Verified: Nov 05, 2024
Authors: James Reed (jamesreed@fb.com), Michael Suo (suo@fb.com), rev2
Warning
TorchScript is no longer in active development.
This tutorial is an introduction to TorchScript, an intermediate representation of a PyTorch model (a subclass of nn.Module) that can then be run in a high-performance environment such as C++.
In this tutorial we will cover:
The basics of model authoring in PyTorch, including:
Modules
Defining forward functions
Composing modules into a hierarchy of modules
Specific methods for converting PyTorch modules to TorchScript, our high-performance deployment runtime
Tracing an existing module
Using scripting to directly compile a module
How to compose both approaches
Saving and loading TorchScript modules
We hope that after you complete this tutorial, you will go on to the follow-on tutorial, which walks you through an example of actually calling a TorchScript model from C++.
```python
import torch  # This is all you need to use both PyTorch and TorchScript!

print(torch.__version__)
torch.manual_seed(191009)  # set the seed for reproducibility
```

```
2.7.0+cu126
<torch._C.Generator object at 0x7fec1599a470>
```
Basics of PyTorch Model Authoring
Let’s start out by defining a simple Module. A Module is the basic unit of composition in PyTorch. It contains:
A constructor, which prepares the module for invocation
A set of Parameters and sub-Modules. These are initialized by the constructor and can be used by the module during invocation.
A forward function. This is the code that is run when the module is invoked.
Let’s examine a small example:
```python
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()

    def forward(self, x, h):
        new_h = torch.tanh(x + h)
        return new_h, new_h


my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell(x, h))
```

```
(tensor([[0.8219, 0.8990, 0.6670, 0.8277],
        [0.5176, 0.4017, 0.8545, 0.7336],
        [0.6013, 0.6992, 0.2618, 0.6668]]), tensor([[0.8219, 0.8990, 0.6670, 0.8277],
        [0.5176, 0.4017, 0.8545, 0.7336],
        [0.6013, 0.6992, 0.2618, 0.6668]]))
```
So we’ve:
Created a class that subclasses torch.nn.Module.
Defined a constructor. The constructor doesn’t do much; it just calls the superclass constructor via super.
Defined a forward function, which takes two inputs and returns two outputs. The actual contents of the forward function are not really important, but it is a sort of fake RNN cell: a function that is applied in a loop.
We instantiated the module and made x and h, which are just 3x4 matrices of random values. Then we invoked the cell with my_cell(x, h). This in turn calls our forward function.
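Calling my_cell(x, h) goes through nn.Module.__call__, which runs forward and also any hooks registered on the module. The sketch below uses a hypothetical Doubler module and register_forward_hook to show the difference: the hook fires for the call syntax but not when forward is called directly, which is one reason to invoke modules rather than their forward methods.

```python
import torch


class Doubler(torch.nn.Module):
    """Hypothetical toy module used only for this illustration."""

    def forward(self, x):
        return 2 * x


calls = []
m = Doubler()
# A forward hook runs after forward() whenever the module is invoked
# with the call syntax m(x), i.e. through nn.Module.__call__.
m.register_forward_hook(lambda module, inputs, output: calls.append(output))

x = torch.ones(2)
out_call = m(x)            # goes through __call__: forward() plus hooks
out_direct = m.forward(x)  # bypasses __call__: no hooks run

print(torch.equal(out_call, out_direct))  # True
print(len(calls))                         # 1
```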
Let’s do something a little more interesting:
```python
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h


my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))
```

```
MyCell(
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[ 0.8573, 0.6190, 0.5774, 0.7869],
        [ 0.3326, 0.0530, 0.0702, 0.8114],
        [ 0.7818, -0.0506, 0.4039, 0.7967]], grad_fn=<TanhBackward0>), tensor([[ 0.8573, 0.6190, 0.5774, 0.7869],
        [ 0.3326, 0.0530, 0.0702, 0.8114],
        [ 0.7818, -0.0506, 0.4039, 0.7967]], grad_fn=<TanhBackward0>))
```
We’ve redefined our module MyCell, but this time we’ve added a self.linear attribute, and we invoke self.linear in the forward function.
What exactly is happening here? torch.nn.Linear is a Module from the PyTorch standard library. Just like MyCell, it can be invoked using the call syntax. We are building a hierarchy of Modules.
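To make this concrete for torch.nn.Linear: it stores a weight of shape (out_features, in_features) and a bias, and its forward computes x @ weight.T + bias. A small check, assuming arbitrary sizes of 4 input and 2 output features:

```python
import torch

torch.manual_seed(0)
linear = torch.nn.Linear(4, 2)  # in_features=4, out_features=2
x = torch.rand(3, 4)

# weight has shape (out_features, in_features), hence the transpose.
manual = x @ linear.weight.T + linear.bias

print(linear.weight.shape)                # torch.Size([2, 4])
print(torch.allclose(linear(x), manual))  # True
```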
Calling print on a Module gives a visual representation of the Module’s submodule hierarchy. In our example, we can see our Linear submodule and its constructor arguments.
By composing Modules in this way, we can succinctly and readably author models with reusable components.
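The hierarchy can also be inspected programmatically. In this sketch, Inner and Outer are hypothetical modules: named_modules() walks the tree of sub-Modules depth-first (the root has the empty name), and parameters() collects the Parameters registered by every module in that tree.

```python
import torch


class Inner(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)  # 4*4 weights + 4 biases

    def forward(self, x):
        return torch.tanh(self.linear(x))


class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()
        self.head = torch.nn.Linear(4, 1)  # 4 weights + 1 bias

    def forward(self, x):
        return self.head(self.inner(x))


model = Outer()
names = [name for name, _ in model.named_modules()]
n_params = sum(p.numel() for p in model.parameters())

print(names)     # ['', 'inner', 'inner.linear', 'head']
print(n_params)  # 25
```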
You may have noticed grad_fn on the outputs. This is a detail of PyTorch’s method of automatic differentiation, called autograd. In short, this system allows us to compute derivatives through potentially complex programs. The design allows for a massive amount of flexibility in model authoring.
Now let’s examine said flexibility:
```python
class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x


class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.dg = MyDecisionGate()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h


my_cell = MyCell()
print(my_cell)
print(my_cell(x, h))
```

```
MyCell(
  (dg): MyDecisionGate()
  (linear): Linear(in_features=4, out_features=4, bias=True)
)
(tensor([[ 0.8346, 0.5931, 0.2097, 0.8232],
        [ 0.2340, -0.1254, 0.2679, 0.8064],
        [ 0.6231, 0.1494, -0.3110, 0.7865]], grad_fn=<TanhBackward0>), tensor([[ 0.8346, 0.5931, 0.2097, 0.8232],
        [ 0.2340, -0.1254, 0.2679, 0.8064],
        [ 0.6231, 0.1494, -0.3110, 0.7865]], grad_fn=<TanhBackward0>))
```
We’ve once again redefined our MyCell class, but here we’ve defined MyDecisionGate. This module utilizes control flow. Control flow consists of things like loops and if-statements.
Many frameworks take the approach of computing symbolic derivatives given a full program representation. However, in PyTorch, we use a gradient tape. We record operations as they occur, and replay them backwards in computing derivatives. In this way, the framework does not have to explicitly define derivatives for all constructs in the language.

[Figure: How autograd works]
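The gradient-tape idea can be sketched in a few lines of plain Python. This is a toy for illustration, not PyTorch’s implementation: Var, backward, and f are hypothetical names. Each operation appends its output and its local derivatives to a shared tape, and backward replays the tape in reverse, applying the chain rule. Because f uses an ordinary Python if, only the operations on the branch actually taken are recorded, so the tape never needs a derivative rule for the if statement itself.

```python
class Var:
    """A scalar value that records how it was produced on a shared tape."""

    def __init__(self, value, tape):
        self.value = value
        self.grad = 0.0
        self.tape = tape

    def __add__(self, other):
        out = Var(self.value + other.value, self.tape)
        # Record (output, [(input, d output / d input), ...]).
        self.tape.append((out, [(self, 1.0), (other, 1.0)]))
        return out

    def __mul__(self, other):
        out = Var(self.value * other.value, self.tape)
        self.tape.append((out, [(self, other.value), (other, self.value)]))
        return out

    def __neg__(self):
        out = Var(-self.value, self.tape)
        self.tape.append((out, [(self, -1.0)]))
        return out


def backward(tape, output):
    """Replay the tape in reverse, accumulating gradients via the chain rule."""
    output.grad = 1.0
    for out, inputs in reversed(tape):
        for inp, local_grad in inputs:
            inp.grad += out.grad * local_grad


def f(x, y):
    # Ordinary Python control flow: only the branch that runs is recorded.
    if x.value > 0:
        return x * y + x
    return -x


tape = []
x, y = Var(2.0, tape), Var(3.0, tape)
z = f(x, y)
backward(tape, z)
print(z.value, x.grad, y.grad)  # 8.0 4.0 2.0
print(len(tape))                # 2 (one multiply, one add)
```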
Basics of TorchScript
Now let’s take our running example and see how we can apply TorchScript.
In short, TorchScript provides tools to capture the definition of your model, even in light of the flexible and dynamic nature of PyTorch. Let’s begin by examining what we call tracing.
Tracing Modules
```python
class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h


my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
traced_cell(x, h)
```

```
MyCell(
  original_name=MyCell
  (linear): Linear(original_name=Linear)
)
(tensor([[-0.2541, 0.2460, 0.2297, 0.1014],
        [-0.2329, -0.2911, 0.5641, 0.5015],
        [ 0.1688, 0.2252, 0.7251, 0.2530]], grad_fn=<TanhBackward0>), tensor([[-0.2541, 0.2460, 0.2297, 0.1014],
        [-0.2329, -0.2911, 0.5641, 0.5015],
        [ 0.1688, 0.2252, 0.7251, 0.2530]], grad_fn=<TanhBackward0>))
```
We’ve rewound a bit and taken the second version of our MyCell class. As before, we’ve instantiated it, but this time we’ve called torch.jit.trace, passed in the Module, and passed in example inputs the network might see.
What exactly has this done? It has invoked the Module, recorded the operations that occurred when the Module was run, and created an instance of torch.jit.ScriptModule (of which TracedModule is a subclass).
TorchScript records its definitions in an Intermediate Representation (or IR), commonly referred to in deep learning as a graph. We can examine the graph with the .graph property:
```python
print(traced_cell.graph)
```

```
graph(%self.1 : __torch__.MyCell,
      %x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %linear : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
  %20 : Tensor = prim::CallMethod[name="forward"](%linear, %x)
  %11 : int = prim::Constant[value=1]() # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:191:0
  %12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%20, %h, %11) # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:191:0
  %13 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%12) # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:191:0
  %14 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%13, %13)
  return (%14)
```
However, this is a very low-level representation, and most of the information contained in the graph is not useful for end users. Instead, we can use the .code property to give a Python-syntax interpretation of the code:
```python
print(traced_cell.code)
```

```
def forward(self, x: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]:
  linear = self.linear
  _0 = torch.tanh(torch.add((linear).forward(x, ), h))
  return (_0, _0)
```
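Besides printing it, the graph can be walked node by node. A minimal sketch, assuming a standalone function cell (hypothetical; it mirrors the earlier MyCell.forward without the Linear layer): after tracing it, graph.nodes() yields the IR nodes, and each node’s kind() names the operator, such as aten::add or prim::Constant.

```python
import torch


def cell(x, h):
    return torch.tanh(x + h)


traced_fn = torch.jit.trace(cell, (torch.rand(3, 4), torch.rand(3, 4)))

# Each IR node has a kind such as "aten::add", "aten::tanh" or "prim::Constant".
kinds = [node.kind() for node in traced_fn.graph.nodes()]
print(kinds)
```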
So why did we do all this? There are several reasons:
TorchScript code can be invoked in its own interpreter, which is basically a restricted Python interpreter. This interpreter does not acquire the Global Interpreter Lock, and so many requests can be processed on the same instance simultaneously.
This format allows us to save the whole model to disk and load it into another environment, such as a server written in a language other than Python.
TorchScript gives us a representation in which we can do compiler optimizations on the code to provide more efficient execution.
TorchScript allows us to interface with many backend/device runtimes that require a broader view of the program than individual operators.
We can see that invoking traced_cell produces the same results as the Python module:
```
(tensor([[-0.2541, 0.2460, 0.2297, 0.1014],
        [-0.2329, -0.2911, 0.5641, 0.5015],
        [ 0.1688, 0.2252, 0.7251, 0.2530]], grad_fn=<TanhBackward0>), tensor([[-0.2541, 0.2460, 0.2297, 0.1014],
        [-0.2329, -0.2911, 0.5641, 0.5015],
        [ 0.1688, 0.2252, 0.7251, 0.2530]], grad_fn=<TanhBackward0>))
(tensor([[-0.2541, 0.2460, 0.2297, 0.1014],
        [-0.2329, -0.2911, 0.5641, 0.5015],
        [ 0.1688, 0.2252, 0.7251, 0.2530]], grad_fn=<TanhBackward0>), tensor([[-0.2541, 0.2460, 0.2297, 0.1014],
        [-0.2329, -0.2911, 0.5641, 0.5015],
        [ 0.1688, 0.2252, 0.7251, 0.2530]], grad_fn=<TanhBackward0>))
```
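A slightly stronger check is to compare the eager and traced modules on fresh inputs that were not used for tracing. This sketch redefines the second MyCell so that it stands alone and compares each output pair with torch.allclose; it only shows agreement for this control-flow-free module.

```python
import torch


class MyCell(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h


my_cell = MyCell()
traced_cell = torch.jit.trace(my_cell, (torch.rand(3, 4), torch.rand(3, 4)))

# Compare on fresh inputs, not the example inputs used for tracing.
x_new, h_new = torch.rand(3, 4), torch.rand(3, 4)
eager_out = my_cell(x_new, h_new)
traced_out = traced_cell(x_new, h_new)
match = all(torch.allclose(a, b) for a, b in zip(eager_out, traced_out))
print(match)  # True
```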
Using Scripting to Convert Modules
There’s a reason we used version two of our module, and not the one with the control-flow-laden submodule. Let’s examine that now:
```python
class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x


class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h


my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))

print(traced_cell.dg.code)
print(traced_cell.code)
```

```
/var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:263: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
def forward(self, argument_1: Tensor) -> NoneType:
  return None

def forward(self, x: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = (linear).forward(x, )
  _1 = (dg).forward(_0, )
  _2 = torch.tanh(torch.add(_0, h))
  return (_2, _2)
```
Looking at the .code output, we can see that the if-else branch is nowhere to be found! Why? Tracing does exactly what we said it would: run the code, record the operations that happen, and construct a ScriptModule that does exactly that. Unfortunately, this means control flow is erased: only the operations on the branch taken for the example inputs are recorded.
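The consequence is easy to reproduce. In this sketch, gate is a hypothetical function that takes a different branch depending on the sign of x.sum(). It is traced with a positive input, so the trace bakes in the x * 2 branch, and calling the traced version with a negative input silently gives a different answer from eager execution.

```python
import warnings

import torch


def gate(x):
    # Data-dependent control flow: the branch depends on x's values.
    if x.sum() > 0:
        return x * 2
    else:
        return -x


positive = torch.ones(2, 2)
negative = -torch.ones(2, 2)

with warnings.catch_warnings():
    # Tracing emits a TracerWarning for the tensor-to-bool conversion.
    warnings.simplefilter("ignore")
    traced_gate = torch.jit.trace(gate, (positive,))

print(gate(negative))         # eager: takes the -x branch, every entry is 1
print(traced_gate(negative))  # traced: x * 2 was baked in, every entry is -2
```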
How can we faithfully represent this module in TorchScript? We provide a script compiler, which does direct analysis of your Python source code to transform it into TorchScript. Let’s convert MyDecisionGate using the script compiler:
```python
scripted_gate = torch.jit.script(MyDecisionGate())

my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell)

print(scripted_gate.code)
print(scripted_cell.code)
```

```
def forward(self, x: Tensor) -> Tensor:
  if bool(torch.gt(torch.sum(x), 0)):
    _0 = x
  else:
    _0 = torch.neg(x)
  return _0

def forward(self, x: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  new_h = torch.tanh(_0)
  return (new_h, new_h)
```
Hooray! We’ve now faithfully captured the behavior of our program in TorchScript. Let’s now try running the program:
```python
# New inputs
x, h = torch.rand(3, 4), torch.rand(3, 4)
print(scripted_cell(x, h))
```

```
(tensor([[ 0.5679, 0.5762, 0.2506, -0.0734],
        [ 0.5228, 0.7122, 0.6985, -0.0656],
        [ 0.6187, 0.4487, 0.7456, -0.0238]], grad_fn=<TanhBackward0>), tensor([[ 0.5679, 0.5762, 0.2506, -0.0734],
        [ 0.5228, 0.7122, 0.6985, -0.0656],
        [ 0.6187, 0.4487, 0.7456, -0.0238]], grad_fn=<TanhBackward0>))
```
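For contrast with tracing, here is the same kind of data-dependent gate compiled with torch.jit.script instead. The function name scripted_gate_fn is hypothetical; this sketch assumes the code lives in a source file, since the script compiler reads the function’s Python source. Both branches survive compilation, so the result is correct for positive and negative inputs alike.

```python
import torch


@torch.jit.script
def scripted_gate_fn(x):
    # The script compiler analyzes the source, so both branches are kept.
    if x.sum() > 0:
        return x * 2
    else:
        return -x


positive = torch.ones(2, 2)
negative = -torch.ones(2, 2)

print(scripted_gate_fn(positive))  # every entry is 2
print(scripted_gate_fn(negative))  # every entry is 1
print(scripted_gate_fn.code)       # the compiled code contains an if/else
```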
Mixing Scripting and Tracing
Some situations call for using tracing rather than scripting (e.g. a module has many architectural decisions that are made based on constant Python values that we would rather not have appear in TorchScript). In this case, scripting can be composed with tracing: torch.jit.script will inline the code for a traced module, and tracing will inline the code for a scripted module.
An example of the first case:
```python
class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h


rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)
```

```
def forward(self, xs: Tensor) -> Tuple[Tensor, Tensor]:
  h = torch.zeros([3, 4])
  y = torch.zeros([3, 4])
  y0 = y
  h0 = h
  for i in range(torch.size(xs, 0)):
    cell = self.cell
    _0 = (cell).forward(torch.select(xs, 0, i), h0, )
    y1, h1, = _0
    y0, h0 = y1, h1
  return (y0, h0)
```
And an example of the second case:
```python
class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)


traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)
```

```
def forward(self, xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)
```
In this way, scripting and tracing can each be used where the situation calls for it, and they can also be used together.
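Tracing suits modules whose architecture is decided by constant Python values: when forward branches on a plain Python attribute, the trace simply records the configured branch, and the attribute never appears in the generated code. A sketch with a hypothetical Head module and use_relu flag:

```python
import torch


class Head(torch.nn.Module):
    def __init__(self, use_relu):
        super().__init__()
        # A plain Python bool fixed at construction time (an
        # "architectural decision"), not a tensor.
        self.use_relu = use_relu
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        y = self.linear(x)
        if self.use_relu:  # Python-level branch, resolved during tracing
            y = torch.relu(y)
        return y


traced_head = torch.jit.trace(Head(use_relu=True), torch.rand(2, 4))
# The generated code calls relu unconditionally and never mentions use_relu.
print(traced_head.code)
```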
Saving and Loading models
We provide APIs to save and load TorchScript modules to/from disk in an archive format. This format includes code, parameters, attributes, and debug information, meaning that the archive is a freestanding representation of the model that can be loaded in an entirely separate process. Let’s save and load our wrapped RNN module:
```python
traced.save('wrapped_rnn.pt')

loaded = torch.jit.load('wrapped_rnn.pt')

print(loaded)
print(loaded.code)
```

```
RecursiveScriptModule(
  original_name=WrapRNN
  (loop): RecursiveScriptModule(
    original_name=MyRNNLoop
    (cell): RecursiveScriptModule(
      original_name=MyCell
      (dg): RecursiveScriptModule(original_name=MyDecisionGate)
      (linear): RecursiveScriptModule(original_name=Linear)
    )
  )
)
def forward(self, xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)
```
As you can see, serialization preserves the module hierarchy and the code we’ve been examining throughout. The model can also be loaded, for example, into C++ for Python-free execution.
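torch.jit.save and torch.jit.load also accept file-like objects, which is convenient for sending a model over a network or checking a round trip without touching disk. A minimal sketch with a hypothetical Scale module, saving to an in-memory io.BytesIO buffer and verifying that the loaded copy produces the same output:

```python
import io

import torch


class Scale(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.full((3,), 2.0))

    def forward(self, x):
        return x * self.weight


traced_scale = torch.jit.trace(Scale(), torch.ones(3))

# Round-trip through an in-memory buffer instead of a file path.
buffer = io.BytesIO()
torch.jit.save(traced_scale, buffer)
buffer.seek(0)  # rewind before reading
loaded_scale = torch.jit.load(buffer)

x = torch.tensor([1.0, 2.0, 3.0])
print(loaded_scale(x))  # values 2., 4., 6.
```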
Further Reading
We’ve completed our tutorial! For a more involved demonstration, check out the NeurIPS demo for converting machine translation models using TorchScript: https://colab.research.google.com/drive/1HiICg6jRkBnr5hvK2-VnMi88Vi9pUzEJ