[export] Add TracingContext (#149294)

TracingContext is added to all tracing locations -- in torch.export this is where we call make_fx (for training IR) and aot_export_module (for inference IR), and in run_decompositions where we call aot_export_module

Differential Revision: [D71298927](https://our.internmc.facebook.com/intern/diff/D71298927)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149294
Approved by: https://github.com/ydwu4
This commit is contained in:
angelayi
2025-03-18 15:55:28 -07:00
committed by PyTorch MergeBot
parent a3c286677b
commit 01a57981aa
2 changed files with 10 additions and 3 deletions

View File

@ -63,7 +63,7 @@ from torch._functorch.aot_autograd import (
_detect_attribute_assignment,
aot_export_module,
)
from torch._guards import detect_fake_mode
from torch._guards import detect_fake_mode, tracing, TracingContext
from torch._library.fake_class_registry import FakeScriptObject
from torch._logging import dtrace_structured
from torch._subclasses.fake_tensor import FakeTensorMode
@ -1403,7 +1403,8 @@ def _strict_export(
if name in reverse_name_lookup
}
with dynamo_fake_mode:
tx = TracingContext(dynamo_fake_mode)
with dynamo_fake_mode, tracing(tx):
aten_export_artifact = _to_aten_func(
gm_torch_level,
# NOTE: graph module expects only positional args
@ -1862,7 +1863,8 @@ def _non_strict_export(
_is_torch_jit_trace=_is_torch_jit_trace,
)
with fake_mode, _NonStrictTorchFunctionHandler():
tx = TracingContext(fake_mode)
with fake_mode, _NonStrictTorchFunctionHandler(), tracing(tx):
with _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) as (
patched_mod,
new_fake_args,

View File

@ -12,6 +12,7 @@ from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any, Callable, final, Optional, TYPE_CHECKING, Union
from torch._guards import tracing, TracingContext
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._library.fake_class_registry import FakeScriptObject
from torch._subclasses.fake_impls import (
@ -462,12 +463,16 @@ def _decompose_and_get_gm_with_new_signature_constants(
else:
retracing_args.append(node.meta["val"])
tx = TracingContext(fake_mode)
with (
fake_mode
), _override_decomp_aten_to_variants(), _override_composite_implicit_decomp(
cia_to_decomp,
), _enable_graph_inputs_of_type_nn_module(
ep.example_inputs
), tracing(
tx
):
retracing_args_unwrapped = pytree.tree_unflatten(
retracing_args, mod._in_spec