Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "[dynamo] Support torch.fx.traceback.annotate (#164678)"
This reverts commit 801e282f39e9ef4424dfd3ecfd2b550a44595229. Reverted https://github.com/pytorch/pytorch/pull/164678 on behalf of https://github.com/izaitsevfb due to breaks executorch internally, see [D84068062](https://www.internalfb.com/diff/D84068062?entry_point=16) ([comment](https://github.com/pytorch/pytorch/pull/164678#issuecomment-3379281844))
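For context, the reverted change taught Dynamo to handle `torch.fx.traceback.annotate`, which tags FX graph nodes with user-supplied metadata under `node.meta["custom"]`. A minimal usage sketch, condensed from the deleted test below (illustrative only; after this revert Dynamo no longer gives the context manager special handling):

import torch
import torch.fx.traceback as fx_traceback
from torch._dynamo.testing import AotEagerAndRecordGraphs

class Mod(torch.nn.Module):
    def forward(self, x):
        with fx_traceback.annotate({"pp_stage": 0}):
            y = torch.sin(x)
        return y / 3

backend = AotEagerAndRecordGraphs()
opt_m = torch.compile(Mod(), backend=backend, fullgraph=True)
opt_m(torch.randn(10, requires_grad=True)).sum().backward()

# Nodes traced inside the annotate block are expected to carry the metadata,
# e.g. ('call_function', 'sin', {'pp_stage': 0}); nodes outside it do not.
for node in backend.graphs[0].graph.nodes:
    if node.meta.get("custom"):
        print(node.op, node.name, node.meta["custom"])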
@@ -1,259 +0,0 @@
# Owner(s): ["module: dynamo"]

import torch
import torch._dynamo.test_case
import torch.fx.traceback as fx_traceback
import torch.utils.checkpoint
from torch._dynamo.test_case import run_tests
from torch._dynamo.testing import AotEagerAndRecordGraphs
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from torch.testing._internal.triton_utils import requires_cuda_and_triton


def checkpoint_wrapper(fn):
    def inner(*args):
        return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)

    return inner


class AnnotateTests(torch._dynamo.test_case.TestCase):
    def get_custom_metadata(self, gm):
        def helper(gm):
            custom_metadata = []
            for node in gm.graph.nodes:
                if hasattr(node, "meta") and node.meta.get("custom", None):
                    custom_metadata.append((node.op, node.name, node.meta["custom"]))
                if node.op == "get_attr" and isinstance(
                    getattr(gm, node.target), torch.fx.GraphModule
                ):
                    custom_metadata.append(helper(getattr(gm, node.target)))
            return custom_metadata

        return "\n".join(str(x) for x in helper(gm))

    def test_annotations(self):
        class Mod(torch.nn.Module):
            def forward(self, x):
                with fx_traceback.annotate({"pp_stage": 0}):
                    with fx_traceback.annotate({"fdsp_bucket": 0}):
                        sin = torch.sin(x)
                    sub = sin - 2
                    with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}):
                        mul = sub * 2
                div = mul / 3
                return div

        m = Mod()
        backend = AotEagerAndRecordGraphs()
        opt_m = torch.compile(m, backend=backend, fullgraph=True)
        x = torch.randn(10, requires_grad=True)
        opt_m(x).sum().backward()

        self.assertEqual(len(backend.fw_graphs), 1)
        self.assertEqual(len(backend.bw_graphs), 1)

        dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
        fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
        bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
        self.assertExpectedInline(
            str(dynamo_metadata),
            """\
('placeholder', 'l_x_', {'pp_stage': 0, 'fdsp_bucket': 0})
('call_function', 'sin', {'pp_stage': 0, 'fdsp_bucket': 0})
('call_function', 'sub', {'pp_stage': 0})
('call_function', 'mul', {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1})""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(fw_metadata),
            """\
('call_function', 'sin', {'pp_stage': 0, 'fdsp_bucket': 0})
('call_function', 'sub', {'pp_stage': 0})
('call_function', 'mul', {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1})""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(bw_metadata),
            """\
('call_function', 'mul_1', {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1})
('call_function', 'cos', {'pp_stage': 0, 'fdsp_bucket': 0})
('call_function', 'mul_2', {'pp_stage': 0, 'fdsp_bucket': 0})""",  # noqa: B950
        )

    def test_activation_checkpointing(self):
        @checkpoint_wrapper
        def gn(x):
            return torch.sin(x)

        def fn(x):
            with fx_traceback.annotate({"ac_sin": 0}):
                ac = gn(x)
            return torch.sigmoid(ac)

        backend = AotEagerAndRecordGraphs()
        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
        x = torch.randn(10, requires_grad=True)
        opt_fn(x).sum().backward()

        self.assertEqual(len(backend.fw_graphs), 1)
        self.assertEqual(len(backend.bw_graphs), 1)

        dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
        fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
        bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
        self.assertExpectedInline(
            str(dynamo_metadata),
            """\
('placeholder', 'l_x_', {'ac_sin': 0})
('get_attr', 'wrap_body_0', {'ac_sin': 0})
[('placeholder', 'l_x_', {'ac_sin': 0}), ('call_function', 'sin', {'ac_sin': 0}), ('output', 'output', {'ac_sin': 0})]
('call_function', 'tag_activation_checkpoint', {'ac_sin': 0})
('call_function', 'ac', {'ac_sin': 0})""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(fw_metadata),
            """('call_function', 'sin', {'ac_sin': 0})""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(bw_metadata),
            """\
('call_function', 'cos', {'ac_sin': 0})
('call_function', 'mul', {'ac_sin': 0})""",  # noqa: B950
        )

    def test_activation_checkpointing_annotation_inside(self):
        @checkpoint_wrapper
        def gn(x):
            x = x + 1
            with fx_traceback.annotate({"stage": 0}):
                p = torch.sin(x)
            return p + 1

        def fn(x):
            ac = gn(x)
            return torch.sigmoid(ac)

        backend = AotEagerAndRecordGraphs()
        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
        x = torch.randn(10, requires_grad=True)
        opt_fn(x).sum().backward()

        self.assertEqual(len(backend.fw_graphs), 1)
        self.assertEqual(len(backend.bw_graphs), 1)

        dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
        fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
        bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
        self.assertExpectedInline(
            str(dynamo_metadata),
            """[('call_function', 'p', {'stage': 0})]""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(fw_metadata),
            """('call_function', 'sin', {'stage': 0})""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(bw_metadata),
            """\
('call_function', 'cos', {'stage': 0})
('call_function', 'mul', {'stage': 0})""",  # noqa: B950
        )

    @requires_cuda_and_triton
    def test_ac_flex_attention(self):
        def _squared(score, b, h, m, n):
            return score * score

        def mask_mod(b, h, q, k):
            return q >= 0

        a = 12
        b = 64
        block_mask = create_block_mask(mask_mod, None, None, a * b, a * b)

        def gn(x: torch.Tensor):
            with fx_traceback.annotate({"compile_inductor": 0}):
                return flex_attention(
                    x, x, x, block_mask=block_mask, score_mod=_squared
                )

        def fn(x):
            x = torch.sin(x)
            x = gn(x)
            return torch.cos(x)

        x = torch.randn(
            1,
            1,
            a * b,
            b,
            dtype=torch.bfloat16,
            device="cuda",
            requires_grad=True,
        )

        backend = AotEagerAndRecordGraphs()
        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
        opt_fn(x).sum().backward()

        self.assertEqual(len(backend.fw_graphs), 1)
        self.assertEqual(len(backend.bw_graphs), 1)

        dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
        fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
        bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
        self.assertExpectedInline(
            str(dynamo_metadata),
            """\
('placeholder', 'l_gn_closure_1_cell_contents_kv_indices', {'compile_inductor': 0})
('placeholder', 'l_gn_closure_1_cell_contents_kv_num_blocks', {'compile_inductor': 0})
('placeholder', 'l_gn_closure_1_cell_contents_full_kv_num_blocks', {'compile_inductor': 0})
('placeholder', 'l_gn_closure_1_cell_contents_full_kv_indices', {'compile_inductor': 0})
('placeholder', 'l_gn_closure_1_cell_contents_q_num_blocks', {'compile_inductor': 0})
('placeholder', 'l_gn_closure_1_cell_contents_q_indices', {'compile_inductor': 0})
('placeholder', 'l_gn_closure_1_cell_contents_full_q_num_blocks', {'compile_inductor': 0})
('placeholder', 'l_gn_closure_1_cell_contents_full_q_indices', {'compile_inductor': 0})
('get_attr', 'score_mod_0', {'compile_inductor': 0})
[('placeholder', 'child', {'compile_inductor': 0}), ('placeholder', 'child_1', {'compile_inductor': 0}), ('placeholder', 'child_2', {'compile_inductor': 0}), ('placeholder', 'child_3', {'compile_inductor': 0}), ('placeholder', 'child_4', {'compile_inductor': 0}), ('call_function', 'mul', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
('get_attr', 'mask_fn_0', {'compile_inductor': 0})
[('placeholder', 'child', {'compile_inductor': 0}), ('placeholder', 'child_1', {'compile_inductor': 0}), ('placeholder', 'child_2', {'compile_inductor': 0}), ('placeholder', 'child_3', {'compile_inductor': 0}), ('call_function', 'ge', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
('call_function', 'flex_attention', {'compile_inductor': 0})
('call_function', 'out', {'compile_inductor': 0})""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(fw_metadata),
            """\
('get_attr', 'sdpa_score0', {'compile_inductor': 0})
[('placeholder', 'arg0_1', {'compile_inductor': 0}), ('placeholder', 'arg1_1', {'compile_inductor': 0}), ('placeholder', 'arg2_1', {'compile_inductor': 0}), ('placeholder', 'arg3_1', {'compile_inductor': 0}), ('placeholder', 'arg4_1', {'compile_inductor': 0}), ('call_function', 'mul', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
('get_attr', 'sdpa_mask0', {'compile_inductor': 0})
[('placeholder', 'arg0_1', {'compile_inductor': 0}), ('placeholder', 'arg1_1', {'compile_inductor': 0}), ('placeholder', 'arg2_1', {'compile_inductor': 0}), ('placeholder', 'arg3_1', {'compile_inductor': 0}), ('call_function', 'ge', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
('call_function', 'flex_attention', {'compile_inductor': 0})
('call_function', 'getitem', {'compile_inductor': 0})
('call_function', 'getitem_1', {'compile_inductor': 0})
('call_function', 'detach_1', {'compile_inductor': 0})
('call_function', 'detach_4', {'compile_inductor': 0})
('call_function', 'detach_5', {'compile_inductor': 0})""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(bw_metadata),
            """\
('placeholder', 'getitem', {'compile_inductor': 0})
('placeholder', 'detach_5', {'compile_inductor': 0})
('call_function', 'zeros', {'compile_inductor': 0})
('call_function', 'detach', {'compile_inductor': 0})
('call_function', 'detach_2', {'compile_inductor': 0})
('call_function', 'detach_3', {'compile_inductor': 0})
('get_attr', 'fw_graph0', {'compile_inductor': 0})
[]
('get_attr', 'joint_graph0', {'compile_inductor': 0})
[]
('get_attr', 'mask_graph0', {'compile_inductor': 0})
[('call_function', 'ge', {'compile_inductor': 0})]
('call_function', 'flex_attention_backward', {'compile_inductor': 0})
('call_function', 'getitem_3', {'compile_inductor': 0})
('call_function', 'getitem_4', {'compile_inductor': 0})
('call_function', 'getitem_5', {'compile_inductor': 0})""",  # noqa: B950
        )


if __name__ == "__main__":
    run_tests()
@@ -15668,6 +15668,11 @@ def forward(self, x):
            test_serdes=True,
        )

    @testing.expectedFailureTrainingIRToRunDecomp
    @testing.expectedFailureRetraceability
    @testing.expectedFailureStrictV2
    @testing.expectedFailureStrict  # annotation needs to be handled in dynamo
    @testing.expectedFailureSerDer
    def test_preserve_annotation(self):
        class M(torch.nn.Module):
            def forward(self, x):
@@ -1490,7 +1490,7 @@ class TestMaxAutotune(TestCase):
            ).run(code[0])
        else:
            FileCheck().check("extern_kernels.bmm_dtype").check_regex(
                "triton_.*_fused_.*.run"
                "triton_.*_fused_0.run"
            ).check("decompose_k").run(code[0])
        check_divisors(code)
        torch.testing.assert_close(out, a @ b, atol=1e-2, rtol=1e-2)
@@ -1504,7 +1504,7 @@ class TestMaxAutotune(TestCase):
            ).run(code[0])
        else:
            FileCheck().check("extern_kernels.bmm_dtype").check_regex(
                "triton_.*_fused_.*.run"
                "triton_.*_fused_0.run"
            ).check("decompose_k").run(code[0])
        check_divisors(code)
        torch.testing.assert_close(
@@ -1525,7 +1525,7 @@ class TestMaxAutotune(TestCase):
            ).run(code[0])
        else:
            FileCheck().check("extern_kernels.bmm_dtype").check_regex(
                "triton_.*_fused_.*.run"
                "triton_.*_fused_0.run"
            ).check("decompose_k").run(code[0])
        check_divisors(code)
        torch.testing.assert_close(
@@ -1576,7 +1576,7 @@ class TestMaxAutotune(TestCase):

        out, code = run_and_get_code(compiled_func, a, b)
        FileCheck().check("extern_kernels.bmm_dtype").check_regex(
            "triton_.*_fused_.*.run"
            "triton_.*_fused_0.run"
        ).check("decompose_k").check_regex(r"s[0-9]+ = s[0-9]+").check_regex(
            r"2\*s[0-9]+"
        ).check_regex("s[0-9]+ = 32").run(code[0])
@@ -1626,7 +1626,7 @@ class TestMaxAutotune(TestCase):
        out.backward()

        FileCheck().check("extern_kernels.bmm_dtype").check_regex(
            "triton_.*_fused_.*.run"
            "triton_.*_fused_0.run"
        ).check("decompose_k").check_regex(r"s[0-9]+ = s[0-9]+").check_regex(
            r"256\*s[0-9]+"
        ).check_regex("s[0-9]+ = 8").run(
@@ -750,7 +750,6 @@ def register_bytecode_hook(hook: BytecodeHook) -> RemovableHandle:


@preserve_global_state
@torch.fx.traceback.preserve_node_meta()
def trace_frame(
    code: types.CodeType,
    globals: dict[str, object],
@@ -51,6 +51,7 @@ from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX
from .utils import (
    getfile,
    hashable,
    is_annotate_wrapped_function,
    is_lru_cache_wrapped_function,
    NP_SUPPORTED_MODULES,
    unwrap_if_wrapper,
@@ -154,6 +155,7 @@ manual_torch_name_rule_map: dict[
        type[UserFunctionVariable],
    ],
] = {
    "torch.fx.traceback.annotate": UserFunctionVariable,
    "torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable,
    "torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable,
    "torch.overrides.is_tensor_like": TorchInGraphFunctionVariable,
@@ -3002,6 +3004,8 @@ def get_torch_obj_rule_map() -> dict[Any, type["VariableTracker"]]:
                continue
            obj = torch_dir + k[len("torch/") :]
            if obj is not None:
                if is_annotate_wrapped_function(obj):
                    obj = obj.__wrapped__
                if is_lru_cache_wrapped_function(obj):
                    obj = obj.__wrapped__
                if obj in d and d[obj] != v:
@@ -3433,7 +3437,6 @@ MOD_INLINELIST = [
    "torch.fx._symbolic_trace",
    "torch.fx.experimental.proxy_tensor",
    "torch.fx.passes.shape_prop",
    "torch.fx.traceback",
    "torch.nn",
    "torch.overrides",
    "torch.random",
@@ -1108,6 +1108,14 @@ def is_lru_cache_wrapped_function(
    )


def is_annotate_wrapped_function(
    value: Any,
) -> bool:
    return value == torch.fx.traceback.annotate and is_function(
        inspect.getattr_static(value, "__wrapped__")
    )


_FuncTypes: TypeAlias = Union[
    types.FunctionType,
    types.BuiltinFunctionType,
@@ -29,7 +29,6 @@ from .ctx_manager import (
    DynamoConfigPatchVariable,
    ErrorOnGraphBreakVariable,
    FSDPParamGroupUseTrainingStateVariable,
    FxTracebackAnnotateVariable,
    GradIncrementNestingCtxManagerVariable,
    GradInplaceRequiresGradCtxManagerVariable,
    GradModeVariable,
@@ -1262,34 +1262,6 @@ class SDPAKernelVariable(ContextWrappingVariable):
        return "_sdpa_kernel_variadic"


class FxTracebackAnnotateVariable(ContextWrappingVariable):
    """
    fx.traceback.annotate is a context manager that allows users to annotate the
    fx graph nodes with custom metadata. In the context of Dynamo, we don't have
    to trace the body of the context manager. Instead, we want to directly run
    the body of the context manager, so that the Dynamo-created fx graphs have the
    right custom metadata. This variable tracker just runs the __enter__ and
    __exit__ methods (instead of tracing them).
    """

    def __init__(self, target_values, initial_values=None, **kwargs) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )

    def enter(self, tx, *args):
        cm = torch.fx.traceback.annotate(self.target_values)
        cm.__enter__()
        self.set_cleanup_hook(tx, lambda: cm.__exit__(None, None, None))
        return variables.ConstantVariable.create(None)

    def module_name(self):
        return "torch.fx.traceback"

    def fn_name(self):
        return "annotate"


class StreamVariable(VariableTracker):
    def __init__(self, proxy, value, device, **kwargs) -> None:
        if proxy is not None and "example_value" in proxy.node.meta:
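To make the docstring above concrete, here is a self-contained toy in plain Python (not Dynamo internals; `annotate`, `current_meta`, and `_cleanup_hooks` are illustrative stand-ins) showing the pattern the variable tracker follows: run the real context manager's __enter__ eagerly so the global metadata is live while tracing proceeds, and register a cleanup that calls __exit__ afterwards.

import contextlib

current_meta = {}    # stand-in for torch.fx.traceback's module-level metadata dict
_cleanup_hooks = []  # stand-in for Dynamo's set_cleanup_hook machinery


@contextlib.contextmanager
def annotate(custom):
    # toy version of fx.traceback.annotate: merge metadata, restore it on exit
    old = dict(current_meta)
    current_meta.update(custom)
    try:
        yield
    finally:
        current_meta.clear()
        current_meta.update(old)


def handle_annotate(custom):
    # mirrors enter() above: call __enter__ for real instead of tracing the
    # context manager's body, then remember how to undo it later
    cm = annotate(custom)
    cm.__enter__()
    _cleanup_hooks.append(lambda: cm.__exit__(None, None, None))


handle_annotate({"pp_stage": 0})
print(current_meta)  # {'pp_stage': 0}, visible to whatever is "traced" now
for hook in reversed(_cleanup_hooks):
    hook()
print(current_meta)  # {}, metadata restored once tracing is done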
@@ -125,7 +125,6 @@ supported_ctx_manager_classes = dict.fromkeys(
        torch.autograd.graph.disable_saved_tensors_hooks,
        torch.cpu.amp.autocast_mode.autocast,
        torch.cuda.amp.autocast_mode.autocast,
        torch.fx.traceback.annotate,
        # We'll let Dynamo inline into the contextlib part of these context
        # manager instances, all the way till it invokes the wrapped function
        # itself (at which point we wrap it back to special context manager
@@ -326,7 +325,6 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
            DisabledSavedTensorsHooksVariable,
            DualLevelContextManager,
            FSDPParamGroupUseTrainingStateVariable,
            FxTracebackAnnotateVariable,
            GradIncrementNestingCtxManagerVariable,
            GradInplaceRequiresGradCtxManagerVariable,
            GradModeVariable,
@@ -361,11 +359,6 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
            assert len(args) <= 1 and len(kwargs) == 0
            inf_mode = args[0].as_python_constant() if len(args) == 1 else True
            return InferenceModeVariable.create(tx, inf_mode)
        elif self.value is torch.fx.traceback.annotate:
            assert len(args) <= 1 and len(kwargs) == 0
            return FxTracebackAnnotateVariable(
                args[0].as_python_constant(), source=self.source
            )
        elif inspect.isclass(self.value) and issubclass(self.value, torch.Stream):
            from torch._dynamo.variables.builder import wrap_fx_proxy_cls
@@ -273,7 +273,11 @@ def annotate(annotation_dict: dict):
    global current_meta

    has_custom = "custom" in current_meta
    old_custom = copy.copy(current_meta.get("custom", {}))
    old_custom = {}
    # cannot use `old_custom = copy.copy(current_meta.get("custom", {}))` here,
    # as dynamo doesn't support copy.copy()
    for k, v in current_meta.get("custom", {}).items():
        old_custom[k] = v  # noqa: PERF403

    try:
        if not has_custom: