mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[export] support functools.partial forward (non-strict) (#153408)
Fixes #153086 Pull Request resolved: https://github.com/pytorch/pytorch/pull/153408 Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
committed by
PyTorch MergeBot
parent
40b719c97d
commit
8ac82c3e72
@ -3,6 +3,7 @@
|
||||
# flake8: noqa
|
||||
import copy
|
||||
import dataclasses
|
||||
import functools
|
||||
import logging
|
||||
import math
|
||||
import operator
|
||||
@ -10244,6 +10245,30 @@ graph():
|
||||
ep = torch.export.export(mod, args, strict=False)
|
||||
self.assertTrue(torch.allclose(ep.module()(*args), mod(*args)))
|
||||
|
||||
def test_partial_patched_forward(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x + 2
|
||||
|
||||
def fancy_forward(x, y):
|
||||
return x + 2 + y
|
||||
|
||||
Foo.forward = functools.partial(fancy_forward, y=torch.randn(4, 4))
|
||||
x = torch.randn(4, 4)
|
||||
# strict unsupported: "Tracing through optional input"
|
||||
ep = export(Foo(), (x,), strict=False)
|
||||
ep.module()(x)
|
||||
|
||||
class Bar(torch.nn.Module):
|
||||
def forward(self, x, y, z):
|
||||
return x + y + z
|
||||
|
||||
mod = Bar()
|
||||
mod.forward = functools.partial(mod.forward, z=2)
|
||||
mod.forward = functools.partial(mod.forward, y=torch.randn(4))
|
||||
ep = export(mod, (x,), strict=False)
|
||||
ep.module()(x)
|
||||
|
||||
@testing.expectedFailureCppRuntime
|
||||
def test_symint_input_basic(self):
|
||||
strict = False # TODO: support strict=True
|
||||
|
@ -295,7 +295,11 @@ def make_fake_inputs(
|
||||
# create another fake mode.
|
||||
fake_mode = context.fake_mode
|
||||
elif not _is_torch_jit_trace:
|
||||
code = nn_module.forward.__code__
|
||||
if isinstance(nn_module.forward, functools.partial):
|
||||
# functools handles nesting by itself, no need to recurse
|
||||
code = nn_module.forward.func.__code__
|
||||
else:
|
||||
code = nn_module.forward.__code__
|
||||
co_fields = {
|
||||
"co_name": code.co_name,
|
||||
"co_filename": code.co_filename,
|
||||
|
Reference in New Issue
Block a user