mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[export] support functools.partial forward (non-strict) (#153408)
Fixes #153086 Pull Request resolved: https://github.com/pytorch/pytorch/pull/153408 Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
committed by
PyTorch MergeBot
parent
40b719c97d
commit
8ac82c3e72
@ -3,6 +3,7 @@
|
|||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
import copy
|
import copy
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import operator
|
import operator
|
||||||
@ -10244,6 +10245,30 @@ graph():
|
|||||||
ep = torch.export.export(mod, args, strict=False)
|
ep = torch.export.export(mod, args, strict=False)
|
||||||
self.assertTrue(torch.allclose(ep.module()(*args), mod(*args)))
|
self.assertTrue(torch.allclose(ep.module()(*args), mod(*args)))
|
||||||
|
|
||||||
|
def test_partial_patched_forward(self):
|
||||||
|
class Foo(torch.nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
return x + 2
|
||||||
|
|
||||||
|
def fancy_forward(x, y):
|
||||||
|
return x + 2 + y
|
||||||
|
|
||||||
|
Foo.forward = functools.partial(fancy_forward, y=torch.randn(4, 4))
|
||||||
|
x = torch.randn(4, 4)
|
||||||
|
# strict unsupported: "Tracing through optional input"
|
||||||
|
ep = export(Foo(), (x,), strict=False)
|
||||||
|
ep.module()(x)
|
||||||
|
|
||||||
|
class Bar(torch.nn.Module):
|
||||||
|
def forward(self, x, y, z):
|
||||||
|
return x + y + z
|
||||||
|
|
||||||
|
mod = Bar()
|
||||||
|
mod.forward = functools.partial(mod.forward, z=2)
|
||||||
|
mod.forward = functools.partial(mod.forward, y=torch.randn(4))
|
||||||
|
ep = export(mod, (x,), strict=False)
|
||||||
|
ep.module()(x)
|
||||||
|
|
||||||
@testing.expectedFailureCppRuntime
|
@testing.expectedFailureCppRuntime
|
||||||
def test_symint_input_basic(self):
|
def test_symint_input_basic(self):
|
||||||
strict = False # TODO: support strict=True
|
strict = False # TODO: support strict=True
|
||||||
|
|||||||
@ -295,7 +295,11 @@ def make_fake_inputs(
|
|||||||
# create another fake mode.
|
# create another fake mode.
|
||||||
fake_mode = context.fake_mode
|
fake_mode = context.fake_mode
|
||||||
elif not _is_torch_jit_trace:
|
elif not _is_torch_jit_trace:
|
||||||
code = nn_module.forward.__code__
|
if isinstance(nn_module.forward, functools.partial):
|
||||||
|
# functools handles nesting by itself, no need to recurse
|
||||||
|
code = nn_module.forward.func.__code__
|
||||||
|
else:
|
||||||
|
code = nn_module.forward.__code__
|
||||||
co_fields = {
|
co_fields = {
|
||||||
"co_name": code.co_name,
|
"co_name": code.co_name,
|
||||||
"co_filename": code.co_filename,
|
"co_filename": code.co_filename,
|
||||||
|
|||||||
Reference in New Issue
Block a user