[dynamic shapes] DynamicInts prototype (#162194)
Initial prototype for dynamic int inputs, allowing users to run with `torch.compile(f)(DynamicInt(4))`, compiling dynamically and using the underlying hint at runtime.

Current behavior:
- Also works in eager mode (mostly by subclassing int), as a scalar input to torch functions, or with numpy/math/etc. For example, `x = DynamicInt(3); torch.randn(x); torch.add(y, z, alpha=x); np.arange(x)` all act as if x = 3.
- Arithmetic ops return new DynamicInts rather than static ints; `DynamicInt(3) * 2 == DynamicInt(6)`. This goes through the SymNode magic methods, but coverage might not be 100% - for example, I had to explicitly override floordiv to avoid int casting. This is not necessarily the case for non-magic-method ops (e.g. `math.cos(x)`). The alternative here is to int-cast on all operations, but I opted for this to propagate dynamism in non-compiled regions.
- Doesn't ban fullgraph=False; DynamicInt objects might be leaked back to the user, but this is probably fine, because they can be cast to ints when needed.
- Dynamo only allocates one symbol per DynamicInt; specifying the same DynamicInt for multiple inputs leads to input deduplication, with a guard installed.
- We don't raise on int specialization (in allowlist/maybe_mark_dynamic style) - but this is an easy change if needed.
- DynamicInts as nn.Module attributes are handled.
- We don't guard on the DynamicInt id, e.g. users can do the following without recompiling (maybe we should guard?):
```python
x = DynamicInt(4)
f(x)
f(1)
f(DynamicInt(3))  # same as f(3)
```

Follow-up work:
- Specifying shape constraints, either at the int level, e.g.
```python
DynamicInt(64, name="s0", constraints=["s0 % 32 == 0", "s0 <= 1024"])
```
or at the compilation level, e.g. something like
```python
s0 = DynamicInt(64, name="s0")
s1 = DynamicInt(128, name="s1")
with some_compiler_config.dynamic_int_constraints(["s1 == 2*s0", "s0 % 32 == 0"]):
    f(s0, s1)
```
This should subsume the need for specifying derived SymInts?
- SymFloat support - currently it seems backed floats are specialized by the tensorify float pass, and there's no handling in inductor.
- Propagating dynamism in tensor constructors, e.g. `x = DynamicInt(4); torch.randn(x)` could annotate `_dynamo_dynamic_indices`.

Differential Revision: D81698719

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162194
Approved by: https://github.com/bobrenjc93
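A minimal usage sketch based on the behavior described above; the function `f` and the concrete values are illustrative only and are not part of this PR's diff:

```python
import torch
from torch.fx.experimental.sym_node import DynamicInt

def f(x, n):
    return x.sum() + n

fn = torch.compile(f, fullgraph=True)

n = DynamicInt(4)
x = torch.randn(8)

fn(x, n)  # compiles with n treated as a dynamic int, using the hint 4 at runtime
fn(x, 7)  # per the description above, plain ints (or other DynamicInts) reuse the same graph

# In eager mode a DynamicInt mostly behaves like its underlying int,
# and arithmetic returns new DynamicInts to propagate dynamism.
m = n * 2        # DynamicInt(8)
torch.randn(n)   # acts like torch.randn(4)
```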
Committed by: PyTorch MergeBot
Parent: f4eca0e3b3 · Commit: 4c007073e6
````diff
@@ -8,6 +8,10 @@
 These APIs are experimental and subject to change without notice.
 :::
 
+```{eval-rst}
+.. autoclass:: torch.fx.experimental.sym_node.DynamicInt
+```
+
 ## torch.fx.experimental.symbolic_shapes
 
 ```{eval-rst}
````

```diff
@@ -1818,6 +1818,96 @@ class TestSymNumberMagicMethods(TestCase):
         self.assertTrue(isinstance(s3, int))
         self.assertTrue(str(s1.node.expr) != str(s2.node.expr))
 
+    @fresh_cache()
+    @torch._dynamo.config.patch("capture_scalar_outputs", True)
+    @parametrize("backend", ["inductor", "eager"])
+    def test_dynamic_int_basic_compile(self, backend):
+        from torch.fx.experimental.sym_node import DynamicInt
+
+        cnt = CompileCounterWithBackend(backend)
+
+        # test scalar inputs to function
+        def f(x, y, z):
+            out = torch.tensor([x + y + z])
+            out = out + torch.zeros(abs(x) + 2).sum()  # test out tensor construction
+            return out
+
+        fn = torch.compile(f, fullgraph=True, backend=cnt)
+        x = DynamicInt(1)
+        z = DynamicInt(3)
+        self.assertEqual(fn(x, x, z), f(1, 1, 3))  # guard: x == y
+        self.assertEqual(fn(2, 2, 0), f(2, 2, 0))
+        self.assertEqual(fn(-1, -1, 2), f(-1, -1, 2))
+        self.assertEqual(cnt.frame_count, 1)  # no recompiles
+
+        self.assertEqual(fn(3, 4, 5), f(3, 4, 5))  # now we recompile
+        self.assertEqual(cnt.frame_count, 2)
+
+        # test nn module property
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.i = DynamicInt(1)
+
+            def forward(self, x):
+                return torch.tensor([x + self.i])
+
+        cnt.clear()
+        m = Foo()
+        mc = torch.compile(m, backend=cnt, fullgraph=True)
+
+        self.assertEqual(mc(DynamicInt(0)), m(0))
+        mc.i = -2  # override attribute
+        self.assertEqual(mc(-1), m(-1))
+        self.assertEqual(cnt.frame_count, 1)
+
+    def test_dynamic_int_eager_usage(self):
+        from torch.fx.experimental.sym_node import DynamicInt
+
+        w = DynamicInt(-1)
+        x = DynamicInt(0)
+        y = DynamicInt(1)
+        z = DynamicInt(2)
+
+        def check(l, r):
+            self.assertTrue(isinstance(l, DynamicInt))
+            self.assertEqual(l, r)
+
+        # test arithmetic
+        check(2 * y + z, 4)
+        check((10 - z) // 2, 4)
+        check(1 // z, 0)
+        check(-w + w**2, 2)
+        check(x % z, 0)
+        check(1 << z, 4)
+        check(z | y, 3)
+        check(min(y, z), 1)
+        self.assertTrue(z > -2)
+        with self.assertRaises(ZeroDivisionError):
+            y % x
+
+        # math, numpy
+        self.assertEqual(math.cos(x), y)
+        self.assertEqual(math.prod([z, z], start=z), 8)
+        self.assertEqual(np.arange(z)[y], 1)
+        self.assertTrue(np.allclose(np.ones([y, z]).sum(axis=x), np.ones(z)))
+
+        # test conversions
+        self.assertTrue(isinstance(x + 2, int))
+        self.assertTrue(isinstance(x + 2, DynamicInt))
+        self.assertEqual(y / 2.0, 0.5)  # this could return DynamicFloat in future
+        self.assertEqual(float(z), 2.0)
+        self.assertFalse(bool(x))
+        self.assertEqual(DynamicInt(x).real, x.real)
+
+        # torch functions, scalar inputs
+        self.assertEqual(torch.arange(z)[:w][x], 0)
+        self.assertEqual(torch.add(torch.tensor(w), torch.tensor(w), alpha=z), -3)
+        self.assertEqual(
+            list(torch.nn.Linear(z, y)(torch.randn(z * 2, z)).shape), [4, 1]
+        )
+        self.assertEqual(z * torch.ones(z).sum(dim=x), 4)
+
 
 instantiate_parametrized_tests(TestSymNumberMagicMethods)
 
```

```diff
@@ -136,6 +136,7 @@ from .source import (
     DefaultsSource,
     DictGetItemSource,
     DictSubclassGetItemSource,
+    DynamicScalarSource,
     FlattenScriptObjectSource,
     FloatTensorSource,
     FSDPNNModuleSource,
```

```diff
@@ -1719,6 +1720,14 @@ class GuardBuilder(GuardBuilderBase):
                 example_value=example_value,
                 guard_manager_enum=guard_manager_enum,
             )
+        elif istype(source, DynamicScalarSource):
+            assert base_guard_manager
+            out = base_guard_manager.lambda_manager(
+                python_lambda=lambda x: int(x),
+                source=source_name,
+                example_value=example_value,
+                guard_manager_enum=guard_manager_enum,
+            )
         else:
             raise AssertionError(
                 f"missing guard manager builder {source} - {source.name()}"
```

```diff
@@ -2698,6 +2698,9 @@ class SubgraphTracer(fx.Tracer):
         # tracer is the current tracer that's readily accessible in current tracer's graph.
         self.bound_symbols: dict[sympy.Symbol, Union[torch.fx.Proxy, LazyProxy]] = {}
 
+        # Maps _DynamicScalar object ids to allocated SymInt nodes, for symbol reuse
+        self.dynamic_scalar_nodes: dict[int, torch.SymInt] = {}
+
         self.prev_inst = None
         # True if this tracer is currently tracing into torch.utils.checkpoint
         # as part of speculate_subgraph.
```

```diff
@@ -526,6 +526,29 @@ class ConvertIntSource(ChainedSource):
         return f"cast_symbool_to_symint_guardless({self.base.name()})"
 
 
+@dataclasses.dataclass(frozen=True)
+class DynamicScalarSource(ChainedSource):
+    is_int: bool
+
+    def __post_init__(self) -> None:
+        assert self.base is not None
+
+    def reconstruct(self, codegen: "PyCodegen") -> None:
+        # Integer casting at reconstruction helps reduce the amount of DynamicInts returned
+        # to the user, in favor of plain ints.
+        # For example, a compiled region that only does int arithmetic could return a
+        # DynamicInt without the casting here.
+        codegen.add_push_null(lambda: codegen.load_import_from("builtins", "int"))
+        codegen(self.base)
+        codegen.extend_output(create_call_function(1, False))
+
+    def guard_source(self) -> GuardSource:
+        return self.base.guard_source()
+
+    def name(self) -> str:
+        return f"int({self.base.name()})"
+
+
 @dataclasses.dataclass(frozen=True)
 class FlattenScriptObjectSource(ChainedSource):
     def __post_init__(self) -> None:
```

```diff
@@ -60,6 +60,7 @@ from torch._subclasses.meta_utils import is_sparse_any, safe_grad
 from torch._utils_internal import justknobs_check
 from torch.fx.experimental._backward_state import BackwardState
 from torch.fx.experimental._dynamism import normalize_source_name
+from torch.fx.experimental.sym_node import _DynamicScalar, DynamicInt
 from torch.fx.experimental.symbolic_shapes import (
     _constrain_range_for_size,
     _nested_int_aware_sort,
```

```diff
@@ -101,6 +102,7 @@ from ..source import (
     ConvertIntSource,
     DictGetItemSource,
     DictSubclassGetItemSource,
+    DynamicScalarSource,
     FloatTensorSource,
     GetItemSource,
     GradSource,
```

```diff
@@ -456,7 +458,9 @@ class VariableBuilder:
             # should NOT track them. If we use a single SymNodeVariable instance to track them
             # across multiple uses, then guards created for one usage will incorrectly apply to
             # all other usages of that constant, leading to unnecessary recompilations.
-            return is_torch_sym(value) and isinstance(vt, SymNodeVariable)
+            return (
+                is_torch_sym(value) or isinstance(value, _DynamicScalar)
+            ) and isinstance(vt, SymNodeVariable)
 
         if (
             (
```

```diff
@@ -1103,6 +1107,46 @@ class VariableBuilder:
         ):
             self.install_guards(GuardBuilder.FUNCTION_MATCH)
             return ItertoolsVariable(value, source=self.source)
+        elif isinstance(value, _DynamicScalar):
+            is_int = isinstance(value, DynamicInt)
+            source = DynamicScalarSource(self.source, is_int)
+            if id(value) in self.tx.output.root_tracer.dynamic_scalar_nodes:
+                # If we've already seen this dynamic scalar, reuse the existing
+                # SymInt/SymFloat node.
+                node = self.tx.output.root_tracer.dynamic_scalar_nodes[id(value)]
+            else:
+                sym = self.tx.output.shape_env.create_unspecified_symbol(
+                    value.real,
+                    source=source,
+                    dynamic_dim=DimDynamic.DYNAMIC,
+                )
+                node = self.tx.output.shape_env.create_symintnode(
+                    sym,
+                    hint=value.real,
+                    source=source,
+                )
+
+            # Bind to graph input
+            sym_node_proxy = self.tx.output.root_tracer.create_graph_input(
+                re.sub(r"[^a-zA-Z0-9]+", "_", self.name),
+                type(node),
+                node,
+                source=source,
+            )
+            sym_node_proxy.node.meta["grapharg"] = GraphArg(
+                source,
+                node,
+                False,
+                None,
+                is_tensor=False,
+                example_strong_ref=node,
+            )
+            sym_expr = node.node.expr
+            assert isinstance(sym_expr, sympy.Symbol), (
+                f"{sym_expr} is not a basic Symbol."
+            )
+            self.tx.output.tracked_fakes.append(TrackedFake(node, source, None))
+            return SymNodeVariable(sym_node_proxy, node)
         elif is_torch_sym(value):
             # Note: this doesn't handle nested symints.
             # For SymBool input, we reuse the infra for SymInt by simulating SymBool with a SymInt in dynamo.
```

```diff
@@ -936,6 +936,9 @@ static bool is_int_or_symint(PyObject* obj) {
   if (torch::is_symint(py::handle(obj))) {
     return true;
   }
+  if (torch::is_dynint(py::handle(obj))) {
+    return true;
+  }
 
   // FakeTensor(..., size=()) is qualified for SymInt param,
   // but we can't go via __index__ (below) as we would normally
```

```diff
@@ -1070,7 +1073,8 @@ auto FunctionParameter::_check(
         return !var.requires_grad() && var.dim() == 0;
       }
       if (torch::is_symfloat(py::handle(obj)) ||
-          torch::is_symint(py::handle(obj))) {
+          torch::is_symint(py::handle(obj)) ||
+          torch::is_dynint(py::handle(obj))) {
         // This will induce a guard
         return true;
       }
```

```diff
@@ -1085,7 +1089,8 @@ auto FunctionParameter::_check(
         return at::isIntegralType(var.scalar_type(), /*includeBool=*/false) &&
             !var.requires_grad() && var.dim() == 0;
       }
-      if (torch::is_symint(py::handle(obj))) {
+      if (torch::is_symint(py::handle(obj)) ||
+          torch::is_dynint(py::handle(obj))) {
         // This will induce a guard
         return true;
       }
```

```diff
@@ -1127,7 +1132,8 @@ auto FunctionParameter::_check(
       // Allow symint to be passed in as device, but we'll specialize and
       // guard in this case.
       return THPUtils_checkLong(obj) || THPUtils_checkString(obj) ||
-          THPDevice_Check(obj) || torch::is_symint(py::handle(obj));
+          THPDevice_Check(obj) || torch::is_symint(py::handle(obj)) ||
+          torch::is_dynint(py::handle(obj));
     case ParameterType::STREAM:
       return THPStream_Check(obj);
     case ParameterType::STRING:
```

```diff
@@ -1881,7 +1887,8 @@ at::Tensor PythonArgs::tensor_slow(int i) {
     // NB: we DO NOT put symbolic ints/floats into the Scalar itself,
     // because although Scalar supports SymInt/SymFloat, the subsequent
     // conversion to Tensor does not. Instead, do it out of band.
-  } else if (torch::is_symint(py::handle(obj))) {
+  } else if (
+      torch::is_symint(py::handle(obj)) || torch::is_dynint(py::handle(obj))) {
     save_symint = true;
     // This scalar value doesn't matter, it shouldn't ever actually
     // get read out. Make it a big and weird looking number to help
```

```diff
@@ -1969,6 +1976,10 @@ at::Scalar PythonArgs::scalar_slow(PyObject* arg) {
     return at::Scalar(py::cast<c10::SymInt>(arg));
   }
 
+  if (torch::is_dynint(arg)) {
+    return at::Scalar(py::cast<int>(arg));
+  }
+
   if (torch::is_symfloat(arg)) {
     return at::Scalar(py::cast<c10::SymFloat>(arg));
   }
```

```diff
@@ -89,7 +89,7 @@ inline bool THPUtils_checkScalar(PyObject* obj) {
   }
 #endif
   return PyFloat_Check(obj) || PyLong_Check(obj) || PyComplex_Check(obj) ||
-      torch::is_symint(py::handle(obj)) ||
+      torch::is_symint(py::handle(obj)) || torch::is_dynint(py::handle(obj)) ||
       torch::is_symfloat(py::handle(obj)) || torch::is_symbool(py::handle(obj));
 }
 
```

```diff
@@ -612,6 +612,8 @@ inline std::vector<c10::SymInt> PythonArgs::symintlist(int i) {
     try {
       if (is_symint(py::handle(obj))) {
         res.push_back(py::handle(obj).cast<c10::SymInt>());
+      } else if (is_dynint(py::handle(obj))) {
+        res.push_back(py::handle(obj).cast<int>());
       } else {
         res.emplace_back(THPUtils_unpackIndex(obj));
       }
```

```diff
@@ -640,6 +642,9 @@ inline std::vector<int64_t> PythonArgs::intlistWithDefault(
         size1,
         py::handle(arg).cast<c10::SymInt>().guard_int(__FILE__, __LINE__));
   }
+  if (size1 > 0 && torch::is_dynint(py::handle(arg))) {
+    return std::vector<int64_t>(size1, py::handle(arg).cast<int>());
+  }
   auto tuple = PyTuple_Check(arg);
   // NOLINTNEXTLINE(bugprone-branch-clone)
   const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
```

```diff
@@ -672,6 +677,8 @@ inline std::vector<int64_t> PythonArgs::intlistWithDefault(
       } else if (torch::is_symint(py::handle(obj))) {
        res[idx] = py::cast<c10::SymInt>(py::handle(obj))
                        .guard_int(__FILE__, __LINE__);
+      } else if (torch::is_dynint(py::handle(obj))) {
+        res[idx] = py::handle(obj).cast<int>();
       } else if (THPVariable_Check(obj)) {
         auto& var = THPVariable_Unpack(obj);
         if (var.numel() != 1 ||
```

```diff
@@ -846,6 +853,10 @@ inline at::Device toDevice(PyObject* obj) {
         py::cast<c10::SymInt>(py::handle(obj)).guard_int(__FILE__, __LINE__);
     return deviceFromLong(device_index);
   }
+  if (torch::is_dynint(py::handle(obj))) {
+    auto device_index = py::cast<int>(py::handle(obj));
+    return deviceFromLong(device_index);
+  }
   const std::string& device_str = THPUtils_unpackString(obj);
   return at::Device(device_str);
 }
```

```diff
@@ -982,6 +993,9 @@ inline int64_t PythonArgs::toInt64(int i) {
     return py::cast<c10::SymInt>(py::handle(args[i]))
         .guard_int(__FILE__, __LINE__);
   }
+  if (torch::is_dynint(py::handle(args[i]))) {
+    return py::cast<int>(py::handle(args[i]));
+  }
   return THPUtils_unpackLong(args[i]);
 }
 
```

```diff
@@ -1055,6 +1069,9 @@ inline double PythonArgs::toDouble(int i) {
     return static_cast<double>(py::cast<c10::SymInt>(py::handle(args[i]))
                                    .guard_int(__FILE__, __LINE__));
   }
+  if (torch::is_dynint(py::handle(args[i]))) {
+    return static_cast<double>(py::cast<int>(py::handle(args[i])));
+  }
   return THPUtils_unpackDouble(args[i]);
 }
 
```

```diff
@@ -53,4 +53,24 @@ py::handle get_symbool_class() {
 #endif
 }
 
+py::handle get_dynint_class() {
+  // NB: leak
+#if IS_PYBIND_2_13_PLUS
+  PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
+      storage;
+  return storage
+      .call_once_and_store_result([]() -> py::object {
+        return py::module::import("torch.fx.experimental.sym_node")
+            .attr("DynamicInt");
+      })
+      .get_stored();
+#else
+  static py::handle symbool_class =
+      py::object(py::module::import("torch.fx.experimental.sym_node")
+                     .attr("DynamicInt"))
+          .release();
+  return symbool_class;
+#endif
+}
+
 } // namespace torch
```

```diff
@@ -12,6 +12,7 @@ namespace torch {
 TORCH_PYTHON_API py::handle get_symint_class();
 TORCH_PYTHON_API py::handle get_symfloat_class();
 TORCH_PYTHON_API py::handle get_symbool_class();
+TORCH_PYTHON_API py::handle get_dynint_class();
 
 // NB: These functions must not be called too early, otherwise torch not setup.
 // Alternate design is to have torch "register" the object to us
```

```diff
@@ -24,6 +25,9 @@ inline bool is_symfloat(py::handle obj) {
 inline bool is_symbool(py::handle obj) {
   return py::isinstance(obj, get_symbool_class());
 }
+inline bool is_dynint(py::handle obj) {
+  return py::isinstance(obj, get_dynint_class());
+}
 
 namespace impl {
 
```

```diff
@@ -49,7 +49,7 @@ log = logging.getLogger(__name__)
 sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node")
 
 
-__all__ = ["SymNode", "method_to_operator", "magic_methods"]
+__all__ = ["SymNode", "method_to_operator", "magic_methods", "DynamicInt"]
 
 
 from torch.types import py_sym_types as SymTypes
```

```diff
@@ -625,6 +625,40 @@ class SymNode:
         return False
 
 
+class _DynamicScalar:
+    def __new__(cls, *args):
+        if cls is _DynamicScalar:
+            raise TypeError("_DynamicScalar is an abstract base class, use DynamicInt.")
+        return super().__new__(cls, *args)
+
+
+class DynamicInt(_DynamicScalar, int):
+    """
+    User API for marking dynamic integers in `torch.compile`.
+    Intended to be compatible with both compile and eager mode.
+
+    Example usage::
+
+        fn = torch.compile(f)
+        x = DynamicInt(4)
+        fn(x)  # compiles x as a dynamic integer input; returns f(4)
+    """
+
+    def __new__(cls, val):
+        assert isinstance(val, int)
+        obj = super().__new__(cls, int(val))
+        return obj
+
+    def __repr__(self):
+        return f"DynamicInt({self.real})"
+
+    def __floordiv__(self, other):  # // was casting to int without these overrides?
+        return DynamicInt(self.real // other)
+
+    def __rfloordiv__(self, other):
+        return DynamicInt(other // self.real)
+
+
 # TODO: this probably needs the sizes-strides eval functions
 METHOD_TO_OPERATOR = {
     "pos": operator.pos,
```

```diff
@@ -1650,7 +1684,6 @@ for method, func in sizes_strides_methods.items():
 def _make_user_magic(method, user_type):
     # User magic takes care of wrapping the other operand into a node,
     # so that our internal logic can assume everything is nodes
-
     if method in magic_methods_on_operator_with_trailing_underscore:
         method_attr = f"sym_{method}"
     else:
```

```diff
@@ -1781,7 +1814,7 @@ def _make_user_magic(method, user_type):
         other = promote(other)
         self, other = promote2(self, other)
         if is_constant(self):
-            return (method_to_operator(method))(get_constant(self), other)
+            return (method_to_operator(method))(other, get_constant(self))
         if is_constant(other):
             other = get_constant(other)
         other_node = to_node(self.node, other)
```

```diff
@@ -1790,11 +1823,31 @@ def _make_user_magic(method, user_type):
         ret = wrap_node(getattr(other_node, method_attr)(self.node))
         return get_constant(ret) if is_constant(ret) else ret
 
+    def setattrs(user_type, attr, symnode_impl):
+        """
+        Registers the SymNode magic method on SymInt/Float/Bool,
+        and optionally registers a corresponding wrapped method on DynamicInt.
+        """
+
+        # SymInt/Float/Bool
+        setattr(user_type, attr, symnode_impl)
+
+        # DynamicInt impl
+        def dynamic_int_impl(*args):
+            args = [x.real if isinstance(x, DynamicInt) else x for x in args]
+            out = getattr(int, attr)(*args)
+            if isinstance(out, int) and not isinstance(out, bool):
+                return DynamicInt(out)
+            return out
+
+        if user_type is SymInt:
+            setattr(DynamicInt, attr, dynamic_int_impl)
+
     if method in unary_magic_methods:
-        setattr(user_type, f"__{method}__", unary_magic_impl)
+        setattrs(user_type, f"__{method}__", unary_magic_impl)
     elif method in unary_nonmagic_methods:
         orig = getattr(user_type, method)
-        setattr(user_type, method, update_wrapper(unary_magic_impl, orig))
+        setattrs(user_type, method, update_wrapper(unary_magic_impl, orig))
     elif method == "sym_ite":
 
         def sym_ite_magic_impl(pred, then_val, else_val):
```

```diff
@@ -1811,7 +1864,7 @@ def _make_user_magic(method, user_type):
             ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node))
             return get_constant(ret) if ret.node.is_constant() else ret
 
-        setattr(user_type, f"__{method}__", sym_ite_magic_impl)
+        setattrs(user_type, f"__{method}__", sym_ite_magic_impl)
     elif method == "round":
 
         def round_magic_impl(self, ndigits=None):
```

```diff
@@ -1820,14 +1873,14 @@ def _make_user_magic(method, user_type):
 
             return wrap_node(getattr(self.node, method)(ndigits))
 
-        setattr(user_type, f"__{method}__", round_magic_impl)
+        setattrs(user_type, f"__{method}__", round_magic_impl)
     else:
         method_name = method
         if method in bitwise_ops:
             method_name = bitwise_ops[method]
-        setattr(user_type, f"__{method_name}__", binary_magic_impl)
+        setattrs(user_type, f"__{method_name}__", binary_magic_impl)
         if method in reflectable_magic_methods:
-            setattr(user_type, f"__r{method_name}__", rbinary_magic_impl)
+            setattrs(user_type, f"__r{method_name}__", rbinary_magic_impl)
 
 
 for method, func in magic_methods.items():  # type: ignore[assignment]
```