mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[torch] Use __prepare_scriptable__ for closures (#121553)
Summary: This fixes a case left incomplete by https://github.com/pytorch/pytorch/pull/106229 The object is using __prepare_scriptable__ correctly inside of torch.jit.script() but the clousre that is obtained below is using the non-prepared version. This causes issues when the prepared and non-prepared versions are in different python modules. Test Plan: ``` buck2 run mode/opt caffe2/test:jit -- -r test_decorator ``` Differential Revision: D54308741 Re-exporting, as #120806 #121307 were not properly merged. Co-authored-by: Daniel Herrera <dherrera@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/121553 Approved by: https://github.com/huydhn, https://github.com/seemethere
This commit is contained in:
committed by
PyTorch MergeBot
parent
b4160fd9c7
commit
dccc1ca839
20
test/jit/mydecorator.py
Normal file
20
test/jit/mydecorator.py
Normal file
@ -0,0 +1,20 @@
|
||||
r"""
|
||||
Decorator used in test_decorator.py. We define it in a
|
||||
separate file on purpose to test that the names in different modules
|
||||
are resolved correctly.
|
||||
"""
|
||||
|
||||
import functools
|
||||
|
||||
|
||||
def my_decorator(func):
|
||||
"""Dummy decorator that removes itself when torchscripting"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# torch.jit.script() uses __prepare_scriptable__ to remove the decorator
|
||||
wrapped_func.__prepare_scriptable__ = lambda: func
|
||||
|
||||
return wrapped_func
|
13
test/jit/myfunction_a.py
Normal file
13
test/jit/myfunction_a.py
Normal file
@ -0,0 +1,13 @@
|
||||
"""
|
||||
Helper function used in test_decorator.py. We define it in a
|
||||
separate file on purpose to test that the names in different modules
|
||||
are resolved correctly.
|
||||
"""
|
||||
|
||||
from jit.mydecorator import my_decorator
|
||||
from jit.myfunction_b import my_function_b
|
||||
|
||||
|
||||
@my_decorator
|
||||
def my_function_a(x: float) -> float:
|
||||
return my_function_b(x) + 1
|
16
test/jit/myfunction_b.py
Normal file
16
test/jit/myfunction_b.py
Normal file
@ -0,0 +1,16 @@
|
||||
r"""
|
||||
Helper function used in test_decorator.py. We define it in a
|
||||
separate file on purpose to test that the names in different modules
|
||||
are resolved correctly.
|
||||
"""
|
||||
|
||||
from jit.mydecorator import my_decorator
|
||||
|
||||
|
||||
@my_decorator
|
||||
def my_function_b(x: float) -> float:
|
||||
return my_function_c(x) + 2
|
||||
|
||||
|
||||
def my_function_c(x: float) -> float:
|
||||
return x + 3
|
27
test/jit/test_decorator.py
Normal file
27
test/jit/test_decorator.py
Normal file
@ -0,0 +1,27 @@
|
||||
# Owner(s): ["oncall: jit"]
|
||||
# flake8: noqa
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
from jit.myfunction_a import my_function_a
|
||||
|
||||
|
||||
class TestDecorator(JitTestCase):
|
||||
def test_decorator(self):
|
||||
# Note: JitTestCase.checkScript() does not work with decorators
|
||||
# self.checkScript(my_function_a, (1.0,))
|
||||
# Error:
|
||||
# RuntimeError: expected def but found '@' here:
|
||||
# @my_decorator
|
||||
# ~ <--- HERE
|
||||
# def my_function_a(x: float) -> float:
|
||||
# Do a simple torch.jit.script() test instead
|
||||
fn = my_function_a
|
||||
fx = torch.jit.script(fn)
|
||||
self.assertEqual(fn(1.0), fx(1.0))
|
@ -996,6 +996,10 @@ def try_compile_fn(fn, loc):
|
||||
f"Consider manually annotating `{fn}` with @torch.jit.script."
|
||||
)
|
||||
|
||||
# The object returned by __prepare_scriptable__ might have a different closure.
|
||||
# Resolve it here to get the right resolution callback.
|
||||
fn = fn.__prepare_scriptable__() if hasattr(fn, "__prepare_scriptable__") else fn # type: ignore[operator]
|
||||
|
||||
# We don't have the actual scope where the function was defined, but we can
|
||||
# extract the necessary info from the closed over variables on the function
|
||||
# object
|
||||
|
Reference in New Issue
Block a user