Factor meta conversion through serializable MetaTensorDesc (#122044)

Fixes https://github.com/pytorch/pytorch/issues/121085

This PR is pretty involved, so pay attention to this description.  At a
high level, the refactor is intended to be mechanical: anywhere in
MetaConverter where we previously took a Tensor as an argument, we now
take a MetaTensorDesc, which contains all of the information we would
have queried off of the Tensor, placed into a separate data structure
that we can serialize or use to recreate a fake tensor in a separate
fake tensor mode with exact fidelity to the original.
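
To make the shape of the change concrete, here is a minimal sketch of
the kind of record MetaConverter now consumes.  The reduced field set
and the names below are illustrative only; the real MetaTensorDesc
carries many more fields (view metadata, autograd state, subclass
info, and so on):

    from dataclasses import dataclass
    from typing import Optional, Tuple

    import torch

    @dataclass
    class TensorDescSketch:  # illustrative stand-in for MetaTensorDesc
        id: int                    # stable int id, used to track aliasing
        size: Tuple[int, ...]
        stride: Tuple[int, ...]
        storage_offset: int
        dtype: torch.dtype
        device: torch.device
        layout: torch.layout
        requires_grad: bool
        storage_id: Optional[int]  # tensors aliasing a storage share this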

However, this transformation is not always entirely mechanical.  Here
is what you need to pay attention to:

- The memo table from real Tensor -> meta/fake Tensor is now broken
  into two memo tables: real Tensor -> stable int id -> meta/fake
  Tensor.  The stable int id is needed so that when we do serialization,
  we know when tensors/storages alias each other and can ensure we
  preserve this aliasing upon deserialization.

  The way I have implemented this changes the weak reference behavior
  (see the memo-table sketch after this list).  Previously, when either
  the real Tensor OR the meta/fake Tensor went dead, we would remove the
  entry from the memo table.  Now, this only removes the entry from one
  of the two memo tables.  This makes sense semantically, because the
  user may have held on to the stable int id out of band, and may expect
  a real Tensor to continue to be numbered consistently / expect to be
  able to look up a meta/fake tensor from this id.  If this is
  unacceptable, it may be possible to rejigger the memo tables so that
  we have real Tensor -> stable int id and real Tensor -> meta/fake
  Tensor, but TBH I find the new implementation a lot simpler, and
  arranging the memo tables that way would mean I have to muck around
  with the real tensor to save to the memo table; in the current
  implementation, I never pass the Tensor to the meta_tensor function AT
  ALL, which makes it impossible to accidentally depend on it.

- When I fill in the fields of MetaTensorDesc in describe_tensor, I need
  to be careful not to poke fields when they are not valid.  Previously,
  preconditions were implicitly checked by the conditional structure
  ("is this sparse? is this nested?") that is tested before we start
  reading attributes.  This structure has to be replicated in
  describe_tensor (see the guarded-reads sketch after this list), and I
  have almost assuredly gotten it wrong on my first try.  I'll be
  grinding through it on CI; a careful audit would help too, checking
  that I guard each field read with the same conditionals that guarded
  the original accesses.

- I originally submitted https://github.com/pytorch/pytorch/pull/121821
  for the symbolic shapes change, but it turned out the way I did it
  there didn't actually work so well for this PR.  I ended up just
  inlining the symbolic shapes allocation logic into MetaConverter
  (look for calls to maybe_specialize_sym_int_with_hint).  Maybe there
  is a better way to structure it, but what I really want is to just
  read sizes/strides/offset directly off of MetaTensorDesc; I don't
  want another intermediate data structure.

- Some fields aren't serializable.  These are documented as "NOT
  serializable".  ctx/type should morally be serializable; I just need
  to set up a contract with subclasses to let them be serialized.  The
  fake_mode field is used solely to test whether we are refakeifying
  with a pre-existing ShapeEnv and want to reuse the SymInt
  directly--serializing this case is hopeless, but I am kind of hoping
  that after this refactor we do not need it at all.  view_func is not
  serializable because it's a bound, C-implemented method.  Joel has
  promised me that this is not too difficult to actually expose as a
  true data structure, but this is the edgiest of edge cases and there
  is no reason to deal with it right now.
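
The following is a minimal sketch (not the actual code) of the
two-level memo arrangement described in the first bullet above.  It
assumes torch.utils.weak.WeakIdKeyDictionary, which keys on object
identity, so the Tensor -> id entry dies with the real tensor while
the id -> fake entry dies with the fake tensor:

    import itertools
    import weakref

    import torch
    from torch.utils.weak import WeakIdKeyDictionary

    class TwoLevelMemoSketch:
        """Illustration only: real Tensor -> stable int id -> fake Tensor."""

        def __init__(self):
            self._next_id = itertools.count()
            # Entry disappears when the *real* tensor dies.
            self.lookup_tensor = WeakIdKeyDictionary()        # Tensor -> id
            # Entry disappears when the *fake* tensor dies; the id itself
            # stays stable for as long as the real tensor is alive.
            self.tensor_memo = weakref.WeakValueDictionary()  # id -> fake

        def get_tensor_id(self, t: torch.Tensor) -> int:
            if t not in self.lookup_tensor:
                self.lookup_tensor[t] = next(self._next_id)
            return self.lookup_tensor[t]

        def get_memo(self, t: torch.Tensor):
            tid = self.lookup_tensor.get(t)
            return None if tid is None else self.tensor_memo.get(tid)

        def set_memo(self, t: torch.Tensor, fake: torch.Tensor) -> None:
            self.tensor_memo[self.get_tensor_id(t)] = fake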
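
And here is a rough illustration of the guarded-reads point from the
describe_tensor bullet.  The helper name and the specific fields are
made up for the example; the point is only that attribute reads have
to stay behind the same kind-of-tensor checks:

    import torch

    def describe_layout_fields(t: torch.Tensor) -> dict:
        # Only poke attributes that are valid for this kind of tensor,
        # mirroring the conditional structure of the original conversion.
        desc = {"dtype": t.dtype, "device": t.device, "layout": t.layout}
        if t.is_nested:
            # Nested tensors have no single size/stride; don't query them.
            desc["is_nested"] = True
        elif t.is_sparse:
            # Sparse COO tensors have sizes but no strides/storage offset.
            desc["size"] = tuple(t.size())
            desc["sparse_dim"] = t.sparse_dim()
            desc["dense_dim"] = t.dense_dim()
        else:
            desc["size"] = tuple(t.size())
            desc["stride"] = tuple(t.stride())
            desc["storage_offset"] = t.storage_offset()
        return desc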

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122044
Approved by: https://github.com/eellison
Author: Edward Z. Yang
Date: 2024-03-24 20:10:51 -07:00
Committed by: PyTorch MergeBot
Commit: 5891c5b3a6 (parent: cf06189a2d)
10 changed files with 587 additions and 231 deletions


@ -936,9 +936,11 @@ class FakeTensorConverterTest(TestCase):
stor_id = torch._C._storage_id(x_conv)
self.assertEqual(stor_id, torch._C._storage_id(y_conv))
del x
del x_conv
self.assertEqual(len(converter.tensor_memo), 1)
self.assertEqual(len(converter.meta_converter.storage_memo), 1)
del y
del y_conv
self.assertEqual(len(converter.tensor_memo), 0)
self.assertEqual(len(converter.meta_converter.storage_memo), 0)
@ -966,6 +968,8 @@ class FakeTensorConverterTest(TestCase):
x_conv2 = converter(mode, x)
assert x_conv2 is x_conv
del x
del x_conv
del x_conv2
self.assertEqual(len(converter.tensor_memo), 0)
def test_no_active_mode(self):


@ -292,7 +292,19 @@ class TestMetaConverter(TestCase):
self.assertIs(y, z)
self.assertEqual(len(m.tensor_memo), 1)
self.assertEqual(len(m.storage_memo), 1)
self.assertEqual(len(m.describer.lookup_tensor), 1)
self.assertEqual(len(m.describer.lookup_storage), 1)
del x
# Entries from Tensor -> int get deallocated when the real tensor
# disappears...
self.assertEqual(len(m.describer.lookup_tensor), 0)
self.assertEqual(len(m.describer.lookup_storage), 0)
del y
del z
# ... but the int -> FakeTensor entries don't die until the fake
# tensors themselves die (because the user may have held onto the
# int key and are expecting to get a consistent fake tensor in
# this case)
self.assertEqual(len(m.tensor_memo), 0)
self.assertEqual(len(m.storage_memo), 0)
li = []
@ -301,7 +313,13 @@ class TestMetaConverter(TestCase):
li.append(torch.rand([i]))
r.append(m(li[-1]))
self.assertEqual(len(m.tensor_memo), 4)
self.assertEqual(len(m.storage_memo), 4)
self.assertEqual(len(m.describer.lookup_tensor), 4)
self.assertEqual(len(m.describer.lookup_storage), 4)
del li
self.assertEqual(len(m.describer.lookup_tensor), 0)
self.assertEqual(len(m.describer.lookup_storage), 0)
del r
self.assertEqual(len(m.tensor_memo), 0)
self.assertEqual(len(m.storage_memo), 0)


@ -1614,6 +1614,15 @@ class TensorBase(metaclass=_TensorMeta):
nbytes: _int
itemsize: _int
_has_symbolic_sizes_strides: _bool
def _view_func_unsafe(
self,
new_base: Tensor,
symint_visitor_fn: Optional[Callable[[_int], _int]] = None,
tensor_visitor_fn: Optional[Callable[[Tensor], Tensor]] = None
):
...
${tensor_method_hints}
_TensorBase = TensorBase


@ -82,9 +82,9 @@ def fakify(
symbolic_context.dynamic_sizes[i] = DimDynamic.DYNAMIC
src = TensorPropertySource(base=source, prop=TensorProperty.SIZE, idx=i)
sources[(t_id, i)].append(src)
mode.shape_env.source_name_to_debug_name[src.name()] = constraint.debug_name
mode.shape_env.source_name_to_debug_name[src.name()] = constraint.debug_name # type: ignore[assignment]
fake = mode.from_tensor(t, source=source, symbolic_context=symbolic_context)
mode.shape_env.tracked_fakes.append(TrackedFake(fake, source, symbolic_context))
mode.shape_env.tracked_fakes.append(TrackedFake(fake, source, symbolic_context)) # type: ignore[union-attr]
return fake
@ -214,6 +214,7 @@ def make_constraints(
)
shape_env = fake_mode.shape_env
assert shape_env.tracked_fakes is not None
placeholders = [tf.fake for tf in shape_env.tracked_fakes]
sources = [tf.source for tf in shape_env.tracked_fakes]
input_contexts = [tf.symbolic_context for tf in shape_env.tracked_fakes]


@ -38,7 +38,6 @@ from torch.utils._python_dispatch import (
from torch.utils._pytree import PyTree, tree_map
from torch.utils._stats import count
from torch.utils._traceback import CapturedTraceback
from torch.utils.weak import WeakIdRef
if TYPE_CHECKING:
from torch.fx.experimental.symbolic_shapes import ShapeEnv
@ -234,28 +233,14 @@ class FakeTensorConverter:
del self.constant_storage_mapping[weak_st]
def _get_memo(self, t):
if WeakIdRef(t) in self.tensor_memo:
out = self.tensor_memo[WeakIdRef(t)]
out._fix_weakref()
return out
return None
tid = self.meta_converter.describer.lookup_tensor.get(t)
if tid is None:
return None
return self.tensor_memo.get(tid)
def set_tensor_memo(self, t, v):
th = WeakIdRef(t)
# hold a weak ref to self, otherwise it will be kept alive
# by the del_ten closure
self_weak_ref = weakref.ref(self)
def del_ten():
self_ref = self_weak_ref()
if self_ref is None:
return
# on shutdown, th may not be in memo
self_ref.tensor_memo.pop(th, None)
weakref.finalize(t, del_ten)
self.tensor_memo[th] = v
tid = self.meta_converter.describer.get_tensor_id(t)
self.meta_converter.tensor_memo[tid] = v
def from_real_tensor(
self,
@ -322,6 +307,8 @@ class FakeTensorConverter:
assert (
t.device.type == "meta"
), f"tensor's device must be `meta`, got {t.device.type} instead"
# This is a bit abusive (this is not the "real" tensor) but whatever,
# the meta tensor should be fresh so there's no way to get it wrong
maybe_memo = self._get_memo(t)
if maybe_memo is not None:
return maybe_memo
@ -860,7 +847,7 @@ class FakeTensorMode(TorchDispatchMode):
# That way when we exit, we know to re-enable the previous fake mode.
self.enter_stack: List[Tuple[bool, Optional[FakeTensorMode]]] = []
self.shape_env = shape_env
self.shape_env: ShapeEnv = shape_env
self.stack = "".join(traceback.format_stack())

File diff suppressed because it is too large.


@ -2326,51 +2326,9 @@ class ShapeEnv:
introduce new symbolic variables.
"""
# Dynamo may want to wrap FakeTensors with SymInt sizes up e.g. make_fx(opt_f(), tracing_mode="symbolic").
# We create symbols in shape_env using the backed hints behind SymInt.
# Case 1: when SymInt is backed, dynamo can proceed with FakeTensors that have concrete shape.
# produce_guards will trigger specializations on the outer stuff
# Case 2: when the SymInt is unbacked, we will throw an data dependent error in require_hint().
#
# It's probably good for now but it's important to note that this approach has implications for
# the original shape_env when checking guards in different order.
# Example:
# ---------
# Consider a function "opt_f" as shown below:
# @torch.compile()
# def opt_f(x: bool, y: Tensor):
# if x == True:
# return y + torch.randn([4])
# else:
# return y
# Depending on the sequence of calls, we might install two different sets of guards:
# 1. opt_f(False, y):
# - "x == False" (always works for any size y)
# 2. opt_f(True, y):
# - Triggers recompilation and results in guards like:
# - "x == True and y.size(0) == 4"
# - (or "y.size(0) == 4 and x == True")
# The order of checking the guards matters. In this specific example:
# If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True,
# we may have an unnessary shape speciliazation for y.
def maybe_specialize_sym_int_with_hint(maybe_sym) -> int:
assert isinstance(maybe_sym, (int, torch.SymInt))
if is_symbolic(maybe_sym):
assert maybe_sym.node.shape_env is not self, \
"expect the symbol is created from an shape env other than current one."
return maybe_sym.node.require_hint()
return maybe_sym
ex_size = tuple(maybe_specialize_sym_int_with_hint(sz) for sz in ex.size())
ex_stride = tuple(maybe_specialize_sym_int_with_hint(sd) for sd in ex.stride())
ex_storage_offset = maybe_specialize_sym_int_with_hint(ex.storage_offset())
ex_size = tuple(self._maybe_specialize_sym_int_with_hint(sz) for sz in ex.size())
ex_stride = tuple(self._maybe_specialize_sym_int_with_hint(sd) for sd in ex.stride())
ex_storage_offset = self._maybe_specialize_sym_int_with_hint(ex.storage_offset())
return self._create_symbolic_sizes_strides_storage_offset(
ex_size,
@ -2381,6 +2339,48 @@ class ShapeEnv:
symbolic_context=symbolic_context,
)
# Dynamo may want to wrap FakeTensors with SymInt sizes up e.g. make_fx(opt_f(), tracing_mode="symbolic").
# We create symbols in shape_env using the backed hints behind SymInt.
# Case 1: when SymInt is backed, dynamo can proceed with FakeTensors that have concrete shape.
# produce_guards will trigger specializations on the outer stuff
# Case 2: when the SymInt is unbacked, we will throw an data dependent error in require_hint().
#
# It's probably good for now but it's important to note that this approach has implications for
# the original shape_env when checking guards in different order.
# Example:
# ---------
# Consider a function "opt_f" as shown below:
# @torch.compile()
# def opt_f(x: bool, y: Tensor):
# if x == True:
# return y + torch.randn([4])
# else:
# return y
# Depending on the sequence of calls, we might install two different sets of guards:
# 1. opt_f(False, y):
# - "x == False" (always works for any size y)
# 2. opt_f(True, y):
# - Triggers recompilation and results in guards like:
# - "x == True and y.size(0) == 4"
# - (or "y.size(0) == 4 and x == True")
# The order of checking the guards matters. In this specific example:
# If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True,
# we may have an unnessary shape speciliazation for y.
def _maybe_specialize_sym_int_with_hint(self, maybe_sym) -> int:
assert isinstance(maybe_sym, (int, torch.SymInt))
if is_symbolic(maybe_sym):
assert maybe_sym.node.shape_env is not self, \
"expect the symbol is created from an shape env other than current one."
return maybe_sym.node.require_hint()
return maybe_sym
@record_shapeenv_event()
def _create_symbolic_sizes_strides_storage_offset(
self,