See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter. You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129754
Approved by: https://github.com/ezyang
61 lines
2.0 KiB
Python
import torch

RPC_SPARSE = "rpc_sparse"
RPC_DENSE = "rpc_dense"


def sparse_tensor_to_rpc_format(sparse_tensor):
    r"""
    A helper function that creates a list containing the indices, values, and
    size of a coalesced sparse tensor.
    Args:
        sparse_tensor (torch.Tensor): the sparse COO tensor to convert
    """
    sparse_tensor = sparse_tensor.coalesce()
    return [sparse_tensor.indices(), sparse_tensor.values(), sparse_tensor.size()]
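

# Illustrative sketch (not part of the original module): shows the list format
# produced by sparse_tensor_to_rpc_format for a tiny, made-up sparse tensor.
# The values below are arbitrary example data.
def _example_rpc_format():
    indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
    values = torch.tensor([3.0, 4.0, 5.0])
    sparse = torch.sparse_coo_tensor(indices, values, (2, 3))
    idx, vals, size = sparse_tensor_to_rpc_format(sparse)
    # idx is a 2 x nnz LongTensor, vals is a 1-D tensor of nnz values,
    # and size is a torch.Size, here torch.Size([2, 3]).
    return idx, vals, size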


def sparse_rpc_format_to_tensor(sparse_rpc_format):
    r"""
    A helper function that creates a sparse_coo_tensor from a list of indices,
    values, and size.
    Args:
        sparse_rpc_format (list): a sparse_coo_tensor represented as a list of
            [indices, values, size]
    """
    return torch.sparse_coo_tensor(
        sparse_rpc_format[0], sparse_rpc_format[1], sparse_rpc_format[2]
    ).coalesce()
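

# Illustrative sketch (not part of the original module): round-trips a sparse
# tensor through the two helpers above and checks that nothing is lost. The
# input is arbitrary caller-supplied example data.
def _example_round_trip(sparse_tensor):
    rpc_format = sparse_tensor_to_rpc_format(sparse_tensor)
    rebuilt = sparse_rpc_format_to_tensor(rpc_format)
    # Comparing dense views sidesteps differences in sparse element ordering.
    assert torch.equal(sparse_tensor.coalesce().to_dense(), rebuilt.to_dense())
    return rebuilt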


def process_bucket_with_remote_server(state, bucket):
    r"""
    Processes a gradient bucket passed to a DDP communication hook during
    .backward(). The method supports both sparse and dense tensors and records
    an RPC future completion time metric for the trainer.
    Args:
        state (object): maintains state during the training process
        bucket (GradBucket): gradient bucket
    """
    cref = state.cref
    tensor = bucket.buffer()
    # If CUDA RPC is not used, move the gradients to CPU before sending.
    if not cref.use_cuda_rpc:
        tensor = tensor.cpu()
    sparse = tensor.is_sparse
    # Sparse gradients are sent as [indices, values, size] lists.
    if sparse:
        tensor = sparse_tensor_to_rpc_format(tensor)
    b_index = bucket.get_index()
    server_args = [cref.server_rref, state.batch_number, b_index, tensor]
    key = state.get_key(b_index)
    cref.record_start("hook_future_metric", key, RPC_SPARSE if sparse else RPC_DENSE)
    # Ask the remote server to average this bucket's gradients asynchronously.
    fut = cref.server_rref.rpc_async().average_gradient(*server_args)

    def callback(fut):
        cref.record_end("hook_future_metric", key)
        tensor = fut.wait()
        # A list response means the server returned a sparse tensor in RPC format.
        if type(tensor) is list:
            tensor = sparse_rpc_format_to_tensor(tensor)
        tensor = tensor.cuda(cref.rank)
        return [tensor]

    return fut.then(callback)
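

# Illustrative sketch (not part of the original module): how a hook like
# process_bucket_with_remote_server is typically attached to a DDP model via
# DistributedDataParallel.register_comm_hook. `ddp_model` and `trainer_state`
# are hypothetical placeholders; the state object must provide the attributes
# used above (cref, batch_number, get_key).
def _example_register_hook(ddp_model, trainer_state):
    ddp_model.register_comm_hook(trainer_state, process_bucket_with_remote_server)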