Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-05 00:14:54 +08:00
Graph-Safe RNG State Exchange for Tensor Parallelism (#114068)
See #113541. This PR allows registering and controlling multiple RNG states by index, ensures cudagraph-safe operations, and includes both C++ and Python API changes to support this functionality. cc @eellison @anijain2305 @jansel @ezyang @ptrblck @csarofeen @mcarilli Pull Request resolved: https://github.com/pytorch/pytorch/pull/114068 Approved by: https://github.com/ezyang, https://github.com/eqy, https://github.com/xuzhao9
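The PR's core idea (keeping several RNG states addressable by index and swapping one in or out without disturbing the others) can be illustrated with a minimal pure-Python analogy. The `RNGStateRegistry` class and its method names below are hypothetical, invented for this sketch; they only mimic the shape of the feature and are not PyTorch's actual C++/Python API, which operates on CUDA generator state.

```python
import random

# Hypothetical sketch: a registry that tracks several RNG states by index,
# mirroring the PR's idea of registering and controlling multiple RNG states.
# (PyTorch's real implementation lives in the CUDA generator; this is an analogy.)
class RNGStateRegistry:
    def __init__(self):
        self._states = {}  # index -> saved generator state

    def register(self, index, rng):
        # Save the generator's current state under the given index.
        self._states[index] = rng.getstate()

    def restore(self, index, rng):
        # Swap a previously saved state back into the generator.
        rng.setstate(self._states[index])

rng = random.Random(1234)
registry = RNGStateRegistry()
registry.register(0, rng)           # snapshot state 0
first_draw = rng.random()           # advances the generator
registry.restore(0, rng)            # rewind to state 0
replayed_draw = rng.random()        # identical to first_draw
print(first_draw == replayed_draw)  # prints True
```

Restoring a saved state makes the generator replay the same sequence, which is the property the graph-safe get/set pair below relies on during CUDA graph capture and replay.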
This commit is contained in:
committed by PyTorch MergeBot
parent fe41ba4765
commit 249e65b92d
@@ -13722,6 +13722,58 @@ Example::
    """,
)

add_docstr(
    torch.Generator.graphsafe_set_state,
    r"""
Generator.graphsafe_set_state(state) -> None

Sets the state of the generator to the specified state in a manner that is safe for use during graph capture.
This method is crucial for ensuring that the generator's state can be captured in the CUDA graph.

Arguments:
    state (torch.Generator): A Generator pointing to the new state for this generator, typically obtained from `graphsafe_get_state`.

Example:
    >>> g_cuda = torch.Generator(device='cuda')
    >>> g_cuda_other = torch.Generator(device='cuda')
    >>> current_state = g_cuda_other.graphsafe_get_state()
    >>> g_cuda.graphsafe_set_state(current_state)
""",
)

add_docstr(
    torch.Generator.graphsafe_get_state,
    r"""
Generator.graphsafe_get_state() -> torch.Generator

Retrieves the current state of the generator in a manner that is safe for graph capture.
This method is crucial for ensuring that the generator's state can be captured in the CUDA graph.

Returns:
    torch.Generator: A Generator pointing to the current state of this generator.

Example:
    >>> g_cuda = torch.Generator(device='cuda')
    >>> current_state = g_cuda.graphsafe_get_state()
""",
)

add_docstr(
    torch.Generator.clone_state,
    r"""
Generator.clone_state() -> torch.Generator

Clones the current state of the generator and returns a new generator pointing to this cloned state.
This method is useful for preserving a particular state of a generator so that it can be restored later.

Returns:
    torch.Generator: A Generator pointing to the newly cloned state.

Example:
    >>> g_cuda = torch.Generator(device='cuda')
    >>> cloned_state = g_cuda.clone_state()
""",
)

add_docstr(
    torch.Generator.manual_seed,