mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-27 17:54:55 +08:00
* Created TensorOptions
Storing the type in TensorOptions to solve the Variable problem
Created convenience creation functions for TensorOptions and added tests
Converted zeros to TensorOptions
Converted rand to TensorOptions
Fix codegen for TensorOptions and multiple arguments
Put TensorOptions convenience functions into torch namespace too
All factory functions except *_like support TensorOptions
Integrated with recent JIT changes
Support *_like functions
Fix in-place modification
Some cleanups and fixes
Support sparse_coo_tensor
Fix bug in Type.cpp
Fix .empty calls in C++ API
Fix bug in Type.cpp
Trying to fix device placement
Make AutoGPU CPU compatible
Remove some auto_gpu.h uses
Fixing some headers
Fix some remaining CUDA/AutoGPU issues
Fix some AutoGPU uses
Fixes to dispatch_tensor_conversion
Reset version of new variables to zero
Implemented parsing device strings
Random fixes to tests
Self review cleanups
Fix flake8 lint warnings
Undo changes to variable.{h,cpp} because they fail on gcc7.2
Add [cuda] tag to tensor_options_cuda.cpp
Move AutoGPU::set_index_from into .cpp file to work around Windows build issues
Fix linker error in AutoGPU.cpp
Fix bad merge conflict in native_functions.yaml
Fixed caffe2/contrib/aten
Fix new window functions added to TensorFactories.cpp
* Removed torch::TensorOptions
Added code to generate wrapper functions for factory methods
Add implicit constructor from Backend to TensorOptions
Remove Var() from C++ API and use torch:: functions
Use torch:: functions more subtly in C++ API
Make AutoGPU::set_device more exception safe
Check status directly in DynamicCUDAHooksInterface
Rename AutoGPU to DeviceGuard
Removed set_requires_grad from python_variables.h and warn appropriately in Variable::set_requires_grad
remove python_default_init: self.type()
Add back original factory functions, but with deprecation warnings
Disable DeviceGuard for a couple functions in ATen
Remove print statement
Fix DeviceGuard construction from undefined tensor
Fixing CUDA device compiler issues
Moved as many methods as possible into header files
Don't generate Python functions for deprecated factories
Remove merge conflict artefact
Fix tensor_options_cuda.cpp
Fix set_requires_grad not being checked
Fix tensor_new.h
TEMPORARILY put some methods in .cpp files to see if it solves issues on windows and mac
Fix bug in DeviceGuard.h
Missing includes
TEMPORARILY moving a few more methods into .cpp to see if it fixes windows
Fixing linker errors
* Fix up SummaryOps to use new factories
Undo device agnostic behavior of DeviceGuard
Use -1 instead of optional for default device index
Also move DeviceGuard methods into header
Fixes around device index after optional -> int32_t switch
Fix use of DeviceGuard in new_with_tensor_copy
Fix tensor_options.cpp
* Fix Type::copy()
* Remove test_non_float_params from ONNX tests
* Set requires_grad=False in ONNX tests that use ints
* Put layout/dtype/device on Tensor
* Post merge fixes
* Change behavior of DeviceGuard to match AutoGPU
* Fix C++ API integration tests
* Fix flip functions
348 lines
9.6 KiB
C++
348 lines
9.6 KiB
C++
#include <catch.hpp>
|
|
|
|
#include <torch/detail/ordered_dict.h>
|
|
#include <torch/expanding_array.h>
|
|
#include <torch/functions.h>
|
|
#include <torch/nn/modules/linear.h>
|
|
#include <torch/tensor.h>
|
|
#include <torch/utils.h>
|
|
|
|
#include <torch/csrc/utils/memory.h>
|
|
|
|
#include <ATen/optional.h>
|
|
|
|
using namespace torch;
|
|
using namespace torch::nn;
|
|
|
|
template <typename T>
|
|
using OrderedDict = detail::OrderedDict<std::string, T>;
|
|
|
|
using Catch::StartsWith;
|
|
|
|
TEST_CASE("misc") {
  SECTION("no_grad") {
    // With the guard alive, running the model and backpropagating is expected
    // to leave the weight's gradient undefined.
    NoGradGuard no_grad_guard;
    auto linear = Linear(5, 2).build();
    auto input = torch::randn({10, 5}, at::requires_grad());
    auto output = linear->forward({input})[0];
    Variable total = output.sum();

    total.backward();

    REQUIRE_FALSE(linear->parameters()["weight"].grad().defined());
  }

  SECTION("CPU random seed") {
    // Re-seeding with the same value must reproduce the same random draws.
    const int size = 100;

    torch::manual_seed(7);
    auto first = torch::randn({size});
    torch::manual_seed(7);
    auto second = torch::randn({size});

    // Largest absolute element-wise difference between the two draws.
    auto abs_difference = (first.data() - second.data()).abs();
    auto l_inf = abs_difference.max().toCFloat();
    REQUIRE(l_inf < 1e-10);
  }
}
|
|
|
|
TEST_CASE("misc_cuda", "[cuda]") {
  SECTION("CUDA random seed") {
    // Same reproducibility check as the CPU variant, but drawing on a CUDA
    // device: identical seeds must yield identical tensors.
    const int size = 100;

    torch::manual_seed(7);
    auto first = torch::randn({size}, at::kCUDA);
    torch::manual_seed(7);
    auto second = torch::randn({size}, at::kCUDA);

    // Largest absolute element-wise difference between the two draws.
    auto abs_difference = (first.data() - second.data()).abs();
    auto l_inf = abs_difference.max().toCFloat();
    REQUIRE(l_inf < 1e-10);
  }
}
|
|
|
|
TEST_CASE("autograd") {
  // Catch re-executes this fixture code for every SECTION, so each section
  // starts with fresh tensors and no previously accumulated gradient.
  auto lhs = torch::randn({3, 3}, at::requires_grad());
  auto rhs = torch::randn({3, 3});
  auto product = lhs * rhs;

  SECTION("derivatives of zero-dim tensors") {
    // d/d(lhs) of sum(lhs * rhs) is rhs.
    product.sum().backward();
    REQUIRE(lhs.grad().allclose(rhs));
  }

  SECTION("derivatives of tensors") {
    product.backward();
    REQUIRE(lhs.grad().allclose(rhs));
  }

  SECTION("custom gradient inputs") {
    // An explicit upstream gradient of 2 scales the result accordingly.
    auto upstream = torch::ones({}) * 2;
    product.sum().backward(upstream);
    REQUIRE(lhs.grad().allclose(rhs * 2));
  }

  // Assume everything else is safe from PyTorch tests.
}
|
|
|
|
TEST_CASE("expanding-array") {
  SECTION("successful construction") {
    // Each supported source (initializer_list, vector, array) with exactly
    // five values should be stored element-for-element.
    SECTION("initializer_list") {
      ExpandingArray<5> arr({1, 2, 3, 4, 5});
      REQUIRE(arr.size() == 5);
      int64_t expected = 1;
      for (size_t idx = 0; idx != arr.size(); ++idx, ++expected) {
        REQUIRE((*arr)[idx] == expected);
      }
    }

    SECTION("vector") {
      ExpandingArray<5> arr(std::vector<int64_t>{1, 2, 3, 4, 5});
      REQUIRE(arr.size() == 5);
      int64_t expected = 1;
      for (size_t idx = 0; idx != arr.size(); ++idx, ++expected) {
        REQUIRE((*arr)[idx] == expected);
      }
    }

    SECTION("array") {
      ExpandingArray<5> arr(std::array<int64_t, 5>({1, 2, 3, 4, 5}));
      REQUIRE(arr.size() == 5);
      int64_t expected = 1;
      for (size_t idx = 0; idx != arr.size(); ++idx, ++expected) {
        REQUIRE((*arr)[idx] == expected);
      }
    }

    SECTION("single value") {
      // A single scalar is replicated into every slot.
      ExpandingArray<5> arr(5);
      REQUIRE(arr.size() == 5);
      for (size_t idx = 0; idx != arr.size(); ++idx) {
        REQUIRE((*arr)[idx] == 5);
      }
    }
  }

  SECTION("throws for incorrect size on construction") {
    // Runtime-sized sources with the wrong length must be rejected.
    SECTION("initializer_list") {
      REQUIRE_THROWS_WITH(
          ExpandingArray<5>({1, 2, 3, 4, 5, 6, 7}),
          StartsWith("Expected 5 values, but instead got 7"));
    }

    SECTION("vector") {
      REQUIRE_THROWS_WITH(
          ExpandingArray<5>(std::vector<int64_t>({1, 2, 3, 4, 5, 6, 7})),
          StartsWith("Expected 5 values, but instead got 7"));
    }
  }
}
|
|
|
|
TEST_CASE("make_unique") {
  // Records which constructor overload was chosen, so the sections can check
  // that make_unique preserves the value category of its argument.
  struct Probe {
    explicit Probe(const int& value) : lvalue_(value) {}
    explicit Probe(int&& value) : rvalue_(value) {}

    at::optional<int> lvalue_;
    at::optional<int> rvalue_;
  };

  SECTION("forwards rvalues correctly") {
    auto probe = torch::make_unique<Probe>(123);
    REQUIRE_FALSE(probe->lvalue_.has_value());
    REQUIRE(probe->rvalue_.has_value());
    REQUIRE(*probe->rvalue_ == 123);
  }

  SECTION("forwards lvalues correctly") {
    int five = 5;
    auto probe = torch::make_unique<Probe>(five);
    REQUIRE(probe->lvalue_.has_value());
    REQUIRE(*probe->lvalue_ == 5);
    REQUIRE_FALSE(probe->rvalue_.has_value());
  }

  SECTION("Can construct unique_ptr of array") {
    auto values = torch::make_unique<int[]>(3);
    // Value initialization is required by the standard, so every element
    // must start out as zero.
    for (int i = 0; i < 3; ++i) {
      REQUIRE(values[i] == 0);
    }
  }
}
|
|
|
|
// Exercises the string-keyed OrderedDict: insertion, lookup by key and by
// position, error messages (which are prefixed with the dict's "subject",
// "Key" by default), iteration, clearing, and copy/move semantics.
TEST_CASE("ordered-dict") {
  SECTION("is empty after default construction") {
    OrderedDict<int> dict;
    REQUIRE(dict.subject() == "Key");
    REQUIRE(dict.is_empty());
    REQUIRE(dict.size() == 0);
  }

  SECTION("insert inserts elements when they are not yet present") {
    OrderedDict<int> dict;
    dict.insert("a", 1);
    dict.insert("b", 2);
    REQUIRE(dict.size() == 2);
  }

  SECTION("get returns values when present") {
    OrderedDict<int> dict;
    dict.insert("a", 1);
    dict.insert("b", 2);
    REQUIRE(dict.get("a") == 1);
    REQUIRE(dict.get("b") == 2);
  }

  SECTION("get throws when passed keys that are not present") {
    OrderedDict<int> dict;
    dict.insert("a", 1);
    dict.insert("b", 2);
    REQUIRE_THROWS_WITH(
        dict.get("foo"), StartsWith("Key 'foo' is not defined"));
    REQUIRE_THROWS_WITH(dict.get(""), StartsWith("Key '' is not defined"));
  }

  SECTION("can initialize from list") {
    OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
    REQUIRE(dict.size() == 2);
    REQUIRE(dict.get("a") == 1);
    REQUIRE(dict.get("b") == 2);
  }

  SECTION("insert throws when passed elements that are present") {
    OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
    REQUIRE_THROWS_WITH(
        dict.insert("a", 1), StartsWith("Key 'a' already defined"));
    REQUIRE_THROWS_WITH(
        dict.insert("b", 1), StartsWith("Key 'b' already defined"));
  }

  SECTION("front() returns the first item") {
    OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
    REQUIRE(dict.front().key == "a");
    REQUIRE(dict.front().value == 1);
  }

  SECTION("back() returns the last item") {
    OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
    REQUIRE(dict.back().key == "b");
    REQUIRE(dict.back().value == 2);
  }

  SECTION("find returns pointers to values when present") {
    OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
    REQUIRE(dict.find("a") != nullptr);
    REQUIRE(*dict.find("a") == 1);
    REQUIRE(dict.find("b") != nullptr);
    REQUIRE(*dict.find("b") == 2);
  }

  SECTION("find returns null pointers when passed keys that are not present") {
    OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
    REQUIRE(dict.find("bar") == nullptr);
    REQUIRE(dict.find("") == nullptr);
  }

  SECTION("operator[] returns values when passed keys that are present") {
    OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
    REQUIRE(dict["a"] == 1);
    REQUIRE(dict["b"] == 2);
  }

  SECTION("operator[] returns items positionally when passed integers") {
    OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
    REQUIRE(dict[0].key == "a");
    REQUIRE(dict[0].value == 1);
    REQUIRE(dict[1].key == "b");
    REQUIRE(dict[1].value == 2);
  }

  SECTION("operator[] throws when passed keys that are not present") {
    OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
    // Call the key-based subscript operator directly (not get()) so this
    // section actually covers operator[]'s missing-key path.
    REQUIRE_THROWS_WITH(dict["foo"], StartsWith("Key 'foo' is not defined"));
    REQUIRE_THROWS_WITH(dict[""], StartsWith("Key '' is not defined"));
  }

  SECTION("update inserts all items from another OrderedDict") {
    OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
    OrderedDict<int> dict2 = {{"c", 3}};
    dict2.update(dict);
    REQUIRE(dict2.size() == 3);
    REQUIRE(dict2.find("a") != nullptr);
    REQUIRE(dict2.find("b") != nullptr);
    REQUIRE(dict2.find("c") != nullptr);
  }

  SECTION("update also checks for duplicates") {
    OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
    OrderedDict<int> dict2 = {{"a", 1}};
    REQUIRE_THROWS_WITH(
        dict2.update(dict), StartsWith("Key 'a' already defined"));
  }

  SECTION("Can iterate items") {
    OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
    auto iterator = dict.begin();
    REQUIRE(iterator != dict.end());
    REQUIRE(iterator->key == "a");
    REQUIRE(iterator->value == 1);
    ++iterator;
    REQUIRE(iterator != dict.end());
    REQUIRE(iterator->key == "b");
    REQUIRE(iterator->value == 2);
    ++iterator;
    REQUIRE(iterator == dict.end());
  }

  SECTION("clear makes the dict empty") {
    OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
    REQUIRE(!dict.is_empty());
    dict.clear();
    REQUIRE(dict.is_empty());
  }

  SECTION("can copy construct") {
    OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
    OrderedDict<int> copy = dict;
    REQUIRE(copy.size() == 2);
    REQUIRE(*copy[0] == 1);
    REQUIRE(*copy[1] == 2);
  }

  SECTION("can copy assign") {
    OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
    OrderedDict<int> copy = {{"c", 1}};
    REQUIRE(copy.find("c") != nullptr);
    copy = dict;
    REQUIRE(copy.size() == 2);
    REQUIRE(*copy[0] == 1);
    REQUIRE(*copy[1] == 2);
    // Assignment replaces, rather than merges, the previous contents.
    REQUIRE(copy.find("c") == nullptr);
  }

  SECTION("can move construct") {
    OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
    OrderedDict<int> copy = std::move(dict);
    REQUIRE(copy.size() == 2);
    REQUIRE(*copy[0] == 1);
    REQUIRE(*copy[1] == 2);
  }

  SECTION("can move assign") {
    OrderedDict<int> dict = {{"a", 1}, {"b", 2}};
    OrderedDict<int> copy = {{"c", 1}};
    REQUIRE(copy.find("c") != nullptr);
    copy = std::move(dict);
    REQUIRE(copy.size() == 2);
    REQUIRE(*copy[0] == 1);
    REQUIRE(*copy[1] == 2);
    // Assignment replaces, rather than merges, the previous contents.
    REQUIRE(copy.find("c") == nullptr);
  }

  SECTION("can insert with braces") {
    OrderedDict<std::pair<int, int>> dict;
    dict.insert("a", {1, 2});
    REQUIRE(!dict.is_empty());
    REQUIRE(dict["a"].first == 1);
    REQUIRE(dict["a"].second == 2);
  }

  SECTION("Error messages include the what") {
    // A custom subject replaces the default "Key" prefix in error messages.
    OrderedDict<int> dict("Penguin");
    REQUIRE(dict.subject() == "Penguin");
    dict.insert("a", 1);
    REQUIRE(!dict.is_empty());
    REQUIRE_THROWS_WITH(
        dict.get("b"), StartsWith("Penguin 'b' is not defined"));
    REQUIRE_THROWS_WITH(
        dict.insert("a", 1), StartsWith("Penguin 'a' already defined"));
  }
}
|