Files
pytorch/torch/csrc/stable/accelerator.h
Mikayla Gawarecki fab8455943 Don't use declarations in global namespace in stable headers (#163352)
Fixes https://github.com/pytorch/pytorch/issues/163338

Configured https://clang.llvm.org/extra/clang-tidy/checks/google/global-names-in-headers.html for torch/csrc/stable

Note that this check doesn't error for the DeleterFnPtr case, but it does generate the following for `using torch::stable::Tensor;`:

```
>>> Lint for torch/csrc/stable/ops.h:

  Error (CLANGTIDY) [google-global-names-in-headers,-warnings-as-errors]
    using declarations in the global namespace in headers are prohibited

         10  |#include <torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h>
         11  |#include <torch/headeronly/core/ScalarType.h>
         12  |
    >>>  13  |using torch::stable::Tensor;
         14  |
         15  |namespace torch::stable {
         16  |
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163352
Approved by: https://github.com/janeyx99
2025-09-19 21:15:52 +00:00

79 lines
2.1 KiB
C++

#pragma once
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/headeronly/util/shim_utils.h>
#include <memory>
namespace torch::stable::accelerator {
// Signature of a type-erased deleter for opaque shim handles.
using DeleterFnPtr = void (*)(void*);

// Releases a DeviceGuardHandle previously obtained from
// aoti_torch_create_device_guard. It has the DeleterFnPtr signature so it can
// serve as the deleter of DeviceGuard's std::unique_ptr.
//
// This is deliberately an `inline` function at namespace scope rather than a
// member of an unnamed namespace. An unnamed namespace in a header gives every
// translation unit its own internal-linkage copy. The inline (externally
// linked) DeviceGuard constructor names this function, so it would then refer
// to a different entity in each TU, which is an ODR violation. `inline` alone
// already allows this definition to appear in every TU that includes the
// header.
// NOTE(review): this runs from a destructor. If TORCH_ERROR_CODE_CHECK reports
// failure by throwing, such a failure would terminate the program. Confirm
// that this is acceptable.
inline void delete_device_guard(void* ptr) {
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_delete_device_guard(static_cast<DeviceGuardHandle>(ptr)));
}
// Device index type used throughout the stable accelerator API.
// This is wider than DeviceIndex in c10/core/Device.h, but it is the type we
// can converge on here, because DeviceIndex in libtorch is not stable.
using DeviceIndex = int32_t;
// Stream identifier type; taken from StreamId in c10/core/Stream.h.
using StreamId = int64_t;
class DeviceGuard {
public:
explicit DeviceGuard() = delete;
explicit DeviceGuard(DeviceIndex device_index)
: guard_(nullptr, delete_device_guard) {
DeviceGuardHandle ptr = nullptr;
TORCH_ERROR_CODE_CHECK(aoti_torch_create_device_guard(device_index, &ptr));
guard_.reset(ptr);
}
void set_index(DeviceIndex device_index) {
TORCH_ERROR_CODE_CHECK(
aoti_torch_device_guard_set_index(guard_.get(), device_index));
}
private:
std::unique_ptr<DeviceGuardOpaque, DeleterFnPtr> guard_;
};
// Reference-counted wrapper around an AOTI StreamHandle.
//
// Copies of a Stream share one underlying handle, which is released through
// aoti_torch_delete_stream when the last copy is destroyed.
class Stream {
 public:
  explicit Stream() = delete;

  // Construct a stable::Stream from a StreamHandle.
  // Steals ownership from the StreamHandle: the caller must not delete it.
  explicit Stream(StreamHandle stream)
      : stream_(stream, [](StreamHandle owned) {
          // std::shared_ptr calls its deleter even when the managed pointer is
          // null, so skip the shim call for an empty handle. Otherwise a null
          // Stream would ask the shim to delete nullptr from a destructor.
          if (owned != nullptr) {
            TORCH_ERROR_CODE_CHECK(aoti_torch_delete_stream(owned));
          }
        }) {}

  // Returns the id of the wrapped stream, as reported by aoti_torch_stream_id.
  StreamId id() const {
    // Value-initialized so the out-parameter never holds an indeterminate value.
    StreamId stream_id = 0;
    TORCH_ERROR_CODE_CHECK(aoti_torch_stream_id(stream_.get(), &stream_id));
    return stream_id;
  }

 private:
  std::shared_ptr<StreamOpaque> stream_;
};
// Returns a Stream wrapping the handle that aoti_torch_get_current_stream
// reports for `device_index`. The returned Stream takes ownership of that
// handle.
inline Stream getCurrentStream(DeviceIndex device_index) {
  StreamHandle current = nullptr;
  TORCH_ERROR_CODE_CHECK(
      aoti_torch_get_current_stream(device_index, &current));
  return Stream{current};
}
// Returns the current device index, as reported by
// aoti_torch_get_current_device_index.
inline DeviceIndex getCurrentDeviceIndex() {
  // Value-initialize the out-parameter so it never holds an indeterminate
  // value, even if the shim call does not write to it.
  DeviceIndex device_index{};
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_device_index(&device_index));
  return device_index;
}
} // namespace torch::stable::accelerator