Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Factor meta conversion through serializable MetaTensorDesc (#122044)
Fixes https://github.com/pytorch/pytorch/issues/121085

This PR is pretty involved, so pay attention to this description. At a high level, the refactor is intended to be mechanical: anywhere in MetaConverter where we previously took a Tensor as an argument, we now take a MetaTensorDesc, which contains all of the information we would have queried off of the Tensor, placed into a separate data structure that we can serialize, or use to recreate a fake tensor in a separate fake tensor mode with exact fidelity to the original. However, the transformation is not always entirely mechanical. Here is what you need to pay attention to:

- The memo table from real Tensor -> meta/fake Tensor is now broken into two memo tables: real Tensor -> stable int id -> meta/fake Tensor (a rough sketch of this two-level scheme follows the commit message below). The stable int id is needed so that, when we serialize, we know when tensors/storages alias each other and can preserve this aliasing upon deserialization. The way I have implemented this changes the weak reference behavior. Previously, when either the real Tensor OR the meta/fake Tensor went dead, we would remove the entry from the memo table. Now, this only removes entries from one of the two memo tables. This makes sense semantically, because the user may have held on to the stable int id out of band, and may expect a real Tensor to continue to be numbered consistently, or expect to be able to look up a meta/fake tensor from this id. If this is unacceptable, it may be possible to rejigger the memo tables so that we have real Tensor -> stable int id and real Tensor -> meta/fake Tensor, but TBH I find the new implementation a lot simpler; arranging the memo tables the other way would mean mucking around with the real tensor to save to the memo table, whereas in the current implementation I never pass the Tensor to the meta_tensor function at all, which makes it impossible to accidentally depend on it.

- When filling in the fields of MetaTensorDesc in describe_tensor, I need to be careful not to poke fields when they are not valid. Previously, preconditions were implicitly checked via the conditional structure ("is this sparse? is this nested?") that is tested before we start reading attributes. This structure has to be replicated in describe_tensor, and I have almost assuredly gotten it wrong on my first try (I'll be grinding through it on CI; a careful audit would also help, by checking that I've tested all the same conditionals the original accesses were guarded by).

- I originally submitted https://github.com/pytorch/pytorch/pull/121821 for the symbolic shapes change, but it turned out the way I did it there didn't actually work so well for this PR. I ended up just inlining the symbolic shapes allocation logic into MetaConverter (look for calls to maybe_specialize_sym_int_with_hint). Maybe there is a better way to structure it, but what I really want is to just read sizes/strides/offset directly off of MetaTensorDesc; I don't want another intermediate data structure.

- Some fields aren't serializable. These are documented as "NOT serializable". ctx/type should morally be serializable, and I just need to set up a contract with subclasses to let them be serialized. The fake_mode is used solely to test whether we are refakeifying with a pre-existing ShapeEnv and want to reuse the SymInt directly; serializing this case is hopeless, but I am kind of hoping that after this refactor we do not need it at all. view_func is not serializable because it is a bound C-implemented method. Joel has promised me that this is not too difficult to actually expose as a true data structure, but this is the edgiest of edge cases and there is no reason to deal with it right now.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122044
Approved by: https://github.com/eellison
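A minimal, self-contained sketch of the two-level memo scheme from the first bullet above. This is not the actual MetaConverter/MetaTensorDescriber code: the class and method names are illustrative, and plain dicts keyed by id() stand in for what are presumably identity-keyed weak mappings in the real implementation; only the deallocation behavior described above is demonstrated.

```python
import weakref
from typing import Any, Dict, Optional


class TwoLevelMemo:
    """Illustrative only: real object -> stable int id -> fake object."""

    def __init__(self) -> None:
        self._next_id = 0
        # Level 1: identity of the real tensor -> stable int id.
        # The entry is dropped as soon as the real tensor is collected.
        self.lookup_tensor: Dict[int, int] = {}
        # Level 2: stable int id -> fake tensor.
        # The entry survives until the fake tensor itself is collected,
        # so a caller holding the int id keeps getting a consistent fake.
        self.tensor_memo: Dict[int, Any] = {}

    def get_tensor_id(self, t: Any) -> int:
        key = id(t)
        if key not in self.lookup_tensor:
            tid = self._next_id
            self._next_id += 1
            self.lookup_tensor[key] = tid
            # Remove only the level-1 entry when the real tensor dies.
            weakref.finalize(t, self.lookup_tensor.pop, key, None)
        return self.lookup_tensor[key]

    def set_memo(self, t: Any, fake: Any) -> None:
        tid = self.get_tensor_id(t)
        self.tensor_memo[tid] = fake
        # Remove only the level-2 entry when the fake tensor dies.
        weakref.finalize(fake, self.tensor_memo.pop, tid, None)

    def get_memo(self, t: Any) -> Optional[Any]:
        tid = self.lookup_tensor.get(id(t))
        return None if tid is None else self.tensor_memo.get(tid)
```

And a similarly hypothetical sketch of the describe_tensor point: only read attributes whose preconditions ("is this sparse? is this nested?") actually hold. The field names here are made up and are not the real MetaTensorDesc schema.

```python
import torch


def describe_layout(t: torch.Tensor) -> dict:
    # Always-safe attributes first.
    desc: dict = {"dtype": t.dtype, "device": t.device}
    if t.is_nested:
        # Nested tensors have no single size/stride; do not read them.
        desc["is_nested"] = True
    elif t.is_sparse:
        # Sparse COO tensors have sizes but no strides or storage offset.
        desc["shape"] = tuple(t.shape)
        desc["sparse_dim"] = t.sparse_dim()
        desc["dense_dim"] = t.dense_dim()
    else:
        desc["shape"] = tuple(t.shape)
        desc["stride"] = t.stride()
        desc["storage_offset"] = t.storage_offset()
    return desc
```

The point in both sketches is the contract rather than the code: which table an entry lives in determines when it is deallocated, and which branch a field lives in determines whether it may be read at all.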
Committed by: PyTorch MergeBot
Parent: cf06189a2d
Commit: 5891c5b3a6
@@ -936,9 +936,11 @@ class FakeTensorConverterTest(TestCase):
stor_id = torch._C._storage_id(x_conv)
self.assertEqual(stor_id, torch._C._storage_id(y_conv))
del x
del x_conv
self.assertEqual(len(converter.tensor_memo), 1)
self.assertEqual(len(converter.meta_converter.storage_memo), 1)
del y
del y_conv
self.assertEqual(len(converter.tensor_memo), 0)
self.assertEqual(len(converter.meta_converter.storage_memo), 0)
@@ -966,6 +968,8 @@ class FakeTensorConverterTest(TestCase):
x_conv2 = converter(mode, x)
assert x_conv2 is x_conv
del x
del x_conv
del x_conv2
self.assertEqual(len(converter.tensor_memo), 0)

def test_no_active_mode(self):
@@ -292,7 +292,19 @@ class TestMetaConverter(TestCase):
self.assertIs(y, z)
self.assertEqual(len(m.tensor_memo), 1)
self.assertEqual(len(m.storage_memo), 1)
self.assertEqual(len(m.describer.lookup_tensor), 1)
self.assertEqual(len(m.describer.lookup_storage), 1)
del x
# Entries from Tensor -> int get deallocated when the real tensor
# disappears...
self.assertEqual(len(m.describer.lookup_tensor), 0)
self.assertEqual(len(m.describer.lookup_storage), 0)
del y
del z
# ... but the int -> FakeTensor entries don't die until the fake
# tensors themselves die (because the user may have held onto the
# int key and are expecting to get a consistent fake tensor in
# this case)
self.assertEqual(len(m.tensor_memo), 0)
self.assertEqual(len(m.storage_memo), 0)
li = []
@@ -301,7 +313,13 @@ class TestMetaConverter(TestCase):
li.append(torch.rand([i]))
r.append(m(li[-1]))
self.assertEqual(len(m.tensor_memo), 4)
self.assertEqual(len(m.storage_memo), 4)
self.assertEqual(len(m.describer.lookup_tensor), 4)
self.assertEqual(len(m.describer.lookup_storage), 4)
del li
self.assertEqual(len(m.describer.lookup_tensor), 0)
self.assertEqual(len(m.describer.lookup_storage), 0)
del r
self.assertEqual(len(m.tensor_memo), 0)
self.assertEqual(len(m.storage_memo), 0)
@@ -1614,6 +1614,15 @@ class TensorBase(metaclass=_TensorMeta):
nbytes: _int
itemsize: _int
_has_symbolic_sizes_strides: _bool

def _view_func_unsafe(
self,
new_base: Tensor,
symint_visitor_fn: Optional[Callable[[_int], _int]] = None,
tensor_visitor_fn: Optional[Callable[[Tensor], Tensor]] = None
):
...

${tensor_method_hints}

_TensorBase = TensorBase
@@ -82,9 +82,9 @@ def fakify(
symbolic_context.dynamic_sizes[i] = DimDynamic.DYNAMIC
src = TensorPropertySource(base=source, prop=TensorProperty.SIZE, idx=i)
sources[(t_id, i)].append(src)
mode.shape_env.source_name_to_debug_name[src.name()] = constraint.debug_name
mode.shape_env.source_name_to_debug_name[src.name()] = constraint.debug_name # type: ignore[assignment]
fake = mode.from_tensor(t, source=source, symbolic_context=symbolic_context)
mode.shape_env.tracked_fakes.append(TrackedFake(fake, source, symbolic_context))
mode.shape_env.tracked_fakes.append(TrackedFake(fake, source, symbolic_context)) # type: ignore[union-attr]
return fake
@@ -214,6 +214,7 @@ def make_constraints(
)

shape_env = fake_mode.shape_env
assert shape_env.tracked_fakes is not None
placeholders = [tf.fake for tf in shape_env.tracked_fakes]
sources = [tf.source for tf in shape_env.tracked_fakes]
input_contexts = [tf.symbolic_context for tf in shape_env.tracked_fakes]
@@ -38,7 +38,6 @@ from torch.utils._python_dispatch import (
from torch.utils._pytree import PyTree, tree_map
from torch.utils._stats import count
from torch.utils._traceback import CapturedTraceback
from torch.utils.weak import WeakIdRef

if TYPE_CHECKING:
from torch.fx.experimental.symbolic_shapes import ShapeEnv
@@ -234,28 +233,14 @@ class FakeTensorConverter:
del self.constant_storage_mapping[weak_st]

def _get_memo(self, t):
if WeakIdRef(t) in self.tensor_memo:
out = self.tensor_memo[WeakIdRef(t)]
out._fix_weakref()
return out
return None
tid = self.meta_converter.describer.lookup_tensor.get(t)
if tid is None:
return None
return self.tensor_memo.get(tid)

def set_tensor_memo(self, t, v):
th = WeakIdRef(t)

# hold a weak ref to self, otherwise it will be kept alive
# by the del_ten closure
self_weak_ref = weakref.ref(self)

def del_ten():
self_ref = self_weak_ref()
if self_ref is None:
return
# on shutdown, th may not be in memo
self_ref.tensor_memo.pop(th, None)

weakref.finalize(t, del_ten)
self.tensor_memo[th] = v
tid = self.meta_converter.describer.get_tensor_id(t)
self.meta_converter.tensor_memo[tid] = v

def from_real_tensor(
self,
@@ -322,6 +307,8 @@ class FakeTensorConverter:
assert (
t.device.type == "meta"
), f"tensor's device must be `meta`, got {t.device.type} instead"
# This is a bit abusive (this is not the "real" tensor) but whatever,
# the meta tensor should be fresh so there's no way to get it wrong
maybe_memo = self._get_memo(t)
if maybe_memo is not None:
return maybe_memo
@@ -860,7 +847,7 @@ class FakeTensorMode(TorchDispatchMode):
# That way when we exit, we know to re-enable the previous fake mode.
self.enter_stack: List[Tuple[bool, Optional[FakeTensorMode]]] = []

self.shape_env = shape_env
self.shape_env: ShapeEnv = shape_env

self.stack = "".join(traceback.format_stack())
File diff suppressed because it is too large
@@ -2326,51 +2326,9 @@ class ShapeEnv:
introduce new symbolic variables.
"""

# Dynamo may want to wrap FakeTensors with SymInt sizes up e.g. make_fx(opt_f(), tracing_mode="symbolic").
# We create symbols in shape_env using the backed hints behind SymInt.

# Case 1: when SymInt is backed, dynamo can proceed with FakeTensors that have concrete shape.
# produce_guards will trigger specializations on the outer stuff

# Case 2: when the SymInt is unbacked, we will throw an data dependent error in require_hint().
#
# It's probably good for now but it's important to note that this approach has implications for
# the original shape_env when checking guards in different order.

# Example:
# ---------
# Consider a function "opt_f" as shown below:

# @torch.compile()
# def opt_f(x: bool, y: Tensor):
# if x == True:
# return y + torch.randn([4])
# else:
# return y
# Depending on the sequence of calls, we might install two different sets of guards:

# 1. opt_f(False, y):
# - "x == False" (always works for any size y)

# 2. opt_f(True, y):
# - Triggers recompilation and results in guards like:
# - "x == True and y.size(0) == 4"
# - (or "y.size(0) == 4 and x == True")

# The order of checking the guards matters. In this specific example:
# If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True,
# we may have an unnessary shape speciliazation for y.
def maybe_specialize_sym_int_with_hint(maybe_sym) -> int:
assert isinstance(maybe_sym, (int, torch.SymInt))
if is_symbolic(maybe_sym):
assert maybe_sym.node.shape_env is not self, \
"expect the symbol is created from an shape env other than current one."
return maybe_sym.node.require_hint()
return maybe_sym

ex_size = tuple(maybe_specialize_sym_int_with_hint(sz) for sz in ex.size())
ex_stride = tuple(maybe_specialize_sym_int_with_hint(sd) for sd in ex.stride())
ex_storage_offset = maybe_specialize_sym_int_with_hint(ex.storage_offset())
ex_size = tuple(self._maybe_specialize_sym_int_with_hint(sz) for sz in ex.size())
ex_stride = tuple(self._maybe_specialize_sym_int_with_hint(sd) for sd in ex.stride())
ex_storage_offset = self._maybe_specialize_sym_int_with_hint(ex.storage_offset())

return self._create_symbolic_sizes_strides_storage_offset(
ex_size,
@@ -2381,6 +2339,48 @@ class ShapeEnv:
symbolic_context=symbolic_context,
)

# Dynamo may want to wrap FakeTensors with SymInt sizes up e.g. make_fx(opt_f(), tracing_mode="symbolic").
# We create symbols in shape_env using the backed hints behind SymInt.

# Case 1: when SymInt is backed, dynamo can proceed with FakeTensors that have concrete shape.
# produce_guards will trigger specializations on the outer stuff

# Case 2: when the SymInt is unbacked, we will throw an data dependent error in require_hint().
#
# It's probably good for now but it's important to note that this approach has implications for
# the original shape_env when checking guards in different order.

# Example:
# ---------
# Consider a function "opt_f" as shown below:

# @torch.compile()
# def opt_f(x: bool, y: Tensor):
# if x == True:
# return y + torch.randn([4])
# else:
# return y
# Depending on the sequence of calls, we might install two different sets of guards:

# 1. opt_f(False, y):
# - "x == False" (always works for any size y)

# 2. opt_f(True, y):
# - Triggers recompilation and results in guards like:
# - "x == True and y.size(0) == 4"
# - (or "y.size(0) == 4 and x == True")

# The order of checking the guards matters. In this specific example:
# If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True,
# we may have an unnessary shape speciliazation for y.
def _maybe_specialize_sym_int_with_hint(self, maybe_sym) -> int:
assert isinstance(maybe_sym, (int, torch.SymInt))
if is_symbolic(maybe_sym):
assert maybe_sym.node.shape_env is not self, \
"expect the symbol is created from an shape env other than current one."
return maybe_sym.node.require_hint()
return maybe_sym

@record_shapeenv_event()
def _create_symbolic_sizes_strides_storage_offset(
self,