Graph-Safe RNG State Exchange for Tensor Parallelism (#114068)

See #113541

This PR allows multiple RNG states to be registered and controlled via indices, keeps those operations cudagraph-safe, and includes both C++ and Python API changes to support this functionality.
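
As a rough usage sketch (illustrative only, based on the docstrings added in the diff below; the generator names are made up), the new Python surface fits together along these lines:

    import torch

    g_main = torch.Generator(device="cuda")   # generator whose state we want to redirect
    g_other = torch.Generator(device="cuda")  # generator owning the state to be shared

    # Obtain a Generator handle that points to g_other's current state.
    # Per the new docstring this is safe during CUDA graph capture, unlike
    # get_state(), which copies the state out into a tensor.
    state = g_other.graphsafe_get_state()

    # Make g_main refer to that state.
    g_main.graphsafe_set_state(state)

    # Snapshot the current state into a fresh Generator so it can be restored later.
    snapshot = g_main.clone_state()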

cc  @eellison @anijain2305 @jansel @ezyang @ptrblck @csarofeen @mcarilli
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114068
Approved by: https://github.com/ezyang, https://github.com/eqy, https://github.com/xuzhao9
Frank Lin
2024-03-27 01:14:38 +00:00
committed by PyTorch MergeBot
parent fe41ba4765
commit 249e65b92d
15 changed files with 644 additions and 139 deletions


@@ -13722,6 +13722,58 @@ Example::
""",
)
add_docstr(
torch.Generator.graphsafe_set_state,
r"""
Generator.graphsafe_set_state(state) -> None
Sets the state of the generator to the specified state in a manner that is safe for use in graph capture.
This method is crucial for ensuring that the generator's state can be captured in the CUDA graph.
Arguments:
state (torch.Generator): A Generator pointing to the new state for the generator, typically obtained from `graphsafe_get_state`.
Example:
>>> g_cuda = torch.Generator(device='cuda')
>>> g_cuda_other = torch.Generator(device='cuda')
>>> current_state = g_cuda_other.graphsafe_get_state()
>>> g_cuda.graphsafe_set_state(current_state)
""",
)
add_docstr(
torch.Generator.graphsafe_get_state,
r"""
Generator.graphsafe_get_state() -> torch.Generator
Retrieves the current state of the generator in a manner that is safe for graph capture.
This method is crucial for ensuring that the generator's state can be captured in the CUDA graph.
Returns:
torch.Generator: A Generator pointing to the current state of the generator.
Example:
>>> g_cuda = torch.Generator(device='cuda')
>>> current_state = g_cuda.graphsafe_get_state()
""",
)
add_docstr(
torch.Generator.clone_state,
r"""
Generator.clone_state() -> torch.Generator
Clones the current state of the generator and returns a new generator pointing to this cloned state.
This method is beneficial for preserving a particular state of a generator to restore at a later point.
Returns:
torch.Generator: A Generator pointing to the newly cloned state.
Example:
>>> g_cuda = torch.Generator(device='cuda')
>>> cloned_state = g_cuda.clone_state()
""",
)
add_docstr(
torch.Generator.manual_seed,