[dynamic shapes] DynamicInts prototype (#162194)

Initial prototype for dynamic int inputs. This allows users to run `torch.compile(f)(DynamicInt(4))`, compiling the integer input dynamically while using the underlying hint (here, 4) at runtime.
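
As a minimal sketch of the intended usage (the function `f` below is illustrative, not from this PR; the recompile expectation follows the test added here):

```python
import torch
from torch.fx.experimental.sym_node import DynamicInt

def f(x, n):
    # n feeds into a shape; DynamicInt keeps it symbolic under compile
    return x.sum() + torch.ones(n * 2).sum()

fn = torch.compile(f, fullgraph=True)
out = fn(torch.randn(3), DynamicInt(4))  # n traced as a dynamic int, evaluated with hint 4
out = fn(torch.randn(3), 8)              # a plain int should reuse the same compiled graph
```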

Current behavior:
- Also works in eager mode (mostly by subclassing `int`), as a scalar input to torch functions, or with numpy/math/etc. For example, with `x = DynamicInt(3)`, the calls `torch.randn(x)`, `torch.add(y, z, alpha=x)`, and `np.arange(x)` all act as if x = 3.
- Arithmetic ops return new DynamicInts rather than static ints; `DynamicInt(3) * 2 == DynamicInt(6)` (see the eager-mode sketch after this list). This goes through the SymNode magic methods, but coverage might not be 100% - for example, I had to explicitly override floordiv to avoid casting to int. Non-magic-method ops (e.g. `math.cos(x)`) don't necessarily preserve DynamicInt. The alternative is to cast to int on all operations, but I opted for this behavior so dynamism propagates through non-compiled regions.
- Doesn't ban `fullgraph=False`; DynamicInt objects might be leaked back to the user, but I guess this is fine, because they can be cast to ints when needed?
- Dynamo only allocates one symbol per DynamicInt; specifying the same DynamicInt for multiple inputs leads to input deduplication, and a guard is installed.
- We don't raise on int specialization (in the allowlist/maybe_mark_dynamic style) - but this is an easy change if needed.
- DynamicInts as nn.Module attributes are handled.
- We don't guard on the DynamicInt's id; e.g. users can do the following without recompiling (maybe we should guard?):
```python
x = DynamicInt(4)
f(x)
f(1)
f(DynamicInt(3))  # same as f(3)
```
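
To illustrate the eager-mode semantics above (arithmetic propagates DynamicInt; non-magic-method ops generally return plain values), a short sketch adapted from the tests in this PR:

```python
import math
from torch.fx.experimental.sym_node import DynamicInt

x = DynamicInt(3)
y = x * 2            # magic methods rewrap the result: DynamicInt(6)
z = (10 - x) // 2    # floordiv is explicitly overridden to stay a DynamicInt
assert isinstance(y, DynamicInt) and y == 6
assert isinstance(z, DynamicInt) and z == 3
assert math.cos(DynamicInt(0)) == 1.0  # non-magic-method op returns a plain float
```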

Follow-up work:
- Specifying shape constraints, either at the int level, e.g.
```python
DynamicInt(64, name="s0", constraints=["s0 % 32 == 0", "s0 <= 1024"])
```
or at the compilation level, e.g. something like
```python
s0 = DynamicInt(64, name="s0")
s1 = DynamicInt(128, name="s1")
with some_compiler_config.dynamic_int_constraints(["s1 == 2*s0", "s0 % 32 == 0"]):
    f(s0, s1)
```
This should subsume the need for specifying derived SymInts?
- SymFloat support - currently it seems backed floats are specialized by the tensorify-float pass, and there's no handling in Inductor.
- Propagating dynamism in tensor constructors: e.g. `x = DynamicInt(4); torch.randn(x)` could annotate `_dynamo_dynamic_indices` on the result (rough sketch below).
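
This last follow-up isn't implemented in this PR; as a rough, hypothetical sketch of the idea (`_dynamo_dynamic_indices` is the same annotation `torch._dynamo.mark_dynamic` uses today):

```python
import torch
from torch.fx.experimental.sym_node import DynamicInt

n = DynamicInt(4)
t = torch.randn(n)  # today: t is a plain tensor with static size 4
# Hypothetical follow-up: the constructor itself could mark dim 0 as dynamic,
# roughly what this manual call does (adds 0 to t._dynamo_dynamic_indices):
torch._dynamo.mark_dynamic(t, 0)
```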

Differential Revision: D81698719

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162194
Approved by: https://github.com/bobrenjc93
Author: Pian Pawakapan
Date: 2025-09-18 23:26:28 +00:00
Committed by: PyTorch MergeBot
Parent: f4eca0e3b3
Commit: 4c007073e6
11 changed files with 293 additions and 15 deletions

View File

@ -8,6 +8,10 @@
These APIs are experimental and subject to change without notice.
:::
```{eval-rst}
.. autoclass:: torch.fx.experimental.sym_node.DynamicInt
```
## torch.fx.experimental.symbolic_shapes
```{eval-rst}

View File

@ -1818,6 +1818,96 @@ class TestSymNumberMagicMethods(TestCase):
self.assertTrue(isinstance(s3, int))
self.assertTrue(str(s1.node.expr) != str(s2.node.expr))
@fresh_cache()
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@parametrize("backend", ["inductor", "eager"])
def test_dynamic_int_basic_compile(self, backend):
from torch.fx.experimental.sym_node import DynamicInt
cnt = CompileCounterWithBackend(backend)
# test scalar inputs to function
def f(x, y, z):
out = torch.tensor([x + y + z])
out = out + torch.zeros(abs(x) + 2).sum() # test out tensor construction
return out
fn = torch.compile(f, fullgraph=True, backend=cnt)
x = DynamicInt(1)
z = DynamicInt(3)
self.assertEqual(fn(x, x, z), f(1, 1, 3)) # guard: x == y
self.assertEqual(fn(2, 2, 0), f(2, 2, 0))
self.assertEqual(fn(-1, -1, 2), f(-1, -1, 2))
self.assertEqual(cnt.frame_count, 1) # no recompiles
self.assertEqual(fn(3, 4, 5), f(3, 4, 5)) # now we recompile
self.assertEqual(cnt.frame_count, 2)
# test nn module property
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.i = DynamicInt(1)
def forward(self, x):
return torch.tensor([x + self.i])
cnt.clear()
m = Foo()
mc = torch.compile(m, backend=cnt, fullgraph=True)
self.assertEqual(mc(DynamicInt(0)), m(0))
mc.i = -2 # override attribute
self.assertEqual(mc(-1), m(-1))
self.assertEqual(cnt.frame_count, 1)
def test_dynamic_int_eager_usage(self):
from torch.fx.experimental.sym_node import DynamicInt
w = DynamicInt(-1)
x = DynamicInt(0)
y = DynamicInt(1)
z = DynamicInt(2)
def check(l, r):
self.assertTrue(isinstance(l, DynamicInt))
self.assertEqual(l, r)
# test arithmetic
check(2 * y + z, 4)
check((10 - z) // 2, 4)
check(1 // z, 0)
check(-w + w**2, 2)
check(x % z, 0)
check(1 << z, 4)
check(z | y, 3)
check(min(y, z), 1)
self.assertTrue(z > -2)
with self.assertRaises(ZeroDivisionError):
y % x
# math, numpy
self.assertEqual(math.cos(x), y)
self.assertEqual(math.prod([z, z], start=z), 8)
self.assertEqual(np.arange(z)[y], 1)
self.assertTrue(np.allclose(np.ones([y, z]).sum(axis=x), np.ones(z)))
# test conversions
self.assertTrue(isinstance(x + 2, int))
self.assertTrue(isinstance(x + 2, DynamicInt))
self.assertEqual(y / 2.0, 0.5) # this could return DynamicFloat in future
self.assertEqual(float(z), 2.0)
self.assertFalse(bool(x))
self.assertEqual(DynamicInt(x).real, x.real)
# torch functions, scalar inputs
self.assertEqual(torch.arange(z)[:w][x], 0)
self.assertEqual(torch.add(torch.tensor(w), torch.tensor(w), alpha=z), -3)
self.assertEqual(
list(torch.nn.Linear(z, y)(torch.randn(z * 2, z)).shape), [4, 1]
)
self.assertEqual(z * torch.ones(z).sum(dim=x), 4)
instantiate_parametrized_tests(TestSymNumberMagicMethods)

View File

@ -136,6 +136,7 @@ from .source import (
DefaultsSource,
DictGetItemSource,
DictSubclassGetItemSource,
DynamicScalarSource,
FlattenScriptObjectSource,
FloatTensorSource,
FSDPNNModuleSource,
@ -1719,6 +1720,14 @@ class GuardBuilder(GuardBuilderBase):
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, DynamicScalarSource):
assert base_guard_manager
out = base_guard_manager.lambda_manager(
python_lambda=lambda x: int(x),
source=source_name,
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
else:
raise AssertionError(
f"missing guard manager builder {source} - {source.name()}"

View File

@ -2698,6 +2698,9 @@ class SubgraphTracer(fx.Tracer):
# tracer is the current tracer that's readily accessible in current tracer's graph.
self.bound_symbols: dict[sympy.Symbol, Union[torch.fx.Proxy, LazyProxy]] = {}
# Maps _DynamicScalar object ids to allocated SymInt nodes, for symbol reuse
self.dynamic_scalar_nodes: dict[int, torch.SymInt] = {}
self.prev_inst = None
# True if this tracer is currently tracing into torch.utils.checkpoint
# as part of speculate_subgraph.

View File

@ -526,6 +526,29 @@ class ConvertIntSource(ChainedSource):
return f"cast_symbool_to_symint_guardless({self.base.name()})"
@dataclasses.dataclass(frozen=True)
class DynamicScalarSource(ChainedSource):
is_int: bool
def __post_init__(self) -> None:
assert self.base is not None
def reconstruct(self, codegen: "PyCodegen") -> None:
# Integer casting at reconstruction helps reduce the amount of DynamicInts returned
# to the user, in favor of plain ints.
# For example, a compiled region that only does int arithmetic could return a
# DynamicInt without the casting here.
codegen.add_push_null(lambda: codegen.load_import_from("builtins", "int"))
codegen(self.base)
codegen.extend_output(create_call_function(1, False))
def guard_source(self) -> GuardSource:
return self.base.guard_source()
def name(self) -> str:
return f"int({self.base.name()})"
@dataclasses.dataclass(frozen=True)
class FlattenScriptObjectSource(ChainedSource):
def __post_init__(self) -> None:

View File

@ -60,6 +60,7 @@ from torch._subclasses.meta_utils import is_sparse_any, safe_grad
from torch._utils_internal import justknobs_check
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental._dynamism import normalize_source_name
from torch.fx.experimental.sym_node import _DynamicScalar, DynamicInt
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
_nested_int_aware_sort,
@ -101,6 +102,7 @@ from ..source import (
ConvertIntSource,
DictGetItemSource,
DictSubclassGetItemSource,
DynamicScalarSource,
FloatTensorSource,
GetItemSource,
GradSource,
@ -456,7 +458,9 @@ class VariableBuilder:
# should NOT track them. If we use a single SymNodeVariable instance to track them
# across multiple uses, then guards created for one usage will incorrectly apply to
# all other usages of that constant, leading to unnecessary recompilations.
return is_torch_sym(value) and isinstance(vt, SymNodeVariable)
return (
is_torch_sym(value) or isinstance(value, _DynamicScalar)
) and isinstance(vt, SymNodeVariable)
if (
(
@ -1103,6 +1107,46 @@ class VariableBuilder:
):
self.install_guards(GuardBuilder.FUNCTION_MATCH)
return ItertoolsVariable(value, source=self.source)
elif isinstance(value, _DynamicScalar):
is_int = isinstance(value, DynamicInt)
source = DynamicScalarSource(self.source, is_int)
if id(value) in self.tx.output.root_tracer.dynamic_scalar_nodes:
# If we've already seen this dynamic scalar, reuse the existing
# SymInt/SymFloat node.
node = self.tx.output.root_tracer.dynamic_scalar_nodes[id(value)]
else:
sym = self.tx.output.shape_env.create_unspecified_symbol(
value.real,
source=source,
dynamic_dim=DimDynamic.DYNAMIC,
)
node = self.tx.output.shape_env.create_symintnode(
sym,
hint=value.real,
source=source,
)
# Bind to graph input
sym_node_proxy = self.tx.output.root_tracer.create_graph_input(
re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
type(node),
node,
source=source,
)
sym_node_proxy.node.meta["grapharg"] = GraphArg(
source,
node,
False,
None,
is_tensor=False,
example_strong_ref=node,
)
sym_expr = node.node.expr
assert isinstance(sym_expr, sympy.Symbol), (
f"{sym_expr} is not a basic Symbol."
)
self.tx.output.tracked_fakes.append(TrackedFake(node, source, None))
return SymNodeVariable(sym_node_proxy, node)
elif is_torch_sym(value):
# Note: this doesn't handle nested symints.
# For SymBool input, we reuse the infra for SymInt by simulating SymBool with a SymInt in dynamo.

View File

@ -936,6 +936,9 @@ static bool is_int_or_symint(PyObject* obj) {
if (torch::is_symint(py::handle(obj))) {
return true;
}
if (torch::is_dynint(py::handle(obj))) {
return true;
}
// FakeTensor(..., size=()) is qualified for SymInt param,
// but we can't go via __index__ (below) as we would normally
@ -1070,7 +1073,8 @@ auto FunctionParameter::_check(
return !var.requires_grad() && var.dim() == 0;
}
if (torch::is_symfloat(py::handle(obj)) ||
torch::is_symint(py::handle(obj))) {
torch::is_symint(py::handle(obj)) ||
torch::is_dynint(py::handle(obj))) {
// This will induce a guard
return true;
}
@ -1085,7 +1089,8 @@ auto FunctionParameter::_check(
return at::isIntegralType(var.scalar_type(), /*includeBool=*/false) &&
!var.requires_grad() && var.dim() == 0;
}
if (torch::is_symint(py::handle(obj))) {
if (torch::is_symint(py::handle(obj)) ||
torch::is_dynint(py::handle(obj))) {
// This will induce a guard
return true;
}
@ -1127,7 +1132,8 @@ auto FunctionParameter::_check(
// Allow symint to be passed in as device, but we'll specialize and
// guard in this case.
return THPUtils_checkLong(obj) || THPUtils_checkString(obj) ||
THPDevice_Check(obj) || torch::is_symint(py::handle(obj));
THPDevice_Check(obj) || torch::is_symint(py::handle(obj)) ||
torch::is_dynint(py::handle(obj));
case ParameterType::STREAM:
return THPStream_Check(obj);
case ParameterType::STRING:
@ -1881,7 +1887,8 @@ at::Tensor PythonArgs::tensor_slow(int i) {
// NB: we DO NOT put symbolic ints/floats into the Scalar itself,
// because although Scalar supports SymInt/SymFloat, the subsequent
// conversion to Tensor does not. Instead, do it out of band.
} else if (torch::is_symint(py::handle(obj))) {
} else if (
torch::is_symint(py::handle(obj)) || torch::is_dynint(py::handle(obj))) {
save_symint = true;
// This scalar value doesn't matter, it shouldn't ever actually
// get read out. Make it a big and weird looking number to help
@ -1969,6 +1976,10 @@ at::Scalar PythonArgs::scalar_slow(PyObject* arg) {
return at::Scalar(py::cast<c10::SymInt>(arg));
}
if (torch::is_dynint(arg)) {
return at::Scalar(py::cast<int>(arg));
}
if (torch::is_symfloat(arg)) {
return at::Scalar(py::cast<c10::SymFloat>(arg));
}

View File

@ -89,7 +89,7 @@ inline bool THPUtils_checkScalar(PyObject* obj) {
}
#endif
return PyFloat_Check(obj) || PyLong_Check(obj) || PyComplex_Check(obj) ||
torch::is_symint(py::handle(obj)) ||
torch::is_symint(py::handle(obj)) || torch::is_dynint(py::handle(obj)) ||
torch::is_symfloat(py::handle(obj)) || torch::is_symbool(py::handle(obj));
}
@ -612,6 +612,8 @@ inline std::vector<c10::SymInt> PythonArgs::symintlist(int i) {
try {
if (is_symint(py::handle(obj))) {
res.push_back(py::handle(obj).cast<c10::SymInt>());
} else if (is_dynint(py::handle(obj))) {
res.push_back(py::handle(obj).cast<int>());
} else {
res.emplace_back(THPUtils_unpackIndex(obj));
}
@ -640,6 +642,9 @@ inline std::vector<int64_t> PythonArgs::intlistWithDefault(
size1,
py::handle(arg).cast<c10::SymInt>().guard_int(__FILE__, __LINE__));
}
if (size1 > 0 && torch::is_dynint(py::handle(arg))) {
return std::vector<int64_t>(size1, py::handle(arg).cast<int>());
}
auto tuple = PyTuple_Check(arg);
// NOLINTNEXTLINE(bugprone-branch-clone)
const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
@ -672,6 +677,8 @@ inline std::vector<int64_t> PythonArgs::intlistWithDefault(
} else if (torch::is_symint(py::handle(obj))) {
res[idx] = py::cast<c10::SymInt>(py::handle(obj))
.guard_int(__FILE__, __LINE__);
} else if (torch::is_dynint(py::handle(obj))) {
res[idx] = py::handle(obj).cast<int>();
} else if (THPVariable_Check(obj)) {
auto& var = THPVariable_Unpack(obj);
if (var.numel() != 1 ||
@ -846,6 +853,10 @@ inline at::Device toDevice(PyObject* obj) {
py::cast<c10::SymInt>(py::handle(obj)).guard_int(__FILE__, __LINE__);
return deviceFromLong(device_index);
}
if (torch::is_dynint(py::handle(obj))) {
auto device_index = py::cast<int>(py::handle(obj));
return deviceFromLong(device_index);
}
const std::string& device_str = THPUtils_unpackString(obj);
return at::Device(device_str);
}
@ -982,6 +993,9 @@ inline int64_t PythonArgs::toInt64(int i) {
return py::cast<c10::SymInt>(py::handle(args[i]))
.guard_int(__FILE__, __LINE__);
}
if (torch::is_dynint(py::handle(args[i]))) {
return py::cast<int>(py::handle(args[i]));
}
return THPUtils_unpackLong(args[i]);
}
@ -1055,6 +1069,9 @@ inline double PythonArgs::toDouble(int i) {
return static_cast<double>(py::cast<c10::SymInt>(py::handle(args[i]))
.guard_int(__FILE__, __LINE__));
}
if (torch::is_dynint(py::handle(args[i]))) {
return static_cast<double>(py::cast<int>(py::handle(args[i])));
}
return THPUtils_unpackDouble(args[i]);
}

View File

@ -53,4 +53,24 @@ py::handle get_symbool_class() {
#endif
}
py::handle get_dynint_class() {
// NB: leak
#if IS_PYBIND_2_13_PLUS
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
storage;
return storage
.call_once_and_store_result([]() -> py::object {
return py::module::import("torch.fx.experimental.sym_node")
.attr("DynamicInt");
})
.get_stored();
#else
static py::handle symbool_class =
py::object(py::module::import("torch.fx.experimental.sym_node")
.attr("DynamicInt"))
.release();
return symbool_class;
#endif
}
} // namespace torch

View File

@ -12,6 +12,7 @@ namespace torch {
TORCH_PYTHON_API py::handle get_symint_class();
TORCH_PYTHON_API py::handle get_symfloat_class();
TORCH_PYTHON_API py::handle get_symbool_class();
TORCH_PYTHON_API py::handle get_dynint_class();
// NB: These functions must not be called too early, otherwise torch not setup.
// Alternate design is to have torch "register" the object to us
@ -24,6 +25,9 @@ inline bool is_symfloat(py::handle obj) {
inline bool is_symbool(py::handle obj) {
return py::isinstance(obj, get_symbool_class());
}
inline bool is_dynint(py::handle obj) {
return py::isinstance(obj, get_dynint_class());
}
namespace impl {

View File

@ -49,7 +49,7 @@ log = logging.getLogger(__name__)
sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node")
__all__ = ["SymNode", "method_to_operator", "magic_methods"]
__all__ = ["SymNode", "method_to_operator", "magic_methods", "DynamicInt"]
from torch.types import py_sym_types as SymTypes
@ -625,6 +625,40 @@ class SymNode:
return False
class _DynamicScalar:
def __new__(cls, *args):
if cls is _DynamicScalar:
raise TypeError("_DynamicScalar is an abstract base class, use DynamicInt.")
return super().__new__(cls, *args)
class DynamicInt(_DynamicScalar, int):
"""
User API for marking dynamic integers in `torch.compile`.
Intended to be compatible with both compile and eager mode.
Example usage::
fn = torch.compile(f)
x = DynamicInt(4)
fn(x) # compiles x as a dynamic integer input; returns f(4)
"""
def __new__(cls, val):
assert isinstance(val, int)
obj = super().__new__(cls, int(val))
return obj
def __repr__(self):
return f"DynamicInt({self.real})"
def __floordiv__(self, other): # // was casting to int without these overrides?
return DynamicInt(self.real // other)
def __rfloordiv__(self, other):
return DynamicInt(other // self.real)
# TODO: this probably needs the sizes-strides eval functions
METHOD_TO_OPERATOR = {
"pos": operator.pos,
@ -1650,7 +1684,6 @@ for method, func in sizes_strides_methods.items():
def _make_user_magic(method, user_type):
# User magic takes care of wrapping the other operand into a node,
# so that our internal logic can assume everything is nodes
if method in magic_methods_on_operator_with_trailing_underscore:
method_attr = f"sym_{method}"
else:
@ -1781,7 +1814,7 @@ def _make_user_magic(method, user_type):
other = promote(other)
self, other = promote2(self, other)
if is_constant(self):
return (method_to_operator(method))(get_constant(self), other)
return (method_to_operator(method))(other, get_constant(self))
if is_constant(other):
other = get_constant(other)
other_node = to_node(self.node, other)
@ -1790,11 +1823,31 @@ def _make_user_magic(method, user_type):
ret = wrap_node(getattr(other_node, method_attr)(self.node))
return get_constant(ret) if is_constant(ret) else ret
def setattrs(user_type, attr, symnode_impl):
"""
Registers the SymNode magic method on SymInt/Float/Bool,
and optionally registers a corresponding wrapped method on DynamicInt.
"""
# SymInt/Float/Bool
setattr(user_type, attr, symnode_impl)
# DynamicInt impl
def dynamic_int_impl(*args):
args = [x.real if isinstance(x, DynamicInt) else x for x in args]
out = getattr(int, attr)(*args)
if isinstance(out, int) and not isinstance(out, bool):
return DynamicInt(out)
return out
if user_type is SymInt:
setattr(DynamicInt, attr, dynamic_int_impl)
if method in unary_magic_methods:
setattr(user_type, f"__{method}__", unary_magic_impl)
setattrs(user_type, f"__{method}__", unary_magic_impl)
elif method in unary_nonmagic_methods:
orig = getattr(user_type, method)
setattr(user_type, method, update_wrapper(unary_magic_impl, orig))
setattrs(user_type, method, update_wrapper(unary_magic_impl, orig))
elif method == "sym_ite":
def sym_ite_magic_impl(pred, then_val, else_val):
@ -1811,7 +1864,7 @@ def _make_user_magic(method, user_type):
ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node))
return get_constant(ret) if ret.node.is_constant() else ret
setattr(user_type, f"__{method}__", sym_ite_magic_impl)
setattrs(user_type, f"__{method}__", sym_ite_magic_impl)
elif method == "round":
def round_magic_impl(self, ndigits=None):
@ -1820,14 +1873,14 @@ def _make_user_magic(method, user_type):
return wrap_node(getattr(self.node, method)(ndigits))
setattr(user_type, f"__{method}__", round_magic_impl)
setattrs(user_type, f"__{method}__", round_magic_impl)
else:
method_name = method
if method in bitwise_ops:
method_name = bitwise_ops[method]
setattr(user_type, f"__{method_name}__", binary_magic_impl)
setattrs(user_type, f"__{method_name}__", binary_magic_impl)
if method in reflectable_magic_methods:
setattr(user_type, f"__r{method_name}__", rbinary_magic_impl)
setattrs(user_type, f"__r{method_name}__", rbinary_magic_impl)
for method, func in magic_methods.items(): # type: ignore[assignment]