Add export_for_training as public API (#134677)
Differential Revision: [D61912084](https://our.internmc.facebook.com/intern/diff/D61912084)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134677
Approved by: https://github.com/avikchaudhuri, https://github.com/zhxchen17
commit 6dd3f81aaf (parent a7933acd5a), committed by PyTorch MergeBot
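For callers, the change is a one-line rename: the private tracing entry point becomes a public API with the same calling convention. A minimal sketch of the migration (the `Foo` module below is illustrative, not taken from the PR):

    import torch


    class Foo(torch.nn.Module):
        def forward(self, x):
            return x.cos()


    # Before this PR: private entry point under torch.export._trace
    # ep = torch.export._trace._export_for_training(Foo(), (torch.randn(2, 2),))

    # After this PR: public API, same arguments and return type
    ep = torch.export.export_for_training(Foo(), (torch.randn(2, 2),))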
@@ -1437,7 +1437,7 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
                 x_linear = self.linear(x_conv)
                 return x_linear.cos() + y_conv_1d.sum()
 
-        ep = torch.export._trace._export_for_training(
+        ep = torch.export.export_for_training(
             Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50))
         )
         ep_has_linear_convd = ep.run_decompositions(
@@ -1570,7 +1570,7 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_
                 return self.linear(x)
 
         eager_model = Foo()
-        ep_for_training = torch.export._trace._export_for_training(
+        ep_for_training = torch.export.export_for_training(
             eager_model, (torch.ones(2, 2),)
         )
         self.assertExpectedInline(
@@ -1609,7 +1609,7 @@ def forward(self, x):
 
         eager_model_for_export = Foo()
         eager_model_for_testing = Foo()
-        ep_for_training = torch.export._trace._export_for_training(
+        ep_for_training = torch.export.export_for_training(
             eager_model_for_export, (torch.ones(4, 4),)
         )
         self.assertExpectedInline(
@@ -1654,7 +1654,7 @@ def forward(self, x):
         eager_model_for_export_training = Foo()
         eager_model_for_export_inference = Foo()
         eager_model_for_testing = Foo()
-        ep_for_training = torch.export._trace._export_for_training(
+        ep_for_training = torch.export.export_for_training(
             eager_model_for_export_training,
             (torch.ones(4, 4),),
             dynamic_shapes=({0: Dim("x")},),
@@ -1691,7 +1691,7 @@ def forward(self, x):
                 return x + y + self.buffer.sum()
 
         eager_model = Foo()
-        ep_for_training = torch.export._trace._export_for_training(
+        ep_for_training = torch.export.export_for_training(
             eager_model,
             ([torch.ones(4, 4), torch.ones(4, 4)],),
         )
@@ -1717,7 +1717,7 @@ def forward(self, x):
                 return self.linear(x) + self.buffer.sum()
 
         eager_model = Foo()
-        ep_for_training = torch.export._trace._export_for_training(
+        ep_for_training = torch.export.export_for_training(
             eager_model,
             (torch.ones(2, 2),),
         )
@@ -1,19 +1,20 @@
 # Owner(s): ["oncall: export"]
+import torch
 
 
 try:
     from . import test_export, testing
 except ImportError:
     import test_export
+    import testing
 
-from torch.export._trace import _export_for_training
-import testing
 
 
 test_classes = {}
 
 
 def mocked_training_ir_to_run_decomp_export_strict(*args, **kwargs):
-    ep = _export_for_training(*args, **kwargs)
+    ep = torch.export.export_for_training(*args, **kwargs)
     return ep.run_decompositions(
         {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY
     )
@@ -21,9 +22,9 @@ def mocked_training_ir_to_run_decomp_export_strict(*args, **kwargs):
 
 def mocked_training_ir_to_run_decomp_export_non_strict(*args, **kwargs):
     if "strict" in kwargs:
-        ep = _export_for_training(*args, **kwargs)
+        ep = torch.export.export_for_training(*args, **kwargs)
     else:
-        ep = _export_for_training(*args, **kwargs, strict=False)
+        ep = torch.export.export_for_training(*args, **kwargs, strict=False)
     return ep.run_decompositions(
         {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY
     )
@@ -65,9 +65,9 @@ def capture_pre_autograd_graph_warning():
     log.warning("| !!! WARNING !!! |")
     log.warning("+============================+")
     log.warning("capture_pre_autograd_graph() is deprecated and doesn't provide any function guarantee moving forward.")
-    log.warning("Please switch to use torch.export._trace._export_for_training instead.")
+    log.warning("Please switch to use torch.export.export_for_training instead.")
     if config.is_fbcode():
-        log.warning("Unless the unittest is in the blocklist, capture_pre_autograd_graph() will fallback to torch.export._trace._export_for_training.")  # noqa: B950
+        log.warning("Unless the unittest is in the blocklist, capture_pre_autograd_graph() will fallback to torch.export.export_for_training.")  # noqa: B950
 
 
 @compatibility(is_backward_compatible=False)
@@ -128,9 +128,9 @@ def capture_pre_autograd_graph(
     if capture_pre_autograd_graph_using_training_ir():
         @lru_cache
         def print_export_warning():
-            log.warning("Using torch.export._trace._export_for_training(...,strict=True)")
+            log.warning("Using torch.export.export_for_training(...,strict=True)")
         print_export_warning()
-        module = torch.export._trace._export_for_training(f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=True).module()
+        module = torch.export.export_for_training(f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=True).module()
     else:
         log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})
 
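The replacement that the deprecation warning points to can be sketched as follows. The helper name `capture_for_training` is hypothetical, but its body mirrors the fallback line in the diff above (export under strict=True, then extract the graph module):

    import torch

    def capture_for_training(m: torch.nn.Module, args: tuple) -> torch.fx.GraphModule:
        # export_for_training returns an ExportedProgram; .module() yields the
        # runnable torch.fx.GraphModule that capture_pre_autograd_graph used to return.
        return torch.export.export_for_training(m, args, strict=True).module()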
@@ -51,6 +51,7 @@ __all__ = [
     "ModuleCallSignature",
     "dims",
     "export",
+    "export_for_training",
     "load",
     "register_dataclass",
     "save",
@@ -69,6 +70,91 @@ from .unflatten import FlatArgsAdapter, unflatten, UnflattenedModule
 PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
 
 
+def export_for_training(
+    mod: torch.nn.Module,
+    args: Tuple[Any, ...],
+    kwargs: Optional[Dict[str, Any]] = None,
+    *,
+    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
+    strict: bool = True,
+    preserve_module_call_signature: Tuple[str, ...] = (),
+) -> ExportedProgram:
+    """
+    :func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing
+    only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion,
+    which can subsequently be executed with different inputs or serialized. The
+    traced graph (1) produces normalized operators in the full ATen operator set
+    (as well as any user-specified custom operators), (2) has eliminated all Python control
+    flow and data structures (with certain exceptions), and (3) records the set of
+    shape constraints needed to show that this normalization and control-flow elimination
+    is sound for future inputs. This API is intended for PT2 quantization training use cases
+    and will become the default IR of torch.export.export in the near future.
+
+    **Soundness Guarantee**
+
+    See the :func:`export()` docstring for more details.
+
+    Args:
+        mod: We will trace the forward method of this module.
+
+        args: Example positional inputs.
+
+        kwargs: Optional example keyword inputs.
+
+        dynamic_shapes:
+         An optional argument whose type should either be:
+         1) a dict from argument names of ``mod``'s forward method to their dynamic shape specifications,
+         2) a tuple that specifies dynamic shape specifications for each input in original order.
+         If you are specifying dynamism on keyword args, you will need to pass them in the order that
+         is defined in the original function signature.
+
+         The dynamic shape of a tensor argument can be specified as either
+         (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
+         not required to include static dimension indices in this dict, but when they are,
+         they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
+         where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
+         are denoted by None. Arguments that are dicts or tuples / lists of tensors are
+         recursively specified by using mappings or sequences of contained specifications.
+
+        strict: When enabled (default), the export function traces the program through
+         TorchDynamo, which ensures the soundness of the resulting graph. Otherwise, the
+         exported program will not validate the implicit assumptions baked into the graph,
+         which may cause behavior divergence between the original model and the exported
+         one. This is useful when users need to work around bugs in the tracer, or simply
+         want to incrementally enable safety in their models. Note that this does not change
+         the resulting IR spec, and the model will be serialized in the same way regardless
+         of the value passed here.
+         WARNING: This option is experimental; use it at your own risk.
+
+    Returns:
+        An :class:`ExportedProgram` containing the traced callable.
+
+    **Acceptable input/output types**
+
+    Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include:
+
+    - Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``.
+    - Dataclasses, but they must be registered by calling :func:`register_dataclass` first.
+    - (Nested) data structures comprising ``dict``, ``list``, ``tuple``, ``namedtuple`` and
+      ``OrderedDict`` containing all of the above types.
+
+    """
+    from ._trace import _export_for_training
+
+    if not isinstance(mod, torch.nn.Module):
+        raise ValueError(
+            f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}."
+        )
+    return _export_for_training(
+        mod,
+        args,
+        kwargs,
+        dynamic_shapes,
+        strict=strict,
+        preserve_module_call_signature=preserve_module_call_signature,
+    )
+
+
 def export(
     mod: torch.nn.Module,
     args: Tuple[Any, ...],
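Putting the pieces together, here is a hedged end-to-end sketch of the new public API; the module, shapes, and Dim name are illustrative rather than taken from the PR:

    import torch
    from torch.export import Dim, export_for_training

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.linear(x).cos()

    # Per the dynamic_shapes spec in the docstring: a tuple with one entry per
    # input, each mapping dynamic dimension indices to Dim types.
    batch = Dim("batch")
    ep = export_for_training(M(), (torch.ones(2, 4),), dynamic_shapes=({0: batch},))

    # The training IR keeps ops such as aten.linear un-decomposed; as the tests
    # above exercise, run_decompositions() lowers the graph to the inference IR.
    ep_inference = ep.run_decompositions({})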