mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[pc] introduce ProgressiveCompilationState and clear callback (#157619)
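Follow-up to https://github.com/pytorch/pytorch/pull/157305, where @aorenste correctly suggested clearing the callback. This refactor introduces a new dataclass, ProgressiveCompilationState, that groups the progression futures, the callback, and the post-compile data, so we don't need to check nullability for each field separately; the whole state is cleared at once when no progression futures are left.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157619
Approved by: https://github.com/aorenste
ghstack dependencies: #157305, #157614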
This commit is contained in:
committed by
PyTorch MergeBot
parent
5ea832e5f6
commit
db00e1699a
@ -33,6 +33,48 @@ class _PostCompileData:
|
|||||||
graph_kwargs: _CompileFxKwargs
|
graph_kwargs: _CompileFxKwargs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProgressiveCompilationState:
|
||||||
|
progression_futures: deque[Future[_WireProtocolPickledOutput]]
|
||||||
|
callback: Callable[[_WireProtocolPickledOutput], OutputCode]
|
||||||
|
post_compile_data: Optional[_PostCompileData]
|
||||||
|
|
||||||
|
def check_and_get_ready_stage(self) -> int:
|
||||||
|
"""Check if any progression stage is ready and return its index, or -1 if none are ready."""
|
||||||
|
if not self.progression_futures:
|
||||||
|
return -1
|
||||||
|
|
||||||
|
stage_index = -1
|
||||||
|
if self.post_compile_data:
|
||||||
|
for i, future in enumerate(self.progression_futures):
|
||||||
|
if future.done():
|
||||||
|
stage_index = i
|
||||||
|
|
||||||
|
return stage_index
|
||||||
|
|
||||||
|
def switch_to_progression_stage(self, stage_index: int) -> tuple[OutputCode, bool]:
|
||||||
|
"""
|
||||||
|
Switch to the specified progression stage and return the optimized output code.
|
||||||
|
Returns a tuple of (optimized_output_code, should_clear_compilation_state).
|
||||||
|
"""
|
||||||
|
future = self.progression_futures[stage_index]
|
||||||
|
assert future is not None
|
||||||
|
optimized_output_code = self.callback(future.result())
|
||||||
|
|
||||||
|
if pcd := self.post_compile_data:
|
||||||
|
optimized_output_code.post_compile(
|
||||||
|
pcd.example_inputs, pcd.constants, pcd.graph_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clear earlier progression futures to free memory
|
||||||
|
for _ in range(stage_index + 1):
|
||||||
|
self.progression_futures.popleft()
|
||||||
|
|
||||||
|
# Return whether all compilation state should be cleared
|
||||||
|
should_clear_state = not self.progression_futures
|
||||||
|
return optimized_output_code, should_clear_state
|
||||||
|
|
||||||
|
|
||||||
# _AsyncOutputCode handles the actual management of waiting for an
|
# _AsyncOutputCode handles the actual management of waiting for an
|
||||||
# out-of-process compile to finish and then switching over to it.
|
# out-of-process compile to finish and then switching over to it.
|
||||||
@final
|
@final
|
||||||
@ -192,9 +234,7 @@ class _AsyncFxCompile(FxCompile):
|
|||||||
class _ProgressiveOutputCode(OutputCode):
|
class _ProgressiveOutputCode(OutputCode):
|
||||||
_fast_output_code: Optional[OutputCode]
|
_fast_output_code: Optional[OutputCode]
|
||||||
_optimized_output_code: Optional[OutputCode]
|
_optimized_output_code: Optional[OutputCode]
|
||||||
_progression_futures: deque[Future[_WireProtocolPickledOutput]]
|
_compilation_state: Optional[ProgressiveCompilationState]
|
||||||
_callback: Callable[[_WireProtocolPickledOutput], OutputCode]
|
|
||||||
_post_compile_data: Optional[_PostCompileData] = None
|
|
||||||
# _boxed_call state is effectively cached (we sometimes wrap unboxed w/
|
# _boxed_call state is effectively cached (we sometimes wrap unboxed w/
|
||||||
# lambdas to box them) so we can't change it mid-way. Since _boxed_call=True
|
# lambdas to box them) so we can't change it mid-way. Since _boxed_call=True
|
||||||
# is more common let's default to that and we'll convert if necessary.
|
# is more common let's default to that and we'll convert if necessary.
|
||||||
@ -211,8 +251,11 @@ class _ProgressiveOutputCode(OutputCode):
|
|||||||
) -> None:
|
) -> None:
|
||||||
self._fast_output_code = fast_output_code
|
self._fast_output_code = fast_output_code
|
||||||
self._optimized_output_code = None
|
self._optimized_output_code = None
|
||||||
self._progression_futures = deque(progression_futures)
|
self._compilation_state = ProgressiveCompilationState(
|
||||||
self._callback = callback
|
progression_futures=deque(progression_futures),
|
||||||
|
callback=callback,
|
||||||
|
post_compile_data=None,
|
||||||
|
)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def __call__(self, args: Sequence[Any]) -> Any:
|
def __call__(self, args: Sequence[Any]) -> Any:
|
||||||
@ -235,15 +278,10 @@ class _ProgressiveOutputCode(OutputCode):
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
def _check_and_switch_progression(self) -> None:
|
def _check_and_switch_progression(self) -> None:
|
||||||
if not self._progression_futures:
|
if not self._compilation_state:
|
||||||
return
|
return
|
||||||
|
|
||||||
stage_index = -1
|
stage_index = self._compilation_state.check_and_get_ready_stage()
|
||||||
if self._post_compile_data:
|
|
||||||
for i, future in enumerate(self._progression_futures):
|
|
||||||
if future.done():
|
|
||||||
stage_index = i
|
|
||||||
|
|
||||||
if stage_index == -1:
|
if stage_index == -1:
|
||||||
# no futures are ready
|
# no futures are ready
|
||||||
return
|
return
|
||||||
@ -251,24 +289,17 @@ class _ProgressiveOutputCode(OutputCode):
|
|||||||
self._switch_to_progression_stage(stage_index)
|
self._switch_to_progression_stage(stage_index)
|
||||||
|
|
||||||
def _switch_to_progression_stage(self, stage_index: int) -> None:
|
def _switch_to_progression_stage(self, stage_index: int) -> None:
|
||||||
future = self._progression_futures[stage_index]
|
assert self._compilation_state is not None
|
||||||
assert future is not None
|
optimized_output_code, should_clear_state = (
|
||||||
optimized_output_code = self._callback(future.result())
|
self._compilation_state.switch_to_progression_stage(stage_index)
|
||||||
|
)
|
||||||
if pcd := self._post_compile_data:
|
|
||||||
# Only clear post_compile_data if this is the final progression stage
|
|
||||||
if stage_index == len(self._progression_futures) - 1:
|
|
||||||
self._post_compile_data = None
|
|
||||||
optimized_output_code.post_compile(
|
|
||||||
pcd.example_inputs, pcd.constants, pcd.graph_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
self._optimized_output_code = optimized_output_code
|
self._optimized_output_code = optimized_output_code
|
||||||
self._fast_output_code = None
|
self._fast_output_code = None
|
||||||
|
|
||||||
# Clear earlier progression futures to free memory
|
# Clear all compilation state if no more progression futures are left
|
||||||
for _ in range(stage_index + 1):
|
if should_clear_state:
|
||||||
self._progression_futures.popleft()
|
self._compilation_state = None
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def post_compile(
|
def post_compile(
|
||||||
@ -279,8 +310,10 @@ class _ProgressiveOutputCode(OutputCode):
|
|||||||
) -> None:
|
) -> None:
|
||||||
assert self._fast_output_code is not None
|
assert self._fast_output_code is not None
|
||||||
self._fast_output_code.post_compile(example_inputs, constants, graph_kwargs)
|
self._fast_output_code.post_compile(example_inputs, constants, graph_kwargs)
|
||||||
# Store for later when optimized version is ready
|
|
||||||
self._post_compile_data = _PostCompileData(
|
assert self._compilation_state is not None
|
||||||
|
# Store for later when optimized version is ready
|
||||||
|
self._compilation_state.post_compile_data = _PostCompileData(
|
||||||
example_inputs, constants, graph_kwargs
|
example_inputs, constants, graph_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user