Add stable Tensor get_device_index, use more stable DeviceIndex (#160143)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160143
Approved by: https://github.com/mikaylagawarecki
Committed by: PyTorch MergeBot
Parent: 41673110cd
Commit: 355462e127
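
The gist of the change, before the hunks: torch::stable::Tensor gains a get_device_index() accessor and the stable DeviceIndex alias widens to int32_t. Below is a minimal sketch of how an extension kernel might use the new accessor; the kernel name, its signature, and the include path are assumptions (the diff does not show file names), while get_device_index(), DeviceIndex, and STD_TORCH_CHECK are taken from the hunks that follow.

// Sketch only, not part of this commit. Assumes the stable Tensor header from
// this PR is available at the path below; adjust to wherever your build puts it.
#include <torch/csrc/stable/tensor.h>  // assumed include path

// Hypothetical extension kernel that rejects non-CPU inputs using the new
// accessor. CPU tensors report a device index of -1 (see the test hunk below).
void my_cpu_only_kernel(const torch::stable::Tensor& input) {
  torch::stable::DeviceIndex idx = input.get_device_index();  // int32_t after this change
  STD_TORCH_CHECK(idx == -1, "my_cpu_only_kernel only supports CPU tensors");
}
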
@@ -36,6 +36,11 @@ Tensor sgd_out_of_place(
     const bool maximize) {
   STD_TORCH_CHECK(param.dim() == 1, "param must be 1D");
 
+  // these test the get_device() and get_device_index() methods
+  // while ascertaining that we are still on CPU
+  STD_TORCH_CHECK(param.get_device() == -1, "CPU device index = -1");
+  STD_TORCH_CHECK(param.get_device_index() == -1, "CPU device index = -1");
+
   int64_t *param_sizes;
   int64_t *param_strides;
   aoti_torch_get_sizes(param.get(), &param_sizes);
@@ -1,13 +1,15 @@
 #pragma once
 
 #include <torch/csrc/inductor/aoti_torch/c/shim.h>
 #include <torch/headeronly/util/Exception.h>
 #include <torch/headeronly/util/shim_utils.h>
 #include <climits>
 #include <memory>
 
 namespace torch::stable {
 
-using DeviceIndex =
-    int8_t; // this is from c10/core/Device.h and can be header only
+// this is bigger than DeviceIndex in c10/core/Device.h but it is the type we
+// can converge on in this world as DeviceIndex in libtorch is not stable.
+using DeviceIndex = int32_t;
 
 // The torch::stable::Tensor class is a highlevel C++ wrapper around
 // the C shim Tensor APIs. We've modeled this class after TensorBase, as custom
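
Since the alias widens from int8_t to int32_t, code that still needs to hand an index to an API expecting libtorch's current 8-bit DeviceIndex has to narrow explicitly. Here is a hedged sketch of that narrowing, reusing the same bounds check that get_device() performs in the next hunk; the helper name is hypothetical, and STD_TORCH_CHECK is assumed to be available via the headers included above.

// Hypothetical helper, not part of this commit: narrow the stable int32_t
// DeviceIndex to libtorch's current int8_t index, failing loudly on overflow.
#include <cstdint>
#include <limits>

int8_t narrow_device_index(int32_t device_index) {
  STD_TORCH_CHECK(
      device_index >= std::numeric_limits<int8_t>::min() &&
          device_index <= std::numeric_limits<int8_t>::max(),
      "device index does not fit into int8_t");
  return static_cast<int8_t>(device_index);
}
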
@@ -103,11 +105,30 @@ class Tensor {
     return stride;
   }
 
-  DeviceIndex get_device() const {
+  // This is almost the same API as the one in TensorBase.h, except
+  // we add a check that the returned device_index is within the
+  // range of int8_t.
+  int8_t get_device() const {
     int32_t device_index;
     TORCH_ERROR_CODE_CHECK(
         aoti_torch_get_device_index(ath_.get(), &device_index));
-    return static_cast<DeviceIndex>(device_index);
+    STD_TORCH_CHECK(
+        device_index >= std::numeric_limits<int8_t>::min() &&
+            device_index <= std::numeric_limits<int8_t>::max(),
+        "Device index is out of range of return type int8_t, please use get_device_index() instead.");
+    return static_cast<int8_t>(device_index);
   }
 
+  // The same as get_device but with two differences:
+  // 1. it has a more suiting name
+  // 2. it returns a DeviceIndex, which is int32_t in this world
+  //    that should be more stable than the likely shifting
+  //    DeviceIndex in libtorch (it is int8_t that might become int16_t)
+  DeviceIndex get_device_index() const {
+    int32_t device_index;
+    TORCH_ERROR_CODE_CHECK(
+        aoti_torch_get_device_index(ath_.get(), &device_index));
+    return device_index;
+  }
+
   bool is_cuda() const {
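
To summarize the two accessors after this hunk: both go through the same shim call, but get_device() narrows to int8_t (with a range check) for parity with TensorBase, while get_device_index() returns the stable int32_t DeviceIndex. A small hedged usage sketch follows, assuming the stable tensor header is included; the function name is illustrative.

// Illustrative only: the two accessors agree on the value but differ in
// return type; prefer get_device_index() in code that should keep compiling
// if libtorch ever widens its own DeviceIndex.
void inspect_device(const torch::stable::Tensor& t) {
  int8_t narrow = t.get_device();  // checks that the index fits into int8_t
  torch::stable::DeviceIndex wide = t.get_device_index();  // int32_t, no narrowing
  STD_TORCH_CHECK(
      static_cast<torch::stable::DeviceIndex>(narrow) == wide,
      "both accessors should report the same device index");
}
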