mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Support unbacked whitelist (#154295)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154295 Approved by: https://github.com/angelayi
This commit is contained in:
committed by
PyTorch MergeBot
parent
ef4d57329b
commit
d865b784e4
@ -7921,6 +7921,34 @@ utils_device.CURRENT_DEVICE == None""".split(
|
||||
|
||||
self.assertEqual(counter.frame_count, 1)
|
||||
|
||||
@torch.compiler.config.patch(unbacked_sources="L['x']")
|
||||
def test_unbacked_sources_tensor(self):
|
||||
counter = CompileCounter()
|
||||
|
||||
@torch.compile(backend=counter)
|
||||
def fn(x):
|
||||
return x * x
|
||||
|
||||
fn(torch.randn(0))
|
||||
fn(torch.randn(1))
|
||||
fn(torch.randn(2))
|
||||
|
||||
self.assertEqual(counter.frame_count, 1)
|
||||
|
||||
@torch.compiler.config.patch(unbacked_sources="L['x']")
|
||||
def test_unbacked_sources_scalar(self):
|
||||
counter = CompileCounter()
|
||||
|
||||
@torch.compile(backend=counter)
|
||||
def fn(x):
|
||||
return x * x
|
||||
|
||||
fn(0)
|
||||
fn(1)
|
||||
fn(2)
|
||||
|
||||
self.assertEqual(counter.frame_count, 1)
|
||||
|
||||
@torch.compiler.config.patch(dynamic_sources="L['x']")
|
||||
def test_dynamic_sources_graph_break(self):
|
||||
counter = CompileCounter()
|
||||
|
@ -1454,6 +1454,11 @@ class VariableBuilder:
|
||||
if value.dynamism is None or value.dynamism.type == _DimHintType.STATIC:
|
||||
return self.wrap_symint(value.val)
|
||||
elif value.dynamism.type == _DimHintType.DYNAMIC:
|
||||
log.debug(
|
||||
"%s marked %s via IntWrapper",
|
||||
self.source.name(),
|
||||
DimDynamic.DYNAMIC,
|
||||
)
|
||||
return self.wrap_symint(
|
||||
value.val,
|
||||
dynamism=DimDynamic.DYNAMIC,
|
||||
@ -1462,6 +1467,11 @@ class VariableBuilder:
|
||||
),
|
||||
)
|
||||
elif value.dynamism.type == _DimHintType.AUTO:
|
||||
log.debug(
|
||||
"%s marked %s via IntWrapper",
|
||||
self.source.name(),
|
||||
DimDynamic.DYNAMIC,
|
||||
)
|
||||
return self.wrap_symint(value.val, dynamism=DimDynamic.DYNAMIC)
|
||||
else:
|
||||
raise RuntimeError(f"Undefined dynamism {value.dynamism}")
|
||||
@ -1767,7 +1777,12 @@ class VariableBuilder:
|
||||
if type(value) is int:
|
||||
# allowlist has higher precedence over specialization control.
|
||||
if is_dynamic_source(self.source.name()):
|
||||
return self.wrap_symint(value, True)
|
||||
log.debug("%s marked dynamic via source whitelist", self.source.name())
|
||||
return self.wrap_symint(value, dynamism=DimDynamic.DYNAMIC)
|
||||
|
||||
if is_unbacked_source(self.source.name()):
|
||||
log.debug("%s marked unbacked via source whitelist", self.source.name())
|
||||
return self.wrap_symint(value, dynamism=DimDynamic.SIZE_LIKE_UNBACKED)
|
||||
|
||||
if not config.specialize_int:
|
||||
# unspecializing int by default, but still
|
||||
@ -2117,7 +2132,6 @@ class VariableBuilder:
|
||||
def wrap_symint(
|
||||
self,
|
||||
value,
|
||||
is_forced_allow_list_dynamic=False,
|
||||
dynamism: Optional[DimDynamic] = None,
|
||||
context: Optional[SymIntSymbolicContext] = None,
|
||||
):
|
||||
@ -2165,12 +2179,8 @@ class VariableBuilder:
|
||||
if isinstance(base_source, ChainedSource):
|
||||
base_source = base_source.get_base()
|
||||
|
||||
if dynamism == DimDynamic.DYNAMIC:
|
||||
log.debug("%s marked %s via IntWrapper", self.source.name(), dynamism)
|
||||
if dynamism is not None:
|
||||
dynamic_dim = dynamism
|
||||
elif is_forced_allow_list_dynamic:
|
||||
log.debug("%s marked dynamic via source whitelist", self.source.name())
|
||||
dynamic_dim = DimDynamic.DYNAMIC
|
||||
elif (
|
||||
config.automatic_dynamic_shapes
|
||||
and frame_state_entry.scalar is auto_dynamic
|
||||
@ -2963,6 +2973,43 @@ def record_automatic_dynamic(
|
||||
)
|
||||
|
||||
|
||||
_UNBACKED_SOURCES: Optional[set[str]] = None
|
||||
_UNBACKED_SOURCES_CONFIG_HASH: Optional[int] = None
|
||||
|
||||
|
||||
def get_unbacked_sources() -> set[str]:
|
||||
global _UNBACKED_SOURCES, _UNBACKED_SOURCES_CONFIG_HASH
|
||||
|
||||
current_hash = hash(torch.compiler.config.unbacked_sources)
|
||||
|
||||
# If we have already calculated the sources and the config hasn't changed, return cached result
|
||||
if _UNBACKED_SOURCES is not None and _UNBACKED_SOURCES_CONFIG_HASH == current_hash:
|
||||
return _UNBACKED_SOURCES
|
||||
|
||||
# Config has changed or first time, (re)calculate the sources
|
||||
_UNBACKED_SOURCES = {
|
||||
s
|
||||
for s in torch.compiler.config.unbacked_sources.replace(" ", "").split(",")
|
||||
if s
|
||||
}
|
||||
_UNBACKED_SOURCES_CONFIG_HASH = current_hash
|
||||
|
||||
return _UNBACKED_SOURCES
|
||||
|
||||
|
||||
def is_unbacked_source(source_name: str) -> bool:
|
||||
unbacked_sources = get_unbacked_sources()
|
||||
for pattern in unbacked_sources:
|
||||
if pattern == source_name or re.match(pattern, source_name):
|
||||
log.debug(
|
||||
"%s was marked unbacked due to unbacked source allowlist pattern: %s",
|
||||
source_name,
|
||||
pattern,
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# Performs automatic dynamic dim determination.
|
||||
# Returns a SymbolicContext
|
||||
def _automatic_dynamic(
|
||||
@ -3135,6 +3182,11 @@ def _automatic_dynamic(
|
||||
automatic_dynamic_size = True
|
||||
automatic_dynamic_stride = True
|
||||
|
||||
if is_unbacked_source(name):
|
||||
log.debug("%s marked unbacked via source whitelist", name)
|
||||
automatic_dynamic_size = True
|
||||
automatic_dynamic_stride = True
|
||||
|
||||
automatic_dynamic = automatic_dynamic_size or automatic_dynamic_stride
|
||||
|
||||
# We will process constraints first, as they will imply that we
|
||||
@ -3185,7 +3237,7 @@ def _automatic_dynamic(
|
||||
constraint_sizes.append(constraint_size)
|
||||
constraint_strides.append(constraint_stride)
|
||||
|
||||
if marked_unbacked:
|
||||
if marked_unbacked or is_unbacked_source(name):
|
||||
dynamic_size = DimDynamic.SIZE_LIKE_UNBACKED
|
||||
elif (
|
||||
constraint_size is not None
|
||||
|
@ -83,5 +83,15 @@ this to a string that identifies the shared profile. This is useful if you want
|
||||
for models that are not identical, but are similar enough to share PGO profiles.
|
||||
"""
|
||||
|
||||
unbacked_sources: str = Config(
|
||||
env_name_default="TORCH_COMPILE_UNBACKED_SOURCES", default=""
|
||||
)
|
||||
"""
|
||||
Comma delimited list of sources that should be marked as unbacked. Primarily useful for large
|
||||
models with graph breaks where you need intermediate tensors marked unbacked.
|
||||
|
||||
This whitelist is dominant over all other flags dynamic=False, force_nn_module_property_static_shapes
|
||||
and force_parameter_static_shapes.
|
||||
"""
|
||||
|
||||
install_config_module(sys.modules[__name__])
|
||||
|
Reference in New Issue
Block a user