Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Use ScalarType instead of string in function _group_tensors_by_device_and_dtype (#127869)
Now that torch.dtype can pass through pybind11, _group_tensors_by_device_and_dtype can work with the scalar type directly, with no conversion between torch.dtype and a string on either the Python or the C++ side (see the Python-level sketch after the diff below). @ezyang @bdhirsh Pull Request resolved: https://github.com/pytorch/pytorch/pull/127869 Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 0ff60236ab
Commit: 3bcc3cddb5
@@ -2154,50 +2154,13 @@ Call this whenever a new thread is created in order to propagate values from
         return torch::should_allow_numbers_as_tensors(name);
       });

-  // FIXME(crcrpar): Better to have `at::ScalarType` get mapped to `torch.dtype`
-  // Currently I see the second item of the key is displayed as
-  // e.g. `torch._C._te.ScalarType at 0x7fcf318adab0`
-  // I thought adding an appropriate type_caster of `at::ScalarType` to
-  // torch/csrc/pybind.h` would solve this but it caused segmentation fault in
-  // my environment.
-  using _DeviceDtypeKey = std::pair<at::Device, std::string>;
-  // Custom hasher is necessary to make unordered_map compilable for Windows
-  // debug targets. As `at::native::ParamsHash` only works on structs with
-  // standard layout, but std::string isn't one in Visual C++ debug builds,
-  // which one can easily verify by running something like:
-  // #define _DEBUG
-  // #include <type_traits>
-  // #include <string>
-  // static_assert(std::is_standard_layout_v<std::string>, "Oh noes");
-  // If above condition is not met, VC++ raises a very cryptic compilation
-  // error. See
-  // https://github.com/pytorch/pytorch/pull/100007#discussion_r1227116292 for
-  // more detail
-  struct _DeviceDtypeHasher {
-    std::size_t operator()(const _DeviceDtypeKey& k) const noexcept {
-      static at::native::ParamsHash<at::Device> device_hasher;
-      static std::hash<std::string> string_hasher;
-      return device_hasher(k.first) ^ string_hasher(k.second);
-    }
-  };
-  using _FlatMap = std::unordered_map<
-      _DeviceDtypeKey,
-      at::native::TensorsAndIndicesT,
-      _DeviceDtypeHasher>;
   py_module.def(
       "_group_tensors_by_device_and_dtype",
       [](const std::vector<std::vector<std::optional<at::Tensor>>>&
             nested_tensorlist,
         const bool with_indices) {
-        _FlatMap map;
-        for (const auto& iter :
-             at::native::_group_tensors_by_first_tensors_device_and_dtype(
-                 nested_tensorlist, with_indices)) {
-          const auto scalar_type_name =
-              torch::utils::getDtypeNames(iter.first.second).first;
-          map.insert({{iter.first.first, scalar_type_name}, iter.second});
-        }
-        return map;
+        return at::native::_group_tensors_by_first_tensors_device_and_dtype(
+            nested_tensorlist, with_indices);
       });

   py_module.def(
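The practical effect is that the keys returned to Python become (torch.device, torch.dtype) pairs instead of (device, dtype-name string) pairs. Below is a minimal, hypothetical usage sketch, not part of the commit: it assumes the binding is reachable as torch._C._group_tensors_by_device_and_dtype (as the py_module.def above registers it) and that the grouped values keep the (nested tensor lists, indices) layout produced by at::native::_group_tensors_by_first_tensors_device_and_dtype.

# Sketch under the assumptions stated above.
import torch

# Two parallel tensor lists (e.g. params and grads) with mixed dtypes.
params = [torch.randn(3), torch.randn(3, dtype=torch.float64)]
grads = [torch.randn(3), torch.randn(3, dtype=torch.float64)]

# Group by the device/dtype of the tensors in the first list; the second
# argument asks for the original indices as well.
grouped = torch._C._group_tensors_by_device_and_dtype([params, grads], True)

for (device, dtype), value in grouped.items():
    # With this change the dtype key is a real torch.dtype (e.g. torch.float32),
    # not the string that torch::utils::getDtypeNames used to produce.
    assert isinstance(dtype, torch.dtype)
    print(device, dtype, value)

The intermediate string key, and the Windows-debug-specific _DeviceDtypeHasher that existed only to hash it, disappear because pybind11 now converts at::ScalarType to torch.dtype directly, so the C++ grouping result can be returned to Python as-is.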