mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[export] Add TracingContext (#149294)
TracingContext is added to all tracing locations -- in torch.export this is where we call make_fx (for training IR) and aot_export_module (for inference IR), and in run_decompositions where we call aot_export_module Differential Revision: [D71298927](https://our.internmc.facebook.com/intern/diff/D71298927) Pull Request resolved: https://github.com/pytorch/pytorch/pull/149294 Approved by: https://github.com/ydwu4
This commit is contained in:
committed by
PyTorch MergeBot
parent
a3c286677b
commit
01a57981aa
@ -63,7 +63,7 @@ from torch._functorch.aot_autograd import (
|
||||
_detect_attribute_assignment,
|
||||
aot_export_module,
|
||||
)
|
||||
from torch._guards import detect_fake_mode
|
||||
from torch._guards import detect_fake_mode, tracing, TracingContext
|
||||
from torch._library.fake_class_registry import FakeScriptObject
|
||||
from torch._logging import dtrace_structured
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
@ -1403,7 +1403,8 @@ def _strict_export(
|
||||
if name in reverse_name_lookup
|
||||
}
|
||||
|
||||
with dynamo_fake_mode:
|
||||
tx = TracingContext(dynamo_fake_mode)
|
||||
with dynamo_fake_mode, tracing(tx):
|
||||
aten_export_artifact = _to_aten_func(
|
||||
gm_torch_level,
|
||||
# NOTE: graph module expects only positional args
|
||||
@ -1862,7 +1863,8 @@ def _non_strict_export(
|
||||
_is_torch_jit_trace=_is_torch_jit_trace,
|
||||
)
|
||||
|
||||
with fake_mode, _NonStrictTorchFunctionHandler():
|
||||
tx = TracingContext(fake_mode)
|
||||
with fake_mode, _NonStrictTorchFunctionHandler(), tracing(tx):
|
||||
with _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) as (
|
||||
patched_mod,
|
||||
new_fake_args,
|
||||
|
@ -12,6 +12,7 @@ from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, final, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from torch._guards import tracing, TracingContext
|
||||
from torch._higher_order_ops.utils import autograd_not_implemented
|
||||
from torch._library.fake_class_registry import FakeScriptObject
|
||||
from torch._subclasses.fake_impls import (
|
||||
@ -462,12 +463,16 @@ def _decompose_and_get_gm_with_new_signature_constants(
|
||||
else:
|
||||
retracing_args.append(node.meta["val"])
|
||||
|
||||
tx = TracingContext(fake_mode)
|
||||
|
||||
with (
|
||||
fake_mode
|
||||
), _override_decomp_aten_to_variants(), _override_composite_implicit_decomp(
|
||||
cia_to_decomp,
|
||||
), _enable_graph_inputs_of_type_nn_module(
|
||||
ep.example_inputs
|
||||
), tracing(
|
||||
tx
|
||||
):
|
||||
retracing_args_unwrapped = pytree.tree_unflatten(
|
||||
retracing_args, mod._in_spec
|
||||
|
Reference in New Issue
Block a user