[torch] Use __prepare_scriptable__ for closures (#121553)

Summary:
This fixes a case left incomplete by https://github.com/pytorch/pytorch/pull/106229
The object is using __prepare_scriptable__ correctly inside of torch.jit.script()
but the closure that is obtained below is built from the non-prepared version.
This causes issues when the prepared and non-prepared versions live in different Python modules.

Test Plan:
```
buck2 run mode/opt caffe2/test:jit -- -r test_decorator
```

Differential Revision: D54308741

Re-exporting, as #120806 #121307 were not properly merged.

Co-authored-by: Daniel Herrera <dherrera@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121553
Approved by: https://github.com/huydhn, https://github.com/seemethere
This commit is contained in:
Daniel Herrera
2024-03-11 19:14:15 +00:00
committed by PyTorch MergeBot
parent b4160fd9c7
commit dccc1ca839
5 changed files with 80 additions and 0 deletions

20
test/jit/mydecorator.py Normal file
View File

@ -0,0 +1,20 @@
r"""
Decorator used in test_decorator.py. We define it in a
separate file on purpose to test that the names in different modules
are resolved correctly.
"""
import functools
def my_decorator(func):
    """Dummy decorator that removes itself when torchscripting.

    Wraps *func* without changing its behavior; exposes the undecorated
    function via ``__prepare_scriptable__`` so torch.jit.script() can
    strip the wrapper.
    """

    def _undecorated():
        # torch.jit.script() calls __prepare_scriptable__ to obtain the
        # original, undecorated function.
        return func

    @functools.wraps(func)
    def wrapped_func(*args, **kwargs):
        return func(*args, **kwargs)

    wrapped_func.__prepare_scriptable__ = _undecorated
    return wrapped_func

13
test/jit/myfunction_a.py Normal file
View File

@ -0,0 +1,13 @@
"""
Helper function used in test_decorator.py. We define it in a
separate file on purpose to test that the names in different modules
are resolved correctly.
"""
from jit.mydecorator import my_decorator
from jit.myfunction_b import my_function_b
@my_decorator
def my_function_a(x: float) -> float:
    """Return my_function_b(x) + 1.

    Decorated on purpose: the test checks that scripting resolves names
    correctly across modules when the decorator strips itself.
    """
    return 1 + my_function_b(x)

16
test/jit/myfunction_b.py Normal file
View File

@ -0,0 +1,16 @@
r"""
Helper function used in test_decorator.py. We define it in a
separate file on purpose to test that the names in different modules
are resolved correctly.
"""
from jit.mydecorator import my_decorator
@my_decorator
def my_function_b(x: float) -> float:
    """Return my_function_c(x) + 2.

    Decorated so the chained call my_function_a -> my_function_b exercises
    nested decorator removal during scripting.
    """
    return 2 + my_function_c(x)
def my_function_c(x: float) -> float:
    """Return x plus 3; an undecorated leaf callee for the scripting test."""
    result = x + 3
    return result

View File

@ -0,0 +1,27 @@
# Owner(s): ["oncall: jit"]
# flake8: noqa
import sys
import unittest
from enum import Enum
from typing import List, Optional
import torch
from torch.testing._internal.jit_utils import JitTestCase
from jit.myfunction_a import my_function_a
class TestDecorator(JitTestCase):
    def test_decorator(self):
        """Script a decorated function defined across separate modules.

        JitTestCase.checkScript() cannot be used here: it re-parses the
        function source and fails on the decorator line with
            RuntimeError: expected def but found '@' here:
              @my_decorator
              ~ <--- HERE
              def my_function_a(x: float) -> float:
        so we compare eager and scripted outputs directly instead.
        """
        eager_fn = my_function_a
        scripted_fn = torch.jit.script(eager_fn)
        self.assertEqual(eager_fn(1.0), scripted_fn(1.0))

View File

@ -996,6 +996,10 @@ def try_compile_fn(fn, loc):
f"Consider manually annotating `{fn}` with @torch.jit.script."
)
# The object returned by __prepare_scriptable__ might have a different closure.
# Resolve it here to get the right resolution callback.
fn = fn.__prepare_scriptable__() if hasattr(fn, "__prepare_scriptable__") else fn # type: ignore[operator]
# We don't have the actual scope where the function was defined, but we can
# extract the necessary info from the closed over variables on the function
# object