[dynamo] Guard serialization for FUNCTORCH_STACK_MATCH (#152616)

Make functorch interpreters serializable in most cases, so that we can save the guards on functorch state.

## Test Cases:

0. torch.compile() without any functorch layers present. The guard should fail once any layer is pushed.
1. torch.compile() nested in vmap.
2. torch.compile() nested in grad.
3. torch.compile() nested in jvp + vmap
4. torch.compile() nested in functionalize
5. torch.compile() nested in vmap + grad (see the sketch after this list)
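
A minimal sketch of the kind of nesting these cases exercise (hypothetical user code, not the test harness in this PR; assumes a recent torch build with `torch.vmap` and `torch.func`):

```python
import torch

def f(x):
    return torch.sin(x).sum()

# Case 1: torch.compile() nested inside vmap. The FUNCTORCH_STACK_MATCH guard
# records the functorch interpreter stack that was active at compile time, so
# the cached artifact is only reused when an equivalent vmap layer is pushed.
compiled = torch.compile(f)
x = torch.randn(4, 3)
out = torch.vmap(compiled)(x)

# Calling the compiled function outside vmap (or under a different transform,
# e.g. torch.func.grad) fails the guard and triggers a recompile.
out_eager = compiled(x)
```

With this change, the interpreter stack recorded by that guard can also be serialized and restored, which is what `test_functorch_stack_match` below round-trips.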

Differential Revision: [D74008787](https://our.internmc.facebook.com/intern/diff/D74008787/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152616
Approved by: https://github.com/zou3519
ghstack dependencies: #152615
Author: zhxchen17
Date: 2025-05-02 07:31:56 -07:00
Committed by: PyTorch MergeBot
Parent: 1d1cbcd8a3
Commit: ffd58293f7
9 changed files with 325 additions and 8 deletions

View File

@@ -376,6 +376,7 @@ cc_library(
":torch_headers",
"@fbgemm",
"@ideep",
"@nlohmann",
],
alwayslink = True,
)

View File

@@ -8,6 +8,8 @@
#include <utility>
#include <variant>
#include <nlohmann/json.hpp>
namespace at::functorch {
// NOTE: [functorch interpreter stack]
@@ -91,24 +93,95 @@ std::ostream& operator<<(std::ostream& os, const TransformType& t);
struct VmapInterpreterMeta {
explicit VmapInterpreterMeta(c10::SymInt batchSize, RandomnessType randomness) :
batchSize_(std::move(batchSize)), randomness_(randomness) {}
c10::SymInt batchSize_;
RandomnessType randomness_;
VmapInterpreterMeta() = default;
VmapInterpreterMeta(const VmapInterpreterMeta&) = default;
VmapInterpreterMeta(VmapInterpreterMeta&&) = default;
VmapInterpreterMeta& operator=(const VmapInterpreterMeta&) = default;
VmapInterpreterMeta& operator=(VmapInterpreterMeta&&) = default;
~VmapInterpreterMeta() = default;
template <typename T>
friend void to_json(T& json_j, const VmapInterpreterMeta& json_t) {
if (json_t.batchSize_.is_heap_allocated()) {
throw std::runtime_error("Serialization for heap-allocated SymInt is not implemented yet");
}
json_j["batchSize"] = json_t.batchSize_.as_int_unchecked();
json_j["randomness"] = static_cast<int64_t>(json_t.randomness_);
}
template <typename T>
friend void from_json(const T& json_j, VmapInterpreterMeta& json_t) {
json_t.batchSize_ = c10::SymInt(SymInt::Unchecked::UNCHECKED, json_j["batchSize"]);
json_t.randomness_ = static_cast<RandomnessType>(json_j["randomness"]);
}
};
struct GradInterpreterMeta {
explicit GradInterpreterMeta(bool prevGradMode): prevGradMode_(prevGradMode) {}
GradInterpreterMeta() = default;
GradInterpreterMeta(const GradInterpreterMeta&) = default;
GradInterpreterMeta(GradInterpreterMeta&&) = default;
GradInterpreterMeta& operator=(const GradInterpreterMeta&) = default;
GradInterpreterMeta& operator=(GradInterpreterMeta&&) = default;
~GradInterpreterMeta() = default;
bool prevGradMode_;
template <typename T>
friend void to_json(T& json_j, const GradInterpreterMeta& json_t) {
json_j["prevGradMode"] = json_t.prevGradMode_;
}
template <typename T>
friend void from_json(const T& json_j, GradInterpreterMeta& json_t) {
json_t.prevGradMode_ = json_j["prevGradMode"];
}
};
struct JvpInterpreterMeta {
explicit JvpInterpreterMeta(bool prevFwdGradMode) : prevFwdGradMode_(prevFwdGradMode) {}
JvpInterpreterMeta() = default;
JvpInterpreterMeta(const JvpInterpreterMeta&) = default;
JvpInterpreterMeta(JvpInterpreterMeta&&) = default;
JvpInterpreterMeta& operator=(const JvpInterpreterMeta&) = default;
JvpInterpreterMeta& operator=(JvpInterpreterMeta&&) = default;
~JvpInterpreterMeta() = default;
bool prevFwdGradMode_;
template <typename T>
friend void to_json(T& json_j, const JvpInterpreterMeta& json_t) {
json_j["prevFwdGradMode"] = json_t.prevFwdGradMode_;
}
template <typename T>
friend void from_json(const T& json_j, JvpInterpreterMeta& json_t) {
json_t.prevFwdGradMode_ = json_j["prevFwdGradMode"];
}
};
struct FunctionalizeInterpreterMeta {
explicit FunctionalizeInterpreterMeta(bool functionalizeAddBackViews) :
functionalizeAddBackViews_(functionalizeAddBackViews) {}
FunctionalizeInterpreterMeta() = default;
FunctionalizeInterpreterMeta(const FunctionalizeInterpreterMeta&) = default;
FunctionalizeInterpreterMeta(FunctionalizeInterpreterMeta&&) = default;
FunctionalizeInterpreterMeta& operator=(const FunctionalizeInterpreterMeta&) = default;
FunctionalizeInterpreterMeta& operator=(FunctionalizeInterpreterMeta&&) = default;
~FunctionalizeInterpreterMeta() = default;
bool functionalizeAddBackViews_;
template <typename T>
friend void to_json(T& json_j, const FunctionalizeInterpreterMeta& json_t) {
json_j["functionalizeAddBackViews"] = json_t.functionalizeAddBackViews_;
}
template <typename T>
friend void from_json(const T& json_j, FunctionalizeInterpreterMeta& json_t) {
json_t.functionalizeAddBackViews_ = json_j["functionalizeAddBackViews"];
}
};
typedef std::variant<
@@ -172,6 +245,75 @@ struct Interpreter {
// Please don't use this
explicit Interpreter() = default;
template <typename T>
friend void to_json(T& json_j, const Interpreter& json_t) {
json_j["type"] = static_cast<int64_t>(json_t.type_);
json_j["level"] = json_t.level_;
if (json_t.savedLocalDispatchKeySet_) {
json_j["savedLocalDispatchKeySet"] = {
{"included", json_t.savedLocalDispatchKeySet_->included_.raw_repr()},
{"excluded", json_t.savedLocalDispatchKeySet_->excluded_.raw_repr()}
};
} else {
json_j["savedLocalDispatchKeySet"] = nlohmann::json();
}
json_j["is_alive"] = *json_t.is_alive_;
std::visit([&](auto&& arg) {
using V = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<V, int64_t>) {
json_j["meta"] = {{"Torch", arg}};
} else if constexpr (std::is_same_v<V, GradInterpreterMeta>) {
json_j["meta"] = {{"Grad", arg}};
} else if constexpr (std::is_same_v<V, JvpInterpreterMeta>) {
json_j["meta"] = {{"Jvp", arg}};
} else if constexpr (std::is_same_v<V, VmapInterpreterMeta>) {
json_j["meta"] = {{"Vmap", arg}};
} else if constexpr (std::is_same_v<V, FunctionalizeInterpreterMeta>) {
json_j["meta"] = {{"Functionalize", arg}};
} else {
static_assert(false && sizeof(V), "unknown variant case");
}
}, json_t.meta_);
}
template <typename T>
friend void from_json(const T& json_j, Interpreter& json_t) {
json_t.type_ = static_cast<TransformType>(json_j["type"]);
json_t.level_ = json_j["level"];
auto savedLocalDispatchKeySet = json_j["savedLocalDispatchKeySet"];
if (savedLocalDispatchKeySet.is_null()) {
json_t.savedLocalDispatchKeySet_ = std::nullopt;
} else {
c10::impl::PODLocalDispatchKeySet pod;
pod.set_included(DispatchKeySet::from_raw_repr(savedLocalDispatchKeySet["included"].template get<uint64_t>()));
pod.set_excluded(DispatchKeySet::from_raw_repr(savedLocalDispatchKeySet["excluded"].template get<uint64_t>()));
json_t.savedLocalDispatchKeySet_ = c10::impl::LocalDispatchKeySet(pod);
}
json_t.is_alive_ = std::make_shared<bool>(json_j["is_alive"]);
auto meta = json_j["meta"];
if (meta.contains("Torch")) {
json_t.meta_.emplace<int64_t>(meta["Torch"].template get<int64_t>());
} else if (meta.contains("Grad")) {
json_t.meta_.emplace<GradInterpreterMeta>(meta["Grad"].template get<GradInterpreterMeta>());
} else if (meta.contains("Jvp")) {
json_t.meta_.emplace<JvpInterpreterMeta>(meta["Jvp"].template get<JvpInterpreterMeta>());
} else if (meta.contains("Vmap")) {
json_t.meta_.emplace<VmapInterpreterMeta>(meta["Vmap"].template get<VmapInterpreterMeta>());
} else if (meta.contains("Functionalize")) {
json_t.meta_.emplace<FunctionalizeInterpreterMeta>(meta["Functionalize"].template get<FunctionalizeInterpreterMeta>());
} else {
throw std::runtime_error("unknown interpreter metadata type");
}
}
std::string serialize() const {
return nlohmann::json(*this).dump();
}
static Interpreter deserialize(const std::string& serialized) {
return nlohmann::json::parse(serialized).get<Interpreter>();
}
private:
explicit Interpreter(TransformType type, int64_t level, InterpreterMeta meta):
type_(type), level_(level), is_alive_(std::make_shared<bool>(false)), meta_(std::move(meta)) {}

View File

@@ -94,7 +94,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
fn.__closure__ or (),
[], # TODO tf_mode_stack,
code_options,
lambda gm, *args, **kwargs: gm.forward,
torch._dynamo.lookup_backend("eager"),
one_graph=False,
export=False,
export_constraints=None,
@@ -326,6 +326,126 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
with torch.autograd.forward_ad.dual_level():
self._test_check_fn(ref, loaded, {"x": x}, False)
def test_functorch_stack_match(self):
# Test when functorch stack is empty.
def fn(x):
return torch.func.jvp(torch.sin, (x,), (x,))
x = torch.randn(3, 4)
ref, loaded = self._test_serialization("FUNCTORCH_STACK_MATCH", fn, x)
self._test_check_fn(ref, loaded, {"x": x}, True)
with torch._functorch.vmap.vmap_increment_nesting(2, "error"):
self._test_check_fn(ref, loaded, {"x": x}, False)
def fn(x):
def g(x):
return torch.vmap(torch.func.grad(torch.sin))(x)
return torch.vmap(g)(x)
x = torch.randn(4, 5)
ref, loaded = self._test_serialization("FUNCTORCH_STACK_MATCH", fn, x)
self._test_check_fn(ref, loaded, {"x": x}, True)
with torch._functorch.eager_transforms.grad_increment_nesting():
self._test_check_fn(ref, loaded, {"x": x}, False)
# Test when there are more than 0 functorch layers.
# Simulate the case where torch.compile is nested inside eager transforms.
# Case 1: vmap
def fn(x):
return x.sum()
ref = loaded = None
def run(x):
nonlocal ref, loaded
# Turn off automatic dynamic shapes so that functionalization
# doesn't produce an extra SymInt to serialize.
with torch._dynamo.config.patch(automatic_dynamic_shapes=False):
ref, loaded = self._test_serialization("FUNCTORCH_STACK_MATCH", fn, x)
return fn(x)
torch.vmap(run)(x)
self._test_check_fn(ref, loaded, {"x": x}, False)
with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
self._test_check_fn(ref, loaded, {"x": x}, True)
with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
self._test_check_fn(ref, loaded, {"x": x}, False)
with torch._functorch.eager_transforms.grad_increment_nesting():
self._test_check_fn(ref, loaded, {"x": x}, False)
# Case 2: grad
x = torch.randn(3, 2)
ref = loaded = None
torch.func.grad(run)(x)
self._test_check_fn(ref, loaded, {"x": x}, False)
with torch._functorch.eager_transforms.grad_increment_nesting():
self._test_check_fn(ref, loaded, {"x": x}, True)
with torch._functorch.eager_transforms.grad_increment_nesting():
self._test_check_fn(ref, loaded, {"x": x}, False)
with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
self._test_check_fn(ref, loaded, {"x": x}, False)
# Case 3: jvp + vmap
x = torch.randn(3, 4)
ref = loaded = None
def fn(x):
return torch.func.jvp(torch.sin, (x,), (x,))
torch.func.jvp(torch.vmap(run), (x,), (x,))
self._test_check_fn(ref, loaded, {"x": x}, False)
with torch._functorch.eager_transforms.jvp_increment_nesting():
with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
self._test_check_fn(ref, loaded, {"x": x}, True)
with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
with torch._functorch.eager_transforms.jvp_increment_nesting():
self._test_check_fn(ref, loaded, {"x": x}, False)
# Case 4: functionalize
x = torch.randn(3, 2)
ref = loaded = None
torch.func.functionalize(run)(x)
self._test_check_fn(ref, loaded, {"x": x}, False)
torch._C._functorch._func_increment_nesting(True)
try:
self._test_check_fn(ref, loaded, {"x": x}, True)
finally:
torch._C._functorch._func_decrement_nesting()
with torch._functorch.eager_transforms.jvp_increment_nesting():
self._test_check_fn(ref, loaded, {"x": x}, False)
# Case 5: vmap + grad
def fn(x):
return x.sum()
x = torch.randn(3, 2)
ref = loaded = None
torch.vmap(torch.func.grad(run))(x)
self._test_check_fn(ref, loaded, {"x": x}, False)
with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
with torch._functorch.eager_transforms.grad_increment_nesting():
self._test_check_fn(ref, loaded, {"x": x}, True)
with torch._functorch.eager_transforms.grad_increment_nesting():
with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
self._test_check_fn(ref, loaded, {"x": x}, False)
with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
self._test_check_fn(ref, loaded, {"x": x}, False)
with torch._functorch.eager_transforms.grad_increment_nesting():
self._test_check_fn(ref, loaded, {"x": x}, False)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@@ -49,6 +49,9 @@ class RandomnessType(Enum):
class CInterpreter:
def key(self) -> TransformType: ...
def level(self) -> int: ...
def serialize(self) -> bytes: ...
@staticmethod
def deserialize(bytes) -> CInterpreter: ...
class CGradInterpreterPtr:
def __init__(self, interpreter: CInterpreter) -> None: ...

View File

@@ -475,7 +475,11 @@ def get_verbose_code_part(code_part: str, guard: Guard) -> str:
extra = f" # {format_frame(fs, line=True)}"
break
elif guard.stack:
extra = f" # {format_frame(guard.stack.summary()[-1])}"
summary = guard.stack.summary()
if len(summary) > 0:
extra = f" # {format_frame(summary[-1])}"
else:
extra = " # <unknown>"
return f"{code_part:<60}{extra}"
@@ -1591,7 +1595,7 @@ class GuardBuilder(GuardBuilderBase):
def FUNCTORCH_STACK_MATCH(self, guard: Guard):
# Invalidate functorch code if current level is different than
# the one when FX graph was generated
cis = torch._functorch.pyfunctorch.retrieve_all_functorch_interpreters()
cis = self.check_fn_manager.output_graph.functorch_layers
states = [ci.get_state() for ci in cis]
code = [f"torch._functorch.pyfunctorch.compare_functorch_state({states})"]
self._set_guard_export_info(guard, code)
@@ -2522,7 +2526,13 @@ class GuardsStatePickler(pickle.Pickler):
def _unpickle_dispatch_key_set(cls, raw_repr: int):
return torch._C.DispatchKeySet.from_raw_repr(raw_repr)
@classmethod
def _unpickle_functorch_interpreter(cls, json: bytes):
return torch._C._functorch.CInterpreter.deserialize(json)
def reducer_override(self, obj):
import sympy
if isinstance(obj, torch.Tensor) and obj.device.type != "meta":
return type(self)._unpickle_tensor, (
torch.empty_like(obj, device="meta"),
@@ -2543,6 +2553,20 @@ class GuardsStatePickler(pickle.Pickler):
elif isinstance(obj, torch._C.DispatchKeySet):
return type(self)._unpickle_dispatch_key_set, (obj.raw_repr(),)
elif isinstance(obj, torch._C._functorch.CInterpreter):
return type(self)._unpickle_functorch_interpreter, (obj.serialize(),)
elif (
inspect.isclass(obj)
and issubclass(obj, sympy.Function)
and hasattr(obj, "_torch_handler_name")
):
assert hasattr(obj, "_torch_unpickler")
return obj._torch_unpickler, (obj._torch_handler_name,)
elif isinstance(obj, torch.SymInt):
raise RuntimeError(f"Cannot serialize SymInt {obj} (node: {obj.node})")
if type(obj).__qualname__ != type(obj).__name__:
raise RuntimeError(
f"Type {type(obj)} for object {obj} cannot be saved "

View File

@@ -301,6 +301,7 @@ class OutputGraphGuardsState:
# Map from graph input's `Source` to sizes / strides metadata
input_source_to_sizes_strides: dict[Source, dict[str, Any]]
dual_level: int
functorch_layers: list[torch._functorch.pyfunctorch.FuncTorchInterpreter]
export: bool = False
export_constraints: bool = False
@@ -354,6 +355,7 @@ class OutputGraph(OutputGraphGuardsState):
guard_on_key_order=set(),
input_source_to_sizes_strides={},
dual_level=torch.autograd.forward_ad._current_level,
functorch_layers=torch._functorch.pyfunctorch.retrieve_all_functorch_interpreters(),
)
self.tracers = [SubgraphTracer(self, is_export=export)]
# Map from graph input's `Source` to its `VariableTracker` to
@@ -590,6 +592,7 @@ class OutputGraph(OutputGraphGuardsState):
guard_on_key_order=self.guard_on_key_order,
input_source_to_sizes_strides=self.input_source_to_sizes_strides,
dual_level=self.dual_level,
functorch_layers=self.functorch_layers,
export=self.export,
export_constraints=self.export_constraints,
_guards=self.guards,

View File

@@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import contextlib
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Any
import torch
@@ -79,6 +80,11 @@ class FuncTorchInterpreter(ABC):
def check_state(self, state):
return state == self.get_state()
def __getstate__(self):
state = self.__dict__.copy()
state.pop("_cptr", None)
return state
@contextlib.contextmanager
def temporarily_pop_interpreter_stack():
@@ -123,7 +129,10 @@ class VmapInterpreter(FuncTorchInterpreter):
# cdata is a generic CInterpreter. We wrap it in a CVmapInterpreterPtr
# so that we can access methods specific to the vmap interpreter
self._cdata = cdata
self._cptr = CVmapInterpreterPtr(cdata)
@cached_property
def _cptr(self):
return CVmapInterpreterPtr(self._cdata)
def process(self, op, args, kwargs):
kernel = op.functorch_table[TransformType.Vmap]
@@ -159,7 +168,10 @@ class GradInterpreter(FuncTorchInterpreter):
assert cdata.key() == TransformType.Grad
# See NOTE: [Interpreter cdata vs cptr]
self._cdata = cdata
self._cptr = CGradInterpreterPtr(cdata)
@cached_property
def _cptr(self):
return CGradInterpreterPtr(self._cdata)
def lift(self, args, kwargs):
args, kwargs = pytree.tree_map_only(
@@ -193,7 +205,10 @@ class JvpInterpreter(FuncTorchInterpreter):
assert cdata.key() == TransformType.Jvp
# See NOTE: [Interpreter cdata vs cptr]
self._cdata = cdata
self._cptr = CJvpInterpreterPtr(cdata)
@cached_property
def _cptr(self):
return CJvpInterpreterPtr(self._cdata)
def lift(self, args, kwargs):
args, kwargs = pytree.tree_map_only(
@@ -226,7 +241,10 @@ class FunctionalizeInterpreter(FuncTorchInterpreter):
def __init__(self, cdata: CInterpreter):
assert cdata.key() == TransformType.Functionalize
self._cdata = cdata
self._cptr = CFunctionalizeInterpreterPtr(cdata)
@cached_property
def _cptr(self):
return CFunctionalizeInterpreterPtr(self._cdata)
def process(self, op, args, kwargs):
kernel = op.functorch_table[TransformType.Functionalize]

View File

@@ -575,7 +575,9 @@ void initFuncTorchBindings(PyObject* module) {
.value("Different", RandomnessType::Different);
py::class_<Interpreter>(m, "CInterpreter")
.def("key", &Interpreter::key)
.def("level", &Interpreter::level);
.def("level", &Interpreter::level)
.def("serialize", &Interpreter::serialize)
.def_static("deserialize", &Interpreter::deserialize);
py::class_<GradInterpreterPtr>(m, "CGradInterpreterPtr")
.def(py::init<const Interpreter*>())
.def("key", &GradInterpreterPtr::key)

View File

@@ -1318,6 +1318,7 @@ def make_opaque_unary_fn(name):
"""
_torch_handler_name = name
_torch_unpickler = make_opaque_unary_fn
@classmethod
def eval(cls, a):
@@ -1378,6 +1379,9 @@ class BitwiseFn(sympy.Function):
class BitwiseFn(sympy.Function):
_torch_handler_name = name
precedence: int = prec
_torch_unpickler = functools.partial(
make_opaque_bitwise_fn, real_op_name=real_op_name
)
@classmethod
def eval(cls, a, b):