[dynamo] Properly handle torch.script.jit under @staticmethod (#153984)

Fixes #153607.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153984
Approved by: https://github.com/williamwen42
This commit is contained in:
Ryan Guo
2025-05-20 14:32:16 -07:00
committed by PyTorch MergeBot
parent b184e3da9c
commit 4c6f0fe22f
2 changed files with 17 additions and 4 deletions

View File

@ -28,6 +28,19 @@ class InteropTests(torch._dynamo.test_case.TestCase):
trace_fn = torch.jit.trace(fn, [torch.zeros(10), torch.zeros(10)])
self._common(lambda a, b: trace_fn(a, b) + 1)
def test_staticmethod_script_fn(self):
class Foo:
@staticmethod
@torch.jit.script
def _g(a):
return a**2
def g(self, a, b):
return self._g(a) + b
foo = Foo()
self._common(lambda a, b: foo.g(a, b) + 1)
def test_vmap_in_graph(self):
from functools import wraps

View File

@ -1161,11 +1161,11 @@ class UserDefinedObjectVariable(UserDefinedVariable):
subobj.fget, self, source=source
).call_function(tx, [], {})
elif isinstance(subobj, staticmethod):
# Safe because `staticmethod.__get__` basically won't trigger user
# code and just returns the underlying `__func__`:
# https://github.com/python/cpython/blob/3.11/Objects/funcobject.c#L1088-L1100
func = subobj.__get__(self.value)
if source is not None:
return trace_rules.lookup(func).create_with_source(func, source=source)
else:
return trace_rules.lookup(func)(func)
return VariableTracker.build(tx, func, source)
elif isinstance(subobj, classmethod):
return variables.UserMethodVariable(
subobj.__func__, self.var_getattr(tx, "__class__"), source=source