[Meta Tensor] fix meta inplace set storage (#123880)

Fixes #123879

Co-authored-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123880
Approved by: https://github.com/ezyang
Edward Z. Yang
2024-05-01 06:53:45 +00:00
committed by PyTorch MergeBot
parent c3c4465f50
commit c511aed27f
5 changed files with 146 additions and 77 deletions


@@ -421,9 +421,19 @@ Tensor& set_storage_meta__symint(Tensor& result, Storage storage, c10::SymInt st
// it. TODO: Actually this might not quite be correct if we use special
// pointers to track whether or not fake cuda tensors are pinned or not
const auto itemsize = result.dtype().itemsize();
c10::SymInt size_bytes = at::detail::computeStorageNbytes(
c10::SymInt new_size_bytes = at::detail::computeStorageNbytes(
size, stride, itemsize, std::move(storage_offset));
storage.set_nbytes(std::move(size_bytes));
// TODO: When there are unbacked SymInts, we unconditionally skip the
// setter. This is technically wrong, but we cannot conveniently test
// the real condition in many cases, because a lot of people use
// set_ just to swizzle metadata on a tensor; they don't actually want
// to check whether the storage needs resizing.
//
// The old behavior was to unconditionally set_nbytes, but I think not
// setting it is safer.
if (new_size_bytes.has_hint() && storage.sym_nbytes().has_hint() && TORCH_GUARD_SIZE_OBLIVIOUS(new_size_bytes.sym_gt(storage.sym_nbytes()))) {
storage.set_nbytes(std::move(new_size_bytes));
}
}
return result;
}
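
In effect, the meta kernel now only grows the storage's reported size when the new metadata actually needs more room, and leaves it untouched otherwise (it is also left untouched whenever either quantity is an unbacked SymInt without a hint). Below is a minimal Python sketch of the intended user-visible behavior, assuming a build that includes this change; it mirrors the test_inplace_set_storage regression test added further down, but the meta-device tensors and shapes are my own illustration, not code from the PR.

import torch

# Storage backing four int64 elements on the meta device: 32 bytes.
base = torch.empty(4, dtype=torch.int64, device="meta")
storage = base.untyped_storage()
nbytes_before = storage.nbytes()

# Re-pointing a scalar at the same storage needs fewer bytes than are
# already there, so with the guarded set_nbytes above the size stays put.
t = torch.empty((), dtype=torch.int64, device="meta")
t.set_(storage, 0, (), ())
assert storage.nbytes() == nbytes_before

# A larger view does need more room, so the storage size is grown.
t.set_(storage, 0, (8,), (1,))
assert storage.nbytes() == 8 * base.element_size()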


@@ -3,6 +3,8 @@ import functools
import itertools
import unittest
from functools import partial
import torch
import torch._dynamo.test_case
@@ -37,6 +39,105 @@ def traceable_subclass(c):
return torch._dynamo.config.patch("traceable_tensor_subclasses", {c})
def get_jagged_tensor(nested_size, offsets, requires_grad=True):
# Makes a jagged tensor with N constituent tensors with size
# as specified ((S0, S1, S2), D)
D = nested_size[1]
out = []
for s in nested_size[0]:
out.append(torch.randn(s, D, requires_grad=requires_grad, dtype=torch.float64))
return jagged_from_list(out, offsets)
def get_view_test_cases():
# Test all cases with both an NT base and a dense base
# Subclass -> Subclass
# Dense -> Subclass
# NB: Don't close over loop variables; they will not get copied into the
# closure.
#
# NB: These return functions so we don't generate tensors during test
# collection time
def mk_basic(base_is_nt):
# There are three cases to consider here based on the logic in
# meta_utils.py
#
# (1) basic case:
# view is not a leaf and has the same requires_grad as its base
x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)
x = x.clone() if base_is_nt else x
assert not x.is_leaf
return x.unsqueeze(-1)
def mk_leaf(base_is_nt, requires_grad_1, requires_grad_2):
x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=requires_grad_1)
x = x.clone() if base_is_nt else x
with torch.no_grad():
x_view = x.unsqueeze(-1)
# The issue is this doesn't quite work
x_view.requires_grad_(requires_grad_2)
return x_view
def mk_obscure(base_is_nt):
x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False)
x = x.clone() if base_is_nt else x
# intermediate leaf view
with torch.no_grad():
x_view = x.unsqueeze(-1)
x_view.requires_grad_(True)
x_view_view = x_view.unsqueeze(-1)
return x_view_view
for base_is_nt in [False, True]:
prefix = f"base_is_nt_{base_is_nt}"
yield partial(mk_basic, base_is_nt), f"{prefix}_basic"
# (2) leaf view case:
# the view has to be a leaf (w/ requires_grad True or requires_grad False)
# base w/ requires_grad True or requires_grad False
for requires_grad_1, requires_grad_2 in itertools.product(
[True, False], repeat=2
):
yield partial(
mk_leaf, base_is_nt, requires_grad_1, requires_grad_2
), f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}"
# (3) obscure case:
# view is not a leaf (implies requires_grad True)
# base w/ requires_grad False
yield partial(mk_obscure, base_is_nt), f"{prefix}_obscure"
# Subclass -> Dense
yield lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[
0
].clone(), "subclass_dense"
# Dense -> Subclass -> Dense -> Subclass
def mk_dense_subclass_dense_subclass():
values = torch.randn(10, 5)
offsets = torch.tensor([0, 3, 6, 10])
offsets2 = offsets.clone().detach()
return nested_view_from_values_offsets(
nested_view_from_values_offsets(values, offsets).values(), offsets
)
yield mk_dense_subclass_dense_subclass, "dense_subclass_dense_subclass"
def mk_subclass_dense_subclass_dense():
x = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
offsets2 = x.offsets().clone().detach()
nt_view = nested_view_from_values_offsets(x.values(), offsets2).values()
return nt_view
yield mk_subclass_dense_subclass_dense, "subclass_dense_subclass_dense"
VIEW_TEST_CASES = {k: v for v, k in get_view_test_cases()}
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
compile_full_eager = torch.compile(backend="eager", fullgraph=True)
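
For orientation: get_view_test_cases yields (factory, name) pairs, and the VIEW_TEST_CASES comprehension inverts them into a name-to-factory map, so the parametrized tests further down can look a case up by name and only build the tensors when the test body runs (and partial sidesteps the loop-variable closure pitfall noted in the NB above). A small standalone sketch of that pattern; the case names and factories below are made up for illustration and are not the ones defined above.

from functools import partial

def example_cases():
    # Yield (factory, name) pairs; construction is deferred until the factory is called.
    def mk_range(n):
        return list(range(n))

    for n in (3, 7):
        yield partial(mk_range, n), f"range_{n}"

EXAMPLE_CASES = {name: factory for factory, name in example_cases()}

def run_case(name):
    value = EXAMPLE_CASES[name]()  # factory invoked lazily, looked up by name
    return len(value)

assert run_case("range_7") == 7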
@@ -1207,15 +1308,7 @@ instantiate_parametrized_tests(SubclassTests)
class TestNestedTensor(torch._dynamo.test_case.TestCase):
def _get_jagged_tensor(self, nested_size, offsets, requires_grad=True):
# Makes a jagged tensor with N constituent tensors with size
# as specified ((S0, S1, S2), D)
D = nested_size[1]
out = []
for s in nested_size[0]:
out.append(
torch.randn(s, D, requires_grad=requires_grad, dtype=torch.float64)
)
return jagged_from_list(out, offsets)
return get_jagged_tensor(nested_size, offsets, requires_grad)
def _get_nc_jagged_tensor(self, inner_dim, starts, lengths, requires_grad=True):
# Makes a jagged tensor with N constituent tensors with size
@@ -1369,62 +1462,9 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
torch.compile(fn, fullgraph=True, backend="aot_eager")(nt)
def _get_views(self):
# Test all cases with both an NT base and a dense base
# Subclass -> Subclass
# Dense -> Subclass
for base_is_nt in [False, True]:
# There are three cases to consider here based on the logic in
# meta_utils.py
#
# (1) basic case:
# view is not a leaf and has the same requires grad as its basic case
x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)
x = x.clone() if base_is_nt else x
self.assertEqual(x.is_leaf, False)
yield x.unsqueeze(-1)
def _input_view_test(self, nt_view_name):
nt_view = VIEW_TEST_CASES[nt_view_name]()
# (2) leaf view case:
# the view has to be a leaf (w/ requires_grad True or requires_grad False)
# base w/ requires_grad True or requires_grad False
for requires_grad_1, requires_grad_2 in itertools.product(
[True, False], repeat=2
):
x, _ = self._get_jagged_tensor(
((2, 3, 4), 3), None, requires_grad=requires_grad_1
)
x = x.clone() if base_is_nt else x
with torch.no_grad():
x_view = x.unsqueeze(-1)
# The issue is this doesn't quite work
x_view.requires_grad_(requires_grad_2)
yield x_view
# (3) obscure case:
# view is not a leaf (implies requires_grad True)
# base w/ requires_grad False)
x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False)
x = x.clone() if base_is_nt else x
# intermediate leaf view
with torch.no_grad():
x_view = x.unsqueeze(-1)
x_view.requires_grad_(True)
x_view_view = x_view.unsqueeze(-1)
yield x_view_view
# Subclass -> Dense
x = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
yield x.values()
# Dense -> Subclass -> Dense -> Subclass
values = torch.randn(10, 5)
offsets = torch.tensor([0, 3, 6, 10])
offsets2 = offsets.clone().detach()
yield nested_view_from_values_offsets(
nested_view_from_values_offsets(values, offsets).values(), offsets
)
def _input_view_test(self, nt_view):
def fn(x):
return x.sin()
@@ -1450,7 +1490,10 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
# varies based on the type of view
guard_str = "\n".join(guards)
if isinstance(nt_view._base, NestedTensor):
if (
isinstance(nt_view._base, NestedTensor)
or nt_view_name == "subclass_dense"
):
self.assertExpectedInline(guard_str, """Eq(s3 - 1, s0)""")
else:
self.assertExpectedInline(guard_str, """""")
@@ -1460,9 +1503,12 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
compile_fn = torch.compile(fn, fullgraph=True, backend=backend, dynamic=True)
out = compile_fn(nt_view)
def test_inputs_to_compiled_fn_are_views(self):
for nt_view in self._get_views():
self._input_view_test(nt_view)
@parametrize(
"nt_view_name",
[k for k in VIEW_TEST_CASES.keys() if k != "subclass_dense_subclass_dense"],
)
def test_inputs_to_compiled_fn_are_views(self, nt_view_name):
self._input_view_test(nt_view_name)
def test_subclass_gives_static_shapes_when_dynamic_false(self):
def check_graph(gm, *args):
@@ -1490,10 +1536,10 @@ class TestNestedTensor(torch._dynamo.test_case.TestCase):
# are cached onto fake offsets to solve this problem.
@unittest.expectedFailure
def test_subclass_dense_subclass_dense_view(self):
x = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
offsets2 = x.offsets().clone().detach()
nt_view = nested_view_from_values_offsets(x.values(), offsets2).values()
self._input_view_test(nt_view)
self._input_view_test("subclass_dense_subclass_dense")
instantiate_parametrized_tests(TestNestedTensor)
if __name__ == "__main__":


@@ -286,6 +286,14 @@ class TestMetaConverter(TestCase):
m = MetaConverter()(y)
self.assertMetadataMatches(m, y)
def test_inplace_set_storage(self):
x = torch.tensor([0, 1], dtype=torch.int64)
storage = x.untyped_storage()
ssize = storage.size()
meta = torch.empty((), dtype=torch.int64)
meta.set_(storage, 0, (), ())
self.assertEqual(storage.size(), ssize)
@skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
def test_weakref(self):
x = torch.randn(4, 4, 4)


@@ -8,6 +8,7 @@ from collections import defaultdict
from dataclasses import dataclass
from typing import (
Any,
Callable,
cast,
Dict,
List,
@@ -1213,16 +1214,20 @@ class FakeTensorMode(TorchDispatchMode):
if metadata.is_neg:
torch._C._set_neg(empty, True)
maybe_suppress: Callable[[], Any] = contextlib.nullcontext
if self.shape_env is not None:
maybe_suppress = self.shape_env.suppress_guards
if func.is_view:
# For view ops, the storage should be the same as the tensor input.
storage = args[cast(int, entry.view_idx)].untyped_storage()
with in_kernel_invocation_manager(self):
with in_kernel_invocation_manager(self), maybe_suppress():
empty.set_(
storage, metadata.storage_offset, metadata.shape, metadata.stride
)
elif metadata.storage_offset != 0:
storage = empty.untyped_storage()
with in_kernel_invocation_manager(self):
with in_kernel_invocation_manager(self), maybe_suppress():
empty.set_(
storage, metadata.storage_offset, metadata.shape, metadata.stride
)
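
The maybe_suppress change above follows a common "optional context manager" idiom: default to contextlib.nullcontext, swap in shape_env.suppress_guards when a shape env exists, and enter whichever was chosen at the set_ call sites so no guards are recorded there. Here is a self-contained sketch of that idiom; FakeShapeEnv is a stand-in written for this illustration, not the real ShapeEnv.

import contextlib
from typing import Any, Callable, Optional

class FakeShapeEnv:
    """Stand-in for a shape env; only tracks whether guards are being suppressed."""
    def __init__(self) -> None:
        self.suppress_depth = 0

    @contextlib.contextmanager
    def suppress_guards(self):
        self.suppress_depth += 1
        try:
            yield
        finally:
            self.suppress_depth -= 1

def run_set(shape_env: Optional[FakeShapeEnv]) -> None:
    # Default to a no-op context manager; only suppress when a shape env exists.
    maybe_suppress: Callable[[], Any] = contextlib.nullcontext
    if shape_env is not None:
        maybe_suppress = shape_env.suppress_guards
    with maybe_suppress():
        pass  # the guarded empty.set_(...) call would go here

run_set(None)            # executes under nullcontext
run_set(FakeShapeEnv())  # executes with "guards" suppressed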


@@ -1265,7 +1265,7 @@ class MetaConverter:
mb_fake_mode = maybe_get_fake_mode(r)
if mb_fake_mode is not None:
maybe_fake_mgr = in_kernel_invocation_manager(mb_fake_mode)
with maybe_fake_mgr, torch.no_grad():
with maybe_fake_mgr, torch.no_grad(), maybe_suppress():
r.set_(r_s, storage_offset, sizes, strides)
if t.grad is not None: