Use ScalarType instead of string in function _group_tensors_by_device_and_dtype. (#127869)

Now that torch.dtype can pass through pybind11, modify the function _group_tensors_by_device_and_dtype to use at::ScalarType directly, removing the need to convert between torch.dtype and string on the Python and C++ sides.
@ezyang @bdhirsh
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127869
Approved by: https://github.com/ezyang
This commit is contained in:
Shan19900305
2024-06-04 18:19:28 +00:00
committed by PyTorch MergeBot
parent 0ff60236ab
commit 3bcc3cddb5
3 changed files with 4 additions and 46 deletions

View File

@ -2154,50 +2154,13 @@ Call this whenever a new thread is created in order to propagate values from
return torch::should_allow_numbers_as_tensors(name);
});
// FIXME(crcrpar): Better to have `at::ScalarType` get mapped to `torch.dtype`
// Currently I see the second item of the key is displayed as
// e.g. `torch._C._te.ScalarType at 0x7fcf318adab0`
// I thought adding an appropriate type_caster of `at::ScalarType` to
// torch/csrc/pybind.h` would solve this but it caused segmentation fault in
// my environment.
using _DeviceDtypeKey = std::pair<at::Device, std::string>;
// Custom hasher is necessary to make unordered_map compilable for Windows
// debug targets. As `at::native::ParamsHash` only works on structs with
// standard layout, but std::string isn't one in Visual C++ debug builds,
// which one can easily verify by running something like:
// #define _DEBUG
// #include <type_traits>
// #include <string>
// static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
// If above condition is not met, VC++ raises a very cryptic compilation
// error. See
// https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for
// more detail
struct _DeviceDtypeHasher {
std::size_t operator()(const _DeviceDtypeKey& k) const noexcept {
static at::native::ParamsHash<at::Device> device_hasher;
static std::hash<std::string> string_hasher;
return device_hasher(k.first) ^ string_hasher(k.second);
}
};
using _FlatMap = std::unordered_map<
_DeviceDtypeKey,
at::native::TensorsAndIndicesT,
_DeviceDtypeHasher>;
py_module.def(
"_group_tensors_by_device_and_dtype",
[](const std::vector<std::vector<std::optional<at::Tensor>>>&
nested_tensorlist,
const bool with_indices) {
_FlatMap map;
for (const auto& iter :
at::native::_group_tensors_by_first_tensors_device_and_dtype(
nested_tensorlist, with_indices)) {
const auto scalar_type_name =
torch::utils::getDtypeNames(iter.first.second).first;
map.insert({{iter.first.first, scalar_type_name}, iter.second});
}
return map;
return at::native::_group_tensors_by_first_tensors_device_and_dtype(
nested_tensorlist, with_indices);
});
py_module.def(