[JIT] Fix typed enum handling in 3.11 (#109717)

In Python-3.11+ typed enums (such as `enum.IntEnum`) retain `__new__`,`__str__` and so on method of the base class via `__init__subclass__()` method (see https://docs.python.org/3/whatsnew/3.11.html#enum ), i.e. following code
```python
import sys
import inspect
from enum import Enum

class IntColor(int, Enum):
    RED = 1
    GREEN = 2

class Color(Enum):
    RED = 1
    GREEN = 2

def get_methods(cls):
    def predicate(m):
        if not inspect.isfunction(m) and not inspect.ismethod(m):
            return False
        return m.__name__ in cls.__dict__
    return inspect.getmembers(cls, predicate=predicate)

if __name__ == "__main__":
    print(sys.version)
    print(f"IntColor methods {get_methods(IntColor)}")
    print(f"Color methods {get_methods(Color)}")
```

Returns empty list for both cases for older Python, but on Python-3.11+ it returns list contains of enum constructors and others:
```shell
% conda run -n py310 python bar.py
3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:41:52) [Clang 15.0.7 ]
IntColor methods []
Color methods []
% conda run -n py311 python bar.py
3.11.0 | packaged by conda-forge | (main, Oct 25 2022, 06:21:25) [Clang 14.0.4 ]
IntColor methods [('__format__', <function Enum.__format__ at 0x105006ac0>), ('__new__', <function Enum.__new__ at 0x105006660>), ('__repr__', <function Enum.__repr__ at 0x1050068e0>)]
Color methods []
```

This change allows typed enums to be scriptable on 3.11, by explicitly marking several `enum.Enum` method to be dropped by jit script and adds test that typed enums are jit-scriptable.

Fixes https://github.com/pytorch/pytorch/issues/108933

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109717
Approved by: https://github.com/atalman, https://github.com/davidberard98
This commit is contained in:
Nikita Shulga
2023-09-20 22:09:38 +00:00
committed by PyTorch MergeBot
parent 7ce69d5dbe
commit 55685d57c0
2 changed files with 21 additions and 1 deletions

View File

@ -362,3 +362,13 @@ class TestEnum(JitTestCase):
GREEN = 2
torch.jit.script(Color)
# Regression test for https://github.com/pytorch/pytorch/issues/108933
def test_typed_enum(self):
class Color(int, Enum):
RED = 1
GREEN = 2
@torch.jit.script
def is_red(x: Color) -> bool:
return x == Color.RED

View File

@ -450,7 +450,7 @@ def createResolutionCallbackForClassMethods(cls):
# Skip built-ins, as they do not have global scope nor type hints
# Needed to support `enum.Enum` derived classes in Python-3.11
# That adds `_new_member_` property which is an alias to `__new__`
fns = [fn for fn in fns if not inspect.isbuiltin(fn)]
fns = [fn for fn in fns if not inspect.isbuiltin(fn) and hasattr(fn, "__globals__")]
captures = {}
for fn in fns:
@ -1491,3 +1491,13 @@ def _extract_tensors(obj):
extractor = _TensorExtractor(io.BytesIO(), protocol=-1, tensors=tensors)
extractor.dump(obj)
return tensors
# In Python-3.11+ typed enums (i.e. IntEnum for example) retain number of base class methods in subclass
# that were previously dropped. To preserve the behavior, explicitly drop them there
if sys.version_info > (3, 10):
_drop(enum.Enum.__new__)
_drop(enum.Enum.__format__)
_drop(enum.Enum.__repr__)
_drop(enum.Enum.__str__)