Mirror of https://github.com/pytorch/pytorch.git
[sigmoid] fix for FX tracing unflattened modules (#115708)
This change hoists the placeholder attribute-transfer helper out of `Tracer` to a module-level `_transfer_attrs`, and adds a fast path so that a flat tuple of `concrete_args` can be traced through a `forward(*args, **kwargs)` signature, with each tuple element becoming its own `input_{i}` placeholder.

Differential Revision: [D52095387](https://our.internmc.facebook.com/intern/diff/D52095387/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115708
Approved by: https://github.com/zhxchen17
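For context, a minimal repro sketch of the behavior being fixed (the module is illustrative; before this change, a flat tuple of concrete_args against a (*args, **kwargs) forward fell through to the arity check and raised, roughly, "Tracing expected 0 arguments but got 2 concrete arguments"):

import torch
from torch.fx import symbolic_trace

class M(torch.nn.Module):
    def forward(self, *args, **kwargs):
        return args[0] + args[1]

# Pre-fix: (*args, **kwargs) means total_args == 1 and arg_names == [],
# so len(arg_names) != len(concrete_args) and tracing raised a RuntimeError.
# Post-fix: each tuple element becomes its own "input_{i}" placeholder.
gm = symbolic_trace(M(), concrete_args=(torch.fx.PH, torch.fx.PH))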
test/test_fx.py
@@ -273,6 +273,21 @@ class TestFX(JitTestCase):
         t = T()
         self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})
 
+    def test_varargs_concrete(self):
+        class T(torch.nn.Module):
+            def forward(self, *args, **kwargs):
+                x = args[0] + args[1]
+                return x
+
+        args = (torch.rand(1), torch.rand(1))
+
+        t = T()
+        ref_outs = t(*args)
+        gm = symbolic_trace(t, concrete_args=(torch.fx.PH, torch.fx.PH))
+        gm.graph.lint()
+        test_outs = gm(*args)
+        self.assertEqual(ref_outs, test_outs)
+
     def test_args_kwargs_no_self(self):
         class T(torch.nn.Module):
             def forward(*args, **kwargs):  # noqa: B902
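Running the traced module from this new test and printing its graph shows the shape of the result (a sketch; the exact node rendering varies across torch.fx versions):

import torch
from torch.fx import symbolic_trace

class T(torch.nn.Module):
    def forward(self, *args, **kwargs):
        return args[0] + args[1]

gm = symbolic_trace(T(), concrete_args=(torch.fx.PH, torch.fx.PH))
gm.graph.lint()  # same sanity check the test performs
print(gm.graph)
# Roughly:
#   %input_0 : [num_users=1] = placeholder[target=input_0]
#   %input_1 : [num_users=1] = placeholder[target=input_1]
#   %add : [num_users=1] = call_function[target=operator.add](args = (%input_0, %input_1), kwargs = {})
#   return add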
torch/fx/_symbolic_trace.py
@@ -217,6 +217,17 @@ class PHWithMeta(PHBase):
         self.ph_key = ph_key
 
 
+def _transfer_attrs(fr, to):
+    for attr_name in dir(fr):
+        attr_val = getattr(fr, attr_name)
+        if (
+            not callable(attr_val)
+            and not attr_name.startswith("__")
+            and not hasattr(to, attr_name)
+        ):
+            setattr(to, attr_name, attr_val)
+
+
 @compatibility(is_backward_compatible=True)
 class Tracer(TracerBase):
     # Reference: https://github.com/pytorch/pytorch/issues/54354
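Since the attribute-transfer logic is now an ordinary module-level helper, its contract is easy to exercise in isolation. A standalone sketch (Src/Dst are illustrative stand-ins for a placeholder object and a graph node):

def _transfer_attrs(fr, to):  # same body as the hoisted helper above
    for attr_name in dir(fr):
        attr_val = getattr(fr, attr_name)
        if (
            not callable(attr_val)
            and not attr_name.startswith("__")
            and not hasattr(to, attr_name)
        ):
            setattr(to, attr_name, attr_val)

class Src: pass
class Dst: pass

src, dst = Src(), Dst()
src.ph_key = "lhs"  # user metadata set on the placeholder
dst.existing = 1    # attributes already on the target are never clobbered

_transfer_attrs(fr=src, to=dst)
assert dst.ph_key == "lhs"
assert dst.existing == 1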
@@ -597,16 +608,6 @@ class Tracer(TracerBase):
                     "placeholder", f"{name}_{str(cnt)}", default, {}
                 )
                 if isinstance(x, PHBase):
-                    def transfer_attrs(fr, to):
-                        for attr_name in dir(fr):
-                            attr_val = getattr(fr, attr_name)
-                            if (
-                                not callable(attr_val)
-                                and not attr_name.startswith("__")
-                                and not hasattr(to, attr_name)
-                            ):
-                                setattr(to, attr_name, attr_val)
-
                     if x != PH:
                         # Transfer attrs in the case where you're using a placeholder other
                         # than the singleton PH (PH has no attributes to transfer).
@@ -615,7 +616,7 @@ class Tracer(TracerBase):
                         # attributes set by the user) from the placeholder to the
                         # underlying nodes (the proxy is unwrapped by the user, but
                         # the metadata should hold).
-                        transfer_attrs(fr=x, to=out.node)
+                        _transfer_attrs(fr=x, to=out.node)
 
                     return out
                 # Union[int, bool] == bool in Python <= 3.6
@@ -657,6 +658,30 @@ class Tracer(TracerBase):
                 type_expr=fn_for_analysis.__annotations__.get(name, None)
             )
 
+        # This covers the very specific case where we are passing in flat
+        # concrete_args as a tuple, but our traced fn takes (*args, **kwargs).
+        # In this case, just take the concrete_args and pass them through.
+        name_idx = 0
+        if isinstance(concrete_args, tuple) and \
+                len(concrete_args) > 0 and \
+                (co.co_flags & HAS_VARSTUFF) and \
+                total_args == 1:
+            for concrete_arg in concrete_args:
+                out = self.create_proxy("placeholder", f"input_{name_idx}", (), {})
+                if isinstance(concrete_arg, PHBase):
+                    if concrete_arg != PH:
+                        # Transfer attrs in the case where you're using a placeholder other
+                        # than the singleton PH (PH has no attributes to transfer).
+                        # Proxies were created out of the placeholders.
+                        # Transfer any metadata (put on the placeholders in the form of
+                        # attributes set by the user) from the placeholder to the
+                        # underlying nodes (the proxy is unwrapped by the user, but
+                        # the metadata should hold).
+                        _transfer_attrs(fr=concrete_arg, to=out.node)
+                args.append(out)
+                name_idx += 1
+            return root_fn, args
+
         arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)]
         if isinstance(concrete_args, tuple):
             if len(arg_names) != len(concrete_args):
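End to end, the new branch also carries placeholder metadata through the varargs path. A sketch (PHBase lives in torch.fx._symbolic_trace and is not a public API; TaggedPH is a hypothetical subclass):

import torch
from torch.fx import symbolic_trace
from torch.fx._symbolic_trace import PHBase

class TaggedPH(PHBase):
    def __init__(self, tag):
        super().__init__()
        self.tag = tag  # metadata expected to land on the placeholder node

class T(torch.nn.Module):
    def forward(self, *args, **kwargs):
        return args[0] + args[1]

gm = symbolic_trace(T(), concrete_args=(TaggedPH("lhs"), TaggedPH("rhs")))
placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
# _transfer_attrs(fr=concrete_arg, to=out.node) ran for each element,
# so the user-set attribute should now live on each placeholder node.
print([getattr(n, "tag", None) for n in placeholders])  # ['lhs', 'rhs']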