Compare commits

...

8 Commits

SHA1        Message                           Date
4290a4a214  api doc                           2025-09-24 13:44:16 -07:00
416ca0ae6f  fix test                          2025-09-24 10:25:14 -07:00
0e19aac62c  fix dynamo                        2025-09-23 22:34:05 -07:00
65720ec2bc  fix dynamo                        2025-09-23 22:19:14 -07:00
0479689440  fix build                         2025-09-23 19:57:13 -07:00
260f4ebffd  address comment                   2025-09-23 14:28:29 -07:00
e04e8c1df5  add test                          2025-09-23 14:03:29 -07:00
9fb573399f  preserve annotation in user code  2025-09-23 13:03:41 -07:00
5 changed files with 106 additions and 0 deletions

View File

@@ -1093,6 +1093,9 @@ The set of leaf modules can be customized by overriding
```{eval-rst}
.. autofunction:: torch.fx.replace_pattern
```
```{eval-rst}
.. autofunction:: torch.fx.traceback.annotate
```
<!-- The experimental and passes submodules are missing docs. -->
<!-- Adding it here for coverage but this doesn't add anything to the -->

View File

@@ -21,6 +21,7 @@ from unittest.mock import MagicMock, patch
import torch
import torch._dynamo as torchdynamo
import torch.fx.traceback as fx_traceback
import torch.nn.functional as F
import torch.utils._pytree as pytree
from functorch.experimental.control_flow import cond, map
@@ -15137,6 +15138,39 @@ def forward(self, x):
            test_serdes=True,
        )

    # TODO: the following tests should be fixed
    @testing.expectedFailureTrainingIRToRunDecomp
    @testing.expectedFailureTrainingIRToRunDecompNonStrict
    def test_preserve_annotation(self):
        class M(torch.nn.Module):
            def forward(self, x):
                with fx_traceback.annotate({"pp_stage": 0}):
                    with fx_traceback.annotate({"fdsp_bucket": 0}):
                        x = x + 1
                    x = x - 2
                    with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}):
                        x = x * 2
                x = x / 3
                return x

        m = M()
        with fx_traceback.preserve_node_meta():
            ep = export(m, (torch.randn(10),))

        for node in ep.graph.nodes:
            if node.target == torch.ops.aten.add.Tensor:
                self.assertEqual(node.meta["custom"], {"pp_stage": 0, "fdsp_bucket": 0})
            if node.target == torch.ops.aten.sub.Tensor:
                self.assertEqual(node.meta["custom"], {"pp_stage": 0})
            if node.target == torch.ops.aten.mul.Tensor:
                self.assertEqual(
                    node.meta["custom"],
                    {"pp_stage": 0, "cuda_stream": 2, "fsdp_bucket": 1},
                )
            if node.target == torch.ops.aten.div.Tensor:
                # nodes traced outside any annotate context carry no custom entries
                self.assertEqual(node.meta.get("custom", {}), {})

    def test_dynamic_shapes_serdes_generic(self):
        from torch._export.serde.dynamic_shapes import (
            _dump_dynamic_shapes,
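
A self-contained sketch of the end-to-end behavior this test locks down, using only the APIs touched by this PR (the tiny module here is illustrative, not the test's `M`):

```python
import torch
import torch.fx.traceback as fx_traceback
from torch.export import export

class Tiny(torch.nn.Module):
    def forward(self, x):
        with fx_traceback.annotate({"pp_stage": 0}):
            x = x + 1  # traced under the annotation
        return x - 2   # traced outside it

with fx_traceback.preserve_node_meta():
    ep = export(Tiny(), (torch.randn(10),))

# The add node should report {"pp_stage": 0}; the sub node should report {}.
for node in ep.graph.nodes:
    print(node.target, node.meta.get("custom", {}))
```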

View File

@@ -51,6 +51,7 @@ from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX
from .utils import (
    getfile,
    hashable,
    is_annotate_wrapped_function,
    is_lru_cache_wrapped_function,
    NP_SUPPORTED_MODULES,
    unwrap_if_wrapper,
@@ -154,6 +155,7 @@ manual_torch_name_rule_map: dict[
        type[UserFunctionVariable],
    ],
] = {
    "torch.fx.traceback.annotate": UserFunctionVariable,
    "torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable,
    "torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable,
    "torch.overrides.is_tensor_like": TorchInGraphFunctionVariable,
@@ -3002,6 +3004,8 @@ def get_torch_obj_rule_map() -> dict[Any, type["VariableTracker"]]:
                    continue
                obj = torch_dir + k[len("torch/") :]
            if obj is not None:
                if is_annotate_wrapped_function(obj):
                    obj = obj.__wrapped__
                if is_lru_cache_wrapped_function(obj):
                    obj = obj.__wrapped__
                if obj in d and d[obj] != v:
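
Why the extra unwrap: `contextlib.contextmanager` applies `functools.wraps`, so the `annotate` object dynamo encounters is a wrapper whose original generator function sits in `__wrapped__`, and the rule table must be keyed on that underlying function. A standalone sketch of the mechanic (plain stdlib, no torch):

```python
import contextlib
import inspect

@contextlib.contextmanager
def annotate(annotation_dict: dict):
    yield

# contextmanager wraps the generator function with functools.wraps,
# which records the original function under __wrapped__.
inner = inspect.getattr_static(annotate, "__wrapped__")
print(inner)                   # <function annotate at 0x...>
print(inner is not annotate)   # True: wrapper and wrappee are distinct objects
```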

View File

@@ -1101,6 +1101,14 @@ def is_lru_cache_wrapped_function(
    )


def is_annotate_wrapped_function(
    value: Any,
) -> bool:
    return value == torch.fx.traceback.annotate and is_function(
        inspect.getattr_static(value, "__wrapped__")
    )


_FuncTypes: TypeAlias = Union[
    types.FunctionType,
    types.BuiltinFunctionType,
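
A usage sketch for the new predicate; the `torch._dynamo.utils` import path is an assumption based on this hunk sitting next to `is_lru_cache_wrapped_function`:

```python
import torch
import torch.fx.traceback as fx_traceback
from torch._dynamo.utils import is_annotate_wrapped_function  # path assumed from the hunk above

# True only for the contextmanager-wrapped torch.fx.traceback.annotate ...
assert is_annotate_wrapped_function(fx_traceback.annotate)
# ... and False otherwise; the equality check short-circuits, so
# __wrapped__ is never probed on unrelated objects.
assert not is_annotate_wrapped_function(torch.relu)
```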

View File

@@ -16,6 +16,7 @@ from .node import Node
log = logging.getLogger(__name__)

__all__ = [
    "annotate",
    "preserve_node_meta",
    "has_preserved_node_meta",
    "set_stack_trace",
@@ -241,6 +242,62 @@ def set_stack_trace(stack: list[str]):
    current_meta["stack_trace"] = "".join(stack)


@compatibility(is_backward_compatible=False)
@contextmanager
def annotate(annotation_dict: dict):
    """
    Temporarily adds custom annotations to the current tracing context.

    FX nodes produced while this context is active will carry the custom
    annotations in their ``node.meta["custom"]`` field.

    This context manager allows you to insert arbitrary metadata into the PT2
    tracing system by updating the global ``current_meta["custom"]`` dictionary.
    The annotations are automatically reverted when the context exits.

    This is intended for advanced users who need to attach additional metadata
    to FX nodes (e.g., for debugging, analysis, or external tooling) during
    export tracing.

    Note:
        This API is **not backward compatible** and may evolve in future
        releases.

    Note:
        This API is not compatible with ``fx.symbolic_trace`` or ``jit.trace``.
        It is intended to be used with the PT2 family of tracers, e.g.
        ``torch.export`` and dynamo.

    Args:
        annotation_dict (dict): A dictionary of custom key-value pairs to
            inject into the FX trace metadata.

    Example:
        >>> with annotate({"source": "custom_pass", "tag": 42}):
        ...     pass  # compute here

        After exiting the context, the custom annotations are removed.
    """
    global current_meta
    has_custom = "custom" in current_meta
    old_custom = {}
    # cannot use `old_custom = copy.copy(current_meta.get("custom", {}))` here,
    # as dynamo doesn't support copy.copy()
    for k, v in current_meta.get("custom", {}).items():
        old_custom[k] = v  # noqa: PERF403

    try:
        if not has_custom:
            current_meta["custom"] = {}
        # Update with all key-value pairs from the input dict
        current_meta["custom"].update(annotation_dict)
        yield
    finally:
        if has_custom:
            # Restore the original custom dict
            current_meta["custom"] = old_custom
        else:
            del current_meta["custom"]


@compatibility(is_backward_compatible=False)
def set_grad_fn_seq_nr(seq_nr):
    global current_meta
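
A behavior sketch of the save/restore logic above; it peeks at `fx_traceback.current_meta`, which is internal state, purely for illustration:

```python
import torch.fx.traceback as fx_traceback

with fx_traceback.annotate({"pp_stage": 0}):
    with fx_traceback.annotate({"cuda_stream": 2}):
        # nested annotations merge into a single "custom" dict
        assert fx_traceback.current_meta["custom"] == {"pp_stage": 0, "cuda_stream": 2}
    # inner entries are reverted on exit; outer ones remain
    assert fx_traceback.current_meta["custom"] == {"pp_stage": 0}

# once the outermost context exits, the "custom" key is removed entirely
assert "custom" not in fx_traceback.current_meta
```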