Deprecate tensor.type() (#30281)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/29161.

I looked through the code changes related to this and think the deprecation message covers all of the use cases of `DeprecatedTypeProperties`, but suggestions from someone with more context on this would be very much appreciated :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30281

Differential Revision: D18830818

Pulled By: ezyang

fbshipit-source-id: 1a7fcee15354ae09e6644577e7fa33bd26acfe20
Authored by Nathan Goldbaum on 2019-12-05 10:53:32 -08:00, committed by Facebook Github Bot
parent 2171f91053
commit f531815526
53 changed files with 254 additions and 209 deletions


@ -99,7 +99,7 @@ void SparseTensorImpl::set_indices_and_values_unsafe(const Tensor& indices, cons
TORCH_CHECK(values.device().type() == device().type(), "device type of values (", values.device().type(), ") must match device type of device().type()", device().type(), ")");
TORCH_CHECK(values.scalar_type() == typeMetaToScalarType(dtype()), "dtype of values (", values.scalar_type(), ") must match dtype of sparse tensor (", typeMetaToScalarType(dtype()), ")");
TORCH_CHECK(indices.scalar_type() == kLong, "indices must be an int64 tensor");
TORCH_CHECK(indices.type().backend() == values.type().backend(), "backend of indices (", indices.type().backend(), ") must match backend of values (", values.type().backend(), ")");
TORCH_CHECK(indices.options().backend() == values.options().backend(), "backend of indices (", indices.options().backend(), ") must match backend of values (", values.options().backend(), ")");
TORCH_CHECK(!indices.is_cuda() || indices.get_device() == values.get_device(), "device of indices (", indices.get_device(), ") must match device of values (", values.get_device(), ")");
TORCH_CHECK(indices.dim() == 2, "indices must be sparse_dim x nnz, but got: ", indices.sizes());


@ -134,7 +134,7 @@ void checkAllSameGPU(CheckedFrom c, ArrayRef<TensorArg> tensors) {
void checkSameType(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
TORCH_CHECK(
t1->type() == t2->type(),
t1->options().type_equal(t2->options()),
"Expected tensor for ", t1, " to have the same type as tensor for ", t2,
"; but type ", t1->toString(), " does not equal ", t2->toString(),
" (while checking arguments for ", c, ")");
@ -196,9 +196,9 @@ void checkAllDefined(CheckedFrom c, ArrayRef<TensorArg> ts) {
void checkBackend(CheckedFrom c, const Tensor& t, Backend backend) {
TORCH_CHECK(
!t.defined() || t.type().backend() == backend,
!t.defined() || t.options().backend() == backend,
"Expected tensor to have ", toString(backend),
" Backend, but got tensor with ", toString(t.type().backend()), " Backend ",
" Backend, but got tensor with ", toString(t.options().backend()), " Backend ",
"(while checking arguments for ", c, ")");
}
@ -210,9 +210,9 @@ void checkBackend(CheckedFrom c, at::ArrayRef<Tensor> tensors, at::Backend backe
void checkDeviceType(CheckedFrom c, const Tensor& t, DeviceType device_type) {
TORCH_CHECK(
!t.defined() || t.type().device_type() == device_type,
!t.defined() || t.device().type() == device_type,
"Expected tensor to have ", device_type,
" DeviceType, but got tensor with ", t.type().device_type(), " DeviceType ",
" DeviceType, but got tensor with ", t.device().type(), " DeviceType ",
"(while checking arguments for ", c, ")");
}


@ -28,14 +28,20 @@ void Tensor::enforce_invariants() {
void Tensor::print() const {
if (defined()) {
std::cerr << "[" << type().toString() << " " << sizes() << "]" << std::endl;
std::cerr << "[" << toString() << " " << sizes() << "]" << std::endl;
} else {
std::cerr << "[UndefinedTensor]" << std::endl;
}
}
std::string Tensor::toString() const {
return type().toString();
std::string base_str;
if (scalar_type() == ScalarType::Undefined) {
base_str = "UndefinedType";
} else {
base_str = std::string(at::toString(options().backend())) + at::toString(scalar_type()) + "Type";
}
return base_str;
}
Tensor Tensor::variable_data() const {

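A rough illustration of the string the reworked toString() now yields (assuming at::toString(Backend) and at::toString(ScalarType) produce names like "CPU" and "Float", as they do elsewhere in ATen; the printed values below are expectations, not guarantees):

#include <ATen/ATen.h>
#include <iostream>

// Sketch: the new Tensor::toString() concatenates backend + scalar type + "Type",
// so a defined tensor reads much like the old type().toString() output.
void print_type_string() {
  at::Tensor t = at::ones({2, 2});               // default options: CPU, Float
  std::cout << t.toString() << "\n";             // expected "CPUFloatType"
  std::cout << at::Tensor().toString() << "\n";  // expected "UndefinedType" per the branch above
}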

@ -179,7 +179,7 @@ auto ConvParams::use_mkldnn(const at::Tensor& input) const -> bool {
return false;
}
return (input.is_mkldnn()) || // input is mkldnn Tensor
(input.type().backend() == at::Backend::CPU &&
(input.options().backend() == at::Backend::CPU &&
input.scalar_type() == kFloat && // only on CPU Float Tensors
!is_dilated() && // doesn't support dilation
!transposed && // or transposed tensors
@ -190,7 +190,7 @@ auto ConvParams::use_mkldnn(const at::Tensor& input) const -> bool {
auto ConvParams::use_nnpack(const at::Tensor& input) const -> bool {
#if AT_NNPACK_ENABLED()
return at::_nnpack_available() &&
input.type().backend() == at::Backend::CPU &&
input.options().backend() == at::Backend::CPU &&
input.scalar_type() == kFloat && // only on CPU Float Tensors
!is_dilated() && // or dilation
!transposed && // or transposed tensors
@ -594,11 +594,11 @@ at::Tensor _convolution(
output = at::thnn_conv_depthwise2d(input.contiguous(), weight, kernel_size, bias, stride, padding, dilation);
}
} else if (params.use_cudnn(input)) {
TORCH_CHECK(input.type() == weight.type(),
"Input type (", input.type().toString(), ") and weight type (", weight.type().toString(),
TORCH_CHECK(input.options().type_equal(weight.options()),
"Input type (", input.toString(), ") and weight type (", weight.toString(),
") should be the same");
TORCH_CHECK(!bias.defined() || (input.type() == bias.type()),
"Input type (", input.type().toString(), ") and bias type (", bias.type().toString(),
TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options())),
"Input type (", input.toString(), ") and bias type (", bias.toString(),
") should be the same");
if (params.transposed) {
@ -611,11 +611,11 @@ at::Tensor _convolution(
params.padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
}
} else if (params.use_miopen(input)) {
TORCH_CHECK(input.type() == weight.type(),
"Input type (", input.type().toString(), ") and weight type (", weight.type().toString(),
TORCH_CHECK(input.options().type_equal(weight.options()),
"Input type (", input.toString(), ") and weight type (", weight.toString(),
") should be the same");
TORCH_CHECK(!bias.defined() || (input.type() == bias.type()),
"Input type (", input.type().toString(), ") and bias type (", bias.type().toString(),
TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options())),
"Input type (", input.toString(), ") and bias type (", bias.toString(),
") should be the same");
if (params.transposed) {
@ -629,11 +629,11 @@ at::Tensor _convolution(
}
} else if (params.use_mkldnn(input)) {
#if AT_MKLDNN_ENABLED()
TORCH_CHECK(input.type() == weight.type(),
"Input type (", input.type().toString(), ") and weight type (", weight.type().toString(),
TORCH_CHECK(input.options().type_equal(weight.options()),
"Input type (", input.toString(), ") and weight type (", weight.toString(),
") should be the same");
TORCH_CHECK(!bias.defined() || (input.type() == bias.type()),
"Input type (", input.type().toString(), ") and bias type (", bias.type().toString(),
TORCH_CHECK(!bias.defined() || (input.options().type_equal(bias.options())),
"Input type (", input.toString(), ") and bias type (", bias.toString(),
") should be the same");
if (!input_is_mkldnn) {
output = at::mkldnn_convolution(input.contiguous(), weight.contiguous(), bias.defined() ? bias.contiguous() : bias,


@ -95,7 +95,7 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
return at::copy_sparse_to_sparse_(self, src, non_blocking);
} else if (self.is_sparse() || src.is_sparse()) {
AT_ERROR("copy_() between dense and sparse Tensors is not implemented! Found self type = ",
self.type(), " and src type = ", src.type());
self.toString(), " and src type = ", src.toString());
}
if (self.is_same(src)) {


@ -15,11 +15,11 @@ Tensor cross(const Tensor & input, const Tensor & other, const c10::optional<int
}
Tensor & cross_out(Tensor & out, const Tensor & input, const Tensor & other, const c10::optional<int64_t> dimension) {
auto device_res = input.type().device_type();
auto device_res = input.device().type();
TORCH_CHECK(device_res == kCPU || device_res == kCUDA, "cross only supports CPU and CUDA devices, out got: ", device_res);
auto device1 = input.type().device_type();
auto device1 = input.device().type();
TORCH_CHECK(device1 == kCPU || device1 == kCUDA, "cross only supports CPU and CUDA devices, input got: ", device1);
auto device2 = other.type().device_type();
auto device2 = other.device().type();
TORCH_CHECK(device2 == kCPU || device2 == kCUDA, "cross only supports CPU and CUDA devices, other got: ", device2);
TORCH_CHECK(device_res == device1, "out and input must have the same device type. out: ", device_res, " input: ", device1);
TORCH_CHECK(device1 == device2, "input and other must have the same device type. input: ", device1, " other: ", device2);


@ -39,10 +39,10 @@ Tensor euclidean_dist_out(const Tensor& x1, const Tensor& x2) {
static Tensor cdist_impl(const Tensor& x1, const Tensor& x2, const double p, c10::optional<int64_t> compute_mode) {
TORCH_CHECK(at::isFloatingType(x1.scalar_type()), "cdist only supports floating-point dtypes, X1 got: ", x1.scalar_type());
auto device1 = x1.type().device_type();
auto device1 = x1.device().type();
TORCH_CHECK(device1 == kCPU || device1 == kCUDA, "cdist only supports CPU and CUDA devices, X1 got: ", device1);
TORCH_CHECK(at::isFloatingType(x1.scalar_type()), "cdist only supports floating-point dtypes, X2 got: ", x2.scalar_type());
auto device2 = x2.type().device_type();
auto device2 = x2.device().type();
TORCH_CHECK(device2 == kCPU || device2 == kCUDA, "cdist only supports CPU and CUDA devices, X2 got: ", device2);
TORCH_CHECK(p >= 0, "cdist only supports non-negative p values");
TORCH_CHECK(device1 == device2, "X1 and X2 must have the same device type. X1: ", device1, " X2: ", device2);
@ -123,9 +123,9 @@ Tensor _cdist_backward(const Tensor& grad, const Tensor& x1, const Tensor& x2, c
TORCH_CHECK(grad.is_contiguous(), "_cdist_backward requires grad to be contiguous");
int64_t n = x1.size(-2);
int64_t m = x1.size(-1);
auto device1 = x1.type().device_type();
auto device1 = x1.device().type();
TORCH_CHECK(device1 == kCPU || device1 == kCUDA, "_cdist_backward only supports CPU and CUDA devices, X1 got: ", device1);
auto device2 = x2.type().device_type();
auto device2 = x2.device().type();
TORCH_CHECK(device2 == kCPU || device2 == kCUDA, "_cdist_backward only supports CPU and CUDA devices, X2 got: ", device2);
IntArrayRef batch_tensor1(x1.sizes().data(), std::max<int64_t>(x1.dim() - 2, 0));
int batch_product = std::accumulate(batch_tensor1.begin(), batch_tensor1.end(), 1, std::multiplies<int64_t>());
@ -136,7 +136,7 @@ Tensor _cdist_backward(const Tensor& grad, const Tensor& x1, const Tensor& x2, c
Tensor _pdist_forward(const Tensor& self, const double p) {
TORCH_CHECK(self.is_contiguous(), "_pdist_forward requires contiguous input");
auto device = self.type().device_type();
auto device = self.device().type();
TORCH_CHECK(device == kCPU || device == kCUDA, "_pdist_forward only supports CPU and CUDA devices, got: ", device);
Tensor result = at::empty({0}, self.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT);
if (self.size(0) <= 1) {
@ -157,7 +157,7 @@ Tensor _pdist_forward(const Tensor& self, const double p) {
Tensor _pdist_backward(const Tensor& grad, const Tensor& self, const double p, const Tensor& pdist) {
TORCH_CHECK(self.is_contiguous(), "_pdist_backward requires self to be contiguous");
TORCH_CHECK(pdist.is_contiguous(), "_pdist_backward requires pdist to be contiguous");
auto device = self.type().device_type();
auto device = self.device().type();
TORCH_CHECK(device == kCPU || device == kCUDA, "_pdist_backward only supports CPU and CUDA devices, got: ", device);
Tensor result = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
pdist_backward_stub(device, result, grad, self, p, pdist);


@ -321,7 +321,7 @@ Tensor& multinomial_out(Tensor& result, const Tensor& self, int64_t n_sample, bo
} else {
result.resize_({n_sample});
}
multinomial_stub(result.type().device_type(), result, self, n_sample, with_replacement, gen);
multinomial_stub(result.device().type(), result, self, n_sample, with_replacement, gen);
return result;
}


@ -165,7 +165,7 @@ AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list)
// For CUDA tensors, force all index tensors to have the same striding to
// simplify the CUDA kernel.
if (indices.size() >= 2 && this->src.type().device_type() == kCUDA) {
if (indices.size() >= 2 && this->src.device().type() == kCUDA) {
if (!all_strides_match(indices)) {
for (size_t i = 0; i < indices.size(); i++) {
indices[i] = indices[i].contiguous();
@ -251,8 +251,8 @@ Tensor & _index_put_impl_(Tensor & self, TensorList indices, const Tensor & valu
if (indices.size() > (size_t)self.dim()) {
AT_INDEX_ERROR("too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
}
if (accumulate && self.type().device_type() == kCUDA) {
index_put_accum_stub(self.type().device_type(), self, indices, value, unsafe);
if (accumulate && self.device().type() == kCUDA) {
index_put_accum_stub(self.device().type(), self, indices, value, unsafe);
return self;
}
auto info = make_info(self, indices);


@ -88,7 +88,7 @@ std::tuple<Tensor, Tensor> slogdet(const Tensor& self) {
Tensor pinverse(const Tensor& self, double rcond) {
TORCH_CHECK((at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type())) && self.dim() >= 2,
"pinverse(", self.type(), "{", self.sizes(), "}): expected a tensor with 2 or more dimensions "
"pinverse(", self.scalar_type(), "{", self.sizes(), "}): expected a tensor with 2 or more dimensions "
"of floating types");
if (self.numel() == 0) {
// Match NumPy
@ -118,7 +118,7 @@ static inline Tensor _matrix_rank_helper(const Tensor& self, bool symmetric) {
Tensor matrix_rank(const Tensor& self, double tol, bool symmetric) {
TORCH_CHECK((at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type())) && self.dim() == 2,
"matrix_rank(", self.type(), "{", self.sizes(), "}): expected a 2D tensor "
"matrix_rank(", self.scalar_type(), "{", self.sizes(), "}): expected a 2D tensor "
"of floating types");
Tensor S = _matrix_rank_helper(self, symmetric);
@ -127,7 +127,7 @@ Tensor matrix_rank(const Tensor& self, double tol, bool symmetric) {
Tensor matrix_rank(const Tensor& self, bool symmetric) {
TORCH_CHECK((at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type())) && self.dim() == 2,
"matrix_rank(", self.type(), "{", self.sizes(), "}): expected a 2D tensor "
"matrix_rank(", self.scalar_type(), "{", self.sizes(), "}): expected a 2D tensor "
"of floating types");
Tensor S = _matrix_rank_helper(self, symmetric);
@ -479,7 +479,7 @@ Tensor& matmul_out(Tensor &result, const Tensor & tensor1, const Tensor & tensor
Tensor matrix_power(const Tensor& a, int64_t n) {
TORCH_CHECK(a.dim() >= 2 && (at::isFloatingType(a.scalar_type()) || at::isComplexType(a.scalar_type())),
"matrix_power(", a.type(), "{", a.sizes(), "}): expected a tensor "
"matrix_power(", a.scalar_type(), "{", a.sizes(), "}): expected a tensor "
"of floating types with dim at least 2");
if (n == 0) {
return a.clone(at::MemoryFormat::Contiguous).copy_(at::eye(a.size(-2), a.options()).expand_as(a));


@ -14,8 +14,8 @@ bool is_pinned(const Tensor& self) {
}
Tensor pin_memory(const Tensor& self) {
if (self.type().backend() != Backend::CPU) {
AT_ERROR("cannot pin '", self.type().toString(), "' only dense CPU tensors can be pinned");
if (self.options().backend() != Backend::CPU) {
AT_ERROR("cannot pin '", self.toString(), "' only dense CPU tensors can be pinned");
}
if (self.is_pinned()) {
return self;


@ -4,7 +4,7 @@
namespace at { namespace native {
void checkLongTensor(const Tensor& tensor) {
TORCH_CHECK(tensor.dim() == 1 && tensor.type().device_type() == at::kCPU && tensor.scalar_type() == at::kLong,
TORCH_CHECK(tensor.dim() == 1 && tensor.device().type() == at::kCPU && tensor.scalar_type() == at::kLong,
"'lengths' argument should be a 1D CPU int64 tensor");
}


@ -35,7 +35,7 @@ Tensor& addcmul_out(
const Tensor& tensor1,
const Tensor& tensor2,
Scalar value) {
checkBackend("addcmul_cpu", result, self.type().backend());
checkBackend("addcmul_cpu", result, self.options().backend());
auto iter = at::TensorIterator();
iter.set_check_mem_overlap(true);
iter.add_output(result);
@ -70,7 +70,7 @@ Tensor& addcdiv_out(
const Tensor& tensor1,
const Tensor& tensor2,
Scalar value) {
checkBackend("addcdiv_cpu", result, self.type().backend());
checkBackend("addcdiv_cpu", result, self.options().backend());
auto iter = at::TensorIterator();
iter.set_check_mem_overlap(true);
iter.add_output(result);


@ -850,13 +850,13 @@ std::tuple<Tensor, Tensor> NAME( \
bool batch_first) { \
if (at::cudnn_is_acceptable(_input)) { \
Tensor output, hy; \
NAME##_cudnn_stub(_input.type().device_type(), output, hy, _input, hx, _params, has_biases, \
NAME##_cudnn_stub(_input.device().type(), output, hy, _input, hx, _params, has_biases, \
num_layers, dropout_p, train, bidirectional, batch_first); \
return std::make_tuple(std::move(output), std::move(hy)); \
} \
if (use_miopen(_input, dropout_p)) { \
Tensor output, hy; \
NAME##_miopen_stub(_input.type().device_type(), output, hy, _input, hx, _params, has_biases, \
NAME##_miopen_stub(_input.device().type(), output, hy, _input, hx, _params, has_biases, \
num_layers, dropout_p, train, bidirectional, batch_first); \
return std::make_tuple(std::move(output), std::move(hy)); \
} \
@ -883,13 +883,13 @@ std::tuple<Tensor, Tensor> NAME( \
bool bidirectional) { \
if (at::cudnn_is_acceptable(data)) { \
Tensor output, hy; \
NAME##_packed_cudnn_stub(data.type().device_type(), output, hy, data, batch_sizes, hx, \
NAME##_packed_cudnn_stub(data.device().type(), output, hy, data, batch_sizes, hx, \
_params, has_biases, num_layers, dropout_p, train, bidirectional); \
return std::make_tuple(std::move(output), std::move(hy)); \
} \
if (use_miopen(data, dropout_p)) { \
Tensor output, hy; \
NAME##_packed_miopen_stub(data.type().device_type(), output, hy, data, batch_sizes, hx, \
NAME##_packed_miopen_stub(data.device().type(), output, hy, data, batch_sizes, hx, \
_params, has_biases, num_layers, dropout_p, train, bidirectional); \
return std::make_tuple(std::move(output), std::move(hy)); \
} \
@ -914,7 +914,7 @@ std::tuple<Tensor, Tensor> NAME( \
bool batch_first) { \
if (at::cudnn_is_acceptable(_input)) { \
Tensor output, hy; \
gru_cudnn_stub(_input.type().device_type(), output, hy, _input, hx, _params, has_biases, \
gru_cudnn_stub(_input.device().type(), output, hy, _input, hx, _params, has_biases, \
num_layers, dropout_p, train, bidirectional, batch_first); \
return std::make_tuple(std::move(output), std::move(hy)); \
} \
@ -941,7 +941,7 @@ std::tuple<Tensor, Tensor> NAME( \
bool bidirectional) { \
if (at::cudnn_is_acceptable(data)) { \
Tensor output, hy; \
gru_packed_cudnn_stub(data.type().device_type(), output, hy, data, batch_sizes, hx, \
gru_packed_cudnn_stub(data.device().type(), output, hy, data, batch_sizes, hx, \
_params, has_biases, num_layers, dropout_p, train, bidirectional); \
return std::make_tuple(std::move(output), std::move(hy)); \
} \
@ -976,14 +976,14 @@ std::tuple<Tensor, Tensor, Tensor> lstm(
TORCH_CHECK(hx.size() == 2, "lstm expects two hidden states");
if (at::cudnn_is_acceptable(_input)) {
Tensor output, hy, cy;
lstm_cudnn_stub(_input.type().device_type(), output, hy, cy, _input, hx, _params, has_biases,
lstm_cudnn_stub(_input.device().type(), output, hy, cy, _input, hx, _params, has_biases,
num_layers, dropout_p, train, bidirectional, batch_first);
return std::make_tuple(std::move(output), std::move(hy), std::move(cy));
}
if (use_miopen(_input, dropout_p)) {
Tensor output, hy, cy;
lstm_miopen_stub(_input.type().device_type(), output, hy, cy, _input, hx, _params, has_biases,
lstm_miopen_stub(_input.device().type(), output, hy, cy, _input, hx, _params, has_biases,
num_layers, dropout_p, train, bidirectional, batch_first);
return std::make_tuple(std::move(output), std::move(hy), std::move(cy));
}
@ -1005,14 +1005,14 @@ std::tuple<Tensor, Tensor, Tensor> lstm(
TORCH_CHECK(hx.size() == 2, "lstm expects two hidden states");
if (at::cudnn_is_acceptable(data)) {
Tensor output, hy, cy;
lstm_packed_cudnn_stub(data.type().device_type(), output, hy, cy, data, batch_sizes, hx,
lstm_packed_cudnn_stub(data.device().type(), output, hy, cy, data, batch_sizes, hx,
_params, has_biases, num_layers, dropout_p, train, bidirectional);
return std::make_tuple(std::move(output), std::move(hy), std::move(cy));
}
if (use_miopen(data, dropout_p)) {
Tensor output, hy, cy;
lstm_packed_miopen_stub(data.type().device_type(), output, hy, cy, data, batch_sizes, hx,
lstm_packed_miopen_stub(data.device().type(), output, hy, cy, data, batch_sizes, hx,
_params, has_biases, num_layers, dropout_p, train, bidirectional);
return std::make_tuple(std::move(output), std::move(hy), std::move(cy));
}
@ -1154,7 +1154,7 @@ std::tuple<Tensor, Tensor, Tensor> quantized_lstm(
TORCH_CHECK(hx.size() == 2, "lstm expects two hidden states");
if (at::cudnn_is_acceptable(_input)) {
Tensor output, hy, cy;
lstm_cudnn_stub(_input.type().device_type(), output, hy, cy, _input, hx, _params, has_biases,
lstm_cudnn_stub(_input.device().type(), output, hy, cy, _input, hx, _params, has_biases,
num_layers, dropout_p, train, bidirectional, batch_first);
return std::make_tuple(std::move(output), std::move(hy), std::move(cy));
}
@ -1202,7 +1202,7 @@ std::tuple<Tensor, Tensor, Tensor> quantized_lstm(
TORCH_CHECK(hx.size() == 2, "lstm expects two hidden states");
if (at::cudnn_is_acceptable(data)) {
Tensor output, hy, cy;
lstm_packed_cudnn_stub(data.type().device_type(), output, hy, cy, data, batch_sizes, hx,
lstm_packed_cudnn_stub(data.device().type(), output, hy, cy, data, batch_sizes, hx,
_params, has_biases, num_layers, dropout_p, train, bidirectional);
return std::make_tuple(std::move(output), std::move(hy), std::move(cy));
}


@ -136,9 +136,9 @@ static TensorIterator make_reduction(
for (const Tensor *t: {&result1, &result2}) {
const Tensor& result = *t;
TORCH_CHECK(
!result.defined() || result.type().scalarType() == dtype,
!result.defined() || result.scalar_type() == dtype,
name, ": provided dtype must match dtype of result. Got ",
toString(result.type().scalarType()),
toString(result.scalar_type()),
" and ",
toString(dtype),
".");
@ -161,8 +161,8 @@ static TensorIterator make_reduction(
// efficiency.
// We don't generalize this to common mismatched input/output types to avoid cross
// product of templated kernel launches.
if (self.type().scalarType() == dtype ||
(self.is_cuda() && self.type().scalarType() == kHalf && dtype == kFloat)) {
if (self.scalar_type() == dtype ||
(self.is_cuda() && self.scalar_type() == kHalf && dtype == kFloat)) {
return TensorIterator::reduce_op(viewed_result1, viewed_result2, self);
}
return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype));
@ -434,8 +434,9 @@ Tensor& logsumexp_out(Tensor& result, const Tensor& self, DimnameList dims, bool
static Tensor& norm_out(Tensor &result, const Tensor &self, optional<Scalar> opt_p,
IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
auto p = opt_p.value_or(2.0);
TORCH_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
"norm only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA,
"norm only supports CPU AND CUDA backend, got: ", toString(self.options().backend()));
ScalarType scalarType = opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type();
TORCH_CHECK(
@ -458,8 +459,8 @@ static inline Tensor _norm(const Tensor &self, Scalar p) {
if (self.is_sparse()) {
return at::native_norm(self, p);
} else {
TORCH_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
"norm only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA,
"norm only supports CPU AND CUDA backend, got: ", toString(self.options().backend()));
TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
"norm only supports floating-point dtypes");
@ -510,9 +511,9 @@ inline Tensor & _all(Tensor & result, TensorIterator & iter) {
}
Tensor all(const Tensor& self) {
TORCH_CHECK(self.type().backend() == Backend::CPU ||
self.type().backend() == Backend::CUDA, "all only supports CPU AND CUDA "
"backend, got: ", toString(self.type().backend()));
TORCH_CHECK(self.options().backend() == Backend::CPU ||
self.options().backend() == Backend::CUDA, "all only supports CPU AND CUDA "
"backend, got: ", toString(self.options().backend()));
TORCH_CHECK(self.scalar_type() == at::ScalarType::Byte || self.scalar_type() == at::ScalarType::Bool,
"all only supports torch.uint8 and torch.bool dtypes");
@ -528,9 +529,9 @@ Tensor all(const Tensor& self, int64_t dim, bool keepdim) {
}
Tensor &all_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) {
TORCH_CHECK(self.type().backend() == Backend::CPU ||
self.type().backend() == Backend::CUDA, "all only supports CPU AND CUDA "
"backend, got: ", toString(self.type().backend()));
TORCH_CHECK(self.options().backend() == Backend::CPU ||
self.options().backend() == Backend::CUDA, "all only supports CPU AND CUDA "
"backend, got: ", toString(self.options().backend()));
TORCH_CHECK(self.scalar_type() == at::ScalarType::Byte || self.scalar_type() == at::ScalarType::Bool,
"all only supports torch.uint8 and torch.bool dtypes");
dim = maybe_wrap_dim(dim, self.dim());
@ -554,11 +555,11 @@ inline Tensor & _any(Tensor & result, TensorIterator & iter) {
}
Tensor any(const Tensor& self) {
TORCH_CHECK(self.type().backend() == Backend::CPU ||
self.type().backend() == Backend::CUDA ||
self.type().backend() == Backend::SparseCPU ||
self.type().backend() == Backend::SparseCUDA, "any only supports CPU, CUDA, "
"SparseCPU and SparseCUDA backend, got: ", toString(self.type().backend()));
TORCH_CHECK(self.options().backend() == Backend::CPU ||
self.options().backend() == Backend::CUDA ||
self.options().backend() == Backend::SparseCPU ||
self.options().backend() == Backend::SparseCUDA, "any only supports CPU, CUDA, "
"SparseCPU and SparseCUDA backend, got: ", toString(self.options().backend()));
TORCH_CHECK(self.scalar_type() == at::ScalarType::Byte || self.scalar_type() == at::ScalarType::Bool,
"all only supports torch.uint8 and torch.bool dtypes");
@ -574,9 +575,9 @@ Tensor any(const Tensor& self, int64_t dim, bool keepdim) {
}
Tensor &any_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) {
TORCH_CHECK(self.type().backend() == Backend::CPU ||
self.type().backend() == Backend::CUDA, "any only supports CPU AND CUDA "
"backend, got: ", toString(self.type().backend()));
TORCH_CHECK(self.options().backend() == Backend::CPU ||
self.options().backend() == Backend::CUDA, "any only supports CPU AND CUDA "
"backend, got: ", toString(self.options().backend()));
TORCH_CHECK(self.scalar_type() == at::ScalarType::Byte || self.scalar_type() == at::ScalarType::Bool,
"all only supports torch.uint8 and torch.bool dtypes");
dim = maybe_wrap_dim(dim, self.dim());
@ -636,7 +637,7 @@ Tensor& argmax_out(Tensor& result, const Tensor& self, c10::optional<int64_t> di
in = self.reshape({-1});
keepdim = false;
}
if (self.type().backend() != Backend::CPU && self.type().backend() != Backend::CUDA) {
if (self.options().backend() != Backend::CPU && self.options().backend() != Backend::CUDA) {
Tensor ignored = at::empty({0}, self.options());
return std::get<1>(at::max_out(ignored, result, in, dim.value_or(0), keepdim));
}
@ -661,7 +662,7 @@ Tensor& argmin_out(Tensor& result, const Tensor& self, c10::optional<int64_t> di
in = self.reshape({-1});
keepdim = false;
}
if (self.type().backend() != Backend::CPU && self.type().backend() != Backend::CUDA) {
if (self.options().backend() != Backend::CPU && self.options().backend() != Backend::CUDA) {
Tensor ignored = at::empty({0}, self.options());
return std::get<1>(at::min_out(ignored, result, in, dim.value_or(0), keepdim));
}
@ -677,8 +678,8 @@ Tensor argmin(const Tensor& self, c10::optional<int64_t> dim, bool keepdims) {
}
static Tensor &std_var_out(Tensor &result, const Tensor &self, IntArrayRef dim, bool unbiased, bool keepdim, bool take_sqrt) {
TORCH_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
"std and var only support CPU AND CUDA backend, got: ", toString(self.type().backend()));
TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA,
"std and var only support CPU AND CUDA backend, got: ", toString(self.options().backend()));
TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
"std and var only support floating-point dtypes");
@ -716,15 +717,13 @@ static Tensor &std_var_out(Tensor &result, const Tensor &self, IntArrayRef dim,
static std::tuple<Tensor&,Tensor&> std_var_mean_out(const char* fname, Tensor &result1, Tensor &result2, const Tensor &self, IntArrayRef dim, bool unbiased, bool keepdim, bool take_sqrt) {
AT_ASSERT(result1.defined() && result2.defined());
TORCH_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
fname, " only support CPU and CUDA backend, got: ", toString(self.type().backend()));
TORCH_CHECK(at::isFloatingType(self.type().scalarType()) || at::isComplexType(self.scalar_type()),
fname, " only support floating-point dtypes");
TORCH_CHECK(result1.type().scalarType() == result2.type().scalarType(),
TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA, fname, " only support CPU AND CUDA backend, got: ", toString(self.options().backend()));
TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()), fname, " only support floating-point dtypes");
TORCH_CHECK(result1.scalar_type() == result2.scalar_type(),
"provided by result1 dtype must match dtype of result2. Got ",
toString(result1.type().scalarType()),
toString(result1.scalar_type()),
" and ",
toString(result2.type().scalarType()),
toString(result2.scalar_type()),
".");
if (at::isComplexType(self.scalar_type())){
ScalarType dtype = c10::toValueType(get_dtype(result1, self, {}, true));
@ -805,8 +804,8 @@ std::tuple<Tensor,Tensor> var_mean(const Tensor& self, bool unbiased) {
}
Tensor var(const Tensor& self, bool unbiased) {
TORCH_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
"var only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA,
"var only supports CPU AND CUDA backend, got: ", toString(self.options().backend()));
TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
"var only supports floating-point dtypes");
auto trivial_return = _allreduce_return_trivial(self, std::numeric_limits<double>::quiet_NaN());
@ -823,8 +822,8 @@ Tensor &var_out(Tensor &result, const Tensor &self, IntArrayRef dim, bool unbias
}
Tensor std(const Tensor& self, bool unbiased) {
TORCH_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
"std only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA,
"std only supports CPU AND CUDA backend, got: ", toString(self.options().backend()));
TORCH_CHECK(at::isFloatingType(self.scalar_type()) || at::isComplexType(self.scalar_type()),
"std only supports floating-point dtypes");
auto trivial_return = _allreduce_return_trivial(self, std::numeric_limits<double>::quiet_NaN());


@ -55,7 +55,7 @@ inline void _reduction_with_indices_allocate_or_resize_output(
}
if (values.defined()) {
TORCH_CHECK(
self.type() == values.type(),
self.options().type_equal(values.options()),
"output values must be of same type as input");
if (!keepdim && values.dim() == self.dim() - 1) {
// unsqueeze to preserve passed in noncontiguous tensor in resize
@ -95,7 +95,7 @@ inline void _allocate_or_resize_output_with_indices(
}
if (values.defined()) {
TORCH_CHECK(
self.type() == values.type(),
self.options().type_equal(values.options()),
"output values must be of same type as input");
values.resize_(result_sizes);
} else {


@ -29,7 +29,7 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
signal_ndim);
TORCH_CHECK(at::isFloatingType(self.scalar_type()),
"Expected an input tensor of floating types, but got input=",
self.type(), self.sizes());
self.toString(), self.sizes());
auto signal_tensor_ndim = signal_ndim + static_cast<int64_t>(complex_input); // add complex dim
if (self.dim() < signal_tensor_ndim) {
@ -39,7 +39,7 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
if (complex_input) {
ss << " (complex input adds an extra dimension)";
}
ss << ", but got input=" << self.type() << self.sizes();
ss << ", but got input=" << self.toString() << self.sizes();
AT_ERROR(ss.str());
}
@ -65,7 +65,7 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
TORCH_CHECK(input.size(signal_ndim + 1) == 2,
"Expected an input tensor with a last dimension of size 2 "
"representing real + imaginary components, but got input ",
self.type(), self.sizes());
self.toString(), self.sizes());
}
// build signal_sizes and output_size
@ -101,7 +101,7 @@ static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
TORCH_CHECK(signal_sizes.size() == 0 || signal_sizes[i] == checked_signal_sizes[i],
"Expected given signal_sizes=", signal_sizes," to have same "
"shape with input at signal dimension ", i, ", but got "
"signal_sizes=", signal_sizes, " and input=", self.type(),
"signal_sizes=", signal_sizes, " and input=", self.toString(),
self.sizes());
}
}
@ -177,11 +177,11 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop
const optional<int64_t> win_lengthOpt, const Tensor& window,
const bool normalized, const bool onesided) {
#define REPR(SS) \
SS << "stft(" << self.type() << self.sizes() << ", n_fft=" << n_fft \
SS << "stft(" << self.toString() << self.sizes() << ", n_fft=" << n_fft \
<< ", hop_length=" << hop_length << ", win_length=" << win_length \
<< ", window="; \
if (window.defined()) { \
SS << window.type() << "{" << window.sizes() << "}"; \
SS << window.toString() << "{" << window.sizes() << "}"; \
} else { \
SS << "None"; \
} \


@ -150,8 +150,8 @@ std::tuple<Tensor, Tensor> mode(const Tensor& self, int64_t dim, bool keepdim) {
std::tuple<Tensor &,Tensor &> mode_out(Tensor& values, Tensor& indices,
const Tensor& self, int64_t dim, bool keepdim) {
TORCH_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
"mode only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA,
"mode only supports CPU AND CUDA backend, got: ", toString(self.options().backend()));
dim = maybe_wrap_dim(dim, self.dim());
if (_dimreduce_return_trivial_no_ident(values, self, dim, keepdim, "mode")) {
AT_ASSERT(values.dim() == 0);
@ -202,8 +202,8 @@ std::tuple<Tensor, Tensor> max(const Tensor& self, int64_t dim, bool keepdim) {
static std::tuple<Tensor &,Tensor &> max_out_impl(Tensor& max, Tensor& max_indices,
const Tensor& self, int64_t dim, bool keepdim) {
TORCH_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
"max only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA,
"max only supports CPU AND CUDA backend, got: ", toString(self.options().backend()));
dim = maybe_wrap_dim(dim, self.dim());
if (_dimreduce_return_trivial_no_ident(max, self, dim, keepdim, "max")) {
AT_ASSERT(max.dim() == 0);
@ -262,8 +262,8 @@ std::tuple<Tensor, Tensor> min(const Tensor& self, int64_t dim, bool keepdim) {
static std::tuple<Tensor &,Tensor &> min_out_impl(Tensor& min, Tensor& min_indices,
const Tensor& self, int64_t dim, bool keepdim) {
TORCH_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
"min only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
TORCH_CHECK(self.options().backend() == Backend::CPU || self.options().backend() == Backend::CUDA,
"min only supports CPU AND CUDA backend, got: ", toString(self.options().backend()));
dim = maybe_wrap_dim(dim, self.dim());
if (_dimreduce_return_trivial_no_ident(min, self, dim, keepdim, "min")) {
AT_ASSERT(min.dim() == 0);


@ -67,7 +67,7 @@ inline void check_size_nonnegative(IntArrayRef size) {
inline void check_supported_max_int_with_precision(int64_t n, const Tensor& tensor) {
TORCH_CHECK(at::scalar_tensor(n, tensor.options()).defined(),
"n is too large for result tensor type: '", tensor.type().toString(), "'");
"n is too large for result tensor type: '", tensor.toString(), "'");
// Ensure sufficient precision for floating point representation.
switch (tensor.scalar_type()) {


@ -350,7 +350,7 @@ Tensor expand(const Tensor& self, IntArrayRef size, bool implicit) {
// requested by the user, because it is legal to remove implicit expands
// from the graph, but not legal to remove the explicit ones.
TORCH_CHECK(size.size() >= (size_t)self.dim(),
"expand(", self.type(), "{", self.sizes(), "}, size=", size,
"expand(", self.toString(), "{", self.sizes(), "}, size=", size,
"): the number of sizes provided (", size.size(), ") ",
"must be greater or equal to the number of dimensions in the tensor (",
self.dim(), ")");


@ -267,7 +267,7 @@ Tensor& _clamp_min_out_cpu(Tensor& result, const Tensor& self, Scalar min) {
Tensor mvlgamma(const Tensor& self, int64_t p) {
TORCH_CHECK(at::isFloatingType(self.scalar_type()),
"mvlgamma is not implemented for ", self.type());
"mvlgamma is not implemented for ", self.scalar_type());
TORCH_CHECK((self > 0.5 * (p - 1.)).all().item<uint8_t>(),
"Condition for computing multivariate log-gamma not met");
TORCH_CHECK(p >= 1, "p has to be greater than or equal to 1");
@ -278,7 +278,7 @@ Tensor mvlgamma(const Tensor& self, int64_t p) {
Tensor& mvlgamma_(Tensor& self, int64_t p) {
TORCH_CHECK(at::isFloatingType(self.scalar_type()),
"mvlgamma is not implemented for ", self.type());
"mvlgamma is not implemented for ", self.scalar_type());
TORCH_CHECK((self > 0.5 * (p - 1.)).all().item<uint8_t>(),
"Condition for computing multivariate log-gamma not met");
TORCH_CHECK(p >= 1, "p has to be greater than or equal to 1");


@ -238,6 +238,7 @@ class CAFFE2_API Tensor {
return impl_->itemsize();
}
C10_DEPRECATED_MESSAGE("Tensor.type() is deprecated. Instead use Tensor.options(), which in many cases (e.g. in a constructor) is a drop-in replacement. If you were using data from type(), that is now available from Tensor itself, so instead of tensor.type().scalar_type(), use tensor.scalar_type() instead and instead of tensor.type().backend() use tensor.device().")
DeprecatedTypeProperties & type() const {
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
tensorTypeIdToBackend(legacyExtractTypeId(type_set())),

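To make the migration guidance in the deprecation message above concrete, here is a minimal sketch of a call site written purely against the replacement APIs. The helper name check_pair and the error messages are illustrative; every accessor used here appears elsewhere in this diff.

#include <ATen/ATen.h>

// Illustrative only: compare two tensors without touching the deprecated type().
void check_pair(const at::Tensor& a, const at::Tensor& b) {
  // tensor.type().scalarType()   ->  tensor.scalar_type()
  TORCH_CHECK(a.scalar_type() == b.scalar_type(), "dtype mismatch");
  // tensor.type().backend()      ->  tensor.options().backend()
  TORCH_CHECK(a.options().backend() == b.options().backend(), "backend mismatch");
  // tensor.type().device_type()  ->  tensor.device().type()
  TORCH_CHECK(a.device().type() == b.device().type(), "device type mismatch");
  // a.type() == b.type()         ->  a.options().type_equal(b.options())
  TORCH_CHECK(a.options().type_equal(b.options()),
              "type mismatch: ", a.toString(), " vs ", b.toString());
}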

@ -20,7 +20,7 @@ struct Foo {
static void apply(Tensor a, Tensor b) {
scalar_type s = 1;
std::stringstream ss;
ss << "hello, dispatch: " << a.type().toString() << s << "\n";
ss << "hello, dispatch: " << a.toString() << s << "\n";
auto data = (scalar_type*)a.data_ptr();
(void)data;
}
@ -110,7 +110,7 @@ TEST(TestScalar, TestScalar) {
scalar_t s = 1;
std::stringstream ss;
ASSERT_NO_THROW(
ss << "hello, dispatch" << x.type().toString() << s << "\n");
ss << "hello, dispatch" << x.toString() << s << "\n");
auto data = (scalar_t*)x.data_ptr();
(void)data;
});


@ -292,6 +292,15 @@ struct C10_API TensorOptions {
return has_pinned_memory_;
}
/// Returns if the layout is sparse
bool is_sparse() const {
return layout_ == c10::Layout::Sparse;
}
// For compatibility with legacy tensor.type() comparisons
bool type_equal(const TensorOptions& other) const {
return backend() == other.backend() && typeMetaToScalarType(dtype_) == typeMetaToScalarType(other.dtype());
}
/// Returns the `pinned_memory` property of the `TensorOptions`, or
/// `c10::nullopt` if `pinned_memory` is not specified.
@ -538,6 +547,12 @@ inline TensorOptions dtype() {
return dtype(caffe2::TypeMeta::Make<T>());
}
inline std::string toString(const TensorOptions options) {
std::ostringstream stream;
stream << options;
return stream.str();
}
// This is intended to be a centralized location by which we can determine
// what an appropriate TensorTypeId for a tensor is.
//

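A minimal usage sketch for the helpers added above, assuming the surrounding header's c10 namespace. Note that type_equal() only compares backend and scalar type, mirroring how legacy tensor.type() comparisons behaved; device index, for example, plays no role.

#include <c10/core/TensorOptions.h>
#include <iostream>

void demo_options_helpers() {
  auto a = c10::TensorOptions().device(c10::Device(c10::DeviceType::CPU)).dtype(c10::kFloat);
  auto b = c10::TensorOptions().device(c10::Device(c10::DeviceType::CPU)).dtype(c10::kFloat);
  bool same = a.type_equal(b);   // true: same backend (CPU) and dtype (Float)
  // The new toString(TensorOptions) overload streams the options, handy for error messages.
  std::cout << c10::toString(a) << " equal: " << same << "\n";
}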

@ -21,7 +21,7 @@ using namespace torch::test;
ASSERT_EQ(tensor.device().type(), Device((device_), (index_)).type()); \
ASSERT_EQ(tensor.device().index(), Device((device_), (index_)).index()); \
ASSERT_EQ(tensor.scalar_type(), (type_)); \
ASSERT_TRUE(tensor.type().layout() == (layout_))
ASSERT_TRUE(tensor.options().layout() == (layout_))
TEST(TensorOptionsTest, DefaultsToTheRightValues) {
TensorOptions options;


@ -29,7 +29,7 @@ at::Device CUDADevice(DeviceIndex index) {
ASSERT_EQ(tensor.device().type(), Device((device_), (index_)).type()); \
ASSERT_EQ(tensor.device().index(), Device((device_), (index_)).index()); \
ASSERT_EQ(tensor.scalar_type(), (type_)); \
ASSERT_TRUE(tensor.type().layout() == (layout_))
ASSERT_TRUE(tensor.options().layout() == (layout_))
TEST(TensorOptionsTest, ConstructsWellFromCUDATypes_CUDA) {
auto options = CUDA(kFloat).options();


@ -6,7 +6,7 @@ namespace torch {
namespace jit {
int device(const autograd::Variable& v) {
return v.type().is_cuda() ? v.get_device() : -1;
return v.device().is_cuda() ? v.get_device() : -1;
}
bool isEqual(at::IntArrayRef lhs, at::IntArrayRef rhs) {
@ -29,18 +29,18 @@ bool isEqual(const ArgumentInfo& ti, const autograd::Variable& v) {
ti.type() == v.scalar_type() && ti.dim() == v.dim();
}
autograd::Variable var(at::DeprecatedTypeProperties& t, at::IntArrayRef sizes, bool requires_grad) {
return autograd::make_variable(at::rand(sizes, t.options()), requires_grad);
autograd::Variable var(at::TensorOptions t, at::IntArrayRef sizes, bool requires_grad) {
return autograd::make_variable(at::rand(sizes, t), requires_grad);
}
autograd::Variable undef() {
return autograd::Variable();
}
void testCompleteArgumentSpec() {
auto& CF = at::CPU(at::kFloat);
auto& CD = at::CPU(at::kDouble);
auto& GF = at::CUDA(at::kFloat);
auto& GD = at::CUDA(at::kDouble);
auto const CF = at::CPU(at::kFloat);
auto const CD = at::CPU(at::kDouble);
auto const GF = at::CUDA(at::kFloat);
auto const GD = at::CUDA(at::kDouble);
auto list = createStack({var(CF, {1}, true),
var(CD, {1, 2}, false),


@ -6,8 +6,8 @@
void sigmoid_add_cuda(const float* x, const float* y, float* output, int size);
torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) {
TORCH_CHECK(x.type().is_cuda(), "x must be a CUDA tensor");
TORCH_CHECK(y.type().is_cuda(), "y must be a CUDA tensor");
TORCH_CHECK(x.device().is_cuda(), "x must be a CUDA tensor");
TORCH_CHECK(y.device().is_cuda(), "y must be a CUDA tensor");
auto output = torch::zeros_like(x);
sigmoid_add_cuda(
x.data_ptr<float>(), y.data_ptr<float>(), output.data_ptr<float>(), output.numel());


@ -656,7 +656,7 @@ def create_python_bindings(python_functions, has_self, is_module=False):
}
python_binding_arguments.append(dtype_arg)
if is_factory_function or is_like_or_new_function_with_options:
py_default_layout = '*torch::getLayout(self.type().backend())' if is_like_or_new_function_with_options else None
py_default_layout = '*torch::getLayout(self.options().backend())' if is_like_or_new_function_with_options else None
layout_arg = {
'default': 'torch.strided',
'dynamic_type': 'Layout',


@ -32,17 +32,17 @@ inline std::vector<Tensor> unpack_list(at::ArrayRef<SavedVariable> xs) {
}
struct TypeAndSize {
TypeAndSize() : type(nullptr) {}
TypeAndSize() : options(at::TensorOptions()) {}
/* implicit */
TypeAndSize(const Tensor & t)
: sizes(t.sizes().vec())
, type(&t.type()) {}
, options(t.options()) {}
Tensor zeros() { return at::zeros(sizes, *type); }
Tensor zeros() { return at::zeros(sizes, options); }
private:
std::vector<int64_t> sizes;
at::DeprecatedTypeProperties* type;
at::TensorOptions options;
};
${autograd_function_declarations}


@ -783,7 +783,7 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwa
ParsedArgs<3> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
if (r.isNone(0)) {
return THPUtils_packString(torch::utils::type_to_string(self_.type()));
return THPUtils_packString(torch::utils::options_to_string(self_.options()));
}
auto obj = r.pyobject(0);
auto opt_memory_format = r.memoryformatOptional(2);
@ -807,9 +807,9 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwa
if (is_dtype) {
scalar_type = r.scalartype(0);
} else {
at::DeprecatedTypeProperties* type = torch::utils::type_from_string(type_name);
scalar_type = type->scalarType();
auto device_type = backendToDeviceType(type->backend());
at::TensorOptions options = torch::utils::options_from_string(type_name);
scalar_type = at::typeMetaToScalarType(options.dtype());
auto device_type = options.device().type();
if (device_type != device.type()) {
device = at::Device(device_type);
}


@ -102,7 +102,7 @@ static PyObject * THPGenerator_setState(THPGenerator *self, PyObject *_new_state
}
auto& tensor = ((THPVariable*)_new_state)->cdata;
if (tensor.layout() != kStrided || tensor.device().type() != kCPU || tensor.scalar_type() != kByte) {
auto type_name = torch::utils::type_to_string(tensor.type());
auto type_name = torch::utils::options_to_string(tensor.options());
throw TypeError("expected a torch.ByteTensor, but got %s", type_name.c_str());
}
if (self->cdata->device().type() == at::kCPU) {


@ -101,7 +101,7 @@ void set_data(const Tensor & self, const Tensor & new_data) {
const auto prior_device = prior_accumulator->input_metadata(0).device();
const auto new_device = new_data.device();
if (new_data.type() != self.type() || prior_device != new_device) {
if (!new_data.options().type_equal(self.options()) || prior_device != new_device) {
autograd_meta->grad_accumulator_.reset();
}
}


@ -99,7 +99,7 @@ variable_list _wrap_outputs(const variable_list &input_vars,
}
void check_variable_result(const Variable& original, const Variable& result, std::string hook_name) {
if (original.type() != result.type()) {
if (!original.options().type_equal(result.options())) {
std::stringstream ss;
ss << "hook '" << hook_name << "' has changed the type of value (";
ss << "was " << original.toString() << " got ";


@ -411,11 +411,10 @@ static variable_list call_post_hooks(Node& fn, variable_list outputs, const vari
return outputs;
}
static bool is_compatible_type(const at::DeprecatedTypeProperties& expected, const at::DeprecatedTypeProperties& actual) {
static bool is_compatible_type(const at::TensorOptions& expected, const at::TensorOptions& actual) {
// Types are compatible if they exactly match or if the gradient is a sparse
// version of the expected type.
return expected == actual || (actual.is_sparse() &&
expected == actual.toBackend(toDense(actual.backend())));
return expected.type_equal(actual) || (actual.is_sparse() && expected.device().type() == actual.device().type());
}
void validate_outputs(
@ -451,14 +450,14 @@ void validate_outputs(
}
grads[i] = at::sum_to(std::move(grads[i]), metadata.shape());
}
TORCH_CHECK(isFloatingType(grads[i].type().scalarType()));
if (metadata.type().scalarType() != grads[i].type().scalarType()) {
grads[i] = grads[i].to(metadata.type().scalarType());
TORCH_CHECK(isFloatingType(grads[i].scalar_type()));
if (c10::typeMetaToScalarType(metadata.options().dtype()) != grads[i].scalar_type()) {
grads[i] = grads[i].to(c10::typeMetaToScalarType(metadata.options().dtype()));
}
if (!is_compatible_type(metadata.type(), grads[i].type())) {
if (!is_compatible_type(metadata.options(), grads[i].options())) {
std::stringstream ss;
ss << "invalid gradient at index " << i << " - expected type ";
ss << metadata.type() << " but got " << grads[i].type();
ss << metadata.options() << " but got " << grads[i].options();
AT_ERROR(format_error(ss.str()));
}
auto output_device = output.device();


@ -137,11 +137,11 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
/// Adds the type and shape metadata for a new input. Returns the index of
/// of the new input.
uint32_t add_input_metadata(
const at::DeprecatedTypeProperties& type
const at::TensorOptions& options
, at::IntArrayRef shape
, at::Device device) noexcept {
uint32_t input_nr = input_metadata_.size();
input_metadata_.emplace_back(type, shape, device);
input_metadata_.emplace_back(options, shape, device);
return input_nr;
}


@ -77,7 +77,7 @@ variable_list Gather::apply(variable_list&& inputs) {
TORCH_CHECK(
input.is_cuda(),
"All inputs to Gather must be CUDA tensors, got ",
input.type());
input.toString());
if (input.dim() > 0) {
all_are_zero_dim = false;
}


@ -20,21 +20,16 @@ namespace torch { namespace autograd {
struct InputMetadata {
InputMetadata() = default;
InputMetadata(const at::DeprecatedTypeProperties& type, at::IntArrayRef shape, at::Device device)
: type_{&type}, shape_{shape}, device_{device} {
InputMetadata(const at::TensorOptions options, at::IntArrayRef shape, at::Device device)
: options_{options}, shape_{shape}, device_{device} {
stream_ = c10::impl::getDeviceGuardImpl(device_.type())->getStream(device_);
}
InputMetadata(const at::Tensor& t)
: InputMetadata(t.type(), t.sizes(), t.device()) { }
: InputMetadata(t.options(), t.sizes(), t.device()) { }
bool is_valid() const {
return type_ != nullptr;
}
const at::DeprecatedTypeProperties& type() const {
AT_ASSERT(type_);
return *type_;
const at::TensorOptions options() const {
return options_;
}
at::IntArrayRef shape() const {
@ -50,11 +45,11 @@ struct InputMetadata {
}
at::Tensor zeros_like() const {
return at::zeros(shape_, type_->options(device_));
return at::zeros(shape_, options_);
}
private:
const at::DeprecatedTypeProperties* type_ = nullptr;
const at::TensorOptions options_;
at::DimVector shape_;
at::Device device_ = at::kCPU;
c10::Stream stream_ = c10::Stream(c10::Stream::Default::DEFAULT, device_);

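A hedged sketch of how the reworked InputMetadata is used (header path and namespace assumed from torch/csrc/autograd; illustrative only): because it now stores TensorOptions rather than a DeprecatedTypeProperties pointer, zeros_like() can build the zero tensor directly from the saved options.

#include <torch/csrc/autograd/input_metadata.h>

void demo_input_metadata(const at::Tensor& t) {
  // Records t.options(), t.sizes() and t.device() via the delegating constructor above.
  torch::autograd::InputMetadata meta(t);
  at::Tensor z = meta.zeros_like();   // at::zeros(shape_, options_), per the diff
}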

@ -275,7 +275,7 @@ int THPVariable_set_grad(THPVariable *self, PyObject *py_grad, void *unused)
bool gradIsSparse = (var.dtype() == grad.dtype() &&
var.device().type() == grad.device().type() &&
grad.layout() == kSparse);
THPUtils_assertRet(-1, grad.type() == var.type() || gradIsSparse,
THPUtils_assertRet(-1, grad.options().type_equal(var.options()) || gradIsSparse,
"assigned grad has data of a different type");
if (var.is_cuda()) {
THPUtils_assertRet(-1, grad.get_device() == var.get_device(),
@ -487,7 +487,7 @@ static PyObject *THPVariable_dtype(THPVariable *self, void *unused)
static PyObject * THPVariable_layout(THPVariable* self, void *unused) {
HANDLE_TH_ERRORS
auto& self_ = self->cdata;
return torch::autograd::utils::wrap(torch::getLayout(self_.type().backend()));
return torch::autograd::utils::wrap(torch::getLayout(self_.options().backend()));
END_HANDLE_TH_ERRORS
}


@ -149,7 +149,7 @@ static Variable valueToTensor(c10::TensorOptions options, PyObject* value) {
throw TypeError(
"can't assign a %s to a %s",
Py_TYPE(value)->tp_name,
torch::utils::type_to_string(getDeprecatedTypeProperties(options.backend(), typeMetaToScalarType(options.dtype()))).c_str());
torch::utils::options_to_string(options).c_str());
}
static Variable boolToIndexingTensor(const Variable& self, bool value) {


@ -57,7 +57,7 @@ Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
if (saved_version_ != version_counter_.current_version()) {
std::stringstream message;
message << "one of the variables needed for gradient computation has been "
"modified by an inplace operation: [" << data_.type().toString() << " "
"modified by an inplace operation: [" << data_.toString() << " "
<< data_.sizes() << "]";
if (grad_fn) {
message << ", which is output " << output_nr_


@ -331,7 +331,7 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(const Tenso
fn->storage_offset = self.storage_offset();
fn->set_next_edges(torch::autograd::collect_next_edges(diff_view_meta->base_));
fn->add_input_metadata(
diff_view_meta->base_.type()
diff_view_meta->base_.options()
, self.sizes() // Note: sizes(), not base_.sizes(), is intentional
, diff_view_meta->base_.device());
diff_view_meta->grad_fn_ = std::move(fn);


@ -193,11 +193,11 @@ void Reducer::mark_variable_ready_dense(VariableIndex index) {
if (grad.defined()) {
// Ensure that the gradient type matches the bucket type.
AT_ASSERTM(
grad.type() == bucket_view.type(),
grad.options().type_equal(bucket_view.options()),
"Expected ",
bucket_view.type(),
bucket_view.toString(),
", got ",
grad.type());
grad.toString());
// Assert that the grad tensor and the bucket don't share storage.
// If they did, we could avoid the copy altogether.
// The reason for not doing this is that existing code calls


@ -277,7 +277,7 @@ struct DifferentiableGraphBackward : public autograd::Node {
// NB: since our requires_grad setting is only a heuristic we might end
// up wanting to differentiate through integral tensors, which is
// generally a hard error in autograd.
if (at::isFloatingType(output.type().scalarType())) {
if (at::isFloatingType(output.scalar_type())) {
autograd::create_gradient_edge(output, shared_from_this());
output.set_requires_grad(true);
} else {


@ -16,7 +16,7 @@ namespace jit {
namespace {
bool tensorEqual(const at::Tensor& lhs, const at::Tensor& rhs) {
return lhs.type() == rhs.type() && lhs.equal(rhs);
return lhs.options().type_equal(rhs.options()) && lhs.equal(rhs);
}
bool tensorListEqual(


@ -296,7 +296,7 @@ struct PythonPrintImpl {
// because it doesn't hash any information about the tensors.
// We will probably need to optimize this at some point using hashing.
for (size_t i = 0; i < tensor_table_.size(); ++i) {
if (t.type() == tensor_table_[i].type() && t.equal(tensor_table_[i])) {
if (t.options().type_equal(tensor_table_[i].options()) && t.equal(tensor_table_[i])) {
return i;
}
}


@ -54,7 +54,7 @@ static void recursive_apply(IntArrayRef sizes, ScalarType scalarType, int64_t di
}
Tensor & apply_(Tensor & self, PyObject* fn) {
if (self.type().backend() != Backend::CPU) {
if (self.options().backend() != Backend::CPU) {
throw TypeError("apply_ is only implemented on CPU tensors");
}
auto scalarType = self.scalar_type();
@ -63,12 +63,12 @@ Tensor & apply_(Tensor & self, PyObject* fn) {
}
Tensor & map_(Tensor & self, const Tensor & other_, PyObject* fn) {
if (self.type().backend() != Backend::CPU) {
if (self.options().backend() != Backend::CPU) {
throw TypeError("map_ is only implemented on CPU tensors");
}
if (other_.type() != self.type()) {
if (!other_.options().type_equal(self.options())) {
throw TypeError("map_: expected %s for 'other' (got %s)",
self.type().toString().c_str(), other_.type().toString().c_str());
self.toString().c_str(), other_.toString().c_str());
}
Tensor other;
std::tie(other) = expand_inplace(self, other_, "map_");
@ -78,16 +78,16 @@ Tensor & map_(Tensor & self, const Tensor & other_, PyObject* fn) {
}
Tensor & map2_(Tensor & self, const Tensor & x_, const Tensor & y_, PyObject* fn) {
if (self.type().backend() != Backend::CPU || x_.type().backend() != Backend::CPU || y_.type().backend() != Backend::CPU) {
if (self.options().backend() != Backend::CPU || x_.options().backend() != Backend::CPU || y_.options().backend() != Backend::CPU) {
throw TypeError("map2_ is only implemented on CPU tensors");
}
if (x_.type() != self.type()) {
if (!x_.options().type_equal(self.options())) {
throw TypeError("map2_: expected %s for argument 'x' (got %s)",
self.type().toString().c_str(), x_.type().toString().c_str());
self.toString().c_str(), x_.toString().c_str());
}
if (y_.type() != self.type()) {
if (!y_.options().type_equal(self.options())) {
throw TypeError("map2_: expected %s for argument 'y' (got %s)",
self.type().toString().c_str(), y_.type().toString().c_str());
self.toString().c_str(), y_.toString().c_str());
}
Tensor other1, other2;
std::tie(other1, other2) = expand_inplace(self, x_, y_, "map2_");

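Backend guards in the Python binding helpers follow the same recipe: Tensor::options().backend() replaces Tensor::type().backend(). A hedged sketch of the CPU-only guard used by apply_, map_ and map2_ (the torch/csrc TypeError plumbing is omitted, TORCH_CHECK stands in for it, and the helper name is made up):

#include <ATen/ATen.h>

// Reject non-CPU tensors before running a CPU-only in-place helper.
void check_cpu_only(const at::Tensor& self, const char* op_name) {
  // Old form: self.type().backend() != Backend::CPU
  TORCH_CHECK(
      self.options().backend() == at::Backend::CPU,
      op_name, " is only implemented on CPU tensors");
}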

@@ -30,7 +30,7 @@ static PyObject* recursive_to_list(
PyObject* tensor_to_list(const Tensor& tensor) {
Tensor data = tensor;
if (data.type().backend() != Backend::CPU) {
if (data.options().backend() != Backend::CPU) {
pybind11::gil_scoped_release no_gil;
data = data.toBackend(Backend::CPU);
}


@@ -81,8 +81,8 @@ PyObject* tensor_to_numpy(const at::Tensor& tensor) {
"can't convert sparse tensor to numpy. Use Tensor.to_dense() to "
"convert to a dense tensor first.");
}
if (tensor.type().backend() != Backend::CPU) {
throw TypeError("NumPy conversion for %s is not supported", tensor.type().toString().c_str());
if (tensor.options().backend() != Backend::CPU) {
throw TypeError("NumPy conversion for %s is not supported", tensor.toString().c_str());
}
if (tensor.requires_grad()) {
throw std::runtime_error(


@@ -28,13 +28,19 @@ static const char* backend_to_string(const at::Backend& backend) {
}
}
std::string options_to_string(const at::TensorOptions options) {
std::ostringstream ss;
ss << backend_to_string(options.backend()) << "." << toString(at::typeMetaToScalarType(options.dtype())) << "Tensor";
return ss.str();
}
std::string type_to_string(const at::DeprecatedTypeProperties& type) {
std::ostringstream ss;
ss << backend_to_string(type.backend()) << "." << toString(type.scalarType()) << "Tensor";
return ss.str();
}
at::DeprecatedTypeProperties* type_from_string(const std::string& str) {
at::TensorOptions options_from_string(const std::string& str) {
static std::string cuda_prefix("torch.cuda.");
static std::once_flag cpu_once;
static std::once_flag cuda_once;
@@ -46,7 +52,7 @@ at::DeprecatedTypeProperties* type_from_string(const std::string& str) {
if (str == "torch.Tensor") {
auto backend = tensorTypeIdToBackend(torch::tensors::get_default_tensor_type_id());
auto scalar_type = torch::tensors::get_default_scalar_type();
return &getDeprecatedTypeProperties(backend, scalar_type);
return getDeprecatedTypeProperties(backend, scalar_type).options();
}
if (std::mismatch(cuda_prefix.begin(), cuda_prefix.end(), str.begin()).first == cuda_prefix.end()) {
@@ -70,7 +76,7 @@ at::DeprecatedTypeProperties* type_from_string(const std::string& str) {
if (it == map->end()) {
throw ValueError("invalid type: '%s'", str.c_str());
}
return it->second;
return it->second->options();
}
std::vector<std::pair<Backend, ScalarType>> all_declared_types() {


@@ -6,8 +6,9 @@
namespace torch { namespace utils {
std::string options_to_string(const at::TensorOptions options);
std::string type_to_string(const at::DeprecatedTypeProperties& type);
at::DeprecatedTypeProperties* type_from_string(const std::string& str);
at::TensorOptions options_from_string(const std::string& str);
// return a vector of all "declared" types, even those that weren't compiled
std::vector<std::pair<at::Backend, at::ScalarType>> all_declared_types();

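The string helpers now speak TensorOptions as well: options_to_string() formats an options object into the familiar "torch.FloatTensor"-style name, and options_from_string() replaces type_from_string(). A rough usage sketch, assuming these declarations live in torch/csrc/utils/tensor_types.h as in the PyTorch tree (the diff above does not show file paths, and this header is internal rather than public API):

#include <torch/csrc/utils/tensor_types.h>
#include <ATen/ATen.h>
#include <iostream>

// Round-trip a legacy tensor-type string through the new helpers.
void round_trip_type_string() {
  at::TensorOptions opts = torch::utils::options_from_string("torch.FloatTensor");
  // Should print "torch.FloatTensor" again.
  std::cout << torch::utils::options_to_string(opts) << std::endl;
}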

@@ -1297,7 +1297,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::allreduce_coalesced(
// tensors must have the same device, layout and type.
assertLayoutMatch(invalidArgument, tensors);
if (!std::all_of(tensors.begin(), tensors.end(), [&](at::Tensor& t) {
return t.type() == tensors[0].type();
return t.options().type_equal(tensors[0].options());
})) {
invalidArgument("tensors must all have the same type");
}
@@ -1670,11 +1670,11 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::allgather(
assertDense(invalidArgument, inputs);
// Expect all input/output tensors to have the same type and sizes
const auto& type = inputs[0].type();
const auto& options = inputs[0].options();
const auto& sizes = inputs[0].sizes();
assertTypeAndSizesMatch(invalidArgument, inputs, type, sizes);
assertTypeAndSizesMatch(invalidArgument, inputs, options, sizes);
for (size_t i = 0; i < outputs.size(); i++) {
assertTypeAndSizesMatch(invalidArgument, outputs[i], type, sizes);
assertTypeAndSizesMatch(invalidArgument, outputs[i], options, sizes);
}
const auto& device = inputs[0].device();
@@ -1807,11 +1807,11 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::allgather_coalesced(
" (expected length " + toString(expected) + ", got " +
toString(actual) + ")");
}
if (input_list[i].type() != output_list[i].type()) {
if (!input_list[i].options().type_equal(output_list[i].options())) {
invalidArgument(
"invalid tensor type at index " + std::to_string(i) +
" (expected " + input_list[i].type().toString() + ", got " +
output_list[i].type().toString() + ")");
" (expected " + input_list[i].toString() + ", got " +
output_list[i].toString() + ")");
}
}
}
@@ -1992,9 +1992,9 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::gather(
invalidArgument(ss.str());
}
const auto& type = inputs[0].type();
const auto& options = inputs[0].options();
const auto& sizes = inputs[0].sizes();
assertTypeAndSizesMatch(invalidArgument, outputs[0], type, sizes);
assertTypeAndSizesMatch(invalidArgument, outputs[0], options, sizes);
} else {
if (outputs.size() != 0) {
invalidArgument("requires empty output on non-root");
@@ -2178,9 +2178,9 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::scatter(
<< ", same as size of the process group.";
invalidArgument(ss.str());
}
const auto& type = outputs[0].type();
const auto& options = outputs[0].options();
const auto& sizes = outputs[0].sizes();
assertTypeAndSizesMatch(invalidArgument, inputs[0], type, sizes);
assertTypeAndSizesMatch(invalidArgument, inputs[0], options, sizes);
} else {
if (inputs.size() != 0) {
invalidArgument("requires empty input on non-root");

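The ProcessGroupGloo validation keeps its structure: the first tensor's TensorOptions and sizes act as the reference, and the remaining tensors are compared with type_equal(). A condensed sketch of that check (hypothetical free function; the real code reports failures through invalidArgument and the c10d assert helpers shown below):

#include <ATen/ATen.h>
#include <vector>

// Require that every tensor in a collective call matches the first one.
void check_same_options_and_sizes(const std::vector<at::Tensor>& tensors) {
  TORCH_CHECK(!tensors.empty(), "requires at least one tensor");
  const auto& options = tensors[0].options();
  const auto sizes = tensors[0].sizes();
  for (size_t i = 1; i < tensors.size(); ++i) {
    // Old form: tensors[i].type() != tensors[0].type()
    TORCH_CHECK(tensors[i].options().type_equal(options),
                "tensors must all have the same type");
    TORCH_CHECK(tensors[i].sizes() == sizes,
                "tensors must all have the same size");
  }
}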

@@ -43,9 +43,9 @@ inline void assertSameType(
const at::DeprecatedTypeProperties& type,
const std::vector<at::Tensor>& tensors) {
for (size_t i = 0; i < tensors.size(); i++) {
if (tensors[i].type() != type) {
if (!tensors[i].options().type_equal(type.options())) {
const std::string expected = type.toString();
const std::string actual = tensors[i].type().toString();
const std::string actual = tensors[i].toString();
throw std::invalid_argument(
"mixed types (" + expected + " and " + actual + ")");
}
@@ -72,12 +72,12 @@ inline void assertSameSizeAndType(const std::vector<at::Tensor>& tensors) {
}
// Ensure all tensors have identical type and shape
auto type = tensors[0].type();
auto options = tensors[0].options();
auto sizes = tensors[0].sizes();
for (size_t i = 1; i < tensors.size(); i++) {
if (tensors[i].type() != type) {
const std::string expected = type.toString();
const std::string actual = tensors[i].type().toString();
if (!tensors[i].options().type_equal(options)) {
const auto expected = toString(options);
const auto actual = toString(tensors[i].options());
throw std::invalid_argument(
"argument contains mixed types (" + expected + " and " + actual +
")");
@@ -97,12 +97,24 @@ inline void assertTypeMatch(
const at::DeprecatedTypeProperties& type,
const at::ArrayRef<at::Tensor>& tensors,
size_t index) {
if (tensors[index].type() != type) {
if (!tensors[index].options().type_equal(type.options())) {
fn("invalid tensor type at index " + std::to_string(index) + " (expected " +
type.toString() + ", got " + tensors[index].type().toString() + ")");
type.toString() + ", got " + tensors[index].toString() + ")");
}
}
inline void assertTypeMatch(
std::function<void(const std::string&)> fn,
const at::TensorOptions& options,
const at::ArrayRef<at::Tensor>& tensors,
size_t index) {
if (!tensors[index].options().type_equal(options)) {
fn("invalid tensor type at index " + std::to_string(index) + " (expected " +
toString(options) + ", got " + toString(tensors[index].options()) + ")");
}
}
inline void assertSizesMatch(
std::function<void(const std::string&)> fn,
const at::IntArrayRef& sizes,
@@ -228,12 +240,23 @@ inline void assertTypeAndSizesMatch(
}
}
inline void assertTypeAndSizesMatch(
std::function<void(const std::string&)> fn,
const at::ArrayRef<at::Tensor>& tensors,
const at::TensorOptions& options,
const at::IntArrayRef& sizes) {
for (size_t i = 0; i < tensors.size(); i++) {
assertTypeMatch(fn, options, tensors, i);
assertSizesMatch(fn, sizes, tensors, i);
}
}
inline void assertTypeAndSizesMatch(
std::function<void(const std::string&)> fn,
const at::ArrayRef<at::Tensor>& tensors) {
const auto& type = tensors[0].type();
const auto& options = tensors[0].options();
const auto sizes = tensors[0].sizes();
assertTypeAndSizesMatch(fn, tensors.slice(1), type, sizes);
assertTypeAndSizesMatch(fn, tensors.slice(1), options, sizes);
}
// Copied from ATen/core/functional.h.
@@ -303,7 +326,7 @@ inline std::vector<std::vector<int64_t>> getSizes(
inline std::vector<int> getDevices(const std::vector<at::Tensor>& tensors) {
std::vector<int> devices(tensors.size(), -1);
if (tensors[0].type().is_cuda()) {
if (tensors[0].device().is_cuda()) {
for (size_t i = 0; i < tensors.size(); i++) {
devices[i] = tensors[i].storage().device().index();
}
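
The final hunk swaps type().is_cuda() for a Device query. A minimal sketch of the same device scan, assuming only ATen (the function name is illustrative; the original reads the index off each tensor's storage rather than the tensor itself):

#include <ATen/ATen.h>
#include <vector>

// Collect CUDA device indices for a list of tensors; -1 marks CPU tensors.
std::vector<int> collect_device_indices(const std::vector<at::Tensor>& tensors) {
  std::vector<int> devices(tensors.size(), -1);
  // Old form: tensors[0].type().is_cuda()
  if (!tensors.empty() && tensors[0].device().is_cuda()) {
    for (size_t i = 0; i < tensors.size(); ++i) {
      devices[i] = tensors[i].device().index();
    }
  }
  return devices;
}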