[BE]: Replace printf with fmtlib call (#154814)

fmtlib formatting is safer, faster, and more concise, with better type checking. Also makes a few miscellaneous changes in the file.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154814
Approved by: https://github.com/jansel
This commit is contained in:
Aaron Gokaslan
2025-06-01 22:27:08 +00:00
committed by PyTorch MergeBot
parent 206e9d5160
commit 2b2245d5db

View File

@ -3,6 +3,7 @@
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/driver_api.h>
#include <fmt/printf.h>
#include <cuda_runtime.h>
#include <nvml.h>
@ -12,18 +13,13 @@ namespace {
// Compile-time constant of 64. Presumably the upper bound on NVLinks
// inspected per device; its use is outside this excerpt (TODO confirm).
constexpr int max_nvlinks = 64;
// Builds the PCI bus id string for the CUDA device `device_idx`.
//
// Queries the device properties with cudaGetDeviceProperties and formats
// pciDomainID / pciBusID / pciDeviceID using the printf-style layout given by
// NVML_DEVICE_PCI_BUS_ID_FMT, returning the result as a std::string.
// The CUDA call is wrapped in C10_CUDA_CHECK (presumably raises on a CUDA
// error; its definition is not visible here).
//
// NOTE(review): this listing is a diff hunk rendered without +/- markers, so
// it interleaves the pre-commit snprintf-into-stack-buffer code with the
// post-commit fmt::sprintf code. The trailing "old"/"new" comments below are
// derived from the hunk header (-12,18 +13,13) and the commit message.
std::string get_bus_id(int device_idx) {
// NOLINTNEXTLINE(*array*)
char bus_id[80]; // old: fixed-size stack buffer used by snprintf
cudaDeviceProp prop{}; // value-initialized so all fields start zeroed
C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_idx));
snprintf( // old: C-style formatting into bus_id
bus_id, // old
sizeof(bus_id), // old: bounds the write to the buffer size
return fmt::sprintf( // new: printf-style formatting straight into a std::string
NVML_DEVICE_PCI_BUS_ID_FMT,
prop.pciDomainID,
prop.pciBusID,
prop.pciDeviceID);
return std::string(bus_id); // old: copy the buffer into the returned string
}
struct C10_EXPORT NVLinkDetector : public c10d::DMAConnectivityDetector {
@ -39,6 +35,7 @@ struct C10_EXPORT NVLinkDetector : public c10d::DMAConnectivityDetector {
// Obtain the bus_id for all visible devices
std::unordered_map<std::string, int> bus_id_to_device_idx;
bus_id_to_device_idx.reserve(num_devices);
std::vector<std::string> bus_ids;
bus_ids.reserve(num_devices);
for (int i = 0; i < num_devices; ++i) {
@ -47,7 +44,7 @@ struct C10_EXPORT NVLinkDetector : public c10d::DMAConnectivityDetector {
bus_ids.push_back(std::move(bus_id));
}
static const char* warning_msg =
static constexpr const char* warning_msg =
"PyTorch features that use NVLinkDetector may assume no NVLink presence.";
auto driver_api = c10::cuda::DriverAPI::get();