#include <ATen/DLConvertor.h>
#include <ATen/Functions.h>

using namespace std;
namespace at {

DLDataType getDLDataType(const Tensor& t) {
  DLDataType dtype;
  dtype.lanes = 1;
  dtype.bits = t.element_size() * 8;
  switch (t.scalar_type()) {
    case ScalarType::UInt1:
    case ScalarType::UInt2:
    case ScalarType::UInt3:
    case ScalarType::UInt4:
    case ScalarType::UInt5:
    case ScalarType::UInt6:
    case ScalarType::UInt7:
    case ScalarType::Byte:
    case ScalarType::UInt16:
    case ScalarType::UInt32:
    case ScalarType::UInt64:
      dtype.code = DLDataTypeCode::kDLUInt;
      break;
    case ScalarType::Int1:
    case ScalarType::Int2:
    case ScalarType::Int3:
    case ScalarType::Int4:
    case ScalarType::Int5:
    case ScalarType::Int6:
    case ScalarType::Int7:
    case ScalarType::Char:
      dtype.code = DLDataTypeCode::kDLInt;
      break;
    // NOLINTNEXTLINE(bugprone-branch-clone)
    case ScalarType::Double:
      dtype.code = DLDataTypeCode::kDLFloat;
      break;
    case ScalarType::Float:
      dtype.code = DLDataTypeCode::kDLFloat;
      break;
    // NOLINTNEXTLINE(bugprone-branch-clone)
    case ScalarType::Int:
      dtype.code = DLDataTypeCode::kDLInt;
      break;
    case ScalarType::Long:
      dtype.code = DLDataTypeCode::kDLInt;
      break;
    case ScalarType::Short:
      dtype.code = DLDataTypeCode::kDLInt;
      break;
    case ScalarType::Half:
      dtype.code = DLDataTypeCode::kDLFloat;
      break;
    case ScalarType::Bool:
      dtype.code = DLDataTypeCode::kDLBool;
      break;
    case ScalarType::ComplexHalf:
    case ScalarType::ComplexFloat:
    case ScalarType::ComplexDouble:
      dtype.code = DLDataTypeCode::kDLComplex;
      break;
    case ScalarType::BFloat16:
      dtype.code = DLDataTypeCode::kDLBfloat;
      break;
    // TODO(#146647): use macro here instead of spelling out each shell dtype
    case ScalarType::Float8_e5m2:
      dtype.code = DLDataTypeCode::kDLFloat8_e5m2;
      break;
    case ScalarType::Float8_e5m2fnuz:
      dtype.code = DLDataTypeCode::kDLFloat8_e5m2fnuz;
      break;
    case ScalarType::Float8_e4m3fn:
      dtype.code = DLDataTypeCode::kDLFloat8_e4m3fn;
      break;
    case ScalarType::Float8_e4m3fnuz:
      dtype.code = DLDataTypeCode::kDLFloat8_e4m3fnuz;
      break;
    case ScalarType::Float8_e8m0fnu:
      dtype.code = DLDataTypeCode::kDLFloat8_e8m0fnu;
      break;
    case ScalarType::Float4_e2m1fn_x2:
      // Each element packs two 4-bit lanes into one byte.
      dtype.code = DLDataTypeCode::kDLFloat4_e2m1fn;
      dtype.lanes = 2;
      dtype.bits = 4;
      break;
    case ScalarType::QInt8:
    case ScalarType::QUInt8:
    case ScalarType::QInt32:
    case ScalarType::QUInt4x2:
    case ScalarType::QUInt2x4:
      TORCH_CHECK_BUFFER(false, "QUInt/QInt types are not supported by dlpack");
      break;
    case ScalarType::Bits1x8:
    case ScalarType::Bits2x4:
    case ScalarType::Bits4x2:
    case ScalarType::Bits8:
    case ScalarType::Bits16:
      TORCH_CHECK_BUFFER(false, "Bit types are not supported by dlpack");
      break;
    case ScalarType::Undefined:
      TORCH_CHECK_BUFFER(false, "Undefined is not a valid ScalarType");
    case ScalarType::NumOptions:
      TORCH_CHECK_BUFFER(false, "NumOptions is not a valid ScalarType");
  }
  return dtype;
}
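
// A minimal usage sketch (hypothetical caller, not part of this file),
// assuming at::rand is available in the surrounding context:
//
//   at::Tensor t = at::rand({2, 3});         // ScalarType::Float
//   DLDataType dt = at::getDLDataType(t);
//   // dt.code == kDLFloat, dt.bits == 32, dt.lanes == 1
//   // (only Float4_e2m1fn_x2 sets lanes == 2 above)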

DLDevice torchDeviceToDLDevice(at::Device device) {
  DLDevice ctx;

  // Only CUDA and PrivateUse1 carry a meaningful device index here; other
  // backends either default to 0 or set their own id below (e.g. XPU).
  ctx.device_id = (device.is_cuda() || device.is_privateuseone())
      ? static_cast<int32_t>(static_cast<unsigned char>(device.index()))
      : 0;

  switch (device.type()) {
    case DeviceType::CPU:
      ctx.device_type = DLDeviceType::kDLCPU;
      break;
    case DeviceType::CUDA:
#ifdef USE_ROCM
      // ROCm, if enabled, looks like CUDA to PyTorch,
      // while everyone else sees HIP.
      ctx.device_type = DLDeviceType::kDLROCM;
#else
      ctx.device_type = DLDeviceType::kDLCUDA;
#endif
      break;
    case DeviceType::OPENCL:
      ctx.device_type = DLDeviceType::kDLOpenCL;
      break;
    case DeviceType::HIP:
      ctx.device_type = DLDeviceType::kDLROCM;
      break;
    case DeviceType::XPU:
      ctx.device_type = DLDeviceType::kDLOneAPI;
      ctx.device_id = at::detail::getXPUHooks().getGlobalIdxFromDevice(device);
      break;
    case DeviceType::MAIA:
      ctx.device_type = DLDeviceType::kDLMAIA;
      break;
    case DeviceType::PrivateUse1:
      ctx.device_type = DLDeviceType::kDLExtDev;
      break;
    case DeviceType::MPS:
      ctx.device_type = DLDeviceType::kDLMetal;
      break;
    default:
      TORCH_CHECK_BUFFER(false, "Cannot pack tensors on " + device.str());
  }

  return ctx;
}
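
// A minimal sketch (hypothetical caller): on a CUDA build,
// at::Device(at::kCUDA, 1) maps to {kDLCUDA, 1}; under USE_ROCM the same
// PyTorch-side "cuda" device maps to {kDLROCM, 1} instead.
//
//   DLDevice d = at::torchDeviceToDLDevice(at::Device(at::kCUDA, 1));
//   // d.device_type == kDLCUDA (kDLROCM under USE_ROCM), d.device_id == 1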

static Device getATenDevice(
    DLDeviceType type,
    c10::DeviceIndex index,
    void* data = nullptr) {
  switch (type) {
    case DLDeviceType::kDLCPU:
      return at::Device(DeviceType::CPU);
#ifndef USE_ROCM
    // If we are compiled under HIP, we cannot handle CUDA devices here.
    case DLDeviceType::kDLCUDA:
      return at::Device(DeviceType::CUDA, index);
#endif
    case DLDeviceType::kDLOpenCL:
      return at::Device(DeviceType::OPENCL, index);
    case DLDeviceType::kDLROCM:
#ifdef USE_ROCM
      // This looks odd, but under ROCm the HIP backend masquerades as CUDA,
      // so we need to return a CUDA device here.
      return at::Device(DeviceType::CUDA, index);
#else
      return at::Device(DeviceType::HIP, index);
#endif
    case DLDeviceType::kDLOneAPI:
      TORCH_CHECK(data != nullptr, "Can't get ATen device for XPU without XPU data.");
      return at::detail::getXPUHooks().getDeviceFromPtr(data);
    case DLDeviceType::kDLMAIA:
      return at::Device(DeviceType::MAIA, index);
    case DLDeviceType::kDLExtDev:
      return at::Device(DeviceType::PrivateUse1, index);
    case DLDeviceType::kDLMetal:
      return at::Device(DeviceType::MPS, index);
    default:
      TORCH_CHECK_BUFFER(
          false, "Unsupported device_type: ", std::to_string(type));
  }
}
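
// getATenDevice is the approximate inverse of torchDeviceToDLDevice above;
// the `data` pointer is consulted only for kDLOneAPI, where the XPU hooks
// recover the device from the allocation itself. A minimal sketch, assuming
// a CUDA build:
//
//   at::Device dev = getATenDevice(DLDeviceType::kDLCUDA, /*index=*/0);
//   // dev == at::Device(at::kCUDA, 0)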

ScalarType toScalarType(const DLDataType& dtype) {
  ScalarType stype = ScalarType::Undefined;
  // Vector dtypes (lanes != 1) are only supported for the packed float4
  // type below.
  if (dtype.code != DLDataTypeCode::kDLFloat4_e2m1fn) {
    TORCH_CHECK_BUFFER(
        dtype.lanes == 1,
        "ATen does not support lanes != 1 for dtype code ",
        std::to_string(dtype.code));
  }
  switch (dtype.code) {
    case DLDataTypeCode::kDLUInt:
      switch (dtype.bits) {
        case 8:
          stype = ScalarType::Byte;
          break;
        case 16:
          stype = ScalarType::UInt16;
          break;
        case 32:
          stype = ScalarType::UInt32;
          break;
        case 64:
          stype = ScalarType::UInt64;
          break;
        default:
          TORCH_CHECK_BUFFER(
              false, "Unsupported kUInt bits ", std::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLInt:
      switch (dtype.bits) {
        case 8:
          stype = ScalarType::Char;
          break;
        case 16:
          stype = ScalarType::Short;
          break;
        case 32:
          stype = ScalarType::Int;
          break;
        case 64:
          stype = ScalarType::Long;
          break;
        default:
          TORCH_CHECK_BUFFER(
              false, "Unsupported kInt bits ", std::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLFloat:
      switch (dtype.bits) {
        case 16:
          stype = ScalarType::Half;
          break;
        case 32:
          stype = ScalarType::Float;
          break;
        case 64:
          stype = ScalarType::Double;
          break;
        default:
          TORCH_CHECK_BUFFER(
              false, "Unsupported kFloat bits ", std::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLBfloat:
      switch (dtype.bits) {
        case 16:
          stype = ScalarType::BFloat16;
          break;
        default:
          TORCH_CHECK_BUFFER(
              false, "Unsupported kDLBfloat bits ", std::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLComplex:
      switch (dtype.bits) {
        case 32:
          stype = ScalarType::ComplexHalf;
          break;
        case 64:
          stype = ScalarType::ComplexFloat;
          break;
        case 128:
          stype = ScalarType::ComplexDouble;
          break;
        default:
          TORCH_CHECK_BUFFER(
              false, "Unsupported kDLComplex bits ", std::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLBool:
      switch (dtype.bits) {
        case 8:
          stype = ScalarType::Bool;
          break;
        default:
          TORCH_CHECK_BUFFER(
              false, "Unsupported kDLBool bits ", std::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLFloat8_e5m2:
      switch (dtype.bits) {
        case 8:
          stype = ScalarType::Float8_e5m2;
          break;
        default:
          TORCH_CHECK_BUFFER(
              false, "Unsupported kDLFloat8_e5m2 bits ", std::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLFloat8_e5m2fnuz:
      switch (dtype.bits) {
        case 8:
          stype = ScalarType::Float8_e5m2fnuz;
          break;
        default:
          TORCH_CHECK_BUFFER(
              false, "Unsupported kDLFloat8_e5m2fnuz bits ", std::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLFloat8_e4m3fn:
      switch (dtype.bits) {
        case 8:
          stype = ScalarType::Float8_e4m3fn;
          break;
        default:
          TORCH_CHECK_BUFFER(
              false, "Unsupported kDLFloat8_e4m3fn bits ", std::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLFloat8_e4m3fnuz:
      switch (dtype.bits) {
        case 8:
          stype = ScalarType::Float8_e4m3fnuz;
          break;
        default:
          TORCH_CHECK_BUFFER(
              false, "Unsupported kDLFloat8_e4m3fnuz bits ", std::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLFloat8_e8m0fnu:
      switch (dtype.bits) {
        case 8:
          stype = ScalarType::Float8_e8m0fnu;
          break;
        default:
          TORCH_CHECK_BUFFER(
              false, "Unsupported kDLFloat8_e8m0fnu bits ", std::to_string(dtype.bits));
      }
      break;
    case DLDataTypeCode::kDLFloat4_e2m1fn:
      switch (dtype.bits) {
        case 4:
          switch (dtype.lanes) {
            case 2:
              stype = ScalarType::Float4_e2m1fn_x2;
              break;
            default:
              TORCH_CHECK_BUFFER(
                  false, "Unsupported kDLFloat4_e2m1fn lanes ", std::to_string(dtype.lanes));
          }
          break;
        default:
          TORCH_CHECK_BUFFER(
              false, "Unsupported kDLFloat4_e2m1fn bits ", std::to_string(dtype.bits));
      }
      break;
    default:
      TORCH_CHECK_BUFFER(false, "Unsupported code ", std::to_string(dtype.code));
  }
  return stype;
}
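
// Round-trip sketch (illustrative): for dtypes whose code/bits map back
// one-to-one above, toScalarType(getDLDataType(t)) == t.scalar_type().
// Sub-byte types (UInt1..UInt7, Int1..Int7) appear to be exported with bits
// derived from element_size(), so they may not round-trip exactly.
//
//   DLDataType dt;
//   dt.code = DLDataTypeCode::kDLFloat;
//   dt.bits = 16;
//   dt.lanes = 1;
//   // at::toScalarType(dt) == ScalarType::Half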

namespace {

// The templated classes below are needed for supporting both:
// - DLManagedTensor
// - DLManagedTensorVersioned
template <class T>
struct ATenDLMTensor {
  Tensor handle;
  T tensor{};
};

template <class T>
void deleter(T* arg) {
  delete static_cast<ATenDLMTensor<T>*>(arg->manager_ctx);
}

// Adds version information for DLManagedTensorVersioned.
// This is a no-op for the other types.
template <class T>
void fillVersion(T* tensor) {}

template <>
void fillVersion<DLManagedTensorVersioned>(
    DLManagedTensorVersioned* tensor) {
  tensor->flags = 0;
  tensor->version.major = DLPACK_MAJOR_VERSION;
  tensor->version.minor = DLPACK_MINOR_VERSION;
}

// This function returns a pointer to a memory-managed DLPack tensor
// (DLManagedTensor or DLManagedTensorVersioned) constructed out of an ATen
// tensor. The handle keeps a reference to `src` via ATenDLMTensor and is
// released by calling its deleter.
template <class T>
T* toDLPackImpl(const Tensor& src) {
  ATenDLMTensor<T>* atDLMTensor(new ATenDLMTensor<T>);
  atDLMTensor->handle = src;
  atDLMTensor->tensor.manager_ctx = atDLMTensor;
  atDLMTensor->tensor.deleter = &deleter<T>;
  atDLMTensor->tensor.dl_tensor.data = src.data_ptr();
  atDLMTensor->tensor.dl_tensor.device = torchDeviceToDLDevice(src.device());
  atDLMTensor->tensor.dl_tensor.ndim = static_cast<int32_t>(src.dim());
  atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
  atDLMTensor->tensor.dl_tensor.shape = const_cast<int64_t*>(src.sizes().data());
  atDLMTensor->tensor.dl_tensor.strides = const_cast<int64_t*>(src.strides().data());
  atDLMTensor->tensor.dl_tensor.byte_offset = 0;
  fillVersion(&atDLMTensor->tensor);

  return &(atDLMTensor->tensor);
}

// Explicitly instantiate the template above for both classes.
template DLManagedTensor* toDLPackImpl<DLManagedTensor>(const Tensor&);
template DLManagedTensorVersioned* toDLPackImpl<DLManagedTensorVersioned>(const Tensor&);
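
// Ownership sketch (hypothetical consumer): the returned handle keeps the
// source Tensor alive via ATenDLMTensor::handle; exactly one party must
// eventually invoke the embedded deleter:
//
//   DLManagedTensor* dlm = toDLPackImpl<DLManagedTensor>(tensor);
//   // ... hand dlm / dlm->dl_tensor to a DLPack consumer ...
//   dlm->deleter(dlm);  // deletes the ATenDLMTensor, dropping the Tensor ref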

// This function constructs a Tensor from a memory-managed DLPack tensor,
// which may be represented as either a DLManagedTensor or a
// DLManagedTensorVersioned.
template <class T>
at::Tensor fromDLPackImpl(T* src, std::function<void(void*)> deleter) {
  // If no deleter was provided, fall back to calling the DLPack deleter
  // embedded in `src` once the resulting Tensor is destroyed.
  if (!deleter) {
    deleter = [src](void* self [[maybe_unused]]) {
      if (src->deleter) {
        src->deleter(src);
      }
    };
  }

  DLTensor& dl_tensor = src->dl_tensor;
  Device device = getATenDevice(
      dl_tensor.device.device_type, dl_tensor.device.device_id, dl_tensor.data);
  ScalarType stype = toScalarType(dl_tensor.dtype);

  // A null strides pointer means the producer guarantees a compact,
  // row-major layout, so from_blob computes contiguous strides.
  if (!dl_tensor.strides) {
    return at::from_blob(
        dl_tensor.data,
        IntArrayRef(dl_tensor.shape, dl_tensor.ndim),
        std::move(deleter),
        at::device(device).dtype(stype),
        {device});
  }
  return at::from_blob(
      dl_tensor.data,
      IntArrayRef(dl_tensor.shape, dl_tensor.ndim),
      IntArrayRef(dl_tensor.strides, dl_tensor.ndim),
      deleter,
      at::device(device).dtype(stype),
      {device});
}

// Explicitly instantiate the template above for both classes.
template at::Tensor fromDLPackImpl<DLManagedTensor>(DLManagedTensor* src, std::function<void(void*)> deleter);
template at::Tensor fromDLPackImpl<DLManagedTensorVersioned>(DLManagedTensorVersioned* src, std::function<void(void*)> deleter);

} // namespace

DLManagedTensor* toDLPack(const Tensor& src) {
  return toDLPackImpl<DLManagedTensor>(src);
}

DLManagedTensorVersioned* toDLPackVersioned(const Tensor& src) {
  return toDLPackImpl<DLManagedTensorVersioned>(src);
}

Tensor fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter) {
  return fromDLPackImpl<DLManagedTensor>(src, std::move(deleter));
}

Tensor fromDLPackVersioned(
    DLManagedTensorVersioned* src,
    std::function<void(void*)> deleter) {
  return fromDLPackImpl<DLManagedTensorVersioned>(src, std::move(deleter));
}
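
// End-to-end round-trip sketch using the public API (hypothetical caller;
// at::rand is assumed available in the caller's context):
//
//   at::Tensor a = at::rand({4, 4});
//   DLManagedTensor* dlm = at::toDLPack(a);  // shares a's storage
//   at::Tensor b = at::fromDLPack(dlm);      // zero-copy view of the same data
//   // With no custom deleter, freeing b's storage calls dlm->deleter(dlm).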

Tensor maybeCopyTensor(
    const Tensor& data,
    std::optional<DLDevice> optional_dl_device,
    std::optional<bool> copy) {
  bool force_copy = copy.has_value() && *copy;
  bool force_move = copy.has_value() && !*copy;

  if (optional_dl_device.has_value()) {
    auto device = at::getATenDevice(
        optional_dl_device->device_type,
        static_cast<c10::DeviceIndex>(optional_dl_device->device_id));

    if (device != data.device()) {
      TORCH_CHECK_VALUE(
          !force_move,
          "cannot move (i.e. copy=False) tensor from ",
          data.device(),
          " to ",
          device,
          " without copying.");
      return data.to(device);
    }
  }

  if (force_copy) {
    return data.clone();
  }

  return data;
}
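
// Copy-semantics sketch (illustrative): copy=true always clones, copy=false
// forbids any cross-device copy (TORCH_CHECK_VALUE fires), and an unset copy
// lets the function copy only when a device change requires it.
//
//   at::Tensor same = at::maybeCopyTensor(t, std::nullopt, std::nullopt);  // no copy
//   at::Tensor dup  = at::maybeCopyTensor(t, std::nullopt, true);          // clone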

} // namespace at