[sigmoid] fix for FX tracing unflattened modules (#115708)

Differential Revision: [D52095387](https://our.internmc.facebook.com/intern/diff/D52095387/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115708
Approved by: https://github.com/zhxchen17
This commit is contained in:
suo
2023-12-12 22:44:05 -08:00
committed by PyTorch MergeBot
parent 75d3bbaaa2
commit 926236305f
2 changed files with 51 additions and 11 deletions

View File

@ -273,6 +273,21 @@ class TestFX(JitTestCase):
t = T()
self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})
def test_varargs_concrete(self):
class T(torch.nn.Module):
def forward(self, *args, **kwargs):
x = args[0] + args[1]
return x
args = (torch.rand(1), torch.rand(1))
t = T()
ref_outs = t(*args)
gm = symbolic_trace(t, concrete_args=(torch.fx.PH, torch.fx.PH))
gm.graph.lint()
test_outs = gm(*args)
self.assertEqual(ref_outs, test_outs)
def test_args_kwargs_no_self(self):
class T(torch.nn.Module):
def forward(*args, **kwargs): # noqa: B902

View File

@ -217,6 +217,17 @@ class PHWithMeta(PHBase):
self.ph_key = ph_key
def _transfer_attrs(fr, to):
for attr_name in dir(fr):
attr_val = getattr(fr, attr_name)
if (
not callable(attr_val)
and not attr_name.startswith("__")
and not hasattr(to, attr_name)
):
setattr(to, attr_name, attr_val)
@compatibility(is_backward_compatible=True)
class Tracer(TracerBase):
# Reference: https://github.com/pytorch/pytorch/issues/54354
@ -597,16 +608,6 @@ class Tracer(TracerBase):
"placeholder", f"{name}_{str(cnt)}", default, {}
)
if isinstance(x, PHBase):
def transfer_attrs(fr, to):
for attr_name in dir(fr):
attr_val = getattr(fr, attr_name)
if (
not callable(attr_val)
and not attr_name.startswith("__")
and not hasattr(to, attr_name)
):
setattr(to, attr_name, attr_val)
if x != PH:
# Transfer attrs in the case where you're using a placeholder other
# than the singleton PH (PH has no attributes to transfer).
@ -615,7 +616,7 @@ class Tracer(TracerBase):
# attributes set by the user) from the placeholder to the
# underlying nodes (the proxy is unwrapped by the user, but
# the metadata should hold).
transfer_attrs(fr=x, to=out.node)
_transfer_attrs(fr=x, to=out.node)
return out
# Union[int, bool] == bool in Python <= 3.6
@ -657,6 +658,30 @@ class Tracer(TracerBase):
type_expr=fn_for_analysis.__annotations__.get(name, None)
)
# This covers the very specific case where we are passing in flat
# concrete_args as a tuple, but our traced fn takes (*args, **kwargs).
# In this case, just take the concrete_args and pass them through.
name_idx = 0
if isinstance(concrete_args, tuple) and \
len(concrete_args) > 0 and \
(co.co_flags & HAS_VARSTUFF) and \
total_args == 1:
for concrete_arg in concrete_args:
out = self.create_proxy("placeholder", f"input_{name_idx}", (), {})
if isinstance(concrete_arg, PHBase):
if concrete_arg != PH:
# Transfer attrs in the case where you're using a placeholder other
# than the singleton PH (PH has no attributes to transfer).
# Proxies were created out of the placeholders.
# Transfer any metadata (put on the placeholders in the form of
# attributes set by the user) from the placeholder to the
# underlying nodes (the proxy is unwrapped by the user, but
# the metadata should hold).
_transfer_attrs(fr=concrete_arg, to=out.node)
args.append(out)
name_idx += 1
return root_fn, args
arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)]
if isinstance(concrete_args, tuple):
if len(arg_names) != len(concrete_args):