mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[dynamo] Guard serialization for FUNCTORCH_STACK_MATCH (#152616)
Make Functorch interpreters serializable most of the time, so that we can save the guards on functorch states. ## Test Cases: 0. torch.compile() without functorch layers present. Guard should fail with any layer being pushed. 1. torch.compile() nested in vmap. 2. torch.compile() nested in grad. 3. torch.compile() nested in jvp + vmap 4. torch.compile() nested functionalize 5. torch.compile() nested in vmap + grad Differential Revision: [D74008787](https://our.internmc.facebook.com/intern/diff/D74008787/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/152616 Approved by: https://github.com/zou3519 ghstack dependencies: #152615
This commit is contained in:
committed by
PyTorch MergeBot
parent
1d1cbcd8a3
commit
ffd58293f7
@ -376,6 +376,7 @@ cc_library(
|
||||
":torch_headers",
|
||||
"@fbgemm",
|
||||
"@ideep",
|
||||
"@nlohmann",
|
||||
],
|
||||
alwayslink = True,
|
||||
)
|
||||
|
@ -8,6 +8,8 @@
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
namespace at::functorch {
|
||||
|
||||
// NOTE: [functorch interpreter stack]
|
||||
@ -91,24 +93,95 @@ std::ostream& operator<<(std::ostream& os, const TransformType& t);
|
||||
struct VmapInterpreterMeta {
|
||||
explicit VmapInterpreterMeta(c10::SymInt batchSize, RandomnessType randomness) :
|
||||
batchSize_(std::move(batchSize)), randomness_(randomness) {}
|
||||
|
||||
c10::SymInt batchSize_;
|
||||
RandomnessType randomness_;
|
||||
|
||||
VmapInterpreterMeta() = default;
|
||||
VmapInterpreterMeta(const VmapInterpreterMeta&) = default;
|
||||
VmapInterpreterMeta(VmapInterpreterMeta&&) = default;
|
||||
VmapInterpreterMeta& operator=(const VmapInterpreterMeta&) = default;
|
||||
VmapInterpreterMeta& operator=(VmapInterpreterMeta&&) = default;
|
||||
~VmapInterpreterMeta() = default;
|
||||
|
||||
template <typename T>
|
||||
friend void to_json(T& json_j, const VmapInterpreterMeta& json_t) {
|
||||
if (json_t.batchSize_.is_heap_allocated()) {
|
||||
throw std::runtime_error("Serialization for heap-allocated SymInt is not implemented yet");
|
||||
}
|
||||
json_j["batchSize"] = json_t.batchSize_.as_int_unchecked();
|
||||
json_j["randomness"] = static_cast<int64_t>(json_t.randomness_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
friend void from_json(const T& json_j, VmapInterpreterMeta& json_t) {
|
||||
json_t.batchSize_ = c10::SymInt(SymInt::Unchecked::UNCHECKED, json_j["batchSize"]);
|
||||
json_t.randomness_ = static_cast<RandomnessType>(json_j["randomness"]);
|
||||
}
|
||||
};
|
||||
|
||||
struct GradInterpreterMeta {
|
||||
explicit GradInterpreterMeta(bool prevGradMode): prevGradMode_(prevGradMode) {}
|
||||
GradInterpreterMeta() = default;
|
||||
GradInterpreterMeta(const GradInterpreterMeta&) = default;
|
||||
GradInterpreterMeta(GradInterpreterMeta&&) = default;
|
||||
GradInterpreterMeta& operator=(const GradInterpreterMeta&) = default;
|
||||
GradInterpreterMeta& operator=(GradInterpreterMeta&&) = default;
|
||||
~GradInterpreterMeta() = default;
|
||||
|
||||
bool prevGradMode_;
|
||||
template <typename T>
|
||||
friend void to_json(T& json_j, const GradInterpreterMeta& json_t) {
|
||||
json_j["prevGradMode"] = json_t.prevGradMode_;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
friend void from_json(const T& json_j, GradInterpreterMeta& json_t) {
|
||||
json_t.prevGradMode_ = json_j["prevGradMode"];
|
||||
}
|
||||
};
|
||||
|
||||
struct JvpInterpreterMeta {
|
||||
explicit JvpInterpreterMeta(bool prevFwdGradMode) : prevFwdGradMode_(prevFwdGradMode) {}
|
||||
JvpInterpreterMeta() = default;
|
||||
JvpInterpreterMeta(const JvpInterpreterMeta&) = default;
|
||||
JvpInterpreterMeta(JvpInterpreterMeta&&) = default;
|
||||
JvpInterpreterMeta& operator=(const JvpInterpreterMeta&) = default;
|
||||
JvpInterpreterMeta& operator=(JvpInterpreterMeta&&) = default;
|
||||
~JvpInterpreterMeta() = default;
|
||||
|
||||
bool prevFwdGradMode_;
|
||||
template <typename T>
|
||||
friend void to_json(T& json_j, const JvpInterpreterMeta& json_t) {
|
||||
json_j["prevFwdGradMode"] = json_t.prevFwdGradMode_;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
friend void from_json(const T& json_j, JvpInterpreterMeta& json_t) {
|
||||
json_t.prevFwdGradMode_ = json_j["prevFwdGradMode"];
|
||||
}
|
||||
};
|
||||
|
||||
struct FunctionalizeInterpreterMeta {
|
||||
explicit FunctionalizeInterpreterMeta(bool functionalizeAddBackViews) :
|
||||
functionalizeAddBackViews_(functionalizeAddBackViews) {}
|
||||
FunctionalizeInterpreterMeta() = default;
|
||||
FunctionalizeInterpreterMeta(const FunctionalizeInterpreterMeta&) = default;
|
||||
FunctionalizeInterpreterMeta(FunctionalizeInterpreterMeta&&) = default;
|
||||
FunctionalizeInterpreterMeta& operator=(const FunctionalizeInterpreterMeta&) = default;
|
||||
FunctionalizeInterpreterMeta& operator=(FunctionalizeInterpreterMeta&&) = default;
|
||||
~FunctionalizeInterpreterMeta() = default;
|
||||
|
||||
bool functionalizeAddBackViews_;
|
||||
template <typename T>
|
||||
friend void to_json(T& json_j, const FunctionalizeInterpreterMeta& json_t) {
|
||||
json_j["functionalizeAddBackViews"] = json_t.functionalizeAddBackViews_;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
friend void from_json(const T& json_j, FunctionalizeInterpreterMeta& json_t) {
|
||||
json_t.functionalizeAddBackViews_ = json_j["functionalizeAddBackViews"];
|
||||
}
|
||||
};
|
||||
|
||||
typedef std::variant<
|
||||
@ -172,6 +245,75 @@ struct Interpreter {
|
||||
// Please don't use this
|
||||
explicit Interpreter() = default;
|
||||
|
||||
template <typename T>
|
||||
friend void to_json(T& json_j, const Interpreter& json_t) {
|
||||
json_j["type"] = static_cast<int64_t>(json_t.type_);
|
||||
json_j["level"] = json_t.level_;
|
||||
if (json_t.savedLocalDispatchKeySet_) {
|
||||
json_j["savedLocalDispatchKeySet"] = {
|
||||
{"included", json_t.savedLocalDispatchKeySet_->included_.raw_repr()},
|
||||
{"excluded", json_t.savedLocalDispatchKeySet_->excluded_.raw_repr()}
|
||||
};
|
||||
} else {
|
||||
json_j["savedLocalDispatchKeySet"] = nlohmann::json();
|
||||
}
|
||||
json_j["is_alive"] = *json_t.is_alive_;
|
||||
std::visit([&](auto&& arg) {
|
||||
using V = std::decay_t<decltype(arg)>;
|
||||
if constexpr (std::is_same_v<V, int64_t>) {
|
||||
json_j["meta"] = {{"Torch", arg}};
|
||||
} else if constexpr (std::is_same_v<V, GradInterpreterMeta>) {
|
||||
json_j["meta"] = {{"Grad", arg}};
|
||||
} else if constexpr (std::is_same_v<V, JvpInterpreterMeta>) {
|
||||
json_j["meta"] = {{"Jvp", arg}};
|
||||
} else if constexpr (std::is_same_v<V, VmapInterpreterMeta>) {
|
||||
json_j["meta"] = {{"Vmap", arg}};
|
||||
} else if constexpr (std::is_same_v<V, FunctionalizeInterpreterMeta>) {
|
||||
json_j["meta"] = {{"Functionalize", arg}};
|
||||
} else {
|
||||
static_assert(false && sizeof(V), "unknown variant case");
|
||||
}
|
||||
}, json_t.meta_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
friend void from_json(const T& json_j, Interpreter& json_t) {
|
||||
json_t.type_ = static_cast<TransformType>(json_j["type"]);
|
||||
json_t.level_ = json_j["level"];
|
||||
auto savedLocalDispatchKeySet = json_j["savedLocalDispatchKeySet"];
|
||||
if (savedLocalDispatchKeySet.is_null()) {
|
||||
json_t.savedLocalDispatchKeySet_ = std::nullopt;
|
||||
} else {
|
||||
c10::impl::PODLocalDispatchKeySet pod;
|
||||
pod.set_included(DispatchKeySet::from_raw_repr(savedLocalDispatchKeySet["included"].template get<uint64_t>()));
|
||||
pod.set_excluded(DispatchKeySet::from_raw_repr(savedLocalDispatchKeySet["excluded"].template get<uint64_t>()));
|
||||
json_t.savedLocalDispatchKeySet_ = c10::impl::LocalDispatchKeySet(pod);
|
||||
}
|
||||
json_t.is_alive_ = std::make_shared<bool>(json_j["is_alive"]);
|
||||
auto meta = json_j["meta"];
|
||||
if (meta.contains("Torch")) {
|
||||
json_t.meta_.emplace<int64_t>(meta["Torch"].template get<int64_t>());
|
||||
} else if (meta.contains("Grad")) {
|
||||
json_t.meta_.emplace<GradInterpreterMeta>(meta["Grad"].template get<GradInterpreterMeta>());
|
||||
} else if (meta.contains("Jvp")) {
|
||||
json_t.meta_.emplace<JvpInterpreterMeta>(meta["Jvp"].template get<JvpInterpreterMeta>());
|
||||
} else if (meta.contains("Vmap")) {
|
||||
json_t.meta_.emplace<VmapInterpreterMeta>(meta["Vmap"].template get<VmapInterpreterMeta>());
|
||||
} else if (meta.contains("Functionalize")) {
|
||||
json_t.meta_.emplace<FunctionalizeInterpreterMeta>(meta["Functionalize"].template get<FunctionalizeInterpreterMeta>());
|
||||
} else {
|
||||
throw std::runtime_error("unknown interpreter metadata type");
|
||||
}
|
||||
}
|
||||
|
||||
std::string serialize() const {
|
||||
return nlohmann::json(*this).dump();
|
||||
}
|
||||
|
||||
static Interpreter deserialize(const std::string& serialized) {
|
||||
return nlohmann::json::parse(serialized).get<Interpreter>();
|
||||
}
|
||||
|
||||
private:
|
||||
explicit Interpreter(TransformType type, int64_t level, InterpreterMeta meta):
|
||||
type_(type), level_(level), is_alive_(std::make_shared<bool>(false)), meta_(std::move(meta)) {}
|
||||
|
@ -94,7 +94,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
||||
fn.__closure__ or (),
|
||||
[], # TODO tf_mode_stack,
|
||||
code_options,
|
||||
lambda gm, *args, **kwargs: gm.forward,
|
||||
torch._dynamo.lookup_backend("eager"),
|
||||
one_graph=False,
|
||||
export=False,
|
||||
export_constraints=None,
|
||||
@ -326,6 +326,126 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
||||
with torch.autograd.forward_ad.dual_level():
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
|
||||
def test_functorch_stack_match(self):
|
||||
# Test when functorch stack is empty.
|
||||
def fn(x):
|
||||
return torch.func.jvp(torch.sin, (x,), (x,))
|
||||
|
||||
x = torch.randn(3, 4)
|
||||
ref, loaded = self._test_serialization("FUNCTORCH_STACK_MATCH", fn, x)
|
||||
|
||||
self._test_check_fn(ref, loaded, {"x": x}, True)
|
||||
with torch._functorch.vmap.vmap_increment_nesting(2, "error"):
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
|
||||
def fn(x):
|
||||
def g(x):
|
||||
return torch.vmap(torch.func.grad(torch.sin))(x)
|
||||
|
||||
return torch.vmap(g)(x)
|
||||
|
||||
x = torch.randn(4, 5)
|
||||
ref, loaded = self._test_serialization("FUNCTORCH_STACK_MATCH", fn, x)
|
||||
self._test_check_fn(ref, loaded, {"x": x}, True)
|
||||
with torch._functorch.eager_transforms.grad_increment_nesting():
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
|
||||
# Test when there are more than 0 functorch layers.
|
||||
# Simulate the case where torch.compile is nested inside eager transforms.
|
||||
|
||||
# Case 1: vmap
|
||||
def fn(x):
|
||||
return x.sum()
|
||||
|
||||
ref = loaded = None
|
||||
|
||||
def run(x):
|
||||
nonlocal ref, loaded
|
||||
# Turn off automatic dynamic shape to so that functionalization
|
||||
# doesn't produce extra SymInt to serialize.
|
||||
with torch._dynamo.config.patch(automatic_dynamic_shapes=False):
|
||||
ref, loaded = self._test_serialization("FUNCTORCH_STACK_MATCH", fn, x)
|
||||
return fn(x)
|
||||
|
||||
torch.vmap(run)(x)
|
||||
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
|
||||
self._test_check_fn(ref, loaded, {"x": x}, True)
|
||||
with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
|
||||
with torch._functorch.eager_transforms.grad_increment_nesting():
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
|
||||
# Case 2: grad
|
||||
x = torch.randn(3, 2)
|
||||
ref = loaded = None
|
||||
torch.func.grad(run)(x)
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
with torch._functorch.eager_transforms.grad_increment_nesting():
|
||||
self._test_check_fn(ref, loaded, {"x": x}, True)
|
||||
with torch._functorch.eager_transforms.grad_increment_nesting():
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
|
||||
with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
|
||||
# Case 3: jvp + vmap
|
||||
x = torch.randn(3, 4)
|
||||
ref = loaded = None
|
||||
|
||||
def fn(x):
|
||||
return torch.func.jvp(torch.sin, (x,), (x,))
|
||||
|
||||
torch.func.jvp(torch.vmap(run), (x,), (x,))
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
|
||||
with torch._functorch.eager_transforms.jvp_increment_nesting():
|
||||
with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
|
||||
self._test_check_fn(ref, loaded, {"x": x}, True)
|
||||
|
||||
with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
|
||||
with torch._functorch.eager_transforms.jvp_increment_nesting():
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
|
||||
# Case 4: functionalize
|
||||
x = torch.randn(3, 2)
|
||||
ref = loaded = None
|
||||
torch.func.functionalize(run)(x)
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
|
||||
torch._C._functorch._func_increment_nesting(True)
|
||||
try:
|
||||
self._test_check_fn(ref, loaded, {"x": x}, True)
|
||||
finally:
|
||||
torch._C._functorch._func_decrement_nesting()
|
||||
|
||||
with torch._functorch.eager_transforms.jvp_increment_nesting():
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
|
||||
# Case 5: vmap + grad
|
||||
def fn(x):
|
||||
return x.sum()
|
||||
|
||||
x = torch.randn(3, 2)
|
||||
ref = loaded = None
|
||||
torch.vmap(torch.func.grad(run))(x)
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
|
||||
with torch._functorch.eager_transforms.grad_increment_nesting():
|
||||
self._test_check_fn(ref, loaded, {"x": x}, True)
|
||||
|
||||
with torch._functorch.eager_transforms.grad_increment_nesting():
|
||||
with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
|
||||
with torch._functorch.vmap.vmap_increment_nesting(1, "error"):
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
|
||||
with torch._functorch.eager_transforms.grad_increment_nesting():
|
||||
self._test_check_fn(ref, loaded, {"x": x}, False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
@ -49,6 +49,9 @@ class RandomnessType(Enum):
|
||||
class CInterpreter:
|
||||
def key(self) -> TransformType: ...
|
||||
def level(self) -> int: ...
|
||||
def serialize(self) -> bytes: ...
|
||||
@staticmethod
|
||||
def deserialize(bytes) -> CInterpreter: ...
|
||||
|
||||
class CGradInterpreterPtr:
|
||||
def __init__(self, interpreter: CInterpreter) -> None: ...
|
||||
|
@ -475,7 +475,11 @@ def get_verbose_code_part(code_part: str, guard: Guard) -> str:
|
||||
extra = f" # {format_frame(fs, line=True)}"
|
||||
break
|
||||
elif guard.stack:
|
||||
extra = f" # {format_frame(guard.stack.summary()[-1])}"
|
||||
summary = guard.stack.summary()
|
||||
if len(summary) > 0:
|
||||
extra = f" # {format_frame(summary[-1])}"
|
||||
else:
|
||||
extra = " # <unknown>"
|
||||
return f"{code_part:<60}{extra}"
|
||||
|
||||
|
||||
@ -1591,7 +1595,7 @@ class GuardBuilder(GuardBuilderBase):
|
||||
def FUNCTORCH_STACK_MATCH(self, guard: Guard):
|
||||
# Invalidate functorch code if current level is different than
|
||||
# the one when FX graph was generated
|
||||
cis = torch._functorch.pyfunctorch.retrieve_all_functorch_interpreters()
|
||||
cis = self.check_fn_manager.output_graph.functorch_layers
|
||||
states = [ci.get_state() for ci in cis]
|
||||
code = [f"torch._functorch.pyfunctorch.compare_functorch_state({states})"]
|
||||
self._set_guard_export_info(guard, code)
|
||||
@ -2522,7 +2526,13 @@ class GuardsStatePickler(pickle.Pickler):
|
||||
def _unpickle_dispatch_key_set(cls, raw_repr: int):
|
||||
return torch._C.DispatchKeySet.from_raw_repr(raw_repr)
|
||||
|
||||
@classmethod
|
||||
def _unpickle_functorch_interpreter(cls, json: bytes):
|
||||
return torch._C._functorch.CInterpreter.deserialize(json)
|
||||
|
||||
def reducer_override(self, obj):
|
||||
import sympy
|
||||
|
||||
if isinstance(obj, torch.Tensor) and obj.device.type != "meta":
|
||||
return type(self)._unpickle_tensor, (
|
||||
torch.empty_like(obj, device="meta"),
|
||||
@ -2543,6 +2553,20 @@ class GuardsStatePickler(pickle.Pickler):
|
||||
elif isinstance(obj, torch._C.DispatchKeySet):
|
||||
return type(self)._unpickle_dispatch_key_set, (obj.raw_repr(),)
|
||||
|
||||
elif isinstance(obj, torch._C._functorch.CInterpreter):
|
||||
return type(self)._unpickle_functorch_interpreter, (obj.serialize(),)
|
||||
|
||||
elif (
|
||||
inspect.isclass(obj)
|
||||
and issubclass(obj, sympy.Function)
|
||||
and hasattr(obj, "_torch_handler_name")
|
||||
):
|
||||
assert hasattr(obj, "_torch_unpickler")
|
||||
return obj._torch_unpickler, (obj._torch_handler_name,)
|
||||
|
||||
elif isinstance(obj, torch.SymInt):
|
||||
raise RuntimeError(f"Cannot serialize SymInt {obj} (node: {obj.node})")
|
||||
|
||||
if type(obj).__qualname__ != type(obj).__name__:
|
||||
raise RuntimeError(
|
||||
f"Type {type(obj)} for object {obj} cannot be saved "
|
||||
|
@ -301,6 +301,7 @@ class OutputGraphGuardsState:
|
||||
# Map from graph input's `Source` to sizes / strides metadata
|
||||
input_source_to_sizes_strides: dict[Source, dict[str, Any]]
|
||||
dual_level: int
|
||||
functorch_layers: list[torch._functorch.pyfunctorch.FuncTorchInterpreter]
|
||||
|
||||
export: bool = False
|
||||
export_constraints: bool = False
|
||||
@ -354,6 +355,7 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
guard_on_key_order=set(),
|
||||
input_source_to_sizes_strides={},
|
||||
dual_level=torch.autograd.forward_ad._current_level,
|
||||
functorch_layers=torch._functorch.pyfunctorch.retrieve_all_functorch_interpreters(),
|
||||
)
|
||||
self.tracers = [SubgraphTracer(self, is_export=export)]
|
||||
# Map from graph input's `Source` to its `VariableTracker` to
|
||||
@ -590,6 +592,7 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
guard_on_key_order=self.guard_on_key_order,
|
||||
input_source_to_sizes_strides=self.input_source_to_sizes_strides,
|
||||
dual_level=self.dual_level,
|
||||
functorch_layers=self.functorch_layers,
|
||||
export=self.export,
|
||||
export_constraints=self.export_constraints,
|
||||
_guards=self.guards,
|
||||
|
@ -1,6 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import contextlib
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@ -79,6 +80,11 @@ class FuncTorchInterpreter(ABC):
|
||||
def check_state(self, state):
|
||||
return state == self.get_state()
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
state.pop("_cptr", None)
|
||||
return state
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def temporarily_pop_interpreter_stack():
|
||||
@ -123,7 +129,10 @@ class VmapInterpreter(FuncTorchInterpreter):
|
||||
# cdata is a generic CInterpreter. We wrap it in a CVmapInterpreterPtr
|
||||
# so that we can access methods specific to the vmap interpreter
|
||||
self._cdata = cdata
|
||||
self._cptr = CVmapInterpreterPtr(cdata)
|
||||
|
||||
@cached_property
|
||||
def _cptr(self):
|
||||
return CVmapInterpreterPtr(self._cdata)
|
||||
|
||||
def process(self, op, args, kwargs):
|
||||
kernel = op.functorch_table[TransformType.Vmap]
|
||||
@ -159,7 +168,10 @@ class GradInterpreter(FuncTorchInterpreter):
|
||||
assert cdata.key() == TransformType.Grad
|
||||
# See NOTE: [Interpreter cdata vs cptr]
|
||||
self._cdata = cdata
|
||||
self._cptr = CGradInterpreterPtr(cdata)
|
||||
|
||||
@cached_property
|
||||
def _cptr(self):
|
||||
return CGradInterpreterPtr(self._cdata)
|
||||
|
||||
def lift(self, args, kwargs):
|
||||
args, kwargs = pytree.tree_map_only(
|
||||
@ -193,7 +205,10 @@ class JvpInterpreter(FuncTorchInterpreter):
|
||||
assert cdata.key() == TransformType.Jvp
|
||||
# See NOTE: [Interpreter cdata vs cptr]
|
||||
self._cdata = cdata
|
||||
self._cptr = CJvpInterpreterPtr(cdata)
|
||||
|
||||
@cached_property
|
||||
def _cptr(self):
|
||||
return CJvpInterpreterPtr(self._cdata)
|
||||
|
||||
def lift(self, args, kwargs):
|
||||
args, kwargs = pytree.tree_map_only(
|
||||
@ -226,7 +241,10 @@ class FunctionalizeInterpreter(FuncTorchInterpreter):
|
||||
def __init__(self, cdata: CInterpreter):
|
||||
assert cdata.key() == TransformType.Functionalize
|
||||
self._cdata = cdata
|
||||
self._cptr = CFunctionalizeInterpreterPtr(cdata)
|
||||
|
||||
@cached_property
|
||||
def _cptr(self):
|
||||
return CFunctionalizeInterpreterPtr(self._cdata)
|
||||
|
||||
def process(self, op, args, kwargs):
|
||||
kernel = op.functorch_table[TransformType.Functionalize]
|
||||
|
@ -575,7 +575,9 @@ void initFuncTorchBindings(PyObject* module) {
|
||||
.value("Different", RandomnessType::Different);
|
||||
py::class_<Interpreter>(m, "CInterpreter")
|
||||
.def("key", &Interpreter::key)
|
||||
.def("level", &Interpreter::level);
|
||||
.def("level", &Interpreter::level)
|
||||
.def("serialize", &Interpreter::serialize)
|
||||
.def_static("deserialize", &Interpreter::deserialize);
|
||||
py::class_<GradInterpreterPtr>(m, "CGradInterpreterPtr")
|
||||
.def(py::init<const Interpreter*>())
|
||||
.def("key", &GradInterpreterPtr::key)
|
||||
|
@ -1318,6 +1318,7 @@ def make_opaque_unary_fn(name):
|
||||
"""
|
||||
|
||||
_torch_handler_name = name
|
||||
_torch_unpickler = make_opaque_unary_fn
|
||||
|
||||
@classmethod
|
||||
def eval(cls, a):
|
||||
@ -1378,6 +1379,9 @@ def make_opaque_bitwise_fn(name, real_op_name):
|
||||
class BitwiseFn(sympy.Function):
|
||||
_torch_handler_name = name
|
||||
precedence: int = prec
|
||||
_torch_unpickler = functools.partial(
|
||||
make_opaque_bitwise_fn, real_op_name=real_op_name
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def eval(cls, a, b):
|
||||
|
Reference in New Issue
Block a user