Compare commits

...

1 Commits

Author SHA1 Message Date
fcb4f76fd0 init 2024-09-26 09:47:15 -07:00
4 changed files with 36 additions and 7 deletions

View File

@ -913,6 +913,22 @@ graph():
self.assertEqual(ep.module()(x, x), model(x, x))
self.assertEqual(ep.module()(x, y), model(x, y))
def test_pre_forward_hook(self):
class Foo(torch.nn.Module):
def forward(self, x):
return x + 2
def _refine(model, inputs):
x = inputs[0]
torch._check(x.shape[0] <= 32)
mod = Foo()
mod.register_forward_pre_hook(_refine)
inputs = (torch.randn(32, 64),)
ep = export(mod, inputs, strict=True)
print(ep)
def test_export_script_module(self):
class Foo(torch.nn.Module):
def forward(self, rv: torch.Tensor, t: torch.Tensor):

View File

@ -160,6 +160,7 @@ class OptimizedModule(torch.nn.Module):
}
def __init__(self, mod: torch.nn.Module, dynamo_ctx) -> None:
breakpoint()
super().__init__()
# Installs the params/buffer
self._orig_mod = mod
@ -169,6 +170,7 @@ class OptimizedModule(torch.nn.Module):
def _initialize(self):
# Do this stuff in constructor to lower overhead slightly
breakpoint()
if isinstance(self.dynamo_ctx, DisableContext):
# No need to check trace rules
self.forward = self.dynamo_ctx(self._orig_mod.__call__)
@ -1326,10 +1328,10 @@ def export(
def guard_export_print(guards: _guards.GuardsSet):
nonlocal out_guards
assert (
out_guards is None
), "whole graph export entails exactly one guard export"
out_guards = guards
if out_guards is None:
out_guards = guards
else:
out_guards.update(guards)
example_inputs = []
@ -1337,9 +1339,18 @@ def export(
gm: torch.fx.GraphModule, inner_example_inputs
):
nonlocal graph
assert (
graph is None
), "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph."
if graph is not None:
# if existing graph looks like a pre-hook with no outputs, then inline into current graph
output = next(iter(reversed(graph.graph.nodes)))
assert (
len(output.args[0]) == 0
), "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph."
ph_map = {}
for ph_old, ph_new in zip(graph.graph.nodes, gm.graph.nodes):
if ph_old.op == "placeholder":
ph_map[ph_old] = ph_new
with gm.graph.inserting_before(next(iter(gm.graph.nodes)).next):
gm.graph.graph_copy(graph.graph, ph_map)
graph = gm
nonlocal fake_mode, example_inputs

View File

@ -1329,6 +1329,7 @@ def _strict_export_lower_to_aten_ir(
for name in non_persistent_buffers
if name in reverse_name_lookup
}
with dynamo_fake_mode:
aten_export_artifact = lower_to_aten_callback(
gm_torch_level,

View File

@ -99,6 +99,7 @@ def _method_from_src(
globals_copy = globals.copy()
_exec_with_source(src, globals_copy, co_fields)
fn = globals_copy[method_name]
breakpoint()
del globals_copy[method_name]
return fn