Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Meta Tensor] fix meta inplace set storage (#123880)
Fixes #123879

Co-authored-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123880
Approved by: https://github.com/ezyang
committed by PyTorch MergeBot
parent c3c4465f50
commit c511aed27f
@@ -421,9 +421,19 @@ Tensor& set_storage_meta__symint(Tensor& result, Storage storage, c10::SymInt st
     // it. TODO: Actually this might not quite be correct if we use special
     // pointers to track whether or not fake cuda tensors are pinned or not
     const auto itemsize = result.dtype().itemsize();
-    c10::SymInt size_bytes = at::detail::computeStorageNbytes(
+    c10::SymInt new_size_bytes = at::detail::computeStorageNbytes(
         size, stride, itemsize, std::move(storage_offset));
-    storage.set_nbytes(std::move(size_bytes));
+    // TODO: When there are unbacked SymInts, we unconditionally skip the
+    // setter. This is technically wrong, but we cannot conveniently test
+    // the real condition in many cases, because a lot of people are using
+    // set_ just to swizzle metadata on a tensor, they didn't actually want
+    // to see if they need to resize the storage.
+    //
+    // The old behavior was to unconditionally set_nbytes, but I think not
+    // setting it is more safe.
+    if (new_size_bytes.has_hint() && storage.sym_nbytes().has_hint() && TORCH_GUARD_SIZE_OBLIVIOUS(new_size_bytes.sym_gt(storage.sym_nbytes()))) {
+      storage.set_nbytes(std::move(new_size_bytes));
+    }
   }
   return result;
 }
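For context, here is a rough Python-level sketch of the semantics the guarded set_nbytes update above is aiming for. It is an illustration only, not code from the commit: the tensor sizes and the use of the meta device are my own assumptions, chosen because, as I understand it, in-place set_ on meta tensors is what routes through set_storage_meta__symint. Installing smaller metadata should leave the existing storage's byte size alone; only metadata that genuinely needs more room should grow it.

import torch

# A meta tensor that owns a 32-byte storage (8 int32 elements).
base = torch.empty(8, dtype=torch.int32, device="meta")
storage = base.untyped_storage()
print(storage.nbytes())  # 32

# Re-point another meta tensor at the same storage with *smaller* metadata.
# With the guarded update above, the new byte size (8) is not greater than
# the current one (32), so set_nbytes is skipped and the storage keeps its size.
small = torch.empty(0, dtype=torch.int32, device="meta")
small.set_(storage, 0, (2,), (1,))
print(storage.nbytes())  # expected to stay 32

# Metadata that actually needs more room may still grow the storage.
big = torch.empty(0, dtype=torch.int32, device="meta")
big.set_(storage, 0, (16,), (1,))
print(storage.nbytes())  # expected to be at least 64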
@@ -3,6 +3,8 @@ import functools
 import itertools
 import unittest
 
+from functools import partial
+
 import torch
 
 import torch._dynamo.test_case
@@ -37,6 +39,105 @@ def traceable_subclass(c):
     return torch._dynamo.config.patch("traceable_tensor_subclasses", {c})
 
 
+def get_jagged_tensor(nested_size, offsets, requires_grad=True):
+    # Makes a jagged tensor with N constituent tensors with size
+    # as specified ((S0, S1, S2), D)
+    D = nested_size[1]
+    out = []
+    for s in nested_size[0]:
+        out.append(torch.randn(s, D, requires_grad=requires_grad, dtype=torch.float64))
+    return jagged_from_list(out, offsets)
+
+
+def get_view_test_cases():
+    # Test all cases with both an NT base and a dense base
+    # Subclass -> Subclass
+    # Dense -> Subclass
+
+    # NB: Don't close over loop variables, they will not get copied into the
+    # closure
+    #
+    # NB: These return functions so we don't generate tensors during test
+    # collection time
+
+    def mk_basic(base_is_nt):
+        # There are three cases to consider here based on the logic in
+        # meta_utils.py
+        #
+        # (1) basic case:
+        # view is not a leaf and has the same requires grad as its basic case
+        x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)
+        x = x.clone() if base_is_nt else x
+        assert not x.is_leaf
+        return x.unsqueeze(-1)
+
+    def mk_leaf(base_is_nt, requires_grad_1, requires_grad_2):
+        x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=requires_grad_1)
+        x = x.clone() if base_is_nt else x
+        with torch.no_grad():
+            x_view = x.unsqueeze(-1)
+            # The issue is this doesn't quite work
+            x_view.requires_grad_(requires_grad_2)
+
+        return x_view
+
+    def mk_obscure(base_is_nt):
+        x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False)
+        x = x.clone() if base_is_nt else x
+        # intermediate leaf view
+        with torch.no_grad():
+            x_view = x.unsqueeze(-1)
+        x_view.requires_grad_(True)
+        x_view_view = x_view.unsqueeze(-1)
+        return x_view_view
+
+    for base_is_nt in [False, True]:
+        prefix = f"base_is_nt_{base_is_nt}"
+
+        yield partial(mk_basic, base_is_nt), f"{prefix}_basic"
+
+        # (2) leaf view case:
+        # the view has to be a leaf (w/ requires_grad True or requires_grad False)
+        # base w/ requires_grad True or requires_grad False
+        for requires_grad_1, requires_grad_2 in itertools.product(
+            [True, False], repeat=2
+        ):
+            yield partial(
+                mk_leaf, base_is_nt, requires_grad_1, requires_grad_2
+            ), f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}"
+
+        # (3) obscure case:
+        # view is not a leaf (implies requires_grad True)
+        # base w/ requires_grad False)
+        yield partial(mk_obscure, base_is_nt), f"{prefix}_obscure"
+
+    # Subclass -> Dense
+    yield lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[
+        0
+    ].clone(), "subclass_dense"
+
+    # Dense -> Subclass -> Dense -> Subclass
+    def mk_dense_subclass_dense_subclass():
+        values = torch.randn(10, 5)
+        offsets = torch.tensor([0, 3, 6, 10])
+        offsets2 = offsets.clone().detach()
+        return nested_view_from_values_offsets(
+            nested_view_from_values_offsets(values, offsets).values(), offsets
+        )
+
+    yield mk_dense_subclass_dense_subclass, "dense_subclass_dense_subclass"
+
+    def mk_subclass_dense_subclass_dense():
+        x = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
+        offsets2 = x.offsets().clone().detach()
+        nt_view = nested_view_from_values_offsets(x.values(), offsets2).values()
+
+    yield mk_subclass_dense_subclass_dense, "subclass_dense_subclass_dense"
+
+
+VIEW_TEST_CASES = {k: v for v, k in get_view_test_cases()}
+
+
 requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
 
 compile_full_eager = torch.compile(backend="eager", fullgraph=True)
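The helpers above deliberately yield zero-argument factories together with a name, and VIEW_TEST_CASES inverts the (factory, name) pairs into a name-to-factory dict, so the parametrized tests further down can look a case up by name and build tensors lazily inside the test body. A minimal, self-contained sketch of that pattern follows; the names and the list stand-in here are hypothetical, not from the PR:

from functools import partial

def get_cases():
    def mk(flag):
        # Stand-in for an expensive tensor construction.
        return [flag] * 3

    for flag in (False, True):
        # NB: bind the loop variable via partial instead of closing over it.
        yield partial(mk, flag), f"case_flag_{flag}"

CASES = {name: factory for factory, name in get_cases()}

# A parametrized test receives only the name and calls the factory inside the
# test body, so nothing is constructed at collection time.
print(CASES["case_flag_True"]())  # [True, True, True]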
@@ -1207,15 +1308,7 @@ instantiate_parametrized_tests(SubclassTests)
 
 class TestNestedTensor(torch._dynamo.test_case.TestCase):
     def _get_jagged_tensor(self, nested_size, offsets, requires_grad=True):
-        # Makes a jagged tensor with N constituent tensors with size
-        # as specified ((S0, S1, S2), D)
-        D = nested_size[1]
-        out = []
-        for s in nested_size[0]:
-            out.append(
-                torch.randn(s, D, requires_grad=requires_grad, dtype=torch.float64)
-            )
-        return jagged_from_list(out, offsets)
+        return get_jagged_tensor(nested_size, offsets, requires_grad)
 
     def _get_nc_jagged_tensor(self, inner_dim, starts, lengths, requires_grad=True):
         # Makes a jagged tensor with N constituent tensors with size
@@ -1369,62 +1462,9 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
 
         torch.compile(fn, fullgraph=True, backend="aot_eager")(nt)
 
-    def _get_views(self):
-        # Test all cases with both an NT base and a dense base
-        # Subclass -> Subclass
-        # Dense -> Subclass
-        for base_is_nt in [False, True]:
-            # There are three cases to consider here based on the logic in
-            # meta_utils.py
-            #
-            # (1) basic case:
-            # view is not a leaf and has the same requires grad as its basic case
-            x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)
-            x = x.clone() if base_is_nt else x
-            self.assertEqual(x.is_leaf, False)
-            yield x.unsqueeze(-1)
+    def _input_view_test(self, nt_view_name):
+        nt_view = VIEW_TEST_CASES[nt_view_name]()
 
-            # (2) leaf view case:
-            # the view has to be a leaf (w/ requires_grad True or requires_grad False)
-            # base w/ requires_grad True or requires_grad False
-            for requires_grad_1, requires_grad_2 in itertools.product(
-                [True, False], repeat=2
-            ):
-                x, _ = self._get_jagged_tensor(
-                    ((2, 3, 4), 3), None, requires_grad=requires_grad_1
-                )
-                x = x.clone() if base_is_nt else x
-                with torch.no_grad():
-                    x_view = x.unsqueeze(-1)
-                    # The issue is this doesn't quite work
-                    x_view.requires_grad_(requires_grad_2)
-                yield x_view
-
-            # (3) obscure case:
-            # view is not a leaf (implies requires_grad True)
-            # base w/ requires_grad False)
-            x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False)
-            x = x.clone() if base_is_nt else x
-            # intermediate leaf view
-            with torch.no_grad():
-                x_view = x.unsqueeze(-1)
-            x_view.requires_grad_(True)
-            x_view_view = x_view.unsqueeze(-1)
-            yield x_view_view
-
-        # Subclass -> Dense
-        x = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
-        yield x.values()
-
-        # Dense -> Subclass -> Dense -> Subclass
-        values = torch.randn(10, 5)
-        offsets = torch.tensor([0, 3, 6, 10])
-        offsets2 = offsets.clone().detach()
-        yield nested_view_from_values_offsets(
-            nested_view_from_values_offsets(values, offsets).values(), offsets
-        )
-
-    def _input_view_test(self, nt_view):
         def fn(x):
             return x.sin()
 
@@ -1450,7 +1490,10 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
 
         # varies based on the type of view
         guard_str = "\n".join(guards)
-        if isinstance(nt_view._base, NestedTensor):
+        if (
+            isinstance(nt_view._base, NestedTensor)
+            or nt_view_name == "subclass_dense"
+        ):
             self.assertExpectedInline(guard_str, """Eq(s3 - 1, s0)""")
         else:
             self.assertExpectedInline(guard_str, """""")
@@ -1460,9 +1503,12 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
         compile_fn = torch.compile(fn, fullgraph=True, backend=backend, dynamic=True)
         out = compile_fn(nt_view)
 
-    def test_inputs_to_compiled_fn_are_views(self):
-        for nt_view in self._get_views():
-            self._input_view_test(nt_view)
+    @parametrize(
+        "nt_view_name",
+        [k for k in VIEW_TEST_CASES.keys() if k != "subclass_dense_subclass_dense"],
+    )
+    def test_inputs_to_compiled_fn_are_views(self, nt_view_name):
+        self._input_view_test(nt_view_name)
 
     def test_subclass_gives_static_shapes_when_dynamic_false(self):
         def check_graph(gm, *args):
@@ -1490,10 +1536,10 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
     # are cached onto fake offsets to solve this problem.
     @unittest.expectedFailure
     def test_subclass_dense_subclass_dense_view(self):
-        x = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
-        offsets2 = x.offsets().clone().detach()
-        nt_view = nested_view_from_values_offsets(x.values(), offsets2).values()
-        self._input_view_test(nt_view)
+        self._input_view_test("subclass_dense_subclass_dense")
 
 
 instantiate_parametrized_tests(TestNestedTensor)
 
 
 if __name__ == "__main__":
@@ -286,6 +286,14 @@ class TestMetaConverter(TestCase):
         m = MetaConverter()(y)
         self.assertMetadataMatches(m, y)
 
+    def test_inplace_set_storage(self):
+        x = torch.tensor([0, 1], dtype=torch.int64)
+        storage = x.untyped_storage()
+        ssize = storage.size()
+        meta = torch.empty((), dtype=torch.int64)
+        meta.set_(storage, 0, (), ())
+        self.assertEqual(storage.size(), ssize)
+
     @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
     def test_weakref(self):
         x = torch.randn(4, 4, 4)
@@ -8,6 +8,7 @@ from collections import defaultdict
 from dataclasses import dataclass
 from typing import (
     Any,
     Callable,
+    cast,
     Dict,
     List,
@@ -1213,16 +1214,20 @@ class FakeTensorMode(TorchDispatchMode):
         if metadata.is_neg:
             torch._C._set_neg(empty, True)
 
+        maybe_suppress: Callable[[], Any] = contextlib.nullcontext
+        if self.shape_env is not None:
+            maybe_suppress = self.shape_env.suppress_guards
+
         if func.is_view:
             # For view ops, the storage should be the same as the tensor input.
             storage = args[cast(int, entry.view_idx)].untyped_storage()
-            with in_kernel_invocation_manager(self):
+            with in_kernel_invocation_manager(self), maybe_suppress():
                 empty.set_(
                     storage, metadata.storage_offset, metadata.shape, metadata.stride
                 )
         elif metadata.storage_offset != 0:
             storage = empty.untyped_storage()
-            with in_kernel_invocation_manager(self):
+            with in_kernel_invocation_manager(self), maybe_suppress():
                 empty.set_(
                     storage, metadata.storage_offset, metadata.shape, metadata.stride
                 )
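The wrapper pattern introduced above (default to contextlib.nullcontext, swap in shape_env.suppress_guards when a shape environment exists) keeps the empty.set_(...) metadata swizzle from recording symbolic-shape guards. A small self-contained sketch of the same pattern, where ShapeEnvStub and reconstruct_output are hypothetical stand-ins rather than PyTorch APIs:

import contextlib
from typing import Any, Callable

class ShapeEnvStub:
    # Hypothetical stand-in for a ShapeEnv; the real suppress_guards
    # temporarily stops new guards from being recorded.
    @contextlib.contextmanager
    def suppress_guards(self):
        yield

def reconstruct_output(shape_env=None):
    maybe_suppress: Callable[[], Any] = contextlib.nullcontext
    if shape_env is not None:
        maybe_suppress = shape_env.suppress_guards
    with maybe_suppress():
        pass  # the guarded mutation (e.g. empty.set_(...)) would go here

reconstruct_output()                # no shape env: plain nullcontext
reconstruct_output(ShapeEnvStub())  # with shape env: guards suppressed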
@@ -1265,7 +1265,7 @@ class MetaConverter:
             mb_fake_mode = maybe_get_fake_mode(r)
             if mb_fake_mode is not None:
                 maybe_fake_mgr = in_kernel_invocation_manager(mb_fake_mode)
-            with maybe_fake_mgr, torch.no_grad():
+            with maybe_fake_mgr, torch.no_grad(), maybe_suppress():
                 r.set_(r_s, storage_offset, sizes, strides)
 
         if t.grad is not None: