mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
```
In file included from /local/pytorch3/test/cpp/api/optim.cpp:7:
/local/pytorch3/test/cpp/api/support.h:44:3: warning: '~WarningCapture' overrides a destructor but is not marked 'override' [-Winconsistent-missing-destructor-override]
  ~WarningCapture() {
  ^
/local/pytorch3/c10/util/Exception.h:167:11: note: overridden virtual function is here
  virtual ~WarningHandler() = default;
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107191
Approved by: https://github.com/janeyx99
197 lines
5.6 KiB
C++
197 lines
5.6 KiB
C++
#pragma once
|
|
|
|
#include <test/cpp/common/support.h>

#include <gtest/gtest.h>

#include <ATen/TensorIndexing.h>
#include <c10/util/Exception.h>
#include <torch/nn/cloneable.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <mutex>
#include <string>
#include <utility>
#include <vector>
|
|
|
|
namespace torch {
|
|
namespace test {
|
|
|
|
// Lets you use a container without making a new class,
|
|
// for experimental implementations
|
|
class SimpleContainer : public nn::Cloneable<SimpleContainer> {
|
|
public:
|
|
void reset() override {}
|
|
|
|
template <typename ModuleHolder>
|
|
ModuleHolder add(
|
|
ModuleHolder module_holder,
|
|
std::string name = std::string()) {
|
|
return Module::register_module(std::move(name), module_holder);
|
|
}
|
|
};
|
|
|
|
struct SeedingFixture : public ::testing::Test {
|
|
SeedingFixture() {
|
|
torch::manual_seed(0);
|
|
}
|
|
};
|
|
|
|
struct WarningCapture : public WarningHandler {
|
|
WarningCapture() : prev_(WarningUtils::get_warning_handler()) {
|
|
WarningUtils::set_warning_handler(this);
|
|
}
|
|
|
|
~WarningCapture() override {
|
|
WarningUtils::set_warning_handler(prev_);
|
|
}
|
|
|
|
const std::vector<std::string>& messages() {
|
|
return messages_;
|
|
}
|
|
|
|
std::string str() {
|
|
return c10::Join("\n", messages_);
|
|
}
|
|
|
|
void process(const c10::Warning& warning) override {
|
|
messages_.push_back(warning.msg());
|
|
}
|
|
|
|
private:
|
|
WarningHandler* prev_;
|
|
std::vector<std::string> messages_;
|
|
};
|
|
|
|
inline bool pointer_equal(at::Tensor first, at::Tensor second) {
|
|
return first.data_ptr() == second.data_ptr();
|
|
}
|
|
|
|
// This mirrors the `isinstance(x, torch.Tensor) and isinstance(y,
|
|
// torch.Tensor)` branch in `TestCase.assertEqual` in
|
|
// torch/testing/_internal/common_utils.py
|
|
inline void assert_tensor_equal(
|
|
at::Tensor a,
|
|
at::Tensor b,
|
|
bool allow_inf = false) {
|
|
ASSERT_TRUE(a.sizes() == b.sizes());
|
|
if (a.numel() > 0) {
|
|
if (a.device().type() == torch::kCPU &&
|
|
(a.scalar_type() == torch::kFloat16 ||
|
|
a.scalar_type() == torch::kBFloat16)) {
|
|
// CPU half and bfloat16 tensors don't have the methods we need below
|
|
a = a.to(torch::kFloat32);
|
|
}
|
|
if (a.device().type() == torch::kCUDA &&
|
|
a.scalar_type() == torch::kBFloat16) {
|
|
// CUDA bfloat16 tensors don't have the methods we need below
|
|
a = a.to(torch::kFloat32);
|
|
}
|
|
b = b.to(a);
|
|
|
|
if ((a.scalar_type() == torch::kBool) !=
|
|
(b.scalar_type() == torch::kBool)) {
|
|
TORCH_CHECK(false, "Was expecting both tensors to be bool type.");
|
|
} else {
|
|
if (a.scalar_type() == torch::kBool && b.scalar_type() == torch::kBool) {
|
|
// we want to respect precision but as bool doesn't support subtraction,
|
|
// boolean tensor has to be converted to int
|
|
a = a.to(torch::kInt);
|
|
b = b.to(torch::kInt);
|
|
}
|
|
|
|
auto diff = a - b;
|
|
if (a.is_floating_point()) {
|
|
// check that NaNs are in the same locations
|
|
auto nan_mask = torch::isnan(a);
|
|
ASSERT_TRUE(torch::equal(nan_mask, torch::isnan(b)));
|
|
diff.index_put_({nan_mask}, 0);
|
|
// inf check if allow_inf=true
|
|
if (allow_inf) {
|
|
auto inf_mask = torch::isinf(a);
|
|
auto inf_sign = inf_mask.sign();
|
|
ASSERT_TRUE(torch::equal(inf_sign, torch::isinf(b).sign()));
|
|
diff.index_put_({inf_mask}, 0);
|
|
}
|
|
}
|
|
// TODO: implement abs on CharTensor (int8)
|
|
if (diff.is_signed() && diff.scalar_type() != torch::kInt8) {
|
|
diff = diff.abs();
|
|
}
|
|
auto max_err = diff.max().item<double>();
|
|
ASSERT_LE(max_err, 1e-5);
|
|
}
|
|
}
|
|
}
|
|
|
|
// This mirrors the `isinstance(x, torch.Tensor) and isinstance(y,
|
|
// torch.Tensor)` branch in `TestCase.assertNotEqual` in
|
|
// torch/testing/_internal/common_utils.py
|
|
inline void assert_tensor_not_equal(at::Tensor x, at::Tensor y) {
|
|
if (x.sizes() != y.sizes()) {
|
|
return;
|
|
}
|
|
ASSERT_GT(x.numel(), 0);
|
|
y = y.type_as(x);
|
|
y = x.is_cuda() ? y.to({torch::kCUDA, x.get_device()}) : y.cpu();
|
|
auto nan_mask = x != x;
|
|
if (torch::equal(nan_mask, y != y)) {
|
|
auto diff = x - y;
|
|
if (diff.is_signed()) {
|
|
diff = diff.abs();
|
|
}
|
|
diff.index_put_({nan_mask}, 0);
|
|
// Use `item()` to work around:
|
|
// https://github.com/pytorch/pytorch/issues/22301
|
|
auto max_err = diff.max().item<double>();
|
|
ASSERT_GE(max_err, 1e-5);
|
|
}
|
|
}
|
|
|
|
// Counts the non-overlapping occurrences of `substr` in `str`, scanning left
// to right (so "aa" occurs twice in "aaaaa").
//
// An empty `substr` yields 0. std::string::find("") matches at every
// position without advancing, so it must be rejected up front; otherwise the
// search loop below would never terminate.
inline int count_substr_occurrences(
    const std::string& str,
    const std::string& substr) {
  if (substr.empty()) {
    return 0;
  }

  int count = 0;
  for (size_t pos = str.find(substr); pos != std::string::npos;
       pos = str.find(substr, pos + substr.size())) {
    ++count;
  }
  return count;
}
|
|
|
|
// A RAII, thread local (!) guard that changes default dtype upon
|
|
// construction, and sets it back to the original dtype upon destruction.
|
|
//
|
|
// Usage of this guard is synchronized across threads, so that at any given
|
|
// time, only one guard can take effect.
|
|
struct AutoDefaultDtypeMode {
|
|
static std::mutex default_dtype_mutex;
|
|
|
|
AutoDefaultDtypeMode(c10::ScalarType default_dtype)
|
|
: prev_default_dtype(
|
|
torch::typeMetaToScalarType(torch::get_default_dtype())) {
|
|
default_dtype_mutex.lock();
|
|
torch::set_default_dtype(torch::scalarTypeToTypeMeta(default_dtype));
|
|
}
|
|
~AutoDefaultDtypeMode() {
|
|
default_dtype_mutex.unlock();
|
|
torch::set_default_dtype(torch::scalarTypeToTypeMeta(prev_default_dtype));
|
|
}
|
|
c10::ScalarType prev_default_dtype;
|
|
};
|
|
|
|
inline void assert_tensor_creation_meta(
|
|
torch::Tensor& x,
|
|
torch::autograd::CreationMeta creation_meta) {
|
|
auto autograd_meta = x.unsafeGetTensorImpl()->autograd_meta();
|
|
TORCH_CHECK(autograd_meta);
|
|
auto view_meta =
|
|
static_cast<torch::autograd::DifferentiableViewMeta*>(autograd_meta);
|
|
TORCH_CHECK(view_meta->has_bw_view());
|
|
ASSERT_EQ(view_meta->get_creation_meta(), creation_meta);
|
|
}
|
|
} // namespace test
|
|
} // namespace torch
|