Support unbacked whitelist (#154295)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154295
Approved by: https://github.com/angelayi
This commit is contained in:
bobrenjc93
2025-05-28 08:21:00 -07:00
committed by PyTorch MergeBot
parent ef4d57329b
commit d865b784e4
3 changed files with 98 additions and 8 deletions

View File

@ -7921,6 +7921,34 @@ utils_device.CURRENT_DEVICE == None""".split(
self.assertEqual(counter.frame_count, 1)
@torch.compiler.config.patch(unbacked_sources="L['x']")
def test_unbacked_sources_tensor(self):
    # A tensor whose source is on the unbacked whitelist should be traced
    # with unbacked sizes, so 0-, 1- and 2-element inputs (including the
    # normally specialized 0/1 sizes) all reuse one compiled frame.
    compiles = CompileCounter()

    @torch.compile(backend=compiles)
    def fn(x):
        return x * x

    for numel in range(3):
        fn(torch.randn(numel))

    self.assertEqual(compiles.frame_count, 1)
@torch.compiler.config.patch(unbacked_sources="L['x']")
def test_unbacked_sources_scalar(self):
    # A Python int whose source is on the unbacked whitelist should be
    # wrapped as an unbacked symint, so 0, 1 and 2 all hit the same
    # compiled frame: exactly one compile expected.
    compiles = CompileCounter()

    @torch.compile(backend=compiles)
    def fn(x):
        return x * x

    for value in (0, 1, 2):
        fn(value)

    self.assertEqual(compiles.frame_count, 1)
@torch.compiler.config.patch(dynamic_sources="L['x']")
def test_dynamic_sources_graph_break(self):
counter = CompileCounter()

View File

@ -1454,6 +1454,11 @@ class VariableBuilder:
if value.dynamism is None or value.dynamism.type == _DimHintType.STATIC:
return self.wrap_symint(value.val)
elif value.dynamism.type == _DimHintType.DYNAMIC:
log.debug(
"%s marked %s via IntWrapper",
self.source.name(),
DimDynamic.DYNAMIC,
)
return self.wrap_symint(
value.val,
dynamism=DimDynamic.DYNAMIC,
@ -1462,6 +1467,11 @@ class VariableBuilder:
),
)
elif value.dynamism.type == _DimHintType.AUTO:
log.debug(
"%s marked %s via IntWrapper",
self.source.name(),
DimDynamic.DYNAMIC,
)
return self.wrap_symint(value.val, dynamism=DimDynamic.DYNAMIC)
else:
raise RuntimeError(f"Undefined dynamism {value.dynamism}")
@ -1767,7 +1777,12 @@ class VariableBuilder:
if type(value) is int:
# allowlist has higher precedence over specialization control.
if is_dynamic_source(self.source.name()):
return self.wrap_symint(value, True)
log.debug("%s marked dynamic via source whitelist", self.source.name())
return self.wrap_symint(value, dynamism=DimDynamic.DYNAMIC)
if is_unbacked_source(self.source.name()):
log.debug("%s marked unbacked via source whitelist", self.source.name())
return self.wrap_symint(value, dynamism=DimDynamic.SIZE_LIKE_UNBACKED)
if not config.specialize_int:
# unspecializing int by default, but still
@ -2117,7 +2132,6 @@ class VariableBuilder:
def wrap_symint(
self,
value,
is_forced_allow_list_dynamic=False,
dynamism: Optional[DimDynamic] = None,
context: Optional[SymIntSymbolicContext] = None,
):
@ -2165,12 +2179,8 @@ class VariableBuilder:
if isinstance(base_source, ChainedSource):
base_source = base_source.get_base()
if dynamism == DimDynamic.DYNAMIC:
log.debug("%s marked %s via IntWrapper", self.source.name(), dynamism)
if dynamism is not None:
dynamic_dim = dynamism
elif is_forced_allow_list_dynamic:
log.debug("%s marked dynamic via source whitelist", self.source.name())
dynamic_dim = DimDynamic.DYNAMIC
elif (
config.automatic_dynamic_shapes
and frame_state_entry.scalar is auto_dynamic
@ -2963,6 +2973,43 @@ def record_automatic_dynamic(
)
# Memoized parse of torch.compiler.config.unbacked_sources, invalidated
# whenever the hash recorded at the last parse no longer matches.
_UNBACKED_SOURCES: Optional[set[str]] = None
_UNBACKED_SOURCES_CONFIG_HASH: Optional[int] = None


def get_unbacked_sources() -> set[str]:
    """Return the set of source names configured as unbacked.

    The comma-delimited ``torch.compiler.config.unbacked_sources`` string is
    parsed once and cached; the cache is rebuilt only when the config
    value's hash differs from the hash stored at the previous parse.
    """
    global _UNBACKED_SOURCES, _UNBACKED_SOURCES_CONFIG_HASH

    config_value = torch.compiler.config.unbacked_sources
    config_hash = hash(config_value)

    cache_is_stale = (
        _UNBACKED_SOURCES is None or _UNBACKED_SOURCES_CONFIG_HASH != config_hash
    )
    if cache_is_stale:
        # Strip spaces, split on commas, and discard empty entries.
        tokens = config_value.replace(" ", "").split(",")
        _UNBACKED_SOURCES = {token for token in tokens if token}
        _UNBACKED_SOURCES_CONFIG_HASH = config_hash

    return _UNBACKED_SOURCES
def is_unbacked_source(source_name: str) -> bool:
    """Return True if ``source_name`` matches an unbacked-whitelist entry.

    Each configured entry is tried both as an exact string and as a regex
    matched from the start of ``source_name`` (``re.match`` semantics).
    A hit is logged at debug level together with the matching pattern.
    """
    for pattern in get_unbacked_sources():
        matched = pattern == source_name or re.match(pattern, source_name)
        if not matched:
            continue
        log.debug(
            "%s was marked unbacked due to unbacked source allowlist pattern: %s",
            source_name,
            pattern,
        )
        return True
    return False
# Performs automatic dynamic dim determination.
# Returns a SymbolicContext
def _automatic_dynamic(
@ -3135,6 +3182,11 @@ def _automatic_dynamic(
automatic_dynamic_size = True
automatic_dynamic_stride = True
if is_unbacked_source(name):
log.debug("%s marked unbacked via source whitelist", name)
automatic_dynamic_size = True
automatic_dynamic_stride = True
automatic_dynamic = automatic_dynamic_size or automatic_dynamic_stride
# We will process constraints first, as they will imply that we
@ -3185,7 +3237,7 @@ def _automatic_dynamic(
constraint_sizes.append(constraint_size)
constraint_strides.append(constraint_stride)
if marked_unbacked:
if marked_unbacked or is_unbacked_source(name):
dynamic_size = DimDynamic.SIZE_LIKE_UNBACKED
elif (
constraint_size is not None

View File

@ -83,5 +83,15 @@ this to a string that identifies the shared profile. This is useful if you want
for models that are not identical, but are similar enough to share PGO profiles.
"""
unbacked_sources: str = Config(
env_name_default="TORCH_COMPILE_UNBACKED_SOURCES", default=""
)
"""
Comma delimited list of sources that should be marked as unbacked. Primarily useful for large
models with graph breaks where you need intermediate tensors marked unbacked.
This whitelist takes precedence over all other flags: dynamic=False, force_nn_module_property_static_shapes,
and force_parameter_static_shapes.
"""
install_config_module(sys.modules[__name__])