[pc] introduce ProgressiveCompilationState and clear callback (#157619)

followup from https://github.com/pytorch/pytorch/pull/157305 where
@aorenste correctly suggested clearing callback. this refactor
introduces a new dataclass so we don't need to check nullability for
each field

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157619
Approved by: https://github.com/aorenste
ghstack dependencies: #157305, #157614
This commit is contained in:
bobrenjc93
2025-07-04 21:30:13 -07:00
committed by PyTorch MergeBot
parent 5ea832e5f6
commit db00e1699a

View File

@ -33,6 +33,48 @@ class _PostCompileData:
graph_kwargs: _CompileFxKwargs graph_kwargs: _CompileFxKwargs
@dataclass
class ProgressiveCompilationState:
progression_futures: deque[Future[_WireProtocolPickledOutput]]
callback: Callable[[_WireProtocolPickledOutput], OutputCode]
post_compile_data: Optional[_PostCompileData]
def check_and_get_ready_stage(self) -> int:
"""Check if any progression stage is ready and return its index, or -1 if none are ready."""
if not self.progression_futures:
return -1
stage_index = -1
if self.post_compile_data:
for i, future in enumerate(self.progression_futures):
if future.done():
stage_index = i
return stage_index
def switch_to_progression_stage(self, stage_index: int) -> tuple[OutputCode, bool]:
"""
Switch to the specified progression stage and return the optimized output code.
Returns a tuple of (optimized_output_code, should_clear_compilation_state).
"""
future = self.progression_futures[stage_index]
assert future is not None
optimized_output_code = self.callback(future.result())
if pcd := self.post_compile_data:
optimized_output_code.post_compile(
pcd.example_inputs, pcd.constants, pcd.graph_kwargs
)
# Clear earlier progression futures to free memory
for _ in range(stage_index + 1):
self.progression_futures.popleft()
# Return whether all compilation state should be cleared
should_clear_state = not self.progression_futures
return optimized_output_code, should_clear_state
# _AsyncOutputCode handles the actual management of waiting for an # _AsyncOutputCode handles the actual management of waiting for an
# out-of-process compile to finish and then switching over to it. # out-of-process compile to finish and then switching over to it.
@final @final
@ -192,9 +234,7 @@ class _AsyncFxCompile(FxCompile):
class _ProgressiveOutputCode(OutputCode): class _ProgressiveOutputCode(OutputCode):
_fast_output_code: Optional[OutputCode] _fast_output_code: Optional[OutputCode]
_optimized_output_code: Optional[OutputCode] _optimized_output_code: Optional[OutputCode]
_progression_futures: deque[Future[_WireProtocolPickledOutput]] _compilation_state: Optional[ProgressiveCompilationState]
_callback: Callable[[_WireProtocolPickledOutput], OutputCode]
_post_compile_data: Optional[_PostCompileData] = None
# _boxed_call state is effectively cached (we sometimes wrap unboxed w/ # _boxed_call state is effectively cached (we sometimes wrap unboxed w/
# lambdas to box them) so we can't change it mid-way. Since _boxed_call=True # lambdas to box them) so we can't change it mid-way. Since _boxed_call=True
# is more common let's default to that and we'll convert if necessary. # is more common let's default to that and we'll convert if necessary.
@ -211,8 +251,11 @@ class _ProgressiveOutputCode(OutputCode):
) -> None: ) -> None:
self._fast_output_code = fast_output_code self._fast_output_code = fast_output_code
self._optimized_output_code = None self._optimized_output_code = None
self._progression_futures = deque(progression_futures) self._compilation_state = ProgressiveCompilationState(
self._callback = callback progression_futures=deque(progression_futures),
callback=callback,
post_compile_data=None,
)
@override @override
def __call__(self, args: Sequence[Any]) -> Any: def __call__(self, args: Sequence[Any]) -> Any:
@ -235,15 +278,10 @@ class _ProgressiveOutputCode(OutputCode):
return res return res
def _check_and_switch_progression(self) -> None: def _check_and_switch_progression(self) -> None:
if not self._progression_futures: if not self._compilation_state:
return return
stage_index = -1 stage_index = self._compilation_state.check_and_get_ready_stage()
if self._post_compile_data:
for i, future in enumerate(self._progression_futures):
if future.done():
stage_index = i
if stage_index == -1: if stage_index == -1:
# no futures are ready # no futures are ready
return return
@ -251,24 +289,17 @@ class _ProgressiveOutputCode(OutputCode):
self._switch_to_progression_stage(stage_index) self._switch_to_progression_stage(stage_index)
def _switch_to_progression_stage(self, stage_index: int) -> None: def _switch_to_progression_stage(self, stage_index: int) -> None:
future = self._progression_futures[stage_index] assert self._compilation_state is not None
assert future is not None optimized_output_code, should_clear_state = (
optimized_output_code = self._callback(future.result()) self._compilation_state.switch_to_progression_stage(stage_index)
)
if pcd := self._post_compile_data:
# Only clear post_compile_data if this is the final progression stage
if stage_index == len(self._progression_futures) - 1:
self._post_compile_data = None
optimized_output_code.post_compile(
pcd.example_inputs, pcd.constants, pcd.graph_kwargs
)
self._optimized_output_code = optimized_output_code self._optimized_output_code = optimized_output_code
self._fast_output_code = None self._fast_output_code = None
# Clear earlier progression futures to free memory # Clear all compilation state if no more progression futures are left
for _ in range(stage_index + 1): if should_clear_state:
self._progression_futures.popleft() self._compilation_state = None
@override @override
def post_compile( def post_compile(
@ -279,8 +310,10 @@ class _ProgressiveOutputCode(OutputCode):
) -> None: ) -> None:
assert self._fast_output_code is not None assert self._fast_output_code is not None
self._fast_output_code.post_compile(example_inputs, constants, graph_kwargs) self._fast_output_code.post_compile(example_inputs, constants, graph_kwargs)
# Store for later when optimized version is ready
self._post_compile_data = _PostCompileData( assert self._compilation_state is not None
# Store for later when optimized version is ready
self._compilation_state.post_compile_data = _PostCompileData(
example_inputs, constants, graph_kwargs example_inputs, constants, graph_kwargs
) )