Introduce EphemeralSource for symbols that should be simplified out (#120948)

Context: view fake-ification should handle closed-over state in ViewFuncs for use in view replay by:
* fake-ifying tensors
* symbolicizing SymInts

This avoids invalid specialization during view replay. However, the symbols / tensors created as intermediates in the view chain should not stick around or be guarded on. This PR introduces an `EphemeralSource` intended to be used as the source for such symbols. It has the following properties (see the sketch below):
* Prioritized for simplification: its symbols are the first to be simplified out by the symbol simplification logic
* Guarding on it is an error
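
A minimal sketch of how the new source behaves (based on the `EphemeralSource` definition added in this PR; the `desc` label is just an illustrative string):

```python
from torch._dynamo.source import EphemeralSource

src = EphemeralSource(desc="intermediate view size")  # desc is optional
assert src.is_ephemeral()  # plain sources return False here
print(src.name())  # "<ephemeral: intermediate view size>"

# Guarding on an ephemeral source is an error.
try:
    src.make_guard()
except NotImplementedError:
    print("ephemeral sources cannot be guarded on")
```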

Differential Revision: [D54561597](https://our.internmc.facebook.com/intern/diff/D54561597)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120948
Approved by: https://github.com/ezyang
Author: Joel Schlosser
Date: 2024-03-05 18:14:18 -05:00
Committed by: PyTorch MergeBot
Parent: d968fc442b
Commit: dad1b76584
4 changed files with 141 additions and 6 deletions

test/test_dynamic_shapes.py

@@ -130,15 +130,18 @@ class FakeSymbolicTensor(torch.Tensor):
         raise RuntimeError(f"operator {func_overload} not supported")
 
 
-def create_symbolic_tensor(name, arg, shape_env):
+def create_symbolic_tensor(name, arg, shape_env, source=None, dynamic_dims=None):
     from torch._dynamo.source import ConstantSource
 
+    if source is None:
+        source = ConstantSource(name)
     constraint_dims = [None] * arg.dim()
-    dynamic_dims = [DimDynamic.DUCK] * arg.dim()
+    if dynamic_dims is None:
+        dynamic_dims = [DimDynamic.DUCK] * arg.dim()
     sym_shapes, sym_strides, sym_storage_offset = \
         shape_env.create_symbolic_sizes_strides_storage_offset(
             arg,
-            source=ConstantSource(name),
+            source=source,
             symbolic_context=StatelessSymbolicContext(
                 dynamic_sizes=dynamic_dims,
                 constraint_sizes=constraint_dims
@@ -749,6 +752,87 @@ class f(torch.nn.Module):
         # No guards should be generated
         self.assertEqual(len(shape_env.guards), 0)
 
+    def test_ephemeral_source_simplification(self):
+        from torch._dynamo.source import EphemeralSource
+
+        # For full robustness, ensure the ephemeral source symbols are simplified out regardless
+        # of construction order or check order.
+        for construct_ephemeral_first, x_first_in_check in (
+            itertools.product([False, True], [False, True])
+        ):
+            shape_env = ShapeEnv()
+            shape = (5, 10)
+            dynamic_dims = [DimDynamic.DYNAMIC for _ in shape]
+            x = create_symbolic_tensor(
+                "x",
+                torch.randn(*shape),
+                shape_env,
+                source=(EphemeralSource() if construct_ephemeral_first else None),
+                dynamic_dims=dynamic_dims,
+            )
+            y = create_symbolic_tensor(
+                "y",
+                torch.randn(*shape),
+                shape_env,
+                source=(EphemeralSource() if not construct_ephemeral_first else None),
+                dynamic_dims=dynamic_dims,
+            )
+            t_with_ephemeral = x if construct_ephemeral_first else y
+
+            def _get_ephemeral_source_symbols(t):
+                return [
+                    s.node.expr for s in itertools.chain(t.shape, t.stride(), (t.storage_offset(),))
+                    if isinstance(s, torch.SymInt) and s.node.expr in shape_env.var_to_sources
+                    and any(
+                        source.is_ephemeral() for source in shape_env.var_to_sources[s.node.expr]
+                    )
+                ]
+
+            # these checks should simplify out the ephemeral symbols, regardless of the
+            # ordering x == y or y == x
+            self.assertTrue(len(_get_ephemeral_source_symbols(t_with_ephemeral)) > 0)
+            if x_first_in_check:
+                torch._check(x.size() == y.size())
+                torch._check(x.stride() == y.stride())
+                torch._check(x.storage_offset() == y.storage_offset())
+            else:
+                torch._check(y.size() == x.size())
+                torch._check(y.stride() == x.stride())
+                torch._check(y.storage_offset() == x.storage_offset())
+            self.assertEqual(len(_get_ephemeral_source_symbols(t_with_ephemeral)), 0)
+
+    def test_ephemeral_source_unified_with_non_ephemeral_source(self):
+        from torch._dynamo.source import EphemeralSource
+
+        for construct_ephemeral_first in (False, True):
+            shape_env = ShapeEnv()
+            shape = (5, 10)
+            # use duck sizing here to ensure symbol reuse across x and y
+            duck_dims = [DimDynamic.DUCK for _ in shape]
+            x = create_symbolic_tensor(
+                "x",
+                torch.randn(*shape),
+                shape_env,
+                source=(EphemeralSource() if construct_ephemeral_first else None),
+                dynamic_dims=duck_dims,
+            )
+            y = create_symbolic_tensor(
+                "y",
+                torch.randn(*shape),
+                shape_env,
+                source=(EphemeralSource() if not construct_ephemeral_first else None),
+                dynamic_dims=duck_dims,
+            )
+
+            # regardless of construction order, non-ephemeral sources should be preferred
+            # first in the var_to_sources list for potential guarding later on
+            for source_list in shape_env.var_to_sources.values():
+                self.assertFalse(source_list[0].is_ephemeral())
+
+            self.assertEqual(x.size(), y.size())
+            self.assertEqual(x.stride(), y.stride())
+            self.assertEqual(x.storage_offset(), y.storage_offset())
+
 
 @skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
 class TestSymNumberMagicMethods(TestCase):

torch/_dynamo/source.py

@@ -162,6 +162,32 @@ class ParamBufferSource(AttrSource):
         return _GUARD_SOURCE_NN_MODULE[self.base.guard_source()]
 
 
+# This source is intended to be used in places where a source is needed but it is expected
+# that the symbol will be simplified out later on. Symbols with ephemeral sources are
+# prioritized to be simplified out when e.g. compared against a symbol without an ephemeral
+# source. Guarding on this source is an error.
+#
+# Example: During subclass view fake-ification, any closed-over ViewFunc state should be
+# symbolicized / fake-ified to avoid invalid specialization during view replay. This source
+# is useful for symbols utilized in the middle of the view chain that are not expected to be
+# present within the final view shape metadata.
+@dataclasses.dataclass(frozen=True)
+class EphemeralSource(Source):
+    desc: Optional[str] = None
+
+    def guard_source(self):
+        return GuardSource.EPHEMERAL
+
+    def name(self):
+        return f"<ephemeral{': ' + self.desc if self.desc is not None else ''}>"
+
+    def make_guard(self):
+        raise NotImplementedError()
+
+    def is_ephemeral(self):
+        return True
+
+
 class TensorProperty(enum.Enum):
     SIZE = 0
     STRIDE = 1

torch/_guards.py

@@ -87,6 +87,7 @@ class GuardSource(enum.Enum):
     LOCAL_FSDP_MODULE = 7
     GLOBAL_FSDP_MODULE = 8
     BACKWARD_STATE = 9
+    EPHEMERAL = 10
 
     def is_fsdp_module(self) -> bool:
         return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE)
@@ -776,6 +777,9 @@ class Source:
     def is_dict_key(self):
         return False
 
+    def is_ephemeral(self):
+        return False
+
     def reconstruct(self, codegen):
         raise NotImplementedError()
 
@@ -803,6 +807,9 @@ class ChainedSource(Source):
         # Recurse until you either hit a ConstDictKey or a Source
         return self.base.is_dict_key()
 
+    def is_ephemeral(self):
+        return self.base.is_ephemeral()
+
 
 def detect_fake_mode(inputs: Any = None):
     """

torch/fx/experimental/symbolic_shapes.py

@@ -2688,7 +2688,12 @@ class ShapeEnv:
             self.log.debug("create_symbol %s duck sized %s", r, source.name())
 
         if isinstance(r, sympy.Symbol):
-            self.var_to_sources[r].append(source)
+            r_sources = self.var_to_sources[r]
+            r_sources.append(source)
+            if not source.is_ephemeral() and r_sources[0].is_ephemeral():
+                # prefer non-ephemeral source first since it may be guarded on later
+                r_sources[0], r_sources[-1] = r_sources[-1], r_sources[0]
+
             # This ensures we get zeros in symbol_guard_counts, which makes
             # some queries simpler (since we will accumulate mass on 0 this
             # way)
@@ -3796,8 +3801,21 @@ class ShapeEnv:
         # In case of really gnarly expression, we don't blow up
         if len(free) > 5:
             return
-        # NB: prioritize unbacked symints for solving by ordering them last
-        free = sorted(free, key=lambda x: (self.size_hint(x, allow_none=True) or sys.maxsize, x.name), reverse=True)  # type: ignore[attr-defined]
+
+        # Prioritize unbacked symints for solving by ordering them last.
+        # Prefer to simplify out lexicographically higher symbols (i.e. simplify out s4 over s3).
+        # (NB: this unfortunately isn't strictly equivalent to simplifying out newer symbols)
+        # Prefer to simplify out symbols with ephemeral sources.
+        def _smart_symbol_sort(x):
+            has_only_ephemeral_sources = (
+                x in self.var_to_sources and all(s.is_ephemeral() for s in self.var_to_sources[x])
+            )
+            size = self.size_hint(x, allow_none=True) or sys.maxsize
+            name = x.name
+            # 1 puts ephemeral sourced symbols first when sorting in reverse
+            return (1 if has_only_ephemeral_sources else 0, size, name)
+
+        free = sorted(free, key=_smart_symbol_sort, reverse=True)  # type: ignore[attr-defined]
+
         lhs = expr.lhs
         rhs = expr.rhs
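
For intuition, a standalone sketch of the ordering `_smart_symbol_sort` produces, using made-up symbol data instead of a real `ShapeEnv` (the dict-based stand-ins and names are purely illustrative):

```python
import sys

# Two stand-in symbols with the same size hint; only s4 has an ephemeral source.
symbols = [
    {"name": "s3", "size": 5, "ephemeral": False},
    {"name": "s4", "size": 5, "ephemeral": True},
]


def sort_key(sym):
    # Mirrors the tuple built above: (ephemeral flag, size hint, name). With
    # reverse=True, ephemeral-sourced symbols sort first, making them the
    # preferred candidates to be solved for and simplified out.
    return (1 if sym["ephemeral"] else 0, sym["size"] or sys.maxsize, sym["name"])


print([s["name"] for s in sorted(symbols, key=sort_key, reverse=True)])  # ['s4', 's3']
```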