mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Use dataclass features in two classes (#164221)
This PR completes two TODO items by using features of `dataclass`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164221 Approved by: https://github.com/Skylion007, https://github.com/mlazos Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
591997490a
commit
fa90090735
@ -30,7 +30,7 @@ class CallbackTests(TestCase):
|
||||
|
||||
def test_callbacks_with_duplicate_prevention(self) -> None:
|
||||
trigger = CallbackTrigger.DYNAMO
|
||||
compile_id = CompileId(0, 0)
|
||||
compile_id = CompileId(frame_id=0, frame_compile_id=0)
|
||||
with (
|
||||
callback_handler.install_callbacks(trigger, compile_id),
|
||||
callback_handler.install_callbacks(trigger, compile_id),
|
||||
@ -40,7 +40,7 @@ class CallbackTests(TestCase):
|
||||
|
||||
def test_counter(self) -> None:
|
||||
trigger = CallbackTrigger.DYNAMO
|
||||
compile_id = CompileId(0, 0)
|
||||
compile_id = CompileId(frame_id=0, frame_compile_id=0)
|
||||
with callback_handler.install_callbacks(trigger, compile_id):
|
||||
self.assertEqual(
|
||||
callback_handler._CompilationCallbackHandler__pending_callbacks_counter,
|
||||
@ -56,7 +56,7 @@ class CallbackTests(TestCase):
|
||||
AssertionError, "Pending callbacks counter cannot become negative."
|
||||
):
|
||||
trigger = CallbackTrigger.DYNAMO
|
||||
compile_id = CompileId(0, 0)
|
||||
compile_id = CompileId(frame_id=0, frame_compile_id=0)
|
||||
with callback_handler.install_callbacks(trigger, str(compile_id)):
|
||||
pass
|
||||
self.assertEqual(
|
||||
|
@ -95,7 +95,11 @@ class FrameInitTests(torch._dynamo.test_case.TestCase):
|
||||
transformed_code = code_map1[frame.f_code]
|
||||
return wrap_guarded_code(
|
||||
GuardedCode(
|
||||
transformed_code, empty_guard_manager, CompileId(None, 0, 0)
|
||||
transformed_code,
|
||||
empty_guard_manager,
|
||||
CompileId(
|
||||
frame_id=None, frame_compile_id=0, compiled_autograd_id=0
|
||||
),
|
||||
)
|
||||
)
|
||||
return ConvertFrameReturn()
|
||||
@ -105,7 +109,11 @@ class FrameInitTests(torch._dynamo.test_case.TestCase):
|
||||
transformed_code = code_map2[frame.f_code]
|
||||
return wrap_guarded_code(
|
||||
GuardedCode(
|
||||
transformed_code, empty_guard_manager, CompileId(None, 0, 0)
|
||||
transformed_code,
|
||||
empty_guard_manager,
|
||||
CompileId(
|
||||
frame_id=None, frame_compile_id=0, compiled_autograd_id=0
|
||||
),
|
||||
)
|
||||
)
|
||||
return ConvertFrameReturn()
|
||||
|
@ -329,7 +329,9 @@ class TestGuardSerializationBase(torch._inductor.test_case.TestCase):
|
||||
package=None,
|
||||
)
|
||||
with (
|
||||
compile_context(CompileContext(CompileId(0, 0))),
|
||||
compile_context(
|
||||
CompileContext(CompileId(frame_id=0, frame_compile_id=0))
|
||||
),
|
||||
tracing(tracer.output.tracing_context),
|
||||
tracer.set_current_tx(),
|
||||
get_metrics_context(),
|
||||
|
@ -6864,7 +6864,9 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
||||
with patch.object(
|
||||
CompileContext,
|
||||
"__init__",
|
||||
lambda self, _: CompileContext_init(self, CompileId(999, 999)),
|
||||
lambda self, _: CompileContext_init(
|
||||
self, CompileId(frame_id=999, frame_compile_id=999)
|
||||
),
|
||||
):
|
||||
_, (coda_a2,) = _run_and_get_stripped_kernels(a, x)
|
||||
_, (coda_c2,) = _run_and_get_stripped_kernels(c, x)
|
||||
|
@ -72,8 +72,7 @@ CA_COMPILE_ID_PATTERN = re.compile(
|
||||
# 3. Compact: The string form is directly displayed by some tools. Special symbols are okay.
|
||||
|
||||
|
||||
# TODO: mark as kw_only=True once we drop support for <Python 3.10
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, kw_only=True, slots=True)
|
||||
class CompileId:
|
||||
frame_id: Optional[int]
|
||||
# This id is per-frame, and counts how many times we've compiled this
|
||||
|
@ -1041,10 +1041,10 @@ def maybe_estimate_runtime_benchmark(snode: BaseSchedulerNode) -> Optional[float
|
||||
return ms
|
||||
|
||||
|
||||
@dataclasses.dataclass(slots=True)
|
||||
class WhyNoFuse:
|
||||
# TODO when we drop support for Python < 3.10, we can use
|
||||
# @dataclass(slots=True) instead of manually specifying __slots__.
|
||||
__slots__ = ["name1", "name2", "reason", "args"]
|
||||
name1: str
|
||||
name2: str
|
||||
reason: str
|
||||
args: tuple[Any, ...]
|
||||
|
||||
|
Reference in New Issue
Block a user