mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] Improve the typing related to model
input argument of torch.compile()
(#153559)
Summary: Match the `overload` typing with the original typing in function definition and adjust the corresponding comments. Test Plan: contbuild & OSS CI Differential Revision: D74746243 Pull Request resolved: https://github.com/pytorch/pytorch/pull/153559 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
d2f6c6df1d
commit
756fd80734
@ -2448,7 +2448,7 @@ def compile(
|
||||
|
||||
|
||||
def compile(
|
||||
model: _Optional[_Callable] = None,
|
||||
model: _Optional[_Callable[_InputT, _RetT]] = None,
|
||||
*,
|
||||
fullgraph: builtins.bool = False,
|
||||
dynamic: _Optional[builtins.bool] = None,
|
||||
@ -2479,7 +2479,7 @@ def compile(
|
||||
function, they will all share the same code cache.
|
||||
|
||||
Args:
|
||||
model (Callable): Module/function to optimize
|
||||
model (Callable or None): Module/function to optimize
|
||||
fullgraph (bool): If False (default), torch.compile attempts to discover compileable regions
|
||||
in the function that it will optimize. If True, then we require that the entire function be
|
||||
capturable into a single graph. If this is not possible (that is, if there are graph breaks),
|
||||
|
Reference in New Issue
Block a user