Implement sym_sizes to create proper IR for sym ints representing tensor sizes (#77756)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77756
Approved by: https://github.com/desertfire
Author: Nikolay Korovaiko
Date: 2022-05-20 05:39:03 +00:00
Committed by: PyTorch MergeBot
parent a8c929b0a6
commit df1f9b9840

8 changed files with 58 additions and 7 deletions
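
For orientation, a minimal sketch of the usage this change enables, mirroring the NarrowCopyViaSymSizes test added below (a sketch assuming a build with the lazy (LTC) backend registered; not part of the diff):

    #include <torch/torch.h>

    int main() {
      // On lazy tensors, sym_sizes() now returns c10::SymInt values backed
      // by size IR nodes rather than reinterpreted static sizes.
      auto x = torch::rand({10}).to(at::kLazy);
      auto y = torch::rand({3}).to(at::kLazy);
      // Passing the SymInt length records a dependency on y's size in the
      // lazy trace instead of baking in the constant 3.
      auto z = x.narrow_copy(/*dim=*/0, /*start=*/0, /*length=*/y.sym_sizes()[0]);
      return 0;
    }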

View File

@@ -221,7 +221,7 @@ class TORCH_API TensorBase {
     return impl_->sizes();
   }
   c10::SymIntArrayRef sym_sizes() const {
-    return c10::SymIntArrayRef(reinterpret_cast<const SymInt*>(sizes().data()), sizes().size());
+    return impl_->sym_sizes();
   }
   IntArrayRef strides() const {
     return impl_->strides();

View File

@@ -2,6 +2,7 @@

 #include <c10/core/Backend.h>
 #include <c10/core/InferenceMode.h>
+#include <c10/core/SymIntArrayRef.h>
 #include <c10/core/WrapDimMinimal.h>
 #include <c10/core/impl/LocalDispatchKeySet.h>
 #include <c10/util/Optional.h>
@@ -371,6 +372,13 @@ IntArrayRef TensorImpl::sizes_custom() const {
   TORCH_CHECK(
       false, "Tensors of type ", tensorimpl_type_name(), " do not have sizes");
 }
+c10::SymIntArrayRef TensorImpl::sym_sizes_custom() const {
+  TORCH_CHECK(
+      false,
+      "Tensors of type ",
+      tensorimpl_type_name(),
+      " do not have sym sizes");
+}
 IntArrayRef TensorImpl::strides_custom() const {
   TORCH_CHECK(
       false,

View File

@@ -548,6 +548,15 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     }
     return sizes_default();
   }
+  c10::SymIntArrayRef sym_sizes() const {
+    if (C10_UNLIKELY(
+            sizes_strides_policy_ >=
+            static_cast<uint8_t>(SizesStridesPolicy::CustomSizes))) {
+      return sym_sizes_custom();
+    }
+    return sym_sizes_default();
+  }
+
   /**
    * Return a reference to the strides of this tensor. This reference remains
    * valid as long as the tensor is live and not restrided.
@@ -655,6 +664,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   virtual bool is_contiguous_custom(at::MemoryFormat memory_format) const;
   // sizes_strides_policy_ >= CustomSizes
   virtual IntArrayRef sizes_custom() const;
+  virtual c10::SymIntArrayRef sym_sizes_custom() const;
   virtual int64_t dim_custom() const;
   virtual int64_t numel_custom() const;
@@ -675,6 +685,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   inline IntArrayRef sizes_default() const {
     return sizes_and_strides_.sizes_arrayref();
   }
+  inline c10::SymIntArrayRef sym_sizes_default() const {
+    return c10::SymIntArrayRef(
+        reinterpret_cast<const c10::SymInt*>(sizes_and_strides_.sizes_data()),
+        sizes_and_strides_.size());
+  }
   inline int64_t dim_default() const {
     return sizes_and_strides_.size();
   }
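
Note the split above: sym_sizes_custom() is the virtual hook for backends that opt into SizesStridesPolicy::CustomSizes (the lazy tensors below), while sym_sizes_default() reinterprets the stored int64_t sizes as SymInts. That cast is only sound if c10::SymInt is layout-compatible with int64_t for concrete values; an illustrative check of the invariant the cast relies on (a sketch, not part of the diff):

    #include <c10/core/SymInt.h>
    #include <cstdint>

    // Dense tensors store plain int64_t sizes; the default path hands them
    // out as SymInts by reinterpretation, so the layouts must match.
    static_assert(
        sizeof(c10::SymInt) == sizeof(int64_t),
        "sym_sizes_default() reinterprets int64_t sizes as SymInt");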

View File

@@ -430,8 +430,7 @@ class TORCH_API Tensor final {
   }
   inline c10::SymIntArrayRef sym_sizes() const {
-    auto sizes = impl_.get()->sizes();
-    return c10::SymIntArrayRef(reinterpret_cast<const c10::SymInt*>(sizes.data()), sizes.size());
+    return impl_->sym_sizes();
   }

   inline int64_t size_from_dim(int k) const {

View File

@@ -94,6 +94,19 @@ TEST(LazyDynamicOpsTest, NarrowCopy) {
   AllClose(z.cpu(), x.cpu().narrow_copy(X_DIM_INDEX, 0, Y_DIM));
 }
+
+TEST(LazyDynamicOpsTest, NarrowCopyViaSymSizes) {
+  auto xc = torch::rand({10});
+  auto x = xc.to(kLazy);
+  const size_t Y_DIM = 3;
+  const size_t X_DIM_INDEX = 0;
+  auto y = torch::rand({Y_DIM}).to(kLazy);
+  auto z = x.narrow_copy(X_DIM_INDEX, 0, y.sym_sizes()[0]);
+  auto zc = xc.narrow_copy(X_DIM_INDEX, 0, Y_DIM);
+  ASSERT_EQ(z.sizes()[0], xc.sizes()[0]); // note, xc not zc
+  // shape inference assumes narrow_copy can copy the whole tensor
+  AllClose(z.cpu(), zc);
+}
 TEST_F(LazyOpsTest, TestScalarTensor) {
   torch::Tensor scalar_tensor = torch::scalar_tensor(
       1., torch::TensorOptions(torch::kFloat).device(DefaultDevice()));

View File

@@ -37,7 +37,7 @@ class TestLazyReuseIr(TestCase):
             torch._lazy.mark_step()
             torch.testing.assert_close(z.cpu(), z_lazy.cpu())
-        assert metrics.counter_value("IrNodeReused_torch::lazy::AddTensor") >= 16
+        assert metrics.counter_value("IrNodeReused_torch::lazy::AddTensor") >= 14
         metrics.reset()
         torch._lazy.ir_cache.reset()
@@ -66,8 +66,7 @@ class TestLazyReuseIr(TestCase):
             torch._lazy.mark_step()
             torch.testing.assert_close(z.cpu(), z_lazy.cpu())
-        assert metrics.counter_value("IrNodeReused_torch::lazy::AddTensor") >= 10
-        assert metrics.counter_value("IrNodeReused_torch::lazy::AddTensor") >= 4
+        assert metrics.counter_value("IrNodeReused_torch::lazy::AddTensor") >= 8
         metrics.reset()
         torch._lazy.ir_cache.reset()
@@ -97,7 +96,7 @@ class TestLazyReuseIr(TestCase):
             torch._lazy.mark_step()
             torch.testing.assert_close(z.cpu(), z_lazy.cpu())
-        assert metrics.counter_value("IrNodeReused_torch::lazy::AddTensor") >= 11
+        assert metrics.counter_value("IrNodeReused_torch::lazy::AddTensor") >= 8
         metrics.reset()
         torch._lazy.ir_cache.reset()
         torch._lazy.config.set_force_fallback("")

View File

@@ -6,6 +6,7 @@
 #include <c10/macros/Macros.h>
 #include <c10/util/irange.h>
 #include <torch/csrc/lazy/core/tensor_util.h>
+#include <torch/csrc/lazy/core/ir_builder.h>

 namespace torch {
 namespace lazy {
@@ -83,6 +84,14 @@ LTCTensorImpl::LTCTensorImpl(LazyTensor&& tensor)
   // according to https://github.com/pytorch/xla/pull/2682.
   is_non_overlapping_and_dense_ = false;
   set_sizes_strides_policy(SizesStridesPolicy::CustomSizes);
+  auto rank = tensor_->shape().Get().sizes().size();
+  sym_sizes_.reserve(rank);
+  for (auto i : c10::irange(rank)) {
+    auto dim_node = getBackend()->GetIrBuilder()->MakeSizeNode(
+        this->tensor_->GetIrValue(), i);
+    auto sn = std::make_shared<torch::lazy::SymbolicIntNode>(dim_node);
+    sym_sizes_.push_back(sn->toSymInt());
+  }
 }

 void LTCTensorImpl::set_tensor(const LazyTensorPtr& lazy_tensor) {
@@ -127,6 +136,10 @@ void LTCTensorImpl::shallow_copy_from(
   generation_ = 0;
 }

+c10::SymIntArrayRef LTCTensorImpl::sym_sizes_custom() const {
+  return c10::SymIntArrayRef(sym_sizes_.data(), sym_sizes_.size());
+}
+
 void LTCTensorImpl::setup_size_properties() {
   size_t generation = tensor_->generation();
   if (generation != generation_) {
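
Taken together, the two hunks above implement a build-once, serve-many pattern: the constructor materializes one size IR node per dimension and caches the resulting SymInts in sym_sizes_, and sym_sizes_custom() simply hands the cached vector out. A rough stand-alone sketch of that pattern with hypothetical stand-in types (the real code goes through getBackend()->GetIrBuilder()->MakeSizeNode and torch::lazy::SymbolicIntNode):

    #include <cstddef>
    #include <memory>
    #include <vector>

    // Hypothetical stand-ins for the lazy IR types, for illustration only.
    struct SizeNode {
      std::size_t dim;  // IR node meaning "the size of dimension `dim`"
    };
    struct SymbolicHandle {
      std::shared_ptr<SizeNode> node;  // what a SymInt points at here
    };

    // Built once at tensor construction, then served by sym_sizes_custom().
    std::vector<SymbolicHandle> build_sym_sizes(std::size_t rank) {
      std::vector<SymbolicHandle> out;
      out.reserve(rank);  // one symbolic handle per tensor dimension
      for (std::size_t i = 0; i < rank; ++i) {
        out.push_back(SymbolicHandle{std::make_shared<SizeNode>(SizeNode{i})});
      }
      return out;
    }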

View File

@@ -1,6 +1,7 @@
 #pragma once

 #include <ATen/Tensor.h>
+#include <c10/core/SymIntArrayRef.h>
 #include <c10/core/TensorImpl.h>
 #include <torch/csrc/lazy/core/tensor.h>
@@ -38,6 +39,8 @@ class TORCH_API LTCTensorImpl final : public c10::TensorImpl {
   int64_t numel_custom() const override;
   bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
+  virtual c10::SymIntArrayRef sym_sizes_custom() const override;
+
 #ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
   const at::Storage& storage() const override { return tensor_->Storage(); }
   bool has_storage() const override { return tensor_->Storage(); }
@@ -47,6 +50,7 @@ class TORCH_API LTCTensorImpl final : public c10::TensorImpl {
   void setup_size_properties();

   LazyTensorPtr tensor_;
+  std::vector<c10::SymInt> sym_sizes_;
   size_t generation_ {0};
 };