mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: This fixes a case left incomplete by https://github.com/pytorch/pytorch/pull/106229. The object is using __prepare_scriptable__ correctly inside of torch.jit.script(), but the closure that is obtained below is using the non-prepared version. This causes issues when the prepared and non-prepared versions are in different Python modules. Test Plan: ``` buck2 run mode/opt caffe2/test:jit -- -r test_decorator ``` Differential Revision: D54308741 Re-exporting, as #120806 and #121307 were not properly merged. Co-authored-by: Daniel Herrera <dherrera@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/121553 Approved by: https://github.com/huydhn, https://github.com/seemethere
28 lines
765 B
Python
28 lines
765 B
Python
# Owner(s): ["oncall: jit"]
|
|
# flake8: noqa
|
|
|
|
import sys
|
|
import unittest
|
|
from enum import Enum
|
|
from typing import List, Optional
|
|
|
|
import torch
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
|
|
from jit.myfunction_a import my_function_a
|
|
|
|
|
|
class TestDecorator(JitTestCase):
    # Tests for scripting functions that are wrapped by a decorator.

    def test_decorator(self):
        # Scripting the decorated ``my_function_a`` must give the same result
        # as calling it eagerly.
        #
        # JitTestCase.checkScript() is not usable for decorated functions;
        # calling self.checkScript(my_function_a, (1.0,)) fails with:
        #
        #     RuntimeError: expected def but found '@' here:
        #     @my_decorator
        #     ~ <--- HERE
        #     def my_function_a(x: float) -> float:
        #
        # so the eager/scripted comparison is done by hand instead.
        scripted_fn = torch.jit.script(my_function_a)

        sample = 1.0
        eager_result = my_function_a(sample)
        scripted_result = scripted_fn(sample)

        self.assertEqual(eager_result, scripted_result)
|