[dynamic shapes] DynamicInts prototype (#162194)

Initial prototype for dynamic int inputs, allows users to run with `torch.compile(f)(DynamicInt(4))`, compiling dynamically and using the underlying hint at runtime.

Current behavior:
- Also works in eager (mostly by subclassing int), as scalar input to torch functions, or numpy/math/etc. For example, `x = DynamicInt(3); torch.randn(x); torch.add(y, z, alpha=x); np.arange(x)` all act as if x = 3.
- Behavior for arithmetic ops is to return new DynamicInts rather than static ints; `DynamicInt(3) * 2 = DynamicInt(6)`. This is via SymNode magic methods, but coverage might not be 100% - for example, I had to explicitly override floordiv to avoid int casting. This is not necessarily the case for non-magic method ops (e.g. `math.cos(x)`). The alternative here is to int cast on all operations, but I opted for this for dynamism propagation in non-compiled regions.
- Doesn't ban fullgraph=False; DynamicInt objects might be leaked back to the user, but I guess this is fine, because they can be casted to ints when needed?
- Dynamo only allocates one symbol per DynamicInt; specifying the same DynamicInt for multiple inputs leads to input deduplication, and a guard installed.
- We don't raise on int specialization (in allowlist/maybe_mark_dynamic style) - but an easy change if needed.
- DynamicInts as nn.Module attributes are handled.
- We don't guard on the DynamicInt id, e.g. users can do the following without recompiling (maybe we should guard?)
```python
x = DynamicInt(4)
f(x)
f(1)
f(DynamicInt(3))  # same as f(3)
```

Follow-up work:
- Specifying shape constraints, either at the int-level, e.g.
```python
DynamicInt(64, name="s0", constraints=["s0 % 32 == 0", "s0 <= 1024"]
```
or at the compilation level, e.g. something like
```python
s0 = DynamicInt(64, name="s0")
s1 = DynamicInt(128, name="s1")
with some_compiler_config.dynamic_int_constraints(["s1 == 2*s0", "s0 % 32 == 0"]):
    f(s0, s1)
```
This should subsume the need for specifying derived SymInts?
- SymFloat support - currently it seems backed floats are specialized by the tensorify float pass, and there's no handling in inductor.
- Propagating dynamism in tensor constructors, e.g. `x = DynamicInt(4); torch.randn(x)` could annotate `_dynamo_dynamic_indices`.

Differential Revision: D81698719

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162194
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Pian Pawakapan
2025-09-18 23:26:28 +00:00
committed by PyTorch MergeBot
parent f4eca0e3b3
commit 4c007073e6
11 changed files with 293 additions and 15 deletions

View File

@ -136,6 +136,7 @@ from .source import (
DefaultsSource,
DictGetItemSource,
DictSubclassGetItemSource,
DynamicScalarSource,
FlattenScriptObjectSource,
FloatTensorSource,
FSDPNNModuleSource,
@ -1719,6 +1720,14 @@ class GuardBuilder(GuardBuilderBase):
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, DynamicScalarSource):
assert base_guard_manager
out = base_guard_manager.lambda_manager(
python_lambda=lambda x: int(x),
source=source_name,
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
else:
raise AssertionError(
f"missing guard manager builder {source} - {source.name()}"