mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[pc] introduce ProgressiveCompilationState and clear callback (#157619)
followup from https://github.com/pytorch/pytorch/pull/157305 where @aorenste correctly suggested clearing callback. this refactor introduces a new dataclass so we don't need to check nullability for each field Pull Request resolved: https://github.com/pytorch/pytorch/pull/157619 Approved by: https://github.com/aorenste ghstack dependencies: #157305, #157614
This commit is contained in:
committed by
PyTorch MergeBot
parent
5ea832e5f6
commit
db00e1699a
@ -33,6 +33,48 @@ class _PostCompileData:
|
||||
graph_kwargs: _CompileFxKwargs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProgressiveCompilationState:
|
||||
progression_futures: deque[Future[_WireProtocolPickledOutput]]
|
||||
callback: Callable[[_WireProtocolPickledOutput], OutputCode]
|
||||
post_compile_data: Optional[_PostCompileData]
|
||||
|
||||
def check_and_get_ready_stage(self) -> int:
|
||||
"""Check if any progression stage is ready and return its index, or -1 if none are ready."""
|
||||
if not self.progression_futures:
|
||||
return -1
|
||||
|
||||
stage_index = -1
|
||||
if self.post_compile_data:
|
||||
for i, future in enumerate(self.progression_futures):
|
||||
if future.done():
|
||||
stage_index = i
|
||||
|
||||
return stage_index
|
||||
|
||||
def switch_to_progression_stage(self, stage_index: int) -> tuple[OutputCode, bool]:
|
||||
"""
|
||||
Switch to the specified progression stage and return the optimized output code.
|
||||
Returns a tuple of (optimized_output_code, should_clear_compilation_state).
|
||||
"""
|
||||
future = self.progression_futures[stage_index]
|
||||
assert future is not None
|
||||
optimized_output_code = self.callback(future.result())
|
||||
|
||||
if pcd := self.post_compile_data:
|
||||
optimized_output_code.post_compile(
|
||||
pcd.example_inputs, pcd.constants, pcd.graph_kwargs
|
||||
)
|
||||
|
||||
# Clear earlier progression futures to free memory
|
||||
for _ in range(stage_index + 1):
|
||||
self.progression_futures.popleft()
|
||||
|
||||
# Return whether all compilation state should be cleared
|
||||
should_clear_state = not self.progression_futures
|
||||
return optimized_output_code, should_clear_state
|
||||
|
||||
|
||||
# _AsyncOutputCode handles the actual management of waiting for an
|
||||
# out-of-process compile to finish and then switching over to it.
|
||||
@final
|
||||
@ -192,9 +234,7 @@ class _AsyncFxCompile(FxCompile):
|
||||
class _ProgressiveOutputCode(OutputCode):
|
||||
_fast_output_code: Optional[OutputCode]
|
||||
_optimized_output_code: Optional[OutputCode]
|
||||
_progression_futures: deque[Future[_WireProtocolPickledOutput]]
|
||||
_callback: Callable[[_WireProtocolPickledOutput], OutputCode]
|
||||
_post_compile_data: Optional[_PostCompileData] = None
|
||||
_compilation_state: Optional[ProgressiveCompilationState]
|
||||
# _boxed_call state is effectively cached (we sometimes wrap unboxed w/
|
||||
# lambdas to box them) so we can't change it mid-way. Since _boxed_call=True
|
||||
# is more common let's default to that and we'll convert if necessary.
|
||||
@ -211,8 +251,11 @@ class _ProgressiveOutputCode(OutputCode):
|
||||
) -> None:
|
||||
self._fast_output_code = fast_output_code
|
||||
self._optimized_output_code = None
|
||||
self._progression_futures = deque(progression_futures)
|
||||
self._callback = callback
|
||||
self._compilation_state = ProgressiveCompilationState(
|
||||
progression_futures=deque(progression_futures),
|
||||
callback=callback,
|
||||
post_compile_data=None,
|
||||
)
|
||||
|
||||
@override
|
||||
def __call__(self, args: Sequence[Any]) -> Any:
|
||||
@ -235,15 +278,10 @@ class _ProgressiveOutputCode(OutputCode):
|
||||
return res
|
||||
|
||||
def _check_and_switch_progression(self) -> None:
|
||||
if not self._progression_futures:
|
||||
if not self._compilation_state:
|
||||
return
|
||||
|
||||
stage_index = -1
|
||||
if self._post_compile_data:
|
||||
for i, future in enumerate(self._progression_futures):
|
||||
if future.done():
|
||||
stage_index = i
|
||||
|
||||
stage_index = self._compilation_state.check_and_get_ready_stage()
|
||||
if stage_index == -1:
|
||||
# no futures are ready
|
||||
return
|
||||
@ -251,24 +289,17 @@ class _ProgressiveOutputCode(OutputCode):
|
||||
self._switch_to_progression_stage(stage_index)
|
||||
|
||||
def _switch_to_progression_stage(self, stage_index: int) -> None:
|
||||
future = self._progression_futures[stage_index]
|
||||
assert future is not None
|
||||
optimized_output_code = self._callback(future.result())
|
||||
|
||||
if pcd := self._post_compile_data:
|
||||
# Only clear post_compile_data if this is the final progression stage
|
||||
if stage_index == len(self._progression_futures) - 1:
|
||||
self._post_compile_data = None
|
||||
optimized_output_code.post_compile(
|
||||
pcd.example_inputs, pcd.constants, pcd.graph_kwargs
|
||||
)
|
||||
assert self._compilation_state is not None
|
||||
optimized_output_code, should_clear_state = (
|
||||
self._compilation_state.switch_to_progression_stage(stage_index)
|
||||
)
|
||||
|
||||
self._optimized_output_code = optimized_output_code
|
||||
self._fast_output_code = None
|
||||
|
||||
# Clear earlier progression futures to free memory
|
||||
for _ in range(stage_index + 1):
|
||||
self._progression_futures.popleft()
|
||||
# Clear all compilation state if no more progression futures are left
|
||||
if should_clear_state:
|
||||
self._compilation_state = None
|
||||
|
||||
@override
|
||||
def post_compile(
|
||||
@ -279,8 +310,10 @@ class _ProgressiveOutputCode(OutputCode):
|
||||
) -> None:
|
||||
assert self._fast_output_code is not None
|
||||
self._fast_output_code.post_compile(example_inputs, constants, graph_kwargs)
|
||||
# Store for later when optimized version is ready
|
||||
self._post_compile_data = _PostCompileData(
|
||||
|
||||
assert self._compilation_state is not None
|
||||
# Store for later when optimized version is ready
|
||||
self._compilation_state.post_compile_data = _PostCompileData(
|
||||
example_inputs, constants, graph_kwargs
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user