mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Support torch.fx.traceback.annotate (#164678)
Builds on top of https://github.com/pytorch/pytorch/pull/163673 and https://github.com/pytorch/pytorch/pull/164174. This will be used in the followup PRs to apply regional inductor compilation. The existing implementation let Dynamo trace into the `torch.fx.traceback.annotate`, but thats not what we want. We want Dynamo to essentially run the torch.fx.traceback.annotate function in eager, so that every Fx node created in Dynamo Fx graph has the custom meta node. What does not work? * We still have to set the context manager `torch.fx.traceback.preserve_node_meta()` in the user code because CI was unhappy. This can be fixed but with some perseverance. * This does not work with graph breaks yet. But we can solve that problem, if needed, in a separate PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164678 Approved by: https://github.com/SherlockNoMad, https://github.com/jansel, https://github.com/xmfan
This commit is contained in:
committed by
PyTorch MergeBot
parent
94b1ec8c7c
commit
4308b8a28f
270
test/dynamo/test_fx_annotate.py
Normal file
270
test/dynamo/test_fx_annotate.py
Normal file
@ -0,0 +1,270 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
import torch.fx.traceback as fx_traceback
|
||||
import torch.utils.checkpoint
|
||||
from torch._dynamo.test_case import run_tests
|
||||
from torch._dynamo.testing import AotEagerAndRecordGraphs
|
||||
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
|
||||
from torch.testing._internal.triton_utils import requires_cuda_and_triton
|
||||
|
||||
|
||||
def checkpoint_wrapper(fn):
|
||||
def inner(*args):
|
||||
return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
class AnnotateTests(torch._dynamo.test_case.TestCase):
|
||||
# TODO - should not need this because we should turn this on in Dynamo but
|
||||
# for some reasons, test fail.
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.cm = torch.fx.traceback.preserve_node_meta()
|
||||
self.cm.__enter__()
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
self.cm.__exit__(None, None, None)
|
||||
|
||||
def get_custom_metadata(self, gm):
|
||||
def helper(gm):
|
||||
custom_metadata = []
|
||||
for node in gm.graph.nodes:
|
||||
if hasattr(node, "meta") and node.meta.get("custom", None):
|
||||
custom_metadata.append((node.op, node.name, node.meta["custom"]))
|
||||
if node.op == "get_attr" and isinstance(
|
||||
getattr(gm, node.target), torch.fx.GraphModule
|
||||
):
|
||||
custom_metadata.append(helper(getattr(gm, node.target)))
|
||||
return custom_metadata
|
||||
|
||||
return "\n".join(str(x) for x in helper(gm))
|
||||
|
||||
def test_annotations(self):
|
||||
class Mod(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
with fx_traceback.annotate({"pp_stage": 0}):
|
||||
with fx_traceback.annotate({"fdsp_bucket": 0}):
|
||||
sin = torch.sin(x)
|
||||
sub = sin - 2
|
||||
with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}):
|
||||
mul = sub * 2
|
||||
div = mul / 3
|
||||
return div
|
||||
|
||||
m = Mod()
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
opt_m = torch.compile(m, backend=backend, fullgraph=True)
|
||||
x = torch.randn(10, requires_grad=True)
|
||||
opt_m(x).sum().backward()
|
||||
|
||||
self.assertEqual(len(backend.fw_graphs), 1)
|
||||
self.assertEqual(len(backend.bw_graphs), 1)
|
||||
|
||||
dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
|
||||
fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
|
||||
bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
|
||||
self.assertExpectedInline(
|
||||
str(dynamo_metadata),
|
||||
"""\
|
||||
('placeholder', 'l_x_', {'pp_stage': 0, 'fdsp_bucket': 0})
|
||||
('call_function', 'sin', {'pp_stage': 0, 'fdsp_bucket': 0})
|
||||
('call_function', 'sub', {'pp_stage': 0})
|
||||
('call_function', 'mul', {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1})""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
str(fw_metadata),
|
||||
"""\
|
||||
('call_function', 'sin', {'pp_stage': 0, 'fdsp_bucket': 0})
|
||||
('call_function', 'sub', {'pp_stage': 0})
|
||||
('call_function', 'mul', {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1})""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
str(bw_metadata),
|
||||
"""\
|
||||
('call_function', 'mul_1', {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1})
|
||||
('call_function', 'cos', {'pp_stage': 0, 'fdsp_bucket': 0})
|
||||
('call_function', 'mul_2', {'pp_stage': 0, 'fdsp_bucket': 0})""", # noqa: B950
|
||||
)
|
||||
|
||||
def test_activation_checkpointing(self):
|
||||
@checkpoint_wrapper
|
||||
def gn(x):
|
||||
return torch.sin(x)
|
||||
|
||||
def fn(x):
|
||||
with fx_traceback.annotate({"ac_sin": 0}):
|
||||
ac = gn(x)
|
||||
return torch.sigmoid(ac)
|
||||
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
|
||||
x = torch.randn(10, requires_grad=True)
|
||||
opt_fn(x).sum().backward()
|
||||
|
||||
self.assertEqual(len(backend.fw_graphs), 1)
|
||||
self.assertEqual(len(backend.bw_graphs), 1)
|
||||
|
||||
dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
|
||||
fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
|
||||
bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
|
||||
self.assertExpectedInline(
|
||||
str(dynamo_metadata),
|
||||
"""\
|
||||
('placeholder', 'l_x_', {'ac_sin': 0})
|
||||
('get_attr', 'wrap_body_0', {'ac_sin': 0})
|
||||
[('placeholder', 'l_x_', {'ac_sin': 0}), ('call_function', 'sin', {'ac_sin': 0}), ('output', 'output', {'ac_sin': 0})]
|
||||
('call_function', 'tag_activation_checkpoint', {'ac_sin': 0})
|
||||
('call_function', 'ac', {'ac_sin': 0})""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
str(fw_metadata),
|
||||
"""('call_function', 'sin', {'ac_sin': 0})""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
str(bw_metadata),
|
||||
"""\
|
||||
('call_function', 'cos', {'ac_sin': 0})
|
||||
('call_function', 'mul', {'ac_sin': 0})""", # noqa: B950
|
||||
)
|
||||
|
||||
def test_activation_checkpointing_annotation_inside(self):
|
||||
@checkpoint_wrapper
|
||||
def gn(x):
|
||||
x = x + 1
|
||||
with fx_traceback.annotate({"stage": 0}):
|
||||
p = torch.sin(x)
|
||||
return p + 1
|
||||
|
||||
def fn(x):
|
||||
ac = gn(x)
|
||||
return torch.sigmoid(ac)
|
||||
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
|
||||
x = torch.randn(10, requires_grad=True)
|
||||
opt_fn(x).sum().backward()
|
||||
|
||||
self.assertEqual(len(backend.fw_graphs), 1)
|
||||
self.assertEqual(len(backend.bw_graphs), 1)
|
||||
|
||||
dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
|
||||
fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
|
||||
bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
|
||||
self.assertExpectedInline(
|
||||
str(dynamo_metadata),
|
||||
"""[('call_function', 'p', {'stage': 0})]""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
str(fw_metadata),
|
||||
"""('call_function', 'sin', {'stage': 0})""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
str(bw_metadata),
|
||||
"""\
|
||||
('call_function', 'cos', {'stage': 0})
|
||||
('call_function', 'mul', {'stage': 0})""", # noqa: B950
|
||||
)
|
||||
|
||||
@requires_cuda_and_triton
|
||||
def test_ac_flex_attention(self):
|
||||
def _squared(score, b, h, m, n):
|
||||
return score * score
|
||||
|
||||
def mask_mod(b, h, q, k):
|
||||
return q >= 0
|
||||
|
||||
a = 12
|
||||
b = 64
|
||||
block_mask = create_block_mask(mask_mod, None, None, a * b, a * b)
|
||||
|
||||
def gn(x: torch.Tensor):
|
||||
with fx_traceback.annotate({"compile_inductor": 0}):
|
||||
return flex_attention(
|
||||
x, x, x, block_mask=block_mask, score_mod=_squared
|
||||
)
|
||||
|
||||
def fn(x):
|
||||
x = torch.sin(x)
|
||||
x = gn(x)
|
||||
return torch.cos(x)
|
||||
|
||||
x = torch.randn(
|
||||
1,
|
||||
1,
|
||||
a * b,
|
||||
b,
|
||||
dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
requires_grad=True,
|
||||
)
|
||||
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
|
||||
opt_fn(x).sum().backward()
|
||||
|
||||
self.assertEqual(len(backend.fw_graphs), 1)
|
||||
self.assertEqual(len(backend.bw_graphs), 1)
|
||||
|
||||
dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
|
||||
fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
|
||||
bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
|
||||
self.assertExpectedInline(
|
||||
str(dynamo_metadata),
|
||||
"""\
|
||||
('placeholder', 'l_gn_closure_1_cell_contents_kv_indices', {'compile_inductor': 0})
|
||||
('placeholder', 'l_gn_closure_1_cell_contents_kv_num_blocks', {'compile_inductor': 0})
|
||||
('placeholder', 'l_gn_closure_1_cell_contents_full_kv_num_blocks', {'compile_inductor': 0})
|
||||
('placeholder', 'l_gn_closure_1_cell_contents_full_kv_indices', {'compile_inductor': 0})
|
||||
('placeholder', 'l_gn_closure_1_cell_contents_q_num_blocks', {'compile_inductor': 0})
|
||||
('placeholder', 'l_gn_closure_1_cell_contents_q_indices', {'compile_inductor': 0})
|
||||
('placeholder', 'l_gn_closure_1_cell_contents_full_q_num_blocks', {'compile_inductor': 0})
|
||||
('placeholder', 'l_gn_closure_1_cell_contents_full_q_indices', {'compile_inductor': 0})
|
||||
('get_attr', 'score_mod_0', {'compile_inductor': 0})
|
||||
[('placeholder', 'child', {'compile_inductor': 0}), ('placeholder', 'child_1', {'compile_inductor': 0}), ('placeholder', 'child_2', {'compile_inductor': 0}), ('placeholder', 'child_3', {'compile_inductor': 0}), ('placeholder', 'child_4', {'compile_inductor': 0}), ('call_function', 'mul', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
|
||||
('get_attr', 'mask_fn_0', {'compile_inductor': 0})
|
||||
[('placeholder', 'child', {'compile_inductor': 0}), ('placeholder', 'child_1', {'compile_inductor': 0}), ('placeholder', 'child_2', {'compile_inductor': 0}), ('placeholder', 'child_3', {'compile_inductor': 0}), ('call_function', 'ge', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
|
||||
('call_function', 'flex_attention', {'compile_inductor': 0})
|
||||
('call_function', 'out', {'compile_inductor': 0})""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
str(fw_metadata),
|
||||
"""\
|
||||
('get_attr', 'sdpa_score0', {'compile_inductor': 0})
|
||||
[('placeholder', 'arg0_1', {'compile_inductor': 0}), ('placeholder', 'arg1_1', {'compile_inductor': 0}), ('placeholder', 'arg2_1', {'compile_inductor': 0}), ('placeholder', 'arg3_1', {'compile_inductor': 0}), ('placeholder', 'arg4_1', {'compile_inductor': 0}), ('call_function', 'mul', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
|
||||
('get_attr', 'sdpa_mask0', {'compile_inductor': 0})
|
||||
[('placeholder', 'arg0_1', {'compile_inductor': 0}), ('placeholder', 'arg1_1', {'compile_inductor': 0}), ('placeholder', 'arg2_1', {'compile_inductor': 0}), ('placeholder', 'arg3_1', {'compile_inductor': 0}), ('call_function', 'ge', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
|
||||
('call_function', 'flex_attention', {'compile_inductor': 0})
|
||||
('call_function', 'getitem', {'compile_inductor': 0})
|
||||
('call_function', 'getitem_1', {'compile_inductor': 0})
|
||||
('call_function', 'detach_1', {'compile_inductor': 0})
|
||||
('call_function', 'detach_4', {'compile_inductor': 0})
|
||||
('call_function', 'detach_5', {'compile_inductor': 0})""", # noqa: B950
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
str(bw_metadata),
|
||||
"""\
|
||||
('placeholder', 'getitem', {'compile_inductor': 0})
|
||||
('placeholder', 'detach_5', {'compile_inductor': 0})
|
||||
('call_function', 'zeros', {'compile_inductor': 0})
|
||||
('call_function', 'detach', {'compile_inductor': 0})
|
||||
('call_function', 'detach_2', {'compile_inductor': 0})
|
||||
('call_function', 'detach_3', {'compile_inductor': 0})
|
||||
('get_attr', 'fw_graph0', {'compile_inductor': 0})
|
||||
[]
|
||||
('get_attr', 'joint_graph0', {'compile_inductor': 0})
|
||||
[]
|
||||
('get_attr', 'mask_graph0', {'compile_inductor': 0})
|
||||
[('call_function', 'ge', {'compile_inductor': 0})]
|
||||
('call_function', 'flex_attention_backward', {'compile_inductor': 0})
|
||||
('call_function', 'getitem_3', {'compile_inductor': 0})
|
||||
('call_function', 'getitem_4', {'compile_inductor': 0})
|
||||
('call_function', 'getitem_5', {'compile_inductor': 0})""", # noqa: B950
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -15660,11 +15660,6 @@ def forward(self, x):
|
||||
test_serdes=True,
|
||||
)
|
||||
|
||||
@testing.expectedFailureTrainingIRToRunDecomp
|
||||
@testing.expectedFailureRetraceability
|
||||
@testing.expectedFailureStrictV2
|
||||
@testing.expectedFailureStrict # annotation needs to be handled in dynamo
|
||||
@testing.expectedFailureSerDer
|
||||
def test_preserve_annotation(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
|
@ -750,6 +750,9 @@ def register_bytecode_hook(hook: BytecodeHook) -> RemovableHandle:
|
||||
return handle
|
||||
|
||||
|
||||
# TODO - We want to run preserve_node_meta context manager here, but the CI
|
||||
# fails (its unclear if the failures were flaky)
|
||||
# @torch.fx.traceback.preserve_node_meta()
|
||||
@preserve_global_state
|
||||
def trace_frame(
|
||||
code: types.CodeType,
|
||||
|
@ -51,7 +51,6 @@ from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX
|
||||
from .utils import (
|
||||
getfile,
|
||||
hashable,
|
||||
is_annotate_wrapped_function,
|
||||
is_lru_cache_wrapped_function,
|
||||
NP_SUPPORTED_MODULES,
|
||||
unwrap_if_wrapper,
|
||||
@ -155,7 +154,6 @@ manual_torch_name_rule_map: dict[
|
||||
type[UserFunctionVariable],
|
||||
],
|
||||
] = {
|
||||
"torch.fx.traceback.annotate": UserFunctionVariable,
|
||||
"torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable,
|
||||
"torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable,
|
||||
"torch.overrides.is_tensor_like": TorchInGraphFunctionVariable,
|
||||
@ -2996,9 +2994,6 @@ def get_torch_obj_rule_map() -> dict[Any, type["VariableTracker"]]:
|
||||
continue
|
||||
obj = torch_dir + k[len("torch/") :]
|
||||
if obj is not None:
|
||||
if is_annotate_wrapped_function(obj):
|
||||
# pyrefly: ignore # missing-attribute
|
||||
obj = obj.__wrapped__
|
||||
if is_lru_cache_wrapped_function(obj):
|
||||
obj = obj.__wrapped__
|
||||
if obj in d and d[obj] != v:
|
||||
@ -3430,6 +3425,7 @@ MOD_INLINELIST = [
|
||||
"torch.fx._symbolic_trace",
|
||||
"torch.fx.experimental.proxy_tensor",
|
||||
"torch.fx.passes.shape_prop",
|
||||
"torch.fx.traceback",
|
||||
"torch.nn",
|
||||
"torch.overrides",
|
||||
"torch.random",
|
||||
|
@ -1111,14 +1111,6 @@ def is_lru_cache_wrapped_function(
|
||||
)
|
||||
|
||||
|
||||
def is_annotate_wrapped_function(
|
||||
value: Any,
|
||||
) -> bool:
|
||||
return value == torch.fx.traceback.annotate and is_function(
|
||||
inspect.getattr_static(value, "__wrapped__")
|
||||
)
|
||||
|
||||
|
||||
_FuncTypes: TypeAlias = Union[
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
|
@ -29,6 +29,7 @@ from .ctx_manager import (
|
||||
DynamoConfigPatchVariable,
|
||||
ErrorOnGraphBreakVariable,
|
||||
FSDPParamGroupUseTrainingStateVariable,
|
||||
FxTracebackAnnotateVariable,
|
||||
GradIncrementNestingCtxManagerVariable,
|
||||
GradInplaceRequiresGradCtxManagerVariable,
|
||||
GradModeVariable,
|
||||
|
@ -1262,6 +1262,34 @@ class SDPAKernelVariable(ContextWrappingVariable):
|
||||
return "_sdpa_kernel_variadic"
|
||||
|
||||
|
||||
class FxTracebackAnnotateVariable(ContextWrappingVariable):
|
||||
"""
|
||||
fx.traceback.annotate is a context manager that allows users to annotate the
|
||||
fx graph nodes with custom metadata. In the context of Dynamo, we don't have
|
||||
to trace the body of the context manager. Instead we want to directly run
|
||||
the body of the context manager, so the Dynamo created Fx graphs have the
|
||||
right custom metadata. This variable tracker just runs __enter__ and
|
||||
__exit__ method (instead of tracing).
|
||||
"""
|
||||
|
||||
def __init__(self, target_values, initial_values=None, **kwargs) -> None:
|
||||
super().__init__(
|
||||
target_values=target_values, initial_values=initial_values, **kwargs
|
||||
)
|
||||
|
||||
def enter(self, tx, *args):
|
||||
cm = torch.fx.traceback.annotate(self.target_values)
|
||||
cm.__enter__()
|
||||
self.set_cleanup_hook(tx, lambda: cm.__exit__(None, None, None))
|
||||
return variables.ConstantVariable.create(None)
|
||||
|
||||
def module_name(self):
|
||||
return "torch.fx.traceback"
|
||||
|
||||
def fn_name(self):
|
||||
return "annotate"
|
||||
|
||||
|
||||
class StreamVariable(VariableTracker):
|
||||
def __init__(self, proxy, value, device, **kwargs) -> None:
|
||||
if proxy is not None and "example_value" in proxy.node.meta:
|
||||
|
@ -125,6 +125,7 @@ supported_ctx_manager_classes = dict.fromkeys(
|
||||
torch.autograd.graph.disable_saved_tensors_hooks,
|
||||
torch.cpu.amp.autocast_mode.autocast,
|
||||
torch.cuda.amp.autocast_mode.autocast,
|
||||
torch.fx.traceback.annotate,
|
||||
# We'll let Dynamo inline into the contextlib part of these context
|
||||
# manager instances, all the way till it invokes the wrapped function
|
||||
# itself (at which point we wrap it back to special context manager
|
||||
@ -325,6 +326,7 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
|
||||
DisabledSavedTensorsHooksVariable,
|
||||
DualLevelContextManager,
|
||||
FSDPParamGroupUseTrainingStateVariable,
|
||||
FxTracebackAnnotateVariable,
|
||||
GradIncrementNestingCtxManagerVariable,
|
||||
GradInplaceRequiresGradCtxManagerVariable,
|
||||
GradModeVariable,
|
||||
@ -359,6 +361,11 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
|
||||
assert len(args) <= 1 and len(kwargs) == 0
|
||||
inf_mode = args[0].as_python_constant() if len(args) == 1 else True
|
||||
return InferenceModeVariable.create(tx, inf_mode)
|
||||
elif self.value is torch.fx.traceback.annotate:
|
||||
assert len(args) <= 1 and len(kwargs) == 0
|
||||
return FxTracebackAnnotateVariable(
|
||||
args[0].as_python_constant(), source=self.source
|
||||
)
|
||||
elif inspect.isclass(self.value) and issubclass(self.value, torch.Stream):
|
||||
from torch._dynamo.variables.builder import wrap_fx_proxy_cls
|
||||
|
||||
|
@ -273,11 +273,7 @@ def annotate(annotation_dict: dict):
|
||||
global current_meta
|
||||
|
||||
has_custom = "custom" in current_meta
|
||||
old_custom = {}
|
||||
# cannot use `old_custom = copy.copy(current_meta.get("custom", {}))` here,
|
||||
# as dynamo doesn't support copy.copy()
|
||||
for k, v in current_meta.get("custom", {}).items():
|
||||
old_custom[k] = v # noqa: PERF403
|
||||
old_custom = copy.copy(current_meta.get("custom", {}))
|
||||
|
||||
try:
|
||||
if not has_custom:
|
||||
|
Reference in New Issue
Block a user