mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Implement sym_sizes to create proper IR for sym ints representing tensor sizes (#77756)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77756 Approved by: https://github.com/desertfire
This commit is contained in:
committed by
PyTorch MergeBot
parent
a8c929b0a6
commit
df1f9b9840
@ -221,7 +221,7 @@ class TORCH_API TensorBase {
|
||||
return impl_->sizes();
|
||||
}
|
||||
c10::SymIntArrayRef sym_sizes() const {
|
||||
return c10::SymIntArrayRef(reinterpret_cast<const SymInt*>(sizes().data()), sizes().size());
|
||||
return impl_->sym_sizes();
|
||||
}
|
||||
IntArrayRef strides() const {
|
||||
return impl_->strides();
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
#include <c10/core/Backend.h>
|
||||
#include <c10/core/InferenceMode.h>
|
||||
#include <c10/core/SymIntArrayRef.h>
|
||||
#include <c10/core/WrapDimMinimal.h>
|
||||
#include <c10/core/impl/LocalDispatchKeySet.h>
|
||||
#include <c10/util/Optional.h>
|
||||
@ -371,6 +372,13 @@ IntArrayRef TensorImpl::sizes_custom() const {
|
||||
TORCH_CHECK(
|
||||
false, "Tensors of type ", tensorimpl_type_name(), " do not have sizes");
|
||||
}
|
||||
c10::SymIntArrayRef TensorImpl::sym_sizes_custom() const {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Tensors of type ",
|
||||
tensorimpl_type_name(),
|
||||
" do not have sym sizes");
|
||||
}
|
||||
IntArrayRef TensorImpl::strides_custom() const {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
|
||||
@ -548,6 +548,15 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
return sizes_default();
|
||||
}
|
||||
|
||||
c10::SymIntArrayRef sym_sizes() const {
|
||||
if (C10_UNLIKELY(
|
||||
sizes_strides_policy_ >=
|
||||
static_cast<uint8_t>(SizesStridesPolicy::CustomSizes))) {
|
||||
return sym_sizes_custom();
|
||||
}
|
||||
return sym_sizes_default();
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a reference to the strides of this tensor. This reference remains
|
||||
* valid as long as the tensor is live and not restrided.
|
||||
@ -655,6 +664,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
virtual bool is_contiguous_custom(at::MemoryFormat memory_format) const;
|
||||
// sizes_strides_policy_ >= CustomSizes
|
||||
virtual IntArrayRef sizes_custom() const;
|
||||
virtual c10::SymIntArrayRef sym_sizes_custom() const;
|
||||
virtual int64_t dim_custom() const;
|
||||
virtual int64_t numel_custom() const;
|
||||
|
||||
@ -675,6 +685,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
inline IntArrayRef sizes_default() const {
|
||||
return sizes_and_strides_.sizes_arrayref();
|
||||
}
|
||||
inline c10::SymIntArrayRef sym_sizes_default() const {
|
||||
return c10::SymIntArrayRef(
|
||||
reinterpret_cast<const c10::SymInt*>(sizes_and_strides_.sizes_data()),
|
||||
sizes_and_strides_.size());
|
||||
}
|
||||
inline int64_t dim_default() const {
|
||||
return sizes_and_strides_.size();
|
||||
}
|
||||
|
||||
@ -430,8 +430,7 @@ class TORCH_API Tensor final {
|
||||
}
|
||||
|
||||
inline c10::SymIntArrayRef sym_sizes() const {
|
||||
auto sizes = impl_.get()->sizes();
|
||||
return c10::SymIntArrayRef(reinterpret_cast<const c10::SymInt*>(sizes.data()), sizes.size());
|
||||
return impl_->sym_sizes();
|
||||
}
|
||||
|
||||
inline int64_t size_from_dim(int k) const {
|
||||
|
||||
@ -94,6 +94,19 @@ TEST(LazyDynamicOpsTest, NarrowCopy) {
|
||||
AllClose(z.cpu(), x.cpu().narrow_copy(X_DIM_INDEX, 0, Y_DIM));
|
||||
}
|
||||
|
||||
TEST(LazyDynamicOpsTest, NarrowCopyViaSymSizes) {
|
||||
auto xc = torch::rand({10});
|
||||
auto x = xc.to(kLazy);
|
||||
const size_t Y_DIM = 3;
|
||||
const size_t X_DIM_INDEX = 0;
|
||||
auto y = torch::rand({Y_DIM}).to(kLazy);
|
||||
auto z = x.narrow_copy(X_DIM_INDEX, 0, y.sym_sizes()[0]);
|
||||
auto zc = xc.narrow_copy(X_DIM_INDEX, 0, Y_DIM);
|
||||
ASSERT_EQ(z.sizes()[0], xc.sizes()[0]); // note, xc not zc
|
||||
// shape inference assumes narrow_copy can copy the whole tensor
|
||||
AllClose(z.cpu(), zc);
|
||||
}
|
||||
|
||||
TEST_F(LazyOpsTest, TestScalarTensor) {
|
||||
torch::Tensor scalar_tensor = torch::scalar_tensor(
|
||||
1., torch::TensorOptions(torch::kFloat).device(DefaultDevice()));
|
||||
|
||||
@ -37,7 +37,7 @@ class TestLazyReuseIr(TestCase):
|
||||
torch._lazy.mark_step()
|
||||
|
||||
torch.testing.assert_close(z.cpu(), z_lazy.cpu())
|
||||
assert metrics.counter_value("IrNodeReused_torch::lazy::AddTensor") >= 16
|
||||
assert metrics.counter_value("IrNodeReused_torch::lazy::AddTensor") >= 14
|
||||
metrics.reset()
|
||||
torch._lazy.ir_cache.reset()
|
||||
|
||||
@ -66,8 +66,7 @@ class TestLazyReuseIr(TestCase):
|
||||
torch._lazy.mark_step()
|
||||
|
||||
torch.testing.assert_close(z.cpu(), z_lazy.cpu())
|
||||
assert metrics.counter_value("IrNodeReused_torch::lazy::AddTensor") >= 10
|
||||
assert metrics.counter_value("IrNodeReused_torch::lazy::AddTensor") >= 4
|
||||
assert metrics.counter_value("IrNodeReused_torch::lazy::AddTensor") >= 8
|
||||
metrics.reset()
|
||||
torch._lazy.ir_cache.reset()
|
||||
|
||||
@ -97,7 +96,7 @@ class TestLazyReuseIr(TestCase):
|
||||
torch._lazy.mark_step()
|
||||
|
||||
torch.testing.assert_close(z.cpu(), z_lazy.cpu())
|
||||
assert metrics.counter_value("IrNodeReused_torch::lazy::AddTensor") >= 11
|
||||
assert metrics.counter_value("IrNodeReused_torch::lazy::AddTensor") >= 8
|
||||
metrics.reset()
|
||||
torch._lazy.ir_cache.reset()
|
||||
torch._lazy.config.set_force_fallback("")
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <torch/csrc/lazy/core/tensor_util.h>
|
||||
#include <torch/csrc/lazy/core/ir_builder.h>
|
||||
|
||||
namespace torch {
|
||||
namespace lazy {
|
||||
@ -83,6 +84,14 @@ LTCTensorImpl::LTCTensorImpl(LazyTensor&& tensor)
|
||||
// according to https://github.com/pytorch/xla/pull/2682.
|
||||
is_non_overlapping_and_dense_ = false;
|
||||
set_sizes_strides_policy(SizesStridesPolicy::CustomSizes);
|
||||
|
||||
auto rank = tensor_->shape().Get().sizes().size();
|
||||
sym_sizes_.reserve(rank);
|
||||
for (auto i: c10::irange(rank)) {
|
||||
auto dim_node = getBackend()->GetIrBuilder()->MakeSizeNode(this->tensor_->GetIrValue(), i);
|
||||
auto sn = std::make_shared<torch::lazy::SymbolicIntNode>(dim_node);
|
||||
sym_sizes_.push_back(sn->toSymInt());
|
||||
}
|
||||
}
|
||||
|
||||
void LTCTensorImpl::set_tensor(const LazyTensorPtr& lazy_tensor) {
|
||||
@ -127,6 +136,10 @@ void LTCTensorImpl::shallow_copy_from(
|
||||
generation_ = 0;
|
||||
}
|
||||
|
||||
c10::SymIntArrayRef LTCTensorImpl::sym_sizes_custom() const {
|
||||
return c10::SymIntArrayRef(sym_sizes_.data(), sym_sizes_.size());
|
||||
}
|
||||
|
||||
void LTCTensorImpl::setup_size_properties() {
|
||||
size_t generation = tensor_->generation();
|
||||
if (generation != generation_) {
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
#include <c10/core/SymIntArrayRef.h>
|
||||
#include <c10/core/TensorImpl.h>
|
||||
|
||||
#include <torch/csrc/lazy/core/tensor.h>
|
||||
@ -38,6 +39,8 @@ class TORCH_API LTCTensorImpl final : public c10::TensorImpl {
|
||||
int64_t numel_custom() const override;
|
||||
bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
|
||||
|
||||
virtual c10::SymIntArrayRef sym_sizes_custom() const override;
|
||||
|
||||
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
|
||||
const at::Storage& storage() const override { return tensor_->Storage(); }
|
||||
bool has_storage() const override { return tensor_->Storage(); }
|
||||
@ -47,6 +50,7 @@ class TORCH_API LTCTensorImpl final : public c10::TensorImpl {
|
||||
void setup_size_properties();
|
||||
|
||||
LazyTensorPtr tensor_;
|
||||
std::vector<c10::SymInt> sym_sizes_;
|
||||
size_t generation_ {0};
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user