Use dataclass features in two classes (#164221)

This PR completes two TODO items by using features of `dataclass`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164221
Approved by: https://github.com/Skylion007, https://github.com/mlazos

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
This commit is contained in:
Yuanyuan Chen
2025-10-01 03:20:39 +00:00
committed by PyTorch MergeBot
parent 591997490a
commit fa90090735
6 changed files with 23 additions and 12 deletions

View File

@ -30,7 +30,7 @@ class CallbackTests(TestCase):
def test_callbacks_with_duplicate_prevention(self) -> None:
trigger = CallbackTrigger.DYNAMO
compile_id = CompileId(0, 0)
compile_id = CompileId(frame_id=0, frame_compile_id=0)
with (
callback_handler.install_callbacks(trigger, compile_id),
callback_handler.install_callbacks(trigger, compile_id),
@ -40,7 +40,7 @@ class CallbackTests(TestCase):
def test_counter(self) -> None:
trigger = CallbackTrigger.DYNAMO
compile_id = CompileId(0, 0)
compile_id = CompileId(frame_id=0, frame_compile_id=0)
with callback_handler.install_callbacks(trigger, compile_id):
self.assertEqual(
callback_handler._CompilationCallbackHandler__pending_callbacks_counter,
@ -56,7 +56,7 @@ class CallbackTests(TestCase):
AssertionError, "Pending callbacks counter cannot become negative."
):
trigger = CallbackTrigger.DYNAMO
compile_id = CompileId(0, 0)
compile_id = CompileId(frame_id=0, frame_compile_id=0)
with callback_handler.install_callbacks(trigger, str(compile_id)):
pass
self.assertEqual(

View File

@ -95,7 +95,11 @@ class FrameInitTests(torch._dynamo.test_case.TestCase):
transformed_code = code_map1[frame.f_code]
return wrap_guarded_code(
GuardedCode(
transformed_code, empty_guard_manager, CompileId(None, 0, 0)
transformed_code,
empty_guard_manager,
CompileId(
frame_id=None, frame_compile_id=0, compiled_autograd_id=0
),
)
)
return ConvertFrameReturn()
@ -105,7 +109,11 @@ class FrameInitTests(torch._dynamo.test_case.TestCase):
transformed_code = code_map2[frame.f_code]
return wrap_guarded_code(
GuardedCode(
transformed_code, empty_guard_manager, CompileId(None, 0, 0)
transformed_code,
empty_guard_manager,
CompileId(
frame_id=None, frame_compile_id=0, compiled_autograd_id=0
),
)
)
return ConvertFrameReturn()

View File

@ -329,7 +329,9 @@ class TestGuardSerializationBase(torch._inductor.test_case.TestCase):
package=None,
)
with (
compile_context(CompileContext(CompileId(0, 0))),
compile_context(
CompileContext(CompileId(frame_id=0, frame_compile_id=0))
),
tracing(tracer.output.tracing_context),
tracer.set_current_tx(),
get_metrics_context(),

View File

@ -6864,7 +6864,9 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
with patch.object(
CompileContext,
"__init__",
lambda self, _: CompileContext_init(self, CompileId(999, 999)),
lambda self, _: CompileContext_init(
self, CompileId(frame_id=999, frame_compile_id=999)
),
):
_, (coda_a2,) = _run_and_get_stripped_kernels(a, x)
_, (coda_c2,) = _run_and_get_stripped_kernels(c, x)

View File

@ -72,8 +72,7 @@ CA_COMPILE_ID_PATTERN = re.compile(
# 3. Compact: The string form is directly displayed by some tools. Special symbols are okay.
# TODO: mark as kw_only=True once we drop support for <Python 3.10
@dataclass(frozen=True)
@dataclass(frozen=True, kw_only=True, slots=True)
class CompileId:
frame_id: Optional[int]
# This id is per-frame, and counts how many times we've compiled this

View File

@ -1041,10 +1041,10 @@ def maybe_estimate_runtime_benchmark(snode: BaseSchedulerNode) -> Optional[float
return ms
@dataclasses.dataclass(slots=True)
class WhyNoFuse:
# TODO when we drop support for Python < 3.10, we can use
# @dataclass(slots=True) instead of manually specifying __slots__.
__slots__ = ["name1", "name2", "reason", "args"]
name1: str
name2: str
reason: str
args: tuple[Any, ...]