Revert "Symintify getitem and add the required helper functions (#86207)"
This reverts commit fd5085c445c3f1a4c90e55154cf26fe30f52a0ab. Reverted https://github.com/pytorch/pytorch/pull/86207 on behalf of https://github.com/seemethere because it fails internal tests; see https://www.internalfb.com/intern/sandcastle/job/22517998926071860/insights
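For context only (this note and the sketch below are not part of the diff): the patch being reverted rewrote the C++ indexing helpers to do their size and bounds arithmetic on c10::SymInt / SymIntArrayRef instead of int64_t / IntArrayRef, so the checks can stay symbolic under dynamic-shape tracing. A minimal sketch of the two styles, assuming libtorch headers; the helper names are illustrative, not the actual functions touched by the diff:

// Illustrative sketch only (assumes libtorch); mirrors the pattern this revert
// toggles, not the actual ATen helpers changed below.
#include <ATen/ATen.h>
#include <c10/core/SymInt.h>

// Reverted-to style: concrete integer sizes, as in applySelect() in the diff.
int64_t normalize_index_int(const at::Tensor& t, int64_t dim, int64_t index) {
  int64_t size = t.size(dim);  // plain int64_t
  TORCH_CHECK_INDEX(index >= -size && index < size, "index ", index, " is out of bounds");
  return index < 0 ? index + size : index;
}

// Style being reverted: the same check over SymInt, so a symbolically traced
// graph records the size relationship instead of a baked-in constant.
c10::SymInt normalize_index_sym(const at::Tensor& t, int64_t dim, int64_t index) {
  c10::SymInt size = t.sym_sizes()[dim];  // possibly symbolic
  TORCH_CHECK_INDEX(size >= -index && size > index, "index ", index, " is out of bounds");
  c10::SymInt idx = index;
  if (idx < 0) {
    idx = idx + size;  // SymInt arithmetic
  }
  return idx;
}

The diff below switches exactly this kind of code back from the symbolic form to the concrete form.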
@@ -4,7 +4,6 @@
 #include <ATen/Functions.h>
 #include <ATen/ScalarOps.h>
 #include <ATen/core/TensorBody.h>
-#include <c10/core/SymInt.h>
 #include <c10/util/Optional.h>
 #include <c10/util/irange.h>
 
@@ -212,7 +211,7 @@ static inline Tensor applySlice(
     int64_t step,
     bool disable_slice_optimization,
     const at::Device& self_device,
-    const c10::optional<SymIntArrayRef>& self_sizes) {
+    const c10::optional<IntArrayRef>& self_sizes) {
   // TODO: implement negative step
   TORCH_CHECK_VALUE(step > 0, "step must be greater than zero");
 
@@ -221,10 +220,10 @@ static inline Tensor applySlice(
   // Skip this optimization if we are tracing, as the trace may be polymorphic
   // over the shape of the `self` tensor, and we still want to record
   // the slice.
-  SymInt length = (self_device == at::kCPU || self_device == at::kCUDA)
+  int64_t length = (self_device == at::kCPU || self_device == at::kCUDA)
       ? (*self_sizes)[dim]
       : self.size(dim);
-  if (!disable_slice_optimization && start == 0 && length == stop &&
+  if (!disable_slice_optimization && start == 0 && stop == length &&
       step == 1) {
     return self;
   }
@@ -238,7 +237,7 @@ static inline Tensor applySelect(
     int64_t index,
     int64_t real_dim,
     const at::Device& /*self_device*/,
-    const c10::optional<SymIntArrayRef>& self_sizes) {
+    const c10::optional<IntArrayRef>& self_sizes) {
   // See NOTE [nested tensor size for indexing]
   if (self_sizes.has_value()) {
     TORCH_CHECK_INDEX(
@@ -246,9 +245,9 @@ static inline Tensor applySelect(
         "invalid index of a 0-dim tensor. ",
         "Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number");
 
-    auto size = (*self_sizes)[dim];
+    int64_t size = (*self_sizes)[dim];
     TORCH_CHECK_INDEX(
-        size >= -index && size > index,
+        index >= -size && index < size,
         "index ",
         index,
         " is out of bounds for dimension ",
@@ -425,7 +424,7 @@ static inline Tensor handleDimInMultiDimIndexing(
     std::vector<Tensor>& outIndices,
     bool disable_slice_optimization,
     const at::Device& original_tensor_device,
-    const c10::optional<SymIntArrayRef>& prev_dim_result_sizes) {
+    const c10::optional<IntArrayRef>& prev_dim_result_sizes) {
   if (index.is_integer()) {
     return impl::applySelect(
         prev_dim_result,
@@ -509,7 +508,7 @@ static inline Tensor applySlicing(
     std::vector<Tensor>& outIndices,
     bool disable_slice_optimization,
     const at::Device& self_device,
-    const c10::optional<SymIntArrayRef>& self_sizes) {
+    const c10::optional<IntArrayRef>& self_sizes) {
   int64_t dim = 0;
   int64_t specified_dims = impl::count_specified_dimensions(indices);
 
@@ -525,9 +524,9 @@ static inline Tensor applySlicing(
   for (const auto i : c10::irange(indices.size())) {
     auto& obj = indices[i];
     // See NOTE [nested tensor size for indexing]
-    c10::optional<SymIntArrayRef> result_sizes = result.is_nested()
-        ? c10::optional<SymIntArrayRef>(c10::nullopt)
-        : c10::optional<SymIntArrayRef>(result.sym_sizes());
+    c10::optional<IntArrayRef> result_sizes = result.is_nested()
+        ? c10::optional<IntArrayRef>(c10::nullopt)
+        : c10::optional<IntArrayRef>(result.sizes());
     result = handleDimInMultiDimIndexing(
         /*prev_dim_result=*/result,
         /*original_tensor=*/self,
@@ -601,9 +600,9 @@ static inline Tensor get_item(
   // nested tensor does not have a size (yet) so for now we represent its size
   // as null may need to be changed after we reach a better solution for nested
   // tensor size
-  c10::optional<SymIntArrayRef> self_sizes = self.is_nested()
-      ? c10::optional<SymIntArrayRef>(c10::nullopt)
-      : c10::optional<SymIntArrayRef>(self.sym_sizes());
+  c10::optional<IntArrayRef> self_sizes = self.is_nested()
+      ? c10::optional<IntArrayRef>(c10::nullopt)
+      : c10::optional<IntArrayRef>(self.sizes());
 
   // handle simple types: integers, slices, none, ellipsis, bool
   if (indices.size() == 1) {
@@ -664,7 +663,7 @@ static inline void set_item(
     const Tensor& value,
     bool disable_slice_optimization = false) {
   at::Device self_device = self.device();
-  SymIntArrayRef self_sizes = self.sym_sizes();
+  IntArrayRef self_sizes = self.sizes();
 
   // handle simple types: integers, slices, ellipsis, bool
   if (indices.size() == 1) {

@@ -1512,49 +1512,39 @@ QuantizerPtr create_subtensor_quantizer(const Tensor& self, bool is_select, int6
   return quantizer;
 }
 
-Tensor select(const Tensor& self, int64_t dim, int64_t index_) {
+Tensor select(const Tensor& self, int64_t dim, int64_t index) {
   int64_t ndim = self.dim();
   if (ndim == 0) {
     TORCH_CHECK_INDEX(false, "select() cannot be applied to a 0-dim tensor.");
   }
   dim = maybe_wrap_dim(dim, ndim);
-  auto size = self.sym_sizes()[dim];
-  if (size < -index_ || size <= index_) {
+  auto size = self.size(dim);
+  if (index < -size || index >= size) {
     if (self.has_names() && self.names()[dim] != Dimname::wildcard()) {
-      TORCH_CHECK_INDEX(false, "select(): index ", index_, " out of range for tensor of size ",
+      TORCH_CHECK_INDEX(false, "select(): index ", index, " out of range for tensor of size ",
                         self.sizes(), " at dimension ", self.names()[dim]);
     }
-    TORCH_CHECK_INDEX(false, "select(): index ", index_, " out of range for tensor of size ",
+    TORCH_CHECK_INDEX(false, "select(): index ", index, " out of range for tensor of size ",
                       self.sizes(), " at dimension ", dim);
   }
-  SymInt index = index_;
   if (index < 0) {
     index += size;
   }
   if (self.is_sparse()) {
-    return select_sparse(self, dim, index.guard_int(__FILE__, __LINE__));
+    return select_sparse(self, dim, index);
   }
+  DimVector sizes(self.sizes().begin(), self.sizes().end());
+  DimVector strides(self.strides().begin(), self.strides().end());
+  auto storage_offset = self.storage_offset() + index * strides[dim];
+  sizes.erase(sizes.begin() + dim);
+  strides.erase(strides.begin() + dim);
 
   Tensor result;
   if (self.is_quantized()) {
-    auto local_index = index.guard_int(__FILE__, __LINE__);
-
-    DimVector sizes(self.sizes().begin(), self.sizes().end());
-    DimVector strides(self.strides().begin(), self.strides().end());
-    auto storage_offset = self.storage_offset() + local_index * strides[dim];
-    sizes.erase(sizes.begin() + dim);
-    strides.erase(strides.begin() + dim);
-
-    auto quantizer = create_subtensor_quantizer(self, true, local_index, local_index + 1, dim, 1);
+    auto quantizer = create_subtensor_quantizer(self, true, index, index + 1, dim, 1);
     result = as_strided_qtensorimpl(self, sizes, strides, storage_offset, quantizer);
   } else {
-    std::vector<c10::SymInt> sizes(self.sym_sizes().begin(), self.sym_sizes().end());
-    std::vector<c10::SymInt> strides(self.sym_strides().begin(), self.sym_strides().end());
-    auto storage_offset = self.sym_storage_offset() + index * strides[dim];
-    sizes.erase(sizes.begin() + dim);
-    strides.erase(strides.begin() + dim);
-
-    result = self.as_strided_symint(sizes, strides, storage_offset);
+    result = self.as_strided(sizes, strides, storage_offset);
   }
   namedinference::propagate_names_except(result, self, {dim});
   return result;

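Aside (not part of the diff): the removed select() path above keeps the index as a c10::SymInt and only forces it to a concrete value where a kernel genuinely needs one, via SymInt::guard_int, as seen in the deleted lines. A minimal sketch of that pattern, assuming libtorch; materialize_index is an illustrative name:

// Illustrative sketch only (assumes libtorch): guard a possibly-symbolic index
// back to a concrete int64_t, as the removed quantized/sparse branches did
// with index.guard_int(__FILE__, __LINE__).
#include <c10/core/SymInt.h>
#include <cstdint>

int64_t materialize_index(const c10::SymInt& index) {
  // On a plain integer this just returns the value; under symbolic tracing it
  // installs a guard on the traced expression and returns the concrete value.
  return index.guard_int(__FILE__, __LINE__);
}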
@@ -37,7 +37,6 @@ from common_utils import (
     skip,
     skipOps,
 )
-from torch._subclasses.fake_tensor import DynamicOutputShapeException
 
 USE_TORCHVISION = False
 try:
@@ -725,6 +724,7 @@ aot_autograd_failures = {
 }
 
 symbolic_aot_autograd_failures = {
+    xfail('__getitem__', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('__rmatmul__', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('addbmm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('addcdiv', ''), # aten.fill_.Scalar - couldn't find symbolic meta function/decomposition
@@ -790,6 +790,7 @@ symbolic_aot_autograd_failures = {
     xfail('hsplit', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('i0', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
     xfail('index_copy', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
+    xfail('index_fill', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('index_put', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('index_select', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('inner', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -986,7 +987,11 @@ symbolic_aot_autograd_failures = {
     xfail('scatter_reduce', 'sum'), # aten.scatter_reduce.two - couldn't find symbolic meta function/decomp...
     xfail('segment_reduce', 'lengths'), # aten.segment_reduce.default - couldn't find symbolic meta functio...
     xfail('segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta functio...
+    xfail('select', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
+    xfail('select_scatter', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('sgn', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
+    xfail('slice', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
+    xfail('slice_scatter', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('sort', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('special.entr', ''), # aten.special_entr.default - couldn't find symbolic meta function/decomposition
     xfail('special.erfcx', ''), # aten.special_erfcx.default - couldn't find symbolic meta function/decompos...
@@ -999,6 +1004,7 @@ symbolic_aot_autograd_failures = {
     xfail('split', 'list_args'), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('split_with_sizes', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('squeeze', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
+    xfail('stack', ''), # aten.select.int - couldn't find symbolic meta function/decomposition
     xfail('std', ''), # Cannot call numel() on tensor with symbolic sizes/strides
     xfail('std_mean', ''), # Cannot call numel() on tensor with symbolic sizes/strides
     xfail('stft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
@@ -1067,33 +1073,30 @@ def _test_aot_autograd_helper(self, device, dtype, op):
 
     compiled_f = compiled_function(f, nop, nop)
 
-    try:
-        reset_grads()
-        call_forwards_backwards(compiled_f)
-        compiled_grad = get_grads(args)
+    reset_grads()
+    call_forwards_backwards(compiled_f)
+    compiled_grad = get_grads(args)
 
-        reset_grads()
-        call_forwards_backwards(f)
-        orig_grad = get_grads(args)
-        self.assertEqual(orig_grad, compiled_grad)
+    reset_grads()
+    call_forwards_backwards(f)
+    orig_grad = get_grads(args)
+    self.assertEqual(orig_grad, compiled_grad)
 
-        def create_new_arg(x):
-            if isinstance(x, torch.Tensor) and x.dtype == torch.float32:
-                return x.detach().uniform_(0, 1).requires_grad_(x.requires_grad)
-            return x
+    def create_new_arg(x):
+        if isinstance(x, torch.Tensor) and x.dtype == torch.float32:
+            return x.detach().uniform_(0, 1).requires_grad_(x.requires_grad)
+        return x
 
-        args = pytree.tree_map(create_new_arg, args)
+    args = pytree.tree_map(create_new_arg, args)
 
-        reset_grads()
-        call_forwards_backwards(compiled_f)
-        compiled_grad = get_grads(args)
+    reset_grads()
+    call_forwards_backwards(compiled_f)
+    compiled_grad = get_grads(args)
 
-        reset_grads()
-        call_forwards_backwards(f)
-        orig_grad = get_grads(args)
-        self.assertEqual(orig_grad, compiled_grad)
-    except DynamicOutputShapeException:
-        self.skipTest("Dynamic output shape operation in trace")
+    reset_grads()
+    call_forwards_backwards(f)
+    orig_grad = get_grads(args)
+    self.assertEqual(orig_grad, compiled_grad)
 
 class TestEagerFusionOpInfo(AOTTestCase):
     @ops(op_db, allowed_dtypes=(torch.float,))

@@ -5269,16 +5269,13 @@ for shape in [(1,), ()]:
         self.assertEqual(out.grad_fn._saved_indices, (None, indices)) # c10::List<c10::optional<Tensor>> -> Tuple[Tensor?]
         self.assertIsInstance(out.grad_fn._saved_indices[1], torch.Tensor)
         self.assertIsInstance(out.grad_fn._raw_saved_indices[1], torch._C._autograd.SavedTensor)
-        self.assertEqual(out.grad_fn._saved_self_sym_sizes, a.shape) # SymIntArrayRef -> Tuple[SymInt]
-        self.assertIsInstance(out.grad_fn._saved_self_sym_sizes[0], int)
+        self.assertEqual(out.grad_fn._saved_self_sizes, a.shape) # IntArrayRef -> Tuple[int]
+        self.assertIsInstance(out.grad_fn._saved_self_sizes[0], int)
 
         out.grad_fn._raw_saved_indices[1].register_hooks(lambda x: x, lambda x: x)
         with self.assertRaisesRegex(RuntimeError, "None is forbidden"):
             out.grad_fn._raw_saved_indices[0].register_hooks(lambda x: x, lambda x: x)
 
-        out = a.mean()
-        self.assertEqual(out.grad_fn._saved_self_sizes, a.shape) # IntArrayRef -> Tuple[int]
-
         a = torch.ones(2, 2, requires_grad=True)
         out = a * a
         out.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x)
@@ -5297,24 +5294,6 @@ for shape in [(1,), ()]:
         else:
             self.assertIsNone(out.grad_fn._saved_scales) # c10::optional<ArrayRef<double>> -> float[]?
 
-        a = torch.ones(1, 1, 3, 3, requires_grad=True)
-        out = nn.Conv2d(1, 1, 3)(a)
-        self.assertEqual(out.grad_fn._saved_bias_sym_sizes_opt, (1,)) # c10::optional<SymIntArrayRef> -> SymInt[]?
-        out = nn.Conv2d(1, 1, 3, bias=False)(a)
-        # TODO: This is BAD! we converted a c10::nullopt into a (0,)
-        self.assertEqual(out.grad_fn._saved_bias_sym_sizes_opt, (0,))
-
-        a = torch.ones(1, 3, 3, requires_grad=True)
-        out = torch.addbmm(a.squeeze(0), a, a)
-        self.assertEqual(out.grad_fn._saved_batch1_argsize_0, 1) # int64_t
-        self.assertEqual(out.grad_fn._saved_batch1_argsize_1, 3) # int64_t
-
-        a = torch.ones(1, 1, 3, 3, requires_grad=True)
-        out = torch.nn.functional.unfold(a, 3)
-        self.assertEqual(out.grad_fn._saved_self_sym_argsize_minus_2, 3) # SymInt
-        self.assertEqual(out.grad_fn._saved_self_sym_argsize_minus_1, 3) # SymInt
-
-        a = torch.ones(1, 1, 2, requires_grad=True)
         out = torch.nn.functional.interpolate(a, scale_factor=0.5, mode="linear")
         self.assertIsNone(out.grad_fn._saved_output_size)
         self.assertEqual(out.grad_fn._saved_scale_factors, (0.5,))

@@ -1032,6 +1032,7 @@ symbolic_tensor_failures = {
     xfail('linalg.eig'),
     xfail('linalg.eigvals'),
     skip('masked.logsumexp', ''), # Tensors of type TensorImpl do not have numel
+    xfail('__getitem__', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('masked.amax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('masked.amin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('masked.argmax', ''), # aten.argmax.default - couldn't find symbolic meta function/decomposition
@@ -1108,6 +1109,7 @@ symbolic_tensor_failures = {
     xfail('hsplit', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('i0', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
     xfail('index_copy', ''), # Expected a long tensor for index, but got Float
+    xfail('index_fill', ''), # aten.index_fill.int_Scalar - couldn't find symbolic meta function/decomposition
     xfail('index_reduce', ''), # Float
     xfail('inner', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('isclose', ''), # The underlying op of 'aten.stride' has no overload name '_schema'
@@ -1257,6 +1259,9 @@ symbolic_tensor_failures = {
     xfail('scatter_reduce', 'sum'), # aten.scatter_reduce.two - couldn't find symbolic meta function/decomposition
     xfail('searchsorted', ''), # Could not run 'aten::searchsorted.Tensor' with arguments from the 'Meta' backend. ...
     xfail('segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta function/decomposition
+    xfail('select', ''), # aten.select.int - couldn't find symbolic meta function/decomposition
+    xfail('select_scatter', ''), # aten.select_scatter.default - couldn't find symbolic meta function/decomposition
+    xfail('slice_scatter', ''), # aten.slice_scatter.default - couldn't find symbolic meta function/decomposition
     xfail('sort', ''), # aten.sort.default - couldn't find symbolic meta function/decomposition
     xfail('special.airy_ai', ''), # aten.special_airy_ai.default - couldn't find symbolic meta function/decomposition
     xfail('special.bessel_y0', ''), # aten.special_bessel_y0.default - couldn't find symbolic meta function/decomposition

@@ -781,7 +781,7 @@
   other: -grad * exp((self - 1) * log(other) - other - lgamma(self))
 
 - name: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
-  self: index_backward(grad.new_zeros_symint(self.sym_sizes(), self.options()), indices, grad)
+  self: index_backward(grad.new_zeros(self.sizes(), self.options()), indices, grad)
   result: auto_linear
 
 - name: index_add(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor
@@ -1388,7 +1388,7 @@
 - name: select.int(Tensor(a) self, int dim, int index) -> Tensor(a)
   dispatch:
     Default:
-      self: select_backward_symint(grad, self.sym_sizes(), dim, index)
+      self: select_backward(grad, self.sizes(), dim, index)
       result: auto_linear
     AutogradNestedTensor:
      self: _nested_select_backward(grad, self, dim, index)

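Aside (not part of the diff): both versions of the select.int backward formula above compute the same gradient; only the size type handed to the helper differs. Conceptually, the backward of select writes the incoming gradient into an otherwise-zero tensor shaped like the input. A minimal sketch under that assumption, assuming libtorch and using an illustrative helper name rather than the real select_backward implementation:

// Illustrative sketch only (assumes libtorch): the gradient of select() is a
// zero tensor with the input's shape, with grad copied into the selected slice.
#include <ATen/ATen.h>

at::Tensor select_backward_sketch(
    const at::Tensor& grad,
    at::IntArrayRef input_sizes,
    int64_t dim,
    int64_t index) {
  auto grad_input = at::zeros(input_sizes, grad.options());
  grad_input.select(dim, index).copy_(grad);
  return grad_input;
}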
@@ -814,24 +814,12 @@ def saved_variables(
         ),
         # replace self.size(2) with self_size_2
         (
-            r"{}.size\((-?\w+)\)",
+            r"{}.size\((\w+)\)",
             {
-                "suffix": lambda m: "_argsize_{}".format(
-                    m.groups()[0].replace("-", "minus_")
-                ),
+                "suffix": lambda m: "_argsize_{}".format(*m.groups()),
                 "nctype": lambda name: NamedCType(name, BaseCType(longT)),
             },
         ),
-        # replace self.sym_size(2) with self_sym_size_2
-        (
-            r"{}.sym_size\((-?\w+)\)",
-            {
-                "suffix": lambda m: "_sym_argsize_{}".format(
-                    m.groups()[0].replace("-", "minus_")
-                ),
-                "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)),
-            },
-        ),
         # replace self.numel() with self_numel
         (
             r"{}.numel\(\)",

@@ -14,8 +14,6 @@ from torch._prims_common import (
 
 from torch._prims_common.wrappers import out_wrapper
 from torch._refs import _broadcast_shapes
-
-from torch._subclasses.fake_tensor import check_no_bool_index_tensors
 from torch.utils._pytree import tree_map
 
 aten = torch.ops.aten
@@ -557,7 +555,6 @@ def vdot(self, other):
 # get shape inference through structured kernels
 @register_meta(aten.index.Tensor, register_dispatcher=False)
 def meta_index_Tensor(self, indices):
-    check_no_bool_index_tensors(aten.index.Tensor, self, indices)
     check(indices, lambda: "at least one index must be provided")
     # aten::index is the internal advanced indexing implementation
     # checkIndexTensorTypes and expandTensors
@@ -1207,45 +1204,6 @@ def arange_start(start, end, **kwargs):
     return aten.arange(end - start, **kwargs)
 
 
-@register_meta(aten.select.int)
-def meta_select(self, dim, index):
-    ndim = self.dim()
-    check(
-        ndim != 0, lambda: "select() cannot be applied to a 0-dim tensor.", IndexError
-    )
-
-    dim = dim if dim >= 0 else dim + ndim
-    size = self.size(dim)
-
-    check(
-        not (-index > size or index >= size),
-        lambda: f"select(): index {index} out of range for tensor of size "
-        f"{self.size()} at dimension {dim}",
-        IndexError,
-    )
-
-    index = index if index >= 0 else index + size
-
-    new_size = list(self.size())
-    new_stride = list(self.stride())
-
-    new_storage_offset = self.storage_offset() + index * new_stride[dim]
-    del new_size[dim]
-    del new_stride[dim]
-
-    return self.as_strided(new_size, new_stride, new_storage_offset)
-
-
-@register_meta(aten.select_scatter.default)
-def meta_select_scatter(self, src, dim, index):
-    return torch.empty_like(self)
-
-
-@register_meta(aten.slice_scatter.default)
-def meta_slice_scatter(self, src, dim=0, start=None, end=None, step=1):
-    return torch.empty_like(self)
-
-
 # We must also trigger meta registrations from PrimTorch ref
 # decompositions
 import torch._refs

@@ -176,18 +176,18 @@ static inline Variable applySlicing(
     variable_list& outIndices,
     bool is_tracing,
     const at::Device& self_device,
-    const c10::optional<int64_t>& self_ndim,
+    const c10::optional<IntArrayRef>& self_sizes,
     int64_t specified_dims) {
   int64_t size =
       PyTuple_GET_SIZE(index); // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
   int64_t dim = 0;
 
   // See NOTE [nested tensor size for indexing]
-  if (self_ndim.has_value()) {
+  if (self_sizes.has_value()) {
     TORCH_CHECK_INDEX(
-        specified_dims <= self_ndim.value(),
+        specified_dims <= (int64_t)self_sizes->size(),
         "too many indices for tensor of dimension ",
-        self_ndim.value());
+        (int)self_sizes->size());
   }
 
   Variable result = self;
@@ -198,9 +198,9 @@ static inline Variable applySlicing(
     // nested tensor does not have a size (yet) so for now we represent its size
     // as null may need to be changed after we reach a better solution for
     // nested tensor size
-    c10::optional<SymIntArrayRef> result_sizes = result.is_nested()
-        ? c10::optional<SymIntArrayRef>(c10::nullopt)
-        : c10::optional<SymIntArrayRef>(result.sym_sizes());
+    c10::optional<IntArrayRef> result_sizes = result.is_nested()
+        ? c10::optional<IntArrayRef>(c10::nullopt)
+        : c10::optional<IntArrayRef>(result.sizes());
     result = at::indexing::handleDimInMultiDimIndexing(
         /*prev_dim_result=*/result,
         /*original_tensor=*/self,
@@ -382,13 +382,17 @@ PyObject* THPVariable_getitem(PyObject* self, PyObject* index) {
     if (specified_dims == -1) {
       return handle_torch_function_indexing(self, holder.get());
     }
+    // See NOTE [nested tensor size for indexing]
+    c10::optional<IntArrayRef> self_sizes = c10::nullopt;
+    if (!self_.is_nested())
+      self_sizes = self_.sizes();
     Variable sliced = applySlicing(
         self_,
         holder.get(),
         variableIndices,
         /*is_tracing=*/is_tracing,
         self_.device(),
-        self_.ndimension(),
+        self_sizes,
         specified_dims);
     if (variableIndices.empty()) {
       if (sliced.is_same(self_)) {
@@ -518,7 +522,7 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
       variableIndices,
       /*is_tracing=*/is_tracing,
      self_device,
-      self_.ndimension(),
+      self_.sizes(),
      specified_dims);
  if (variableIndices.empty()) {
    pybind11::gil_scoped_release no_gil;