Initial prototype for dynamic int inputs. This lets users run `torch.compile(f)(DynamicInt(4))`, compiling dynamically and using the underlying hint at runtime.

Current behavior:

- Also works in eager mode (mostly by subclassing `int`), as a scalar input to torch functions, or with numpy/math/etc. For example, with `x = DynamicInt(3)`, the calls `torch.randn(x)`, `torch.add(y, z, alpha=x)`, and `np.arange(x)` all act as if `x = 3`. (A usage sketch follows at the end of this note.)
- Arithmetic ops return new DynamicInts rather than static ints; `DynamicInt(3) * 2` yields `DynamicInt(6)`. This works via SymNode magic methods, but coverage might not be 100% - for example, floordiv had to be explicitly overridden to avoid int casting. Non-magic-method ops (e.g. `math.cos(x)`) don't necessarily behave this way. The alternative is to cast to int on every operation, but I opted for this behavior so dynamism propagates through non-compiled regions.
- `fullgraph=False` is not banned; DynamicInt objects might be leaked back to the user, but this seems fine, since they can be cast to ints when needed.
- Dynamo allocates only one symbol per DynamicInt; passing the same DynamicInt for multiple inputs leads to input deduplication, with a guard installed.
- We don't raise on int specialization (matching the allowlist/`maybe_mark_dynamic` style) - but this is an easy change if needed.
- DynamicInts as `nn.Module` attributes are handled.
- We don't guard on the DynamicInt id, e.g. users can do the following without recompiling (maybe we should guard?):

```python
x = DynamicInt(4)
f(x)
f(1)
f(DynamicInt(3))  # same as f(3)
```

Follow-up work:

- Specifying shape constraints, either at the int level, e.g.

```python
DynamicInt(64, name="s0", constraints=["s0 % 32 == 0", "s0 <= 1024"])
```

  or at the compilation level, e.g. something like

```python
s0 = DynamicInt(64, name="s0")
s1 = DynamicInt(128, name="s1")
with some_compiler_config.dynamic_int_constraints(["s1 == 2*s0", "s0 % 32 == 0"]):
    f(s0, s1)
```

  This should subsume the need for specifying derived SymInts?

- SymFloat support - currently it seems backed floats are specialized by the tensorify float pass, and there's no handling in inductor.
- Propagating dynamism in tensor constructors, e.g. `x = DynamicInt(4); torch.randn(x)` could annotate `_dynamo_dynamic_indices`.

Differential Revision: D81698719

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162194
Approved by: https://github.com/bobrenjc93
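To make the described behavior concrete, here is a minimal sketch of the intended usage. It is not taken from the PR's test suite: `f`, the tensor shapes, and the commented outputs are illustrative assumptions based on the description above, and `DynamicInt` is imported from the module path used by the accompanying C++ file.

```python
import torch
from torch.fx.experimental.sym_node import DynamicInt

@torch.compile
def f(x, n):
    return x[:n] * 2

x = torch.randn(8)

# Compile with a dynamic int input; the underlying hint (4) is used at runtime.
print(f(x, DynamicInt(4)).shape)  # torch.Size([4])

# No guard on the DynamicInt id: plain ints and other DynamicInts
# reuse the same dynamic graph, per the PR description.
print(f(x, 6).shape)              # torch.Size([6])
print(f(x, DynamicInt(3)).shape)  # torch.Size([3]), same as f(x, 3)

# Eager mode: arithmetic propagates dynamism rather than collapsing to int.
m = DynamicInt(3) * 2   # behaves like DynamicInt(6)
torch.randn(m)          # acts as if m == 6
```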
#include <torch/csrc/utils/python_symnode.h>

namespace torch {

// Lazily resolved, cached handles to the Python-side symbolic classes.
// Each handle is intentionally leaked (never decref'd) so it remains valid
// through interpreter shutdown.

py::handle get_symint_class() {
  // NB: leak
#if IS_PYBIND_2_13_PLUS
  PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
      storage;
  return storage
      .call_once_and_store_result([]() -> py::object {
        return py::module::import("torch").attr("SymInt");
      })
      .get_stored();
#else
  static py::handle symint_class =
      py::object(py::module::import("torch").attr("SymInt")).release();
  return symint_class;
#endif
}

py::handle get_symfloat_class() {
  // NB: leak
#if IS_PYBIND_2_13_PLUS
  PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
      storage;
  return storage
      .call_once_and_store_result([]() -> py::object {
        return py::module::import("torch").attr("SymFloat");
      })
      .get_stored();
#else
  static py::handle symfloat_class =
      py::object(py::module::import("torch").attr("SymFloat")).release();
  return symfloat_class;
#endif
}

py::handle get_symbool_class() {
  // NB: leak
#if IS_PYBIND_2_13_PLUS
  PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
      storage;
  return storage
      .call_once_and_store_result([]() -> py::object {
        return py::module::import("torch").attr("SymBool");
      })
      .get_stored();
#else
  static py::handle symbool_class =
      py::object(py::module::import("torch").attr("SymBool")).release();
  return symbool_class;
#endif
}

py::handle get_dynint_class() {
  // NB: leak
#if IS_PYBIND_2_13_PLUS
  PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
      storage;
  return storage
      .call_once_and_store_result([]() -> py::object {
        return py::module::import("torch.fx.experimental.sym_node")
            .attr("DynamicInt");
      })
      .get_stored();
#else
  static py::handle dynint_class =
      py::object(py::module::import("torch.fx.experimental.sym_node")
                     .attr("DynamicInt"))
          .release();
  return dynint_class;
#endif
}

} // namespace torch
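For reference, a small Python-side sketch of what the four accessors above resolve to, using the module paths hard-coded in this file (the `DynamicInt` location is experimental and may move):

```python
import torch
from torch.fx.experimental.sym_node import DynamicInt

# get_symint_class()   -> torch.SymInt
# get_symfloat_class() -> torch.SymFloat
# get_symbool_class()  -> torch.SymBool
# get_dynint_class()   -> DynamicInt

x = DynamicInt(4)
print(isinstance(x, int))  # True: DynamicInt subclasses int, per the PR description
```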