mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Properly handle torch.script.jit
under @staticmethod
(#153984)
Fixes #153607. Pull Request resolved: https://github.com/pytorch/pytorch/pull/153984 Approved by: https://github.com/williamwen42
This commit is contained in:
committed by
PyTorch MergeBot
parent
b184e3da9c
commit
4c6f0fe22f
@ -28,6 +28,19 @@ class InteropTests(torch._dynamo.test_case.TestCase):
|
||||
trace_fn = torch.jit.trace(fn, [torch.zeros(10), torch.zeros(10)])
|
||||
self._common(lambda a, b: trace_fn(a, b) + 1)
|
||||
|
||||
def test_staticmethod_script_fn(self):
|
||||
class Foo:
|
||||
@staticmethod
|
||||
@torch.jit.script
|
||||
def _g(a):
|
||||
return a**2
|
||||
|
||||
def g(self, a, b):
|
||||
return self._g(a) + b
|
||||
|
||||
foo = Foo()
|
||||
self._common(lambda a, b: foo.g(a, b) + 1)
|
||||
|
||||
def test_vmap_in_graph(self):
|
||||
from functools import wraps
|
||||
|
||||
|
@ -1161,11 +1161,11 @@ class UserDefinedObjectVariable(UserDefinedVariable):
|
||||
subobj.fget, self, source=source
|
||||
).call_function(tx, [], {})
|
||||
elif isinstance(subobj, staticmethod):
|
||||
# Safe because `staticmethod.__get__` basically won't trigger user
|
||||
# code and just returns the underlying `__func__`:
|
||||
# https://github.com/python/cpython/blob/3.11/Objects/funcobject.c#L1088-L1100
|
||||
func = subobj.__get__(self.value)
|
||||
if source is not None:
|
||||
return trace_rules.lookup(func).create_with_source(func, source=source)
|
||||
else:
|
||||
return trace_rules.lookup(func)(func)
|
||||
return VariableTracker.build(tx, func, source)
|
||||
elif isinstance(subobj, classmethod):
|
||||
return variables.UserMethodVariable(
|
||||
subobj.__func__, self.var_getattr(tx, "__class__"), source=source
|
||||
|
Reference in New Issue
Block a user