mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Introduce EphemeralSource for symbols that should be simplified out (#120948)
Context: view fake-ification should handle closed-over state in ViewFuncs for use in view replay by:

* fake-ifying tensors
* symbolicizing SymInts

This avoids invalid specialization during view replay. However, the symbols / tensors created as intermediates in the view chain should not stick around or be guarded on. This PR introduces an `EphemeralSource` intended to be used as a source for this purpose. It has the following properties:

* Considered first to be simplified out in symbol simplification logic
* Errors if guarded on

Differential Revision: [D54561597](https://our.internmc.facebook.com/intern/diff/D54561597)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120948
Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
d968fc442b
commit
dad1b76584
@ -130,15 +130,18 @@ class FakeSymbolicTensor(torch.Tensor):
|
||||
raise RuntimeError(f"operator {func_overload} not supported")
|
||||
|
||||
|
||||
def create_symbolic_tensor(name, arg, shape_env):
|
||||
def create_symbolic_tensor(name, arg, shape_env, source=None, dynamic_dims=None):
|
||||
from torch._dynamo.source import ConstantSource
|
||||
|
||||
if source is None:
|
||||
source = ConstantSource(name)
|
||||
constraint_dims = [None] * arg.dim()
|
||||
dynamic_dims = [DimDynamic.DUCK] * arg.dim()
|
||||
if dynamic_dims is None:
|
||||
dynamic_dims = [DimDynamic.DUCK] * arg.dim()
|
||||
sym_shapes, sym_strides, sym_storage_offset = \
|
||||
shape_env.create_symbolic_sizes_strides_storage_offset(
|
||||
arg,
|
||||
source=ConstantSource(name),
|
||||
source=source,
|
||||
symbolic_context=StatelessSymbolicContext(
|
||||
dynamic_sizes=dynamic_dims,
|
||||
constraint_sizes=constraint_dims
|
||||
@ -749,6 +752,87 @@ class f(torch.nn.Module):
|
||||
# No guards should be generated
|
||||
self.assertEqual(len(shape_env.guards), 0)
|
||||
|
||||
def test_ephemeral_source_simplification(self):
    """Ephemeral-sourced symbols must vanish once unified with regular ones.

    Two symbolic tensors of identical shape are built, exactly one of them
    with an ``EphemeralSource``. After checking that their sizes, strides and
    storage offsets are equal, the tensor created with the ephemeral source
    should no longer expose any symbol that has an ephemeral source. Every
    combination of construction order and comparison order is exercised.
    """
    from torch._dynamo.source import EphemeralSource

    def _ephemeral_symbols(t, env):
        # Collect the expressions of all symbolic size / stride / storage-offset
        # entries of `t` that have at least one ephemeral source recorded in `env`.
        found = []
        for s in itertools.chain(t.shape, t.stride(), (t.storage_offset(),)):
            if not isinstance(s, torch.SymInt):
                continue
            expr = s.node.expr
            if expr not in env.var_to_sources:
                continue
            if any(src.is_ephemeral() for src in env.var_to_sources[expr]):
                found.append(expr)
        return found

    shape = (5, 10)
    orderings = itertools.product([False, True], [False, True])
    for construct_ephemeral_first, x_first_in_check in orderings:
        shape_env = ShapeEnv()
        dynamic_dims = [DimDynamic.DYNAMIC] * len(shape)

        x = create_symbolic_tensor(
            "x",
            torch.randn(*shape),
            shape_env,
            source=(EphemeralSource() if construct_ephemeral_first else None),
            dynamic_dims=dynamic_dims,
        )
        y = create_symbolic_tensor(
            "y",
            torch.randn(*shape),
            shape_env,
            source=(None if construct_ephemeral_first else EphemeralSource()),
            dynamic_dims=dynamic_dims,
        )
        t_with_ephemeral = x if construct_ephemeral_first else y

        # Before any check, the ephemeral-sourced tensor still carries such symbols.
        self.assertTrue(len(_ephemeral_symbols(t_with_ephemeral, shape_env)) > 0)

        # The checks below should simplify out the ephemeral symbols whether
        # they are written as x == y or as y == x.
        lhs, rhs = (x, y) if x_first_in_check else (y, x)
        torch._check(lhs.size() == rhs.size())
        torch._check(lhs.stride() == rhs.stride())
        torch._check(lhs.storage_offset() == rhs.storage_offset())

        self.assertEqual(len(_ephemeral_symbols(t_with_ephemeral, shape_env)), 0)
|
||||
|
||||
def test_ephemeral_source_unified_with_non_ephemeral_source(self):
    """A symbol shared by ephemeral and regular sources lists a regular one first.

    Two duck-sized symbolic tensors of the same shape are created, exactly one
    of them with an ``EphemeralSource``. For either construction order, the
    first entry of every ``var_to_sources`` list must be non-ephemeral and the
    two tensors must report equal sizes, strides and storage offsets.
    """
    from torch._dynamo.source import EphemeralSource

    shape = (5, 10)
    for construct_ephemeral_first in (False, True):
        shape_env = ShapeEnv()
        # Duck sizing is used so that x and y end up reusing the same symbols.
        duck_dims = [DimDynamic.DUCK] * len(shape)

        x = create_symbolic_tensor(
            "x",
            torch.randn(*shape),
            shape_env,
            source=(EphemeralSource() if construct_ephemeral_first else None),
            dynamic_dims=duck_dims,
        )
        y = create_symbolic_tensor(
            "y",
            torch.randn(*shape),
            shape_env,
            source=(None if construct_ephemeral_first else EphemeralSource()),
            dynamic_dims=duck_dims,
        )

        # Whatever the construction order, the head of each source list must
        # be a non-ephemeral source, since it may be guarded on later.
        for sources in shape_env.var_to_sources.values():
            head = sources[0]
            self.assertFalse(head.is_ephemeral())

        self.assertEqual(x.size(), y.size())
        self.assertEqual(x.stride(), y.stride())
        self.assertEqual(x.storage_offset(), y.storage_offset())
|
||||
|
||||
|
||||
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
|
||||
class TestSymNumberMagicMethods(TestCase):
|
||||
|
||||
@ -162,6 +162,32 @@ class ParamBufferSource(AttrSource):
|
||||
return _GUARD_SOURCE_NN_MODULE[self.base.guard_source()]
|
||||
|
||||
|
||||
# This source is intended to be used in places where a source is needed but it is expected
|
||||
# that the symbol will be simplified out later on. Symbols with ephemeral sources are
|
||||
# prioritized to be simplified out when e.g. compared against a symbol without an ephemeral
|
||||
# source. Guarding on this source is an error.
|
||||
#
|
||||
# Example: During subclass view fake-ification, any closed-over ViewFunc state should be
|
||||
# symbolicized / fake-ified to avoid invalid specialization during view replay. This source
|
||||
# is useful for symbols utilized in the middle of the view chain that are not expected to be
|
||||
# present within the final view shape metadata.
|
||||
@dataclasses.dataclass(frozen=True)
class EphemeralSource(Source):
    """Placeholder source for symbols that are expected to be simplified out.

    Instances report ``GuardSource.EPHEMERAL`` as their guard source, answer
    ``True`` from ``is_ephemeral()``, and refuse to produce a guard.
    """

    # Optional free-form label; when set it is embedded in the string
    # returned by name().
    desc: Optional[str] = None

    def guard_source(self):
        """Return the dedicated ``GuardSource.EPHEMERAL`` marker."""
        return GuardSource.EPHEMERAL

    def name(self):
        """Return ``"<ephemeral>"``, or ``"<ephemeral: DESC>"`` when ``desc`` is set."""
        suffix = "" if self.desc is None else ": " + self.desc
        return f"<ephemeral{suffix}>"

    def make_guard(self, fn=None):
        """Always raise ``NotImplementedError``: ephemeral sources cannot be guarded on.

        NOTE(review): the base ``Source.make_guard`` presumably receives a
        guard-construction callable (confirm against torch/_guards.py). ``fn``
        is accepted and ignored so that such a call fails with the intended
        ``NotImplementedError`` instead of a ``TypeError`` about the argument
        count. Calling with no argument keeps working as before.
        """
        raise NotImplementedError(
            f"Cannot create a guard on ephemeral source {self.name()}; "
            "symbols with ephemeral sources are expected to be simplified out"
        )

    def is_ephemeral(self):
        """Return ``True``; this source is ephemeral by definition."""
        return True
|
||||
|
||||
|
||||
class TensorProperty(enum.Enum):
|
||||
SIZE = 0
|
||||
STRIDE = 1
|
||||
|
||||
@ -87,6 +87,7 @@ class GuardSource(enum.Enum):
|
||||
LOCAL_FSDP_MODULE = 7
|
||||
GLOBAL_FSDP_MODULE = 8
|
||||
BACKWARD_STATE = 9
|
||||
EPHEMERAL = 10
|
||||
|
||||
def is_fsdp_module(self) -> bool:
    """Return True when this member is one of the two FSDP-module guard sources."""
    fsdp_members = (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE)
    return self in fsdp_members
|
||||
@ -776,6 +777,9 @@ class Source:
|
||||
def is_dict_key(self):
    """Base answer to "is this source a dict key?": always ``False``."""
    return False
|
||||
|
||||
def is_ephemeral(self):
    """Base answer to "is this source ephemeral?": always ``False``."""
    return False
|
||||
|
||||
def reconstruct(self, codegen):
    """Base implementation: unconditionally raises ``NotImplementedError``."""
    raise NotImplementedError()
|
||||
|
||||
@ -803,6 +807,9 @@ class ChainedSource(Source):
|
||||
# Recurse until you either hit a ConstDictKey or a Source
|
||||
return self.base.is_dict_key()
|
||||
|
||||
def is_ephemeral(self):
    """Delegate the ephemeral check to the wrapped ``base`` source."""
    base_source = self.base
    return base_source.is_ephemeral()
|
||||
|
||||
|
||||
def detect_fake_mode(inputs: Any = None):
|
||||
"""
|
||||
|
||||
@ -2688,7 +2688,12 @@ class ShapeEnv:
|
||||
self.log.debug("create_symbol %s duck sized %s", r, source.name())
|
||||
|
||||
if isinstance(r, sympy.Symbol):
|
||||
self.var_to_sources[r].append(source)
|
||||
r_sources = self.var_to_sources[r]
|
||||
r_sources.append(source)
|
||||
if not source.is_ephemeral() and r_sources[0].is_ephemeral():
|
||||
# prefer non-ephemeral source first since it may be guarded on later
|
||||
r_sources[0], r_sources[-1] = r_sources[-1], r_sources[0]
|
||||
|
||||
# This ensures we get zeros in symbol_guard_counts, which makes
|
||||
# some queries simpler (since we will accumulate mass on 0 this
|
||||
# way)
|
||||
@ -3796,8 +3801,21 @@ class ShapeEnv:
|
||||
# In case of really gnarly expression, we don't blow up
|
||||
if len(free) > 5:
|
||||
return
|
||||
# NB: prioritize unbacked symints for solving by ordering them last
|
||||
free = sorted(free, key=lambda x: (self.size_hint(x, allow_none=True) or sys.maxsize, x.name), reverse=True) # type: ignore[attr-defined]
|
||||
|
||||
# Prioritize unbacked symints for solving by ordering them last.
|
||||
# Prefer to simplify out lexicographically higher symbols (i.e. simplify out s4 over s3).
|
||||
# (NB: this unfortunately isn't strictly equivalent to simplifying out newer symbols)
|
||||
# Prefer to simplify out symbols with ephemeral sources.
|
||||
def _smart_symbol_sort(x):
    """Build the sort key ``(ephemeral_flag, size, name)`` for symbol ``x``.

    The key is meant to be used with ``reverse=True``:
    * ``ephemeral_flag`` is 1 when ``x`` is in ``var_to_sources`` and all of
      its recorded sources are ephemeral, so such symbols sort first;
    * ``size`` is the size hint, or ``sys.maxsize`` when the hint is missing
      (or otherwise falsy);
    * ``name`` breaks remaining ties by symbol name.
    """
    ephemeral_only = False
    if x in self.var_to_sources:
        ephemeral_only = all(src.is_ephemeral() for src in self.var_to_sources[x])

    hint = self.size_hint(x, allow_none=True)
    if not hint:
        hint = sys.maxsize

    # 1 puts ephemeral sourced symbols first when sorting in reverse
    return (int(ephemeral_only), hint, x.name)
|
||||
|
||||
free = sorted(free, key=_smart_symbol_sort, reverse=True) # type: ignore[attr-defined]
|
||||
lhs = expr.lhs
|
||||
rhs = expr.rhs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user