[export] support functools.partial forward (non-strict) (#153408)

Fixes #153086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153408
Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
Pian Pawakapan
2025-05-13 23:30:10 +00:00
committed by PyTorch MergeBot
parent 40b719c97d
commit 8ac82c3e72
2 changed files with 30 additions and 1 deletions

View File

@ -3,6 +3,7 @@
# flake8: noqa
import copy
import dataclasses
import functools
import logging
import math
import operator
@ -10244,6 +10245,30 @@ graph():
ep = torch.export.export(mod, args, strict=False)
self.assertTrue(torch.allclose(ep.module()(*args), mod(*args)))
def test_partial_patched_forward(self):
class Foo(torch.nn.Module):
def forward(self, x):
return x + 2
def fancy_forward(x, y):
return x + 2 + y
Foo.forward = functools.partial(fancy_forward, y=torch.randn(4, 4))
x = torch.randn(4, 4)
# strict unsupported: "Tracing through optional input"
ep = export(Foo(), (x,), strict=False)
ep.module()(x)
class Bar(torch.nn.Module):
def forward(self, x, y, z):
return x + y + z
mod = Bar()
mod.forward = functools.partial(mod.forward, z=2)
mod.forward = functools.partial(mod.forward, y=torch.randn(4))
ep = export(mod, (x,), strict=False)
ep.module()(x)
@testing.expectedFailureCppRuntime
def test_symint_input_basic(self):
strict = False # TODO: support strict=True

View File

@ -295,7 +295,11 @@ def make_fake_inputs(
# create another fake mode.
fake_mode = context.fake_mode
elif not _is_torch_jit_trace:
code = nn_module.forward.__code__
if isinstance(nn_module.forward, functools.partial):
# functools handles nesting by itself, no need to recurse
code = nn_module.forward.func.__code__
else:
code = nn_module.forward.__code__
co_fields = {
"co_name": code.co_name,
"co_filename": code.co_filename,