[c10d] Add hashing as a debug feature for before and after NCCL collective call (#113238)

For now, we use `TORCH_DISTRIBUTED_DEBUG=DETAIL` to turn on a debug feature that computes a hash of the input tensors and output results of c10d collectives under NCCL. This is a debugging feature that lets us rule out bugs at the c10d level.
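
A minimal sketch of exercising the feature end to end (the script name and launch command are illustrative; the exact log text comes from the `PRINT_COLLECTIVE_HASH_SIGNATURE` macro added below):

```python
# Run with, e.g.:
#   TORCH_DISTRIBUTED_DEBUG=DETAIL torchrun --nproc_per_node=2 demo_hash_debug.py
# At DETAIL level each rank logs a WARNING with the hash of the collective's
# input before the NCCL call and of its output after wait(), along the lines of
#   "Hash of input to NCCL ALLREDUCE with size 4 is <hash>".
import torch
import torch.distributed as dist


def main():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    t = torch.full((4,), float(rank + 1), device="cuda")
    dist.all_reduce(t)  # input/output hashes are logged at DETAIL level
    print(f"rank {rank}: {t.tolist()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```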

<img width="840" alt="image" src="https://github.com/pytorch/pytorch/assets/6937752/cdc70b0b-ae3c-4efd-86ff-adc5c5ba505f">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113238
Approved by: https://github.com/wconstab, https://github.com/fegin
This commit is contained in:
fduwjj
2023-12-22 05:43:59 +00:00
committed by PyTorch MergeBot
parent 039fbeb016
commit f6dfbffb3b
9 changed files with 103 additions and 19 deletions

View File

@ -25,6 +25,7 @@ def _register_builtin_comm_hook(
comm_hook_type: BuiltinCommHookType,
): ...
def _set_global_rank(rank: int) -> None: ...
def _hash_tensors(tensors: List[Tensor]) -> int: ...
class GradBucket:
def index(self) -> int: ...

View File

@ -4,7 +4,9 @@
#include <c10/util/env.h>
#ifdef USE_C10D_NCCL
#include <vector>
#include <cuda_runtime.h>
#include <mutex>
namespace c10d {
@ -60,6 +62,31 @@ std::string getNcclVersion() {
return versionString;
}
#ifdef USE_C10D_NCCL
size_t hashTensors(const std::vector<at::Tensor>& tensors) {
size_t hash = 0;
for (auto& tensor : tensors) {
if (tensor.numel() > 0 && tensor.storage()) {
size_t data_size = tensor.storage().nbytes();
if (data_size > 0 && tensor.storage().data_ptr()) {
auto src = static_cast<const char*>(tensor.storage().data_ptr().get());
char* dst = (char*)std::calloc(data_size, sizeof(char));
// This triggers a device synchronization so that, if the collective was
// launched on the GPU, it has finished before we hash its output.
cudaMemcpy(dst, src, data_size, cudaMemcpyDeviceToHost);
for (size_t i = 0; i < data_size; ++i) {
// Update the hash for each byte in the tensor
hash = c10::hash_combine(
hash, c10::get_hash(((char*)dst)[i], data_size));
}
free(dst);
}
}
}
return hash;
}
#endif
bool nccl_use_nonblocking() {
static bool nccl_use_nonblocking_ =
c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true;

View File

@ -8,6 +8,7 @@
#include <memory>
#include <mutex>
#include <ATen/ATen.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <nccl.h>
@ -163,6 +164,7 @@
namespace c10d {
TORCH_API size_t hashTensors(const std::vector<at::Tensor>& tensors);
std::string getNcclVersion();
std::string ncclGetErrorWithVersion(ncclResult_t error);
bool nccl_use_nonblocking();

View File

@ -390,12 +390,14 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(
const char* profilingTitle,
const c10::optional<std::vector<at::Tensor>>& inputs,
bool desyncDebug,
bool enableTiming)
bool enableTiming,
DebugLevel distDebugLevel)
: Work(rank, opType, profilingTitle, inputs),
devices_(devices),
workStartTime_(std::chrono::steady_clock::now()),
seq_(seq),
timingEnabled_(enableTiming) {
timingEnabled_(enableTiming),
distDebugLevel_(distDebugLevel) {
// Creates the CUDA event wrappers
// Note: The actual events are lazily created when first recorded to with
// DEFAULT_FLAGS = cudaEventDisableTiming.
@ -431,7 +433,8 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w)
numelOut_(w.numelOut_),
store_(w.store_),
timingEnabled_(w.timingEnabled_),
trace_id_(w.trace_id_) {
trace_id_(w.trace_id_),
distDebugLevel_(w.distDebugLevel_) {
exception_ = w.exception_;
}
@ -648,6 +651,12 @@ bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) {
static_cast<int>(devices_.size())); // worldSize
synchronizeInternal(timeout);
// Always return true, because abort API is not implemented.
if (distDebugLevel_ >= DebugLevel::Detail) {
auto numel = getTensorsNumel(*outputs_);
auto hashValue = hashTensors(*outputs_);
PRINT_COLLECTIVE_HASH_SIGNATURE(
"output", opTypeToString(opType_), numel, hashValue);
}
return true;
}
@ -730,6 +739,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
heartbeatTimeoutInSec_ =
getCvarInt(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 10 /*10 Mins*/);
ncclTraceBufferSize_ = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, 0);
enableCollecticeHashDebug_ = (dist_debug_level_ >= DebugLevel::Detail);
#ifdef ENABLE_NCCL_ERROR_CHECKING
enableTiming_.store(
getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_);
@ -2218,7 +2228,8 @@ c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
profilingTitle != nullptr ? c10::optional<std::vector<at::Tensor>>(inputs)
: c10::nullopt,
desyncDebug_,
enableTiming_.load());
enableTiming_.load(),
dist_debug_level_);
r->trace_id_ = NCCLTraceBuffer::get()->record(
uid_,
seq_,
@ -2421,6 +2432,13 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::collective(
}
}
if (enableCollecticeHashDebug_.load()) {
auto numel = getTensorsNumel(inputs);
auto hashValue = hashTensors(inputs);
PRINT_COLLECTIVE_HASH_SIGNATURE(
"input", opTypeToString(opType), numel, hashValue);
}
{
torch::cuda::nccl::AutoNcclGroup nccl_group_guard(
comms_, nccl_use_nonblocking());

View File

@ -95,6 +95,10 @@ enum ErrorHandlingMode {
#define SHOULD_TEAR_DOWN(a) (a != NoHandling && a != CleanUpOnly)
#define PRINT_COLLECTIVE_HASH_SIGNATURE(phase, opType, numel, hashValue) \
LOG(WARNING) << logPrefix() << "Hash of " << phase << " to NCCL " << opType \
<< " with size " << numel << " is " << hashValue;
// If set, ProcessGroupNCCL doesn't use recordStream calls to ensure
// caching allocator safety for tensors used on both user-facing and
// internal comm streams.
@ -161,7 +165,8 @@ class TORCH_API ProcessGroupNCCL : public Backend {
const char* profilingTitle = nullptr,
const c10::optional<std::vector<at::Tensor>>& inputs = c10::nullopt,
bool desyncDebug = false,
bool enableTiming = false);
bool enableTiming = false,
DebugLevel distDebugLevel = DebugLevel::Off);
// Copy constructor doing a partial copy without outputs_. The cleanup thread
// monitors and removes finished works. However, it will deadlock when
// destructing outputs_ tensors that are view tensors in the autograd graph.
@ -313,6 +318,7 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// unique id used to tell the trace buffer that this
// work has completed
c10::optional<uint64_t> trace_id_;
DebugLevel distDebugLevel_;
friend class ProcessGroupNCCL;
};
@ -930,6 +936,10 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// is set to true.
std::atomic<bool> enableTiming_;
// Flag to enable printing the hash values of collective inputs/outputs for
// verification.
std::atomic<bool> enableCollecticeHashDebug_;
// Whether or not TORCH_NCCL_AVOID_RECORD_STREAMS was set
bool avoidRecordStreams_ = false;

View File

@ -22,4 +22,12 @@ std::vector<at::Tensor> getTensorShapes(
return shapeTensors;
}
size_t getTensorsNumel(const std::vector<at::Tensor>& tensors) {
size_t numel = 0;
for (auto& tensor : tensors) {
numel += tensor.numel();
}
return numel;
}
} // namespace c10d

View File

@ -33,6 +33,8 @@ typedef SSIZE_T ssize_t;
namespace c10d {
TORCH_API size_t getTensorsNumel(const std::vector<at::Tensor>& tensors);
// Retrieve tensor shapes from a given tensor.
TORCH_API std::vector<at::Tensor> getTensorShapes(
const std::vector<at::Tensor>& tensors);

View File

@ -2792,6 +2792,16 @@ such as `dist.all_reduce(tensor, async_op=True)`.
)");
#ifdef USE_C10D_NCCL
module.def(
"_hash_tensors",
[](const std::vector<at::Tensor>& tensors) {
return ::c10d::hashTensors(tensors);
},
py::arg("tensors"),
R"(
Arguments:
tensors(List[torch.Tensor]): List of tensors we want to hash.
)");
module.def("_dump_nccl_trace", []() {
return py::bytes(::c10d::dump_nccl_trace());
});
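
The new binding can also be invoked directly, which is handy when bisecting a correctness issue by hand. A minimal sketch, assuming a CUDA build with `USE_C10D_NCCL` (the binding is only registered under that flag) and noting that this is a private debugging API:

```python
import torch
from torch._C._distributed_c10d import _hash_tensors

a = torch.arange(8, dtype=torch.float32, device="cuda")
b = a.clone()

# Identical byte contents hash to the same value; a perturbed copy should not.
assert _hash_tensors([a]) == _hash_tensors([b])
print(_hash_tensors([a]), _hash_tensors([a + 1.0]))
```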

View File

@ -2201,7 +2201,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
else:
work.wait()
def _object_to_tensor(obj, device):
def _object_to_tensor(obj, device, group):
f = io.BytesIO()
_pickler(f).dump(obj)
byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) # type: ignore[attr-defined]
@ -2209,11 +2209,21 @@ def _object_to_tensor(obj, device):
# Otherwise, it will cause 100X slowdown.
# See: https://github.com/pytorch/pytorch/issues/65696
byte_tensor = torch.ByteTensor(byte_storage).to(device)
if get_debug_level() == DebugLevel.DETAIL and is_nccl_available():
backend = get_backend(group)
if backend == Backend.NCCL:
hash = torch._C._distributed_c10d._hash_tensors([byte_tensor])
logger.warning(f"_object_to_tensor size: {byte_tensor.numel()} hash value: {hash}") # noqa: G004
local_size = torch.LongTensor([byte_tensor.numel()]).to(device)
return byte_tensor, local_size
def _tensor_to_object(tensor, tensor_size):
def _tensor_to_object(tensor, tensor_size, group):
if get_debug_level() == DebugLevel.DETAIL and is_nccl_available():
backend = get_backend(group)
if backend == Backend.NCCL:
hash = torch._C._distributed_c10d._hash_tensors([tensor])
logger.warning(f"_tensor_to_object size: {tensor.numel()} hash value: {hash}") # noqa: G004
tensor = tensor.cpu()
buf = tensor.numpy().tobytes()[:tensor_size]
return _unpickler(io.BytesIO(buf)).load()
@ -2278,7 +2288,7 @@ def all_gather_object(object_list, obj, group=None):
return
current_device = _get_pg_default_device(group)
input_tensor, local_size = _object_to_tensor(obj, current_device)
input_tensor, local_size = _object_to_tensor(obj, current_device, group)
# Gather all local sizes. This is so that we can find the max size, and index
# until the correct size when deserializing the tensors.
@ -2306,10 +2316,8 @@ def all_gather_object(object_list, obj, group=None):
# Deserialize outputs back to object.
for i, tensor in enumerate(output_tensors):
tensor = tensor.type(torch.uint8)
if tensor.device != torch.device("cpu"):
tensor = tensor.cpu()
tensor_size = object_size_list[i]
object_list[i] = _tensor_to_object(tensor, tensor_size)
object_list[i] = _tensor_to_object(tensor, tensor_size, group)
@_exception_logger
@ -2380,7 +2388,7 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
my_rank = get_rank()
_validate_output_list_for_rank(my_rank, dst, object_gather_list)
current_device = _get_pg_default_device(group)
input_tensor, local_size = _object_to_tensor(obj, current_device)
input_tensor, local_size = _object_to_tensor(obj, current_device, group)
# Gather all local sizes. This is so that we can find the max size, and index
# until the correct size when deserializing the tensors.
@ -2420,7 +2428,7 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
for i, tensor in enumerate(output_tensors):
tensor = tensor.type(torch.uint8)
tensor_size = object_size_list[i]
object_gather_list[i] = _tensor_to_object(tensor, tensor_size)
object_gather_list[i] = _tensor_to_object(tensor, tensor_size, group)
@_exception_logger
@ -2498,7 +2506,7 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
my_rank = get_rank()
# Serialize object_list elements to tensors on src rank.
if my_rank == src:
tensor_list, size_list = zip(*[_object_to_tensor(obj, current_device) for obj in object_list])
tensor_list, size_list = zip(*[_object_to_tensor(obj, current_device, group) for obj in object_list])
object_sizes_tensor = torch.cat(size_list)
else:
object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long, device=current_device)
@ -2528,10 +2536,8 @@ def broadcast_object_list(object_list, src=0, group=None, device=None):
for i, obj_size in enumerate(object_sizes_tensor):
obj_view = object_tensor[offset : offset + obj_size]
obj_view = obj_view.type(torch.uint8)
if obj_view.device != torch.device("cpu"):
obj_view = obj_view.cpu()
offset += obj_size
object_list[i] = _tensor_to_object(obj_view, obj_size)
object_list[i] = _tensor_to_object(obj_view, obj_size, group)
@_exception_logger
@ -2608,7 +2614,7 @@ def scatter_object_list(
pg_device = _get_pg_default_device(group)
if my_rank == src:
tensor_list, tensor_sizes = zip(
*[_object_to_tensor(obj, pg_device) for obj in scatter_object_input_list]
*[_object_to_tensor(obj, pg_device, group) for obj in scatter_object_input_list]
)
tensor_list, tensor_sizes = list(tensor_list), list(tensor_sizes)
@ -2641,7 +2647,7 @@ def scatter_object_list(
)
# Deserialize back to object
scatter_object_output_list[0] = _tensor_to_object(output_tensor, obj_tensor_size)
scatter_object_output_list[0] = _tensor_to_object(output_tensor, obj_tensor_size, group)
@_exception_logger
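
For the object collectives touched above, the extra `group` argument lets `_object_to_tensor`/`_tensor_to_object` check the group's backend before hashing. A minimal sketch of triggering that path (launch command illustrative; requires `TORCH_DISTRIBUTED_DEBUG=DETAIL` and an NCCL process group):

```python
# Run with, e.g.:
#   TORCH_DISTRIBUTED_DEBUG=DETAIL torchrun --nproc_per_node=2 demo_object_hash.py
# Each rank should log "_object_to_tensor size: ... hash value: ..." when
# serializing and "_tensor_to_object size: ... hash value: ..." when
# deserializing the gathered objects.
import torch
import torch.distributed as dist


def main():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, {"rank": rank, "payload": [rank] * 3})
    if rank == 0:
        print(gathered)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```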