Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 13:44:15 +08:00
Created Tensor::to functions (#8643)

* Created Tensor::to functions
* Only have to(dtype) and to(device)
* Ignore requires_grad in TensorOptions(Tensor) constructor
Committed by GitHub
Parent: d97c9dd019
Commit: 065fdbd500
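The new API surface is small: Tensor::to(dtype), Tensor::to(device), and a combined to(device, dtype) overload, all exercised by the tests in the diff below. A minimal usage sketch under the same assumptions as those tests (a CUDA build with at least two devices; the main wrapper and variable names here are illustrative, not part of the change):

#include <ATen/ATen.h>

int main() {
  // Default tensor: CPU, float, strided (matches the tests' starting state).
  at::Tensor tensor = at::empty({3, 4});

  // to(dtype): change only the scalar type; the device stays CPU.
  tensor = tensor.to(at::kInt);

  // to(device): change only the device; the dtype stays kInt.
  tensor = tensor.to(at::Device(at::kCUDA, 1));

  // Combined overload: device and dtype in one call.
  tensor = tensor.to({at::kCUDA, 0}, at::kDouble);

  // When nothing would change, to() returns the original tensor, not a copy
  // (pinned down by Tensor/ToDoesNotCopyWhenOptionsAreAllTheSame below).
  at::Tensor same = tensor.to(at::kDouble);
  return 0;
}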
@@ -1,23 +1,85 @@
 #include <catch.hpp>

 #include <torch/functions.h>

 #include <ATen/ATen.h>

-TEST_CASE("tensor/device-placement") {
-  SECTION("DeviceGuard") {
-    // SECTION("On index zero by default") {
-    //   auto tensor = at::ones({3, 3}, at::kCUDA);
-    //   REQUIRE(tensor.get_device() == 0);
-    // }
+#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_)                \
+  REQUIRE(tensor.device().type() == at::Device((device_), (index_)).type());   \
+  REQUIRE(tensor.device().index() == at::Device((device_), (index_)).index()); \
+  REQUIRE(tensor.dtype() == (type_));                                          \
+  REQUIRE(tensor.layout() == (layout_))
+
-    // // right hand side is TensorOptions
-    // torch::OptionGuard guard = torch::device(torch::kCUDA, 1);
-    // // convenience wrapper over OptionGuard
-    // torch::DeviceGuard guard(torch::kCUDA, 1);
-    // /// default device is CUDA
-    // torch::DeviceGuard guard(1);
+TEST_CASE("Tensor/ToDtype") {
+  auto tensor = at::empty({3, 4});
+  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
+
-    // note that this is separate from DeviceGuard. DeviceGuard should move into the
-    // detail namespace and do the actual thing. OptionGuard just modifies a
-    // global singleton of option defaults. It operates at a higher level.
-  }
-}
+  tensor = tensor.to(at::kInt);
+  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kInt, at::kStrided);
+
+  tensor = tensor.to(at::kChar);
+  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kChar, at::kStrided);
+
+  tensor = tensor.to(at::kDouble);
+  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kDouble, at::kStrided);
+}
+
+// Not currently supported.
+// TEST_CASE("Tensor/ToLayout") {
+//   auto tensor = at::empty({3, 4});
+//   REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
+//
+//   tensor = tensor.to(at::kSparse);
+//   REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kSparse);
+//
+//   tensor = tensor.to(at::kStrided);
+//   REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
+// }
+
+TEST_CASE("Tensor/ToDevice", "[cuda]") {
+  auto tensor = at::empty({3, 4});
+  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
+
+  tensor = tensor.to({at::kCUDA, 1});
+  REQUIRE_TENSOR_OPTIONS(at::kCUDA, 1, at::kFloat, at::kStrided);
+
+  tensor = tensor.to({at::kCUDA, 0});
+  REQUIRE_TENSOR_OPTIONS(at::kCUDA, 0, at::kFloat, at::kStrided);
+
+  tensor = tensor.to({at::kCUDA, 1});
+  REQUIRE_TENSOR_OPTIONS(at::kCUDA, 1, at::kFloat, at::kStrided);
+
+  tensor = tensor.to(at::Device(at::kCPU));
+  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
+}
+
+TEST_CASE("Tensor/ToDeviceAndDtype", "[cuda]") {
+  auto tensor = at::empty({3, 4});
+  REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
+
+  tensor = tensor.to({at::kCUDA, 1}, at::kInt);
+  REQUIRE_TENSOR_OPTIONS(at::kCUDA, 1, at::kInt, at::kStrided);
+}
+
+TEST_CASE("Tensor/ToOptionsRespectsRequiresGrad") {
+  {
+    auto tensor = torch::empty({3, 4}, at::requires_grad());
+    REQUIRE(tensor.requires_grad());
+
+    tensor = tensor.to(at::kDouble);
+    REQUIRE(tensor.requires_grad());
+  }
+  {
+    auto tensor = torch::empty({3, 4});
+    REQUIRE(!tensor.requires_grad());
+
+    tensor = tensor.to(at::kDouble);
+    REQUIRE(!tensor.requires_grad());
+  }
+}
+
+TEST_CASE("Tensor/ToDoesNotCopyWhenOptionsAreAllTheSame") {
+  auto tensor = at::empty({3, 4}, at::kFloat);
+  auto hopefully_not_copy = tensor.to(at::kFloat);
+  REQUIRE(hopefully_not_copy.data<float>() == tensor.data<float>());
+}
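The third commit-message bullet is the subtle one: the TensorOptions(Tensor) constructor now ignores requires_grad, so a round-trip through to() cannot silently change a variable's autograd state. A minimal sketch of the behavior the Tensor/ToOptionsRespectsRequiresGrad test above pins down (the main wrapper is illustrative):

#include <torch/functions.h>

int main() {
  // requires_grad set at construction survives a dtype conversion;
  // to() keeps the source tensor's flag rather than reading it from options.
  auto tensor = torch::empty({3, 4}, at::requires_grad());
  tensor = tensor.to(at::kDouble);
  // tensor.requires_grad() is still true at this point.
  return 0;
}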