More doc edits (#19929)

Summary:
* document `torch.jit.Attribute`
* add JIT one-liner to `README.md`
* misc clarity edits
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19929

Pulled By: driazati

Differential Revision: D15152418

fbshipit-source-id: dfee03f0a17300aaf453fcf17f418463288f66c2
This commit is contained in:
davidriazati
2019-04-30 13:44:58 -07:00
committed by Facebook Github Bot
parent a9c189ca14
commit 947fd9c3f5
7 changed files with 286 additions and 131 deletions


@@ -37,11 +37,12 @@ At a granular level, PyTorch is a library that consists of the following compone
| Component | Description |
| ---- | --- |
| **torch** | a Tensor library like NumPy, with strong GPU support |
| **torch.autograd** | a tape-based automatic differentiation library that supports all differentiable Tensor operations in torch |
| **torch.nn** | a neural networks library deeply integrated with autograd designed for maximum flexibility |
| **torch.multiprocessing** | Python multiprocessing, but with magical memory sharing of torch Tensors across processes. Useful for data loading and Hogwild training |
| **torch.utils** | DataLoader and other utility functions for convenience |
| [**torch**](https://pytorch.org/docs/stable/torch.html) | a Tensor library like NumPy, with strong GPU support |
| [**torch.autograd**](https://pytorch.org/docs/stable/autograd.html) | a tape-based automatic differentiation library that supports all differentiable Tensor operations in torch |
| [**torch.jit**](https://pytorch.org/docs/stable/jit.html) | a compilation stack (TorchScript) to create serializable and optimizable models from PyTorch code |
| [**torch.nn**](https://pytorch.org/docs/stable/nn.html) | a neural networks library deeply integrated with autograd designed for maximum flexibility |
| [**torch.multiprocessing**](https://pytorch.org/docs/stable/multiprocessing.html) | Python multiprocessing, but with magical memory sharing of torch Tensors across processes. Useful for data loading and Hogwild training |
| [**torch.utils**](https://pytorch.org/docs/stable/data.html) | DataLoader and other utility functions for convenience |
Usually one uses PyTorch either as:


@@ -63,7 +63,7 @@ Build + CI
- Jesse Hellemn (`pjh5 <https://github.com/pjh5>`__)
- Soumith Chintala (`soumith <https://github.com/soumith>`__)
- (sunsetting) Orion Reblitz-Richardson
(`orionr <https://github.com/orionr>`__)
(`orionr <https://github.com/orionr>`__)
Distributions & RNG
~~~~~~~~~~~~~~~~~~~


@@ -258,7 +258,7 @@ Probability distributions - torch.distributions
:show-inheritance:
:hidden:`LogitRelaxedBernoulli`
~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. currentmodule:: torch.distributions.relaxed_bernoulli
.. autoclass:: LogitRelaxedBernoulli
@@ -301,7 +301,7 @@ Probability distributions - torch.distributions
:members:
:undoc-members:
:show-inheritance:
:hidden:`Weibull`
~~~~~~~~~~~~~~~~~~~~~~~


@@ -151,7 +151,7 @@ net models. In particular, TorchScript supports:
Unlike Python, each variable in a TorchScript function must have a single static type.
This makes it easier to optimize TorchScript functions.
Example::
Example (a type mismatch)::
@torch.jit.script
def an_error(x):
@@ -201,35 +201,34 @@ Example::
@torch.jit.script_method
def forward(self, x):
# type: (Tensor) -> Tuple[List[Tuple[Tensor, Tensor]], Dict[int, Tensor]]
# type: (Tensor) -> Tuple[List[Tuple[int, float]], Dict[str, int]]
# This annotates the list to be a `List[Tuple[Tensor, Tensor]]`
list_of_tuple = torch.jit.annotate(List[Tuple[Tensor, Tensor]], [])
# This annotates the list to be a `List[Tuple[int, float]]`
my_list = torch.jit.annotate(List[Tuple[int, float]], [])
for i in range(10):
list_of_tuple.append((x, x))
my_list.append((i, float(i)))
# This annotates the list to be a `Dict[int, Tensor]`
int_tensor_dict = torch.jit.annotate(Dict[int, Tensor], {})
return list_of_tuple, int_tensor_dict
my_dict = torch.jit.annotate(Dict[str, int], {})
return my_list, my_dict
Optional Type Refinement
^^^^^^^^^^^^^^^^^^^^^^^^
TorchScript will refine the type of a variable of type Optional[T] when
a comparison to None is made inside the conditional of an if statement.
The compiler can reason about multiple None checks that are combined with
AND, OR, or NOT. Refinement will also occur for else blocks of if statements
TorchScript will refine the type of a variable of type ``Optional[T]`` when
a comparison to ``None`` is made inside the conditional of an if-statement.
The compiler can reason about multiple ``None`` checks that are combined with
``and``, ``or``, and ``not``. Refinement will also occur for else blocks of if-statements
that are not explicitly written.
The check must be written directly within the conditional; assigning
a None check to a variable and using it in the conditional will not refine types.
a ``None`` check to a variable and using it in the conditional will not refine types.
Example::
@torch.jit.script
def opt_unwrap(x, y, z):
def optional_unwrap(x, y, z):
# type: (Optional[int], Optional[int], Optional[int]) -> int
if x is None:
x = 1
@@ -240,6 +239,66 @@ Example::
return x
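By contrast, assigning the ``None`` check to a variable first defeats refinement. A minimal sketch of this failure mode (``will_not_refine`` is a hypothetical name): ::

    @torch.jit.script
    def will_not_refine(x):
        # type: (Optional[int]) -> int
        check = x is None
        if not check:
            # `x` is still `Optional[int]` here because the check was not
            # written directly in the conditional, so this fails to compile
            return x + 1
        return 0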
User Defined Types
^^^^^^^^^^^^^^^^^^^^^^^^
Python classes can be used in TorchScript if they are decorated with ``@torch.jit.script``,
similar to how you would declare a TorchScript function: ::
@torch.jit.script
class Foo:
def __init__(self, x, y):
self.x = x
def aug_add_x(self, inc):
self.x += inc
This subset is restricted:
* All functions must be valid TorchScript functions (including ``__init__()``)
* Classes must be new-style classes, as we use ``__new__()`` to construct them with pybind11
* TorchScript classes are statically typed. Members are declared by assigning to
self in the ``__init__()`` method
For example, assigning outside of the ``__init__()`` method: ::
@torch.jit.script
class Foo:
def assign_x(self):
self.x = torch.rand(2, 3)
Will result in: ::
RuntimeError:
Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?:
def assign_x(self):
self.x = torch.rand(2, 3)
~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
* No expressions except method definitions are allowed in the body of the class
* No support for inheritance or any other polymorphism strategy, except for inheriting
from object to specify a new-style class
After a class is defined, it can be used in both TorchScript and Python interchangeably
like any other TorchScript type:
::
@torch.jit.script
class Pair:
def __init__(self, first, second):
self.first = first
self.second = second
@torch.jit.script
def sum_pair(p):
# type: (Pair) -> Tensor
return p.first + p.second
p = Pair(torch.rand(2, 3), torch.rand(2, 3))
print(sum_pair(p))
Expressions
~~~~~~~~~~~
@@ -255,8 +314,9 @@ List Construction
``[3, 4]``, ``[]``, ``[torch.rand(3), torch.rand(4)]``
.. note::
an empty list is assumed have type ``List[Tensor]``.
An empty list is assumed to have type ``List[Tensor]``.
The types of other list literals are derived from the type of the members.
To denote an empty list of another type, use ``torch.jit.annotate``.
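For instance, an empty list of integers could be written as (a minimal sketch): ::

    from typing import List

    int_list = torch.jit.annotate(List[int], [])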
Tuple Construction
""""""""""""""""""
@@ -268,8 +328,9 @@ Dict Construction
``{'hello': 3}``, ``{}``, ``{'a': torch.rand(3), 'b': torch.rand(4)}``
.. note::
an empty dict is assumed have type ``Dict[str, Tensor]``.
An empty dict is assumed to have type ``Dict[str, Tensor]``.
The types of other dict literals are derived from the type of the members.
To denote an empty dict of another type, use ``torch.jit.annotate``.
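Similarly, an empty dict of another type (a minimal sketch): ::

    from typing import Dict

    str_to_float = torch.jit.annotate(Dict[str, float], {})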
Variables
^^^^^^^^^
@@ -341,10 +402,6 @@ Subscripts
``t[i:j, i]``
.. note::
TorchScript currently does not support mutating tensors in place, so any
tensor indexing can only appear on the right-hand size of an expression.
Function Calls
^^^^^^^^^^^^^^
Calls to built-in functions: ``torch.rand(3, dtype=torch.int)``
@@ -468,11 +525,6 @@ For loops with ``range``
for i in range(10):
x *= i
.. note::
Script currently does not support iterating over generic iterable
objects like lists or tensors. Script currently does not support start or
increment parameters to range. These will be added in a future version.
For loops over tuples:
::
@@ -512,9 +564,9 @@ For loops over constant ``torch.nn.ModuleList``
return v
.. note::
To use a module list inside a ``@script_method`` it must be marked
To use a ``nn.ModuleList`` inside a ``@script_method`` it must be marked
constant by adding the name of the attribute to the ``__constants__``
list for the type. For loops over a ModuleList will unroll the body of the
list for the type. For loops over a ``nn.ModuleList`` will unroll the body of the
loop at compile time, once for each member of the constant module list.
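A minimal sketch of this pattern (``SubmoduleRunner`` is a hypothetical name): ::

    class SubmoduleRunner(torch.jit.ScriptModule):
        # Marking `mods` constant lets the compiler unroll the loop in `forward`
        __constants__ = ['mods']

        def __init__(self):
            super(SubmoduleRunner, self).__init__()
            self.mods = torch.nn.ModuleList([torch.nn.Linear(10, 10) for _ in range(3)])

        @torch.jit.script_method
        def forward(self, v):
            for layer in self.mods:
                v = layer(v)
            return v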
Return
@@ -557,17 +609,17 @@ To make writing TorchScript more convenient, we allow script code to refer
to Python values in the surrounding scope. For instance, any time there is a
reference to ``torch``, the TorchScript compiler is actually resolving it to the
``torch`` Python module when the function is declared. These Python values are
not a first class part of TorchScript. Instead they are desugared at compile-time
into the primitive types that TorchScript supports. This section describes the
rules that are used when accessing Python values in TorchScript. They depend
on the dynamic type of the python valued referenced.
not a first class part of TorchScript. Instead they are de-sugared at compile-time
into the primitive types that TorchScript supports. This depends
on the dynamic type of the Python value referenced when compilation occurs.
This section describes the rules that are used when accessing Python values in TorchScript.
Functions
^^^^^^^^^
TorchScript can call Python functions. This functionality is very useful when
incrementally converting a model into script. The model can be moved function-by-function
to script, leaving calls to Python functions in place. This way you can incrementally
incrementally converting a model to TorchScript. The model can be moved function-by-function
to TorchScript, leaving calls to Python functions in place. This way you can incrementally
check the correctness of the model as you go.
Example::
@@ -581,10 +633,37 @@ Functions
def bar(x):
return foo(x + 1)
.. note::
Attempting to call ``save`` on a ScriptModule that contains calls to Python
functions will fail. The intention is that this pathway is used for debugging
and the calls removed or turned into script functions before saving.
Attempting to call ``save`` on a ScriptModule that contains calls to Python
functions will fail. The intention is that this pathway is used for debugging
and the calls removed or turned into script functions before saving. If you
want to export a module with a Python function, add the ``@torch.jit.ignore``
decorator to the function which will replace these function calls with an
exception when the model is saved: ::
class M(torch.jit.ScriptModule):
def __init__(self):
super(M, self).__init__()
@torch.jit.script_method
def forward(self, x):
self.ignored_code(x)
return x + 2
@torch.jit.ignore
def ignored_code(self, x):
# non-TorchScript code
import pdb; pdb.set_trace()
m = M()
# Runs, makes upcall to Python to run `ignored_code`
m(torch.ones(2, 2))
# Replaces all calls to `ignored_code` with a `raise`
m.save("m.pt")
loaded = torch.jit.load("m.pt")
# Since `ignored_code` was replaced with a `raise`, calling the loaded model raises an Exception!
loaded(torch.ones(2, 2))
Attribute Lookup On Python Modules
@@ -621,6 +700,7 @@ Python-defined Constants
Supported constant Python values are listed below; a usage sketch follows the list:
* ``int``
* ``float``
* ``bool``
* ``torch.device``
* ``torch.layout``
@@ -629,6 +709,31 @@ Python-defined Constants
* ``torch.nn.ModuleList`` which can be used in a TorchScript for loop
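A minimal sketch of declaring such a constant (``Net`` and ``hidden_size`` are hypothetical names): ::

    class Net(torch.jit.ScriptModule):
        __constants__ = ['hidden_size']

        def __init__(self):
            super(Net, self).__init__()
            # an `int` constant, baked directly into the compiled code
            self.hidden_size = 256

        @torch.jit.script_method
        def forward(self, x):
            return x.view(-1, self.hidden_size)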
Module Attributes
^^^^^^^^^^^^^^^^^
The ``torch.nn.Parameter`` wrapper and ``register_buffer`` can be used to assign
tensors to a ``ScriptModule``. In a similar vein, attributes of any type can be
assigned to a ``ScriptModule`` by wrapping them with ``torch.jit.Attribute`` and
specifying the type. All types available in TorchScript are supported. These
attributes are mutable and are saved in a separate archive in the serialized
model binary. Tensor attributes are semantically the same as buffers.
Example::
from typing import List, Dict

class Foo(torch.jit.ScriptModule):
def __init__(self, a_dict):
super(Foo, self).__init__(False)
self.words = torch.jit.Attribute([], List[str])
self.some_dict = torch.jit.Attribute(a_dict, Dict[str, int])
@torch.jit.script_method
def forward(self, input):
# type: (str) -> int
self.words.append(input)
return self.some_dict[input]
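Instantiating and calling ``Foo`` might then look like this (a usage sketch under the same assumptions): ::

    f = Foo({'hello': 2})
    f('hello')  # appends 'hello' to `f.words` and returns 2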
Debugging
~~~~~~~~~
@@ -655,21 +760,21 @@ Disable JIT for Debugging
traced_fn(torch.rand(3, 4))
Debugging this script with PDB works except for when we invoke the @script
function. We can globally disable JIT, so that we can call the @script
Debugging this script with PDB works except for when we invoke the ``@torch.jit.script``
function. We can globally disable JIT, so that we can call the ``@torch.jit.script``
function as a normal Python function and not compile it. If the above script
is called ``disable_jit_example.py``, we can invoke it like so::
$ PYTORCH_JIT=0 python disable_jit_example.py
and we will be able to step into the @script function as a normal Python
and we will be able to step into the ``@torch.jit.script`` function as a normal Python
function.
Inspecting Code
^^^^^^^^^^^^^^^
TorchScript provides a code pretty-printer for all ScriptModule instances. This
TorchScript provides a code pretty-printer for all ``ScriptModule`` instances. This
pretty-printer gives an interpretation of the script method's code as valid
Python syntax. For example::
@@ -688,11 +793,11 @@ Inspecting Code
A ``ScriptModule`` with a single ``forward`` method will have an attribute
``code``, which you can use to inspect the ``ScriptModule``'s code.
If the ScriptModule has more than one method, you will need to access
If the ``ScriptModule`` has more than one method, you will need to access
``.code`` on the method itself and not the module. We can inspect the
code of a method named ``bar`` on a ``ScriptModule`` by accessing ``.bar.code``.
The example script abouve produces the code::
The example script above produces the code::
def forward(self,
len: int) -> Tensor:
@@ -706,7 +811,7 @@ Inspecting Code
rv0 = rv1
return rv0
This is TorchScript's interpretation of the code for the ``forward`` method.
This is TorchScript's compilation of the code for the ``forward`` method.
You can use this to ensure TorchScript (tracing or scripting) has captured
your model code correctly.
@@ -734,7 +839,7 @@ Interpreting Graphs
print(foo.graph)
``.graph`` follows the same rules described in the Inspecting Code section
``.graph`` follows the same rules described in the `Inspecting Code`_ section
with regard to ``forward`` method lookup.
The example script above produces the graph::
@@ -949,9 +1054,9 @@ best practices?
# ... later, when using the model:
if use_gpu:
model = torch.jit.load("gpu.pth")
model = torch.jit.load("gpu.pth")
else:
model = torch.jit.load("cpu.pth")
model = torch.jit.load("cpu.pth")
model(input)
@@ -961,6 +1066,40 @@ best practices?
the correct device information.
Q: How do I store attributes on a ``ScriptModule``?
Say we have a model like: ::
class Model(torch.jit.ScriptModule):
def __init__(self):
super(Model, self).__init__()
self.x = 2
@torch.jit.script_method
def forward(self):
return self.x
If ``Model`` is instantiated it will result in a compilation error
since the compiler doesn't know about ``x``. There are 4 ways to inform the
compiler of attributes on ``ScriptModule``:
1. ``nn.Parameter`` - values wrapped in ``nn.Parameter`` will work as they
do on ``nn.Module``\s
2. ``register_buffer`` - values wrapped in ``register_buffer`` will work as
they do on ``nn.Module``\s
3. ``__constants__`` - adding a list called ``__constants__`` at the
class definition level will mark the contained names as constants. Constants
are saved directly in the code of the model. See
`Python-defined Constants`_.
4. ``torch.jit.Attribute`` - values wrapped in ``torch.jit.Attribute`` can
be of any TorchScript type, can be mutated, and are saved outside of the code of
the model. See `Module Attributes`_ and the sketch below.
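A minimal sketch combining the four approaches on one module (``M``, ``depth``, and ``history`` are hypothetical names): ::

    import torch
    from typing import List

    class M(torch.jit.ScriptModule):
        __constants__ = ['depth']  # 3. compile-time constant

        def __init__(self):
            super(M, self).__init__()
            self.weight = torch.nn.Parameter(torch.rand(2, 2))        # 1. parameter
            self.register_buffer('running_mean', torch.zeros(2, 2))   # 2. buffer
            self.depth = 4
            self.history = torch.jit.Attribute([], List[str])         # 4. mutable, typed attribute

        @torch.jit.script_method
        def forward(self, x):
            return x + self.weight + self.running_mean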
Builtin Functions
~~~~~~~~~~~~~~~~~


@@ -530,7 +530,7 @@ Linear layers
----------------------------------
:hidden:`Identity`
~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~
.. autoclass:: Identity
:members:


@@ -1290,20 +1290,17 @@ During export a list of all the tensors in a model is created. Tensors can come
### `attributes.pkl`
Attributes are all module properties that are not parameters or constants. Attributes are saved in a list in the order they were defined on the module. The list is stored as a Python `pickle` archive. `pickle`'s format was chosen due to:
* **user friendliness** - the attributes file can be loaded in Python with `pickle` without having PyTorch installed
* **size limits** - formats such as Protobuf impose size limits on total message size, whereas pickle limits are on individual values (e.g. strings cannot be longer than 4 GB)
* **standard format** - `pickle` is a standard Python module with a reasonably simple format. The format is a program to be consumed by a stack machine that is detailed in Python's [`pickletools.py`](https://svn.python.org/projects/python/trunk/Lib/pickletools.py)
* **built-in memoization** - for shared reference types (e.g. Tensor, string, lists, dicts)
* **self describing** - a separate definition file is not needed to understand the pickled data
* **eager mode save** - `torch.save()` already produces a `pickle` archive, so doing the same with attributes may ease unification of these formats in the future
[pickler.h](pickler.h),
[pickler.cpp](pickler.cpp),
[torch/jit/_pickle.py](../../../torch/jit/_pickle.py)
[caffe2/proto/torch.proto](../../../caffe2/proto/torch.proto)
A given module may have many attributes of different types and many submodules, each with their own attributes. Attributes are recorded in `model.json`:
Attributes are all module properties that are not parameters or constants. Attributes are saved in a list in the order they were defined on the module. A given module may have many attributes of different types and many submodules, each with their own attributes. Attribute metadata is recorded in `model.json`:
* `type` - the full type of the attribute (in [Mypy syntax](https://mypy.readthedocs.io/en/latest/cheat_sheet_py3.html))
* `name` - the attribute's name
* `id` - the offset into the saved list of all model attributes
`model.json`
In `model.json`:
```json
{
"mainModule": {
@@ -1344,41 +1341,61 @@ A given module may have many attributes of different types and many submodules,
}
```
Attributes of the main module and its submodules are saved to a single file in the `zip` archive of a `.pt` file named `attributes.pkl`. A single file is used so that attributes can reference each other and shared values. Unpickling this will return a list of values corresponding to the attributes.
Attributes of the main module and its submodules are saved to a single file in the `zip` archive of a `.pt` file named `attributes.pkl`. Attributes are stored as a Python `pickle` archive. `pickle`'s format was chosen due to:
* **user friendliness** - the attributes file can be loaded in Python with `pickle`
* **size limits** - formats such as Protobuf impose size limits on total message size, whereas pickle limits are on individual values (e.g. strings cannot be longer than 4 GB)
* **standard format** - `pickle` is a standard Python module with a reasonably simple format. The format is a program to be consumed by a stack machine that is detailed in Python's [`pickletools.py`](https://svn.python.org/projects/python/trunk/Lib/pickletools.py)
* **built-in memoization** - for shared reference types (e.g. Tensor, string, lists, dicts)
* **self describing** - a separate definition file is not needed to understand the pickled data
* **eager mode save** - `torch.save()` already produces a `pickle` archive, so doing the same with attributes avoids introducing yet another format
All attributes are written into the `attributes.pkl` file with the exception of tensors, which store only a tensor table index (see "tensors" above). Classes are used to mark special data types, such as this tensor table index or specialized lists. To load the `attributes.pkl` file without PyTorch for inspection or manual editing, these classes must be defined, so a custom [`Unpickler`](https://docs.python.org/3/library/pickle.html#pickle.Unpickler) is necessary:
[pickler.cpp](pickler.cpp) implements a subset of the Pickle format necessary for TorchScript models.
A single file is used for the top level module and all submodules so that attributes can reference each other and share values. Unpickling `attributes.pkl` will return a tuple of values corresponding to the attributes.
All attributes are written into the `attributes.pkl` file with the exception of tensors, which store only a tensor table index (see "tensors" above). PyTorch functions defined in [torch/jit/_pickle.py](../../../torch/jit/_pickle.py) are used to mark special data types, such as this tensor table index or specialized lists. To load the `attributes.pkl` file, use the `pickle` module in Python:
```python
import pickle
# attributes.pkl includes references to functions in torch.jit._pickle
import torch
pickle.load(open("attributes.pkl", "rb"))
```
If for some reason you don't have PyTorch installed, you can still load `attributes.pkl` with a custom [`Unpickler`](https://docs.python.org/3/library/pickle.html#pickle.Unpickler):
```python
import pickle
# Tensor objects are stored as instances of this class
class TensorID(object):
def __setstate__(self, id):
self.id = id
# List[int] has internal specializations, and these are indicated with this class
class IntList(object):
def __setstate__(self, data):
self.data = data
class JitUnpickler(pickle.Unpickler):
def find_class(self, module, name):
if not module == '__main__':
return None
if module != 'torch.jit._pickle':
raise RuntimeError("Unknown module")
if name == 'TensorID':
return TensorID
elif name == 'IntList':
return IntList
identity = lambda x: x
if name == 'build_tensor_from_id':
# Without the tensor table we can't do anything other than
# return the tensor ID
return identity
elif name == 'build_intlist':
return identity
JitUnpickler(open("my_model/attributes.pkl", "rb")).load()
print(JitUnpickler(open("out_dir/out/attributes.pkl", "rb")).load())
```
#### Binary Format
Running the following snippet produces a `ScriptModule` with several attributes.
Python's `pickletools` module can be used to decode the binary blob of `attributes.pkl` into a human readable format.
```python
import pickletools
import zipfile
import torch
from typing import Tuple, List
class M(torch.jit.ScriptModule):
def __init__(self):
super(M, self).__init__()
@@ -1391,50 +1408,46 @@ class M(torch.jit.ScriptModule):
def forward(self):
return (self.float, self.tuple, self.tensor, self.int_list)
M().save("out.pt")
M().save("out.zip")
model_zip = zipfile.ZipFile("out.zip", 'r')
model_zip.extractall("out_dir")
pickletools.dis(open("out_dir/out/attributes.pkl", "rb"))
```
In a terminal, Python's `pickletools` module can be used to decode the binary blob of `attributes.pkl` into a human readable format.
```bash
unzip -o out.pt
python -m pickletools out/attributes.pkl
```

The output of the above commands demonstrates the concepts described earlier. Attributes are wrapped in a list (`2: EMPTY_LIST`) and appear in the order they are defined on the module. Classes for certain special types (`List[int]`, `Tensor`) can be seen at `37: GLOBAL` and `66: GLOBAL`, followed by data specific to that type, then finally by an instruction to build the object at `65: BUILD` and `113: BUILD` respectively.
The output of the above commands demonstrates the concepts described earlier. Attributes are wrapped in a tuple (opened at `2: MARK`, closed at `122: TUPLE`) and appear in the order they are defined on the module. Functions for certain special types (e.g. `List[int]`, `Tensor`) can be seen at `24: GLOBAL` and `71: GLOBAL`, followed by data specific to that type, then finally by an instruction to build the object at `70: REDUCE` and `119: REDUCE` respectively.

```
0: \x80 PROTO 2
2: ] EMPTY_LIST
3: ( MARK
4: G BINFLOAT 2.3
13: ( MARK
14: J BININT 1
19: J BININT 2
24: J BININT 3
29: J BININT 4
34: t TUPLE (MARK at 13)
35: q BINPUT 0
37: c GLOBAL '__main__ TensorID'
56: q BINPUT 1
58: ) EMPTY_TUPLE
59: \x81 NEWOBJ
60: J BININT 0
65: b BUILD
66: c GLOBAL '__main__ IntList'
84: q BINPUT 2
86: ) EMPTY_TUPLE
87: \x81 NEWOBJ
88: ] EMPTY_LIST
89: q BINPUT 3
91: ( MARK
92: J BININT 1
97: J BININT 2
102: J BININT 3
107: J BININT 4
112: e APPENDS (MARK at 91)
113: b BUILD
114: e APPENDS (MARK at 3)
115: . STOP
0: \x80 PROTO 2
2: ( MARK
3: G BINFLOAT 2.3
12: ( MARK
13: K BININT1 1
15: K BININT1 2
17: K BININT1 3
19: K BININT1 4
21: t TUPLE (MARK at 12)
22: q BINPUT 0
24: c GLOBAL 'torch.jit._pickle build_tensor_from_id'
64: q BINPUT 1
66: ( MARK
67: K BININT1 0
69: t TUPLE (MARK at 66)
70: R REDUCE
71: c GLOBAL 'torch.jit._pickle build_intlist'
104: q BINPUT 2
106: ( MARK
107: ] EMPTY_LIST
108: ( MARK
109: K BININT1 1
111: K BININT1 2
113: K BININT1 3
115: K BININT1 4
117: e APPENDS (MARK at 108)
118: t TUPLE (MARK at 106)
119: R REDUCE
120: q BINPUT 3
122: t TUPLE (MARK at 2)
123: . STOP
highest protocol among opcodes = 2
```


@@ -98,7 +98,8 @@ def load(f, map_location=None, _extra_files=DEFAULT_EXTRA_FILES_MAP):
Returns:
A ``ScriptModule`` object.
Example:
Example: ::
torch.jit.load('scriptmodule.pt')
# Load ScriptModule from io.BytesIO object
@@ -177,7 +178,8 @@ def save(m, f, _extra_files=DEFAULT_EXTRA_FILES_MAP):
Please use something like ``io.BytesIO`` instead.
Example:
Example: ::
m = torch.jit.ScriptModule()
# Save to file
@@ -1068,13 +1070,13 @@ if _enabled:
The core data structure in TorchScript is the ``ScriptModule``. It is an
analogue of torch's ``nn.Module`` and represents an entire model as a tree of
submodules. Like normal modules, each individual module in a ``ScriptModule`` can
have submodules, parameters, and methods. In ``nn.Module``s methods are implemented
as Python functions, but in ``ScriptModule``s methods are implemented as
have submodules, parameters, and methods. In ``nn.Module``\s, methods are implemented
as Python functions, but in ``ScriptModule``\s, methods are implemented as
TorchScript functions, a statically-typed subset of Python that contains all
of PyTorch's built-in Tensor operations. This difference allows your
``ScriptModule``'s code to run without the need for a Python interpreter.
``ScriptModule``s be created in two ways:
``ScriptModule``\s can be created in two ways:
**Tracing:**
@@ -1131,9 +1133,9 @@ if _enabled:
You can write TorchScript code directly using Python syntax. You do this
using the ``@torch.jit.script`` decorator (for functions) or
``@torch.jit.script_method`` decorator (for methods) on subclasses of
ScriptModule. With this decorator the body of the annotated function is
``ScriptModule``. With this decorator the body of the annotated function is
directly translated into TorchScript. TorchScript itself is a subset of
the Python language, so not all features in python work, but we provide
the Python language, so not all features in Python work, but we provide
enough functionality to compute on tensors and do control-dependent
operations.