Add stable Tensor get_device_index, use more stable DeviceIndex (#160143)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160143
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
Jane Xu
2025-08-12 13:52:59 -07:00
committed by PyTorch MergeBot
parent 41673110cd
commit 355462e127
2 changed files with 31 additions and 5 deletions

View File

@ -36,6 +36,11 @@ Tensor sgd_out_of_place(
const bool maximize) {
STD_TORCH_CHECK(param.dim() == 1, "param must be 1D");
// these test the get_device() and get_device_index() methods
// while ascertaining that we are still on CPU
STD_TORCH_CHECK(param.get_device() == -1, "CPU device index = -1");
STD_TORCH_CHECK(param.get_device_index() == -1, "CPU device index = -1");
int64_t *param_sizes;
int64_t *param_strides;
aoti_torch_get_sizes(param.get(), &param_sizes);

View File

@ -1,13 +1,15 @@
#pragma once
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/util/shim_utils.h>
#include <climits>
#include <limits>
#include <memory>
namespace torch::stable {
using DeviceIndex =
int8_t; // this is from c10/core/Device.h and can be header only
// This is wider than the DeviceIndex in c10/core/Device.h, but it is the type
// we can converge on in this world, as DeviceIndex in libtorch is not stable.
using DeviceIndex = int32_t;
// The torch::stable::Tensor class is a highlevel C++ wrapper around
// the C shim Tensor APIs. We've modeled this class after TensorBase, as custom
@ -103,11 +105,30 @@ class Tensor {
return stride;
}
// NOTE(review): this span appears to interleave the pre-change and post-change
// lines of the diff hunk (two signatures, two return statements) — confirm
// against the actual header before relying on it as compilable code.
DeviceIndex get_device() const {
// This is almost the same API as the one in TensorBase.h, except
// we add a check that the returned device_index is within the
// range of int8_t.
int8_t get_device() const {
// The C shim writes the index into an int32_t out-parameter; the call's
// return code is passed through TORCH_ERROR_CODE_CHECK.
int32_t device_index;
TORCH_ERROR_CODE_CHECK(
aoti_torch_get_device_index(ath_.get(), &device_index));
return static_cast<DeviceIndex>(device_index);
// Verify the value fits in int8_t before the narrowing cast below; the
// failure message directs callers to get_device_index() instead.
STD_TORCH_CHECK(
device_index >= std::numeric_limits<int8_t>::min() &&
device_index <= std::numeric_limits<int8_t>::max(),
"Device index is out of range of return type int8_t, please use get_device_index() instead.");
return static_cast<int8_t>(device_index);
}
// Returns the index of the device this tensor lives on.
//
// Equivalent to get_device(), with two differences:
//   1. the name says what is actually returned;
//   2. the result is a torch::stable::DeviceIndex (int32_t in this world),
//      which should be more stable than the likely shifting DeviceIndex in
//      libtorch (currently int8_t, and it might become int16_t).
//
// The value is obtained from aoti_torch_get_device_index, whose return code
// is passed through TORCH_ERROR_CODE_CHECK.
DeviceIndex get_device_index() const {
  int32_t idx;
  TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_index(ath_.get(), &idx));
  return idx;
}
bool is_cuda() const {