Hashing logic for c10::complex (#51441)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51441

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D26170195

Pulled By: anjali411

fbshipit-source-id: 9247c1329229405426cfbd8463cabcdbe5bdb740
This commit is contained in:
anjali411
2021-02-01 15:44:19 -08:00
committed by Facebook GitHub Bot
parent 8fa328f88e
commit 09bc58796e
2 changed files with 23 additions and 1 deletions

View File

@ -3,7 +3,9 @@
#include <sstream>
#include <c10/util/complex.h>
#include <c10/macros/Macros.h>
#include <c10/util/hash.h>
#include <gtest/gtest.h>
#include <unordered_map>
#if (defined(__CUDACC__) || defined(__HIPCC__))
#define MAYBE_GLOBAL __global__
@ -167,6 +169,17 @@ TEST(TestConstructors, FromThrust) {
}
#endif
TEST(TestConstructors, UnorderedMap) {
std::unordered_map<c10::complex<double>, c10::complex<double>, c10::hash<c10::complex<double>>> m;
auto key1 = c10::complex<double>(2.5, 3);
auto key2 = c10::complex<double>(2, 0);
auto val1 = c10::complex<double>(2, -3.2);
auto val2 = c10::complex<double>(0, -3);
m[key1] = val1;
m[key2] = val2;
ASSERT_EQ(m[key1], val1);
ASSERT_EQ(m[key2], val2);
}
} // constructors

View File

@ -2,7 +2,7 @@
#include <functional>
#include <vector>
#include <c10/util/complex.h>
namespace c10 {
// NOTE: hash_combine is based on implementation from Boost
@ -139,4 +139,13 @@ size_t get_hash(const Types&... args) {
return c10::hash<decltype(std::tie(args...))>()(std::tie(args...));
}
// Specialization for c10::complex
template <typename T>
struct hash<c10::complex<T>> {
size_t operator()(const c10::complex<T>& c) const {
return get_hash(c.real(), c.imag());
}
};
} // namespace c10