Store ScalarType and Backend instead of Type in TensorIterator
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17601
Reviewed By: ezyang
Differential Revision: D14274754
fbshipit-source-id: b08880ae586b6ae57d4c0bbeb203796d087926c4
Committed by: Facebook Github Bot
Parent: c705d9eb1e
Commit: f5741eb855
@@ -407,7 +407,7 @@ static AdvancedIndex make_info(Tensor self, TensorList orig) {
 static std::unique_ptr<TensorIterator> make_index_iterator(const AdvancedIndex& info) {
   auto builder = TensorIterator::Builder();
   builder.dont_compute_common_dtype();
-  builder.add_output(Tensor(), &info.src.dispatch_type());
+  builder.add_output(Tensor(), info.src.type().backend(), info.src.scalar_type());
   builder.add_input(info.src);
   for (auto& index : info.indices) {
     builder.add_input(index);
@@ -424,7 +424,7 @@ static std::unique_ptr<TensorIterator> make_index_put_iterator(const AdvancedInd
   builder.dont_compute_common_dtype();
   builder.dont_resize_outputs();
   builder.add_output(info.src);
-  builder.add_input(value, &info.src.dispatch_type());
+  builder.add_input(value, info.src.type().backend(), info.src.scalar_type());
   for (auto& index : info.indices) {
     builder.add_input(index);
   }
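Both indexing call sites now hand the builder the source tensor's Backend and ScalarType instead of a pointer to its dispatch Type. Below is a condensed sketch of the new call pattern, not the repository code: it assumes ATen plus the Builder API shown in this diff, `src` is a hypothetical source tensor, the include path and the `build()` call are assumptions based on the surrounding API.

// Sketch only: the updated call pattern at an indexing call site.
// `src` is hypothetical; the Builder API is the one introduced in this diff.
#include <memory>
#include <ATen/ATen.h>
#include <ATen/native/TensorIterator.h>  // assumed header location

std::unique_ptr<at::TensorIterator> make_iterator_sketch(const at::Tensor& src) {
  auto builder = at::TensorIterator::Builder();
  builder.dont_compute_common_dtype();
  // Backend and ScalarType are passed separately instead of &src.dispatch_type():
  builder.add_output(at::Tensor(), src.type().backend(), src.scalar_type());
  builder.add_input(src);
  return builder.build();  // build() assumed from the surrounding Builder API
}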
@@ -87,53 +87,53 @@ compute_result_type(at::ArrayRef<OperandInfo> operands, const F& predicate) {
 void TensorIterator::compute_types() {
   bool missing_dtypes = false;
   for (auto& op : operands_) {
-    if (!op.tensor.defined() && !op.type) {
+    if (!op.tensor.defined() && !op.is_type_defined()) {
       missing_dtypes = true;
     }
   }

   if (missing_dtypes || compute_common_dtype_) {
-    auto& type = compute_common_type();
+    ScalarType common_dtype;
+    Backend common_backend;
+    std::tie(common_backend, common_dtype) = compute_common_type();
     for (auto& op : operands_) {
-      auto& op_tensor_type = at::globalContext().getNonVariableType(op.tensor.type().backend(), op.tensor.scalar_type());
-      if (!op.type) {
-        op.type = &type;
-      } else if (compute_common_dtype_ && op.type != &type) {
+      if (!op.is_type_defined()) {
+        op.set_type(common_backend, common_dtype);
+      } else if (compute_common_dtype_ && !op.is_type_equal(common_backend, common_dtype)) {
         if (allow_cpu_scalars_ && op.tensor.defined() && op.tensor.dim() == 0 &&
-            type.device_type() == kCUDA && op_tensor_type.device_type() == kCPU) {
+            common_backend == Backend::CUDA && op.tensor.type().backend() == Backend::CPU) {
           // don't cast CPU scalars in CUDA ops that directly support them
-          op.type = &op_tensor_type;
+          op.set_type(op.tensor.type().backend(), op.tensor.scalar_type());
         } else if (promote_gpu_output_dtypes_ && op.tensor.defined() &&
-            !op.is_output && op_tensor_type.scalarType() == kHalf &&
-            type.scalarType() == kFloat && type.device_type() == kCUDA &&
-            op_tensor_type.device_type() == kCUDA) {
+            !op.is_output && op.tensor.scalar_type() == kHalf &&
+            common_dtype == kFloat && common_backend == Backend::CUDA &&
+            op.tensor.type().backend() == Backend::CUDA) {
           // allow input tensor type upcasting for fp16 to fp32 in fused kernel
           // on GPU
-          op.type = &op_tensor_type;
+          op.set_type(op.tensor.type().backend(), op.tensor.scalar_type());
         } else {
-          op.type = &type;
+          op.set_type(common_backend, common_dtype);
         }
       }
     }
   }

   for (auto& op : operands_) {
-    auto& op_tensor_type = at::globalContext().getNonVariableType(op.tensor.type().backend(), op.tensor.scalar_type());
-    if (op.tensor.defined() && op_tensor_type != *op.type) {
+    if (op.tensor.defined() && !op.is_type_equal(op.tensor.type().backend(), op.tensor.scalar_type())) {
       if (op.is_output) {
-        AT_ERROR("output with type ", op_tensor_type.toString(),
-                 " doesn't match the desired type ", op.type->toString());
+        AT_ERROR("output with backend ", toString(op.tensor.type().backend()), " and dtype ", toString(op.tensor.scalar_type()),
+                 " doesn't match the desired backend ", toString(op.backend), " and dtype ", toString(op.dtype));
       } else if (op.tensor.dim() == 0) {
-        op.tensor = op.tensor.to(*op.type);
+        op.tensor = op.tensor.to(op.options());
       } else {
-        AT_ERROR("expected type ", op.type->toString(), " but got ",
-                 op_tensor_type.toString());
+        AT_ERROR("expected backend ", toString(op.backend), " and dtype ", toString(op.dtype),
+                 " but got backend ", toString(op.tensor.type().backend()), " and dtype ", toString(op.tensor.scalar_type()));
       }
     }
   }
 }

-Type& TensorIterator::compute_common_type() {
+std::pair<Backend, ScalarType> TensorIterator::compute_common_type() {
   // See [Result type computation] in TensorIterator.h
   auto result_type = ScalarType::Undefined;
   auto backend = Backend::Undefined;
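compute_common_type() now returns a (Backend, ScalarType) pair that the caller unpacks with std::tie, rather than a reference to a Type object. A minimal, self-contained sketch of that return-and-unpack pattern follows; the enums are stand-ins, not the real ATen definitions.

// Stand-in sketch of the pair-return pattern used by compute_types() above.
#include <iostream>
#include <tuple>
#include <utility>

enum class Backend { Undefined, CPU, CUDA };
enum class ScalarType { Undefined, Float, Half };

std::pair<Backend, ScalarType> compute_common_type_sketch() {
  // The real code inspects every operand; here we just pick a result.
  return std::make_pair(Backend::CUDA, ScalarType::Float);
}

int main() {
  Backend common_backend;
  ScalarType common_dtype;
  std::tie(common_backend, common_dtype) = compute_common_type_sketch();
  std::cout << (common_backend == Backend::CUDA) << " "
            << (common_dtype == ScalarType::Float) << "\n";  // prints "1 1"
}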
@@ -154,7 +154,7 @@ Type& TensorIterator::compute_common_type() {
   AT_ASSERT(result_type != ScalarType::Undefined);
   AT_ASSERT(backend != Backend::Undefined);

-  return at::globalContext().getNonVariableType(backend, result_type);
+  return std::make_pair(backend, result_type);
 }

 DimVector TensorIterator::compatible_stride(int element_size) const {
@@ -182,8 +182,8 @@ void TensorIterator::allocate_outputs() {
   for (int i = 0; i < num_outputs_; i++) {
     auto& op = operands_[i];
     if (!op.tensor.defined()) {
-      AT_ASSERTM(op.type, "no type for operand", i);
-      int element_size = op.type->typeMeta().itemsize();
+      AT_ASSERTM(op.is_type_defined(), "no type for operand", i);
+      int element_size = elementSize(op.dtype);
       op.stride_bytes = compatible_stride(element_size);

       auto tensor_shape = invert_perm(shape_);
@@ -191,7 +191,7 @@ void TensorIterator::allocate_outputs() {
       for (int dim = 0; dim < ndim(); dim++) {
         tensor_stride[dim] /= element_size;
       }
-      op.tensor = at::empty_strided(tensor_shape, tensor_stride, op.type->options());
+      op.tensor = at::empty_strided(tensor_shape, tensor_stride, op.options());
     }
   }
 }
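allocate_outputs() now builds its outputs from op.options(), i.e. a TensorOptions derived from the stored backend and dtype, instead of Type::options(). A small sketch of allocating a tensor that way is below; it assumes ATen, and the device, dtype, shape, and strides are arbitrary illustration values.

// Sketch: allocate a tensor from a device/dtype pair via TensorOptions,
// mirroring the op.options() call above. Values are illustrative only.
#include <ATen/ATen.h>

at::Tensor allocate_output_sketch() {
  auto options = at::TensorOptions(at::kCPU).dtype(at::kFloat);
  // Same shape/stride-based allocation as in allocate_outputs().
  return at::empty_strided({2, 3}, {3, 1}, options);
}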
@@ -420,7 +420,7 @@ bool TensorIterator::is_scalar(int arg) const {
 }

 bool TensorIterator::is_cpu_scalar(int arg) const {
-  return is_scalar(arg) && operands_[arg].tensor.type().device_type() == kCPU;
+  return is_scalar(arg) && device_type(arg) == kCPU;
 }

 void* TensorIterator::data_ptr(int arg) const {
@@ -66,10 +66,11 @@ struct DimCounter {
 };
 struct CAFFE2_API OperandInfo {
   OperandInfo() {}
-  OperandInfo(const Tensor& t, const Type* type=nullptr)
-    : tensor(t), type(const_cast<Type*>(type)) {
-    if (t.defined() && !type) {
-      this->type = &t.dispatch_type();
+  explicit OperandInfo(const Tensor& t, const Backend backend=Backend::Undefined, const ScalarType dtype=ScalarType::Undefined)
+    : tensor(t), backend(backend), dtype(dtype) {
+    if (t.defined() && (backend == Backend::Undefined || dtype == ScalarType::Undefined)) {
+      this->backend = t.type().backend();
+      this->dtype = t.scalar_type();
     }
   }

@@ -85,7 +86,25 @@ struct CAFFE2_API OperandInfo {
   /// input should be converted to this type if necessary. For outputs, this
   /// specifies which type to allocate. Note that there is very limited support
   /// for type conversions currently: they are only allowed for zero-dim tensors.
-  Type* type = nullptr;
+  Backend backend = Backend::Undefined;
+  ScalarType dtype = ScalarType::Undefined;
+
+  bool is_type_defined() {
+    return dtype != ScalarType::Undefined && backend != Backend::Undefined;
+  }
+
+  bool is_type_equal(Backend b, ScalarType s) {
+    return dtype == s && backend == b;
+  }
+
+  void set_type(Backend b, ScalarType s) {
+    dtype = s;
+    backend = b;
+  }
+
+  TensorOptions options() {
+    return TensorOptions(backendToDeviceType(backend)).dtype(dtype);
+  }

   /// The data pointer. This may be different from tensor.data_ptr() if the
   /// iterator is split.
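The new OperandInfo helpers keep the backend/dtype pair consistent and let callers query or overwrite it as a unit. The following self-contained sketch reproduces just that bookkeeping with stand-in enums (not the real ATen Backend/ScalarType) to show how the helpers interact.

// Stand-in sketch of the OperandInfo type bookkeeping introduced above.
#include <iostream>

enum class Backend { Undefined, CPU, CUDA };
enum class ScalarType { Undefined, Float, Half };

struct OperandInfoSketch {
  Backend backend = Backend::Undefined;
  ScalarType dtype = ScalarType::Undefined;

  bool is_type_defined() const {
    return dtype != ScalarType::Undefined && backend != Backend::Undefined;
  }
  bool is_type_equal(Backend b, ScalarType s) const {
    return dtype == s && backend == b;
  }
  void set_type(Backend b, ScalarType s) {
    dtype = s;
    backend = b;
  }
};

int main() {
  OperandInfoSketch op;
  std::cout << op.is_type_defined() << "\n";                               // 0: nothing set yet
  op.set_type(Backend::CUDA, ScalarType::Half);
  std::cout << op.is_type_equal(Backend::CUDA, ScalarType::Half) << "\n";  // 1
}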
@@ -148,13 +167,9 @@ struct CAFFE2_API TensorIterator {
   /// Accessors for each operand
   IntArrayRef strides(int arg) const { return operands_[arg].stride_bytes; }
   void* data_ptr(int arg) const;
-  const Type& type(int arg=0) const {
-    AT_ASSERT(operands_[arg].type);
-    return *operands_[arg].type;
-  }
-  ScalarType dtype(int arg=0) const { return type(arg).scalarType(); }
-  DeviceType device_type(int arg=0) const { return type(arg).device_type(); }
-  int64_t element_size(int arg) const { return type(arg).typeMeta().itemsize(); }
+  ScalarType dtype(int arg=0) const { return operands_[arg].dtype; }
+  DeviceType device_type(int arg=0) const { return backendToDeviceType(operands_[arg].backend); }
+  int64_t element_size(int arg) const { return elementSize(dtype(arg)); }
   bool is_scalar(int arg) const;
   bool is_cpu_scalar(int arg) const;

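element_size() now derives the byte width from the stored ScalarType via elementSize() instead of going through Type::typeMeta(). A stand-in sketch of that dtype-to-size lookup is below; the enum and size table are illustrative, not ATen's definitions.

// Stand-in sketch of a dtype -> byte-size lookup like the one element_size() relies on.
#include <cstddef>
#include <iostream>

enum class ScalarTypeSketch { Half, Float, Double };

std::size_t element_size_sketch(ScalarTypeSketch t) {
  switch (t) {
    case ScalarTypeSketch::Half:   return 2;
    case ScalarTypeSketch::Float:  return 4;
    case ScalarTypeSketch::Double: return 8;
  }
  return 0;  // unreachable for the cases above
}

int main() {
  std::cout << element_size_sketch(ScalarTypeSketch::Float) << "\n";  // 4
}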
@@ -237,7 +252,7 @@ protected:
   void reorder_dimensions();
   void permute_dimensions(IntArrayRef perm);
   void compute_types();
-  Type& compute_common_type();
+  std::pair<Backend, ScalarType> compute_common_type();
   void allocate_outputs();
   void coalesce_dimensions();

@@ -261,13 +276,13 @@ struct TensorIterator::Builder {

   Builder() : iter_(new TensorIterator()) {};

-  void add_output(const Tensor& output, const Type* type=nullptr) {
-    iter_->operands_.emplace_back(output, type);
+  void add_output(const Tensor& output, const Backend backend=Backend::Undefined, const ScalarType dtype=ScalarType::Undefined) {
+    iter_->operands_.emplace_back(output, backend, dtype);
     iter_->num_outputs_++;
   }

-  void add_input(const Tensor& input, const Type* type=nullptr) {
-    iter_->operands_.emplace_back(input, type);
+  void add_input(const Tensor& input, const Backend backend=Backend::Undefined, const ScalarType dtype=ScalarType::Undefined) {
+    iter_->operands_.emplace_back(input, backend, dtype);
   }

   void dont_compute_common_dtype() {
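Because the Builder parameters default to Backend::Undefined and ScalarType::Undefined, call sites such as builder.add_input(index) need no change: the OperandInfo constructor earlier in this diff falls back to the tensor's own backend and scalar type. The self-contained sketch below models that defaulting behaviour with stand-in types; TensorStub and OperandStub are hypothetical, not ATen classes.

// Stand-in sketch of the defaulted-parameter fallback used by OperandInfo above.
#include <iostream>

enum class Backend { Undefined, CPU, CUDA };
enum class ScalarType { Undefined, Float, Half };

struct TensorStub {  // hypothetical stand-in for at::Tensor
  Backend backend = Backend::CPU;
  ScalarType dtype = ScalarType::Float;
};

struct OperandStub {
  OperandStub(const TensorStub& t,
              Backend backend = Backend::Undefined,
              ScalarType dtype = ScalarType::Undefined)
      : backend(backend), dtype(dtype) {
    if (backend == Backend::Undefined || dtype == ScalarType::Undefined) {
      this->backend = t.backend;  // fall back to the tensor's own type info
      this->dtype = t.dtype;
    }
  }
  Backend backend;
  ScalarType dtype;
};

int main() {
  TensorStub t;
  OperandStub defaulted(t);                                     // picks up CPU/Float from t
  OperandStub explicit_op(t, Backend::CUDA, ScalarType::Half);  // explicit values win
  std::cout << (defaulted.backend == Backend::CPU) << " "
            << (explicit_op.dtype == ScalarType::Half) << "\n";  // prints "1 1"
}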
@@ -1025,7 +1025,7 @@ class TestCuda(TestCase):

         self.assertEqual(x * y, 4.5)
         self.assertEqual(y * x, 4.5)
-        with self.assertRaisesRegex(RuntimeError, "doesn't match the desired type"):
+        with self.assertRaisesRegex(RuntimeError, "doesn't match the desired"):
             y *= x
         x *= y
         self.assertEqual(x, 4.5)
@@ -2059,15 +2059,13 @@ class TestCuda(TestCase):
     def test_sum_cpu_gpu_mismatch(self):
         x = torch.randn(20, dtype=torch.float32, device='cuda')
         y = torch.randn(1, dtype=torch.float32)
-        with self.assertRaisesRegex(RuntimeError, 'expected type'
-                                    ' torch.FloatTensor but got'
-                                    ' torch.cuda.FloatTensor'):
+        with self.assertRaisesRegex(RuntimeError,
+                                    'expected backend CPU and dtype Float but got backend CUDA and dtype Float'):
            torch.sum(x, dim=[0], dtype=torch.float32, out=y)
         # makeing sure half to float promotion is also properly working.
         x = x.half()
-        with self.assertRaisesRegex(RuntimeError, 'expected type'
-                                    ' torch.FloatTensor but got'
-                                    ' torch.cuda.HalfTensor'):
+        with self.assertRaisesRegex(RuntimeError,
+                                    'expected backend CPU and dtype Float but got backend CUDA and dtype Half'):
             torch.sum(x, dim=[0], dtype=torch.float32, out=y)

     @skipIfRocm