mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-19 18:14:54 +08:00
Update
[ghstack-poisoned]
This commit is contained in:
@ -342,7 +342,7 @@ uint64_t get_any_data_ptr(Tensor t, bool mutable_) {
|
||||
|
||||
uint64_t get_template_any_data_ptr(Tensor t, c10::ScalarType dtype, bool mutable_) {
|
||||
#define DEFINE_CASE(T, name) \
|
||||
case torch::headeronly::ScalarType::name: { \
|
||||
case torch::headeronly::ScalarType::name: { \
|
||||
if (mutable_) { \
|
||||
return reinterpret_cast<uint64_t>(t.mutable_data_ptr<T>()); \
|
||||
} else { \
|
||||
@ -352,7 +352,6 @@ uint64_t get_template_any_data_ptr(Tensor t, c10::ScalarType dtype, bool mutable
|
||||
switch (dtype) {
|
||||
// per aten/src/ATen/templates/TensorMethods.cpp:
|
||||
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE)
|
||||
AT_FORALL_QINT_TYPES(DEFINE_CASE)
|
||||
DEFINE_CASE(uint16_t, UInt16)
|
||||
DEFINE_CASE(uint32_t, UInt32)
|
||||
DEFINE_CASE(uint64_t, UInt64)
|
||||
|
||||
Reference in New Issue
Block a user