Files
pytorch/torch/package/_package_pickler.py
Xuehai Pan 7837a12474 [BE] enforce style for empty lines in import segments (#129751)
This PR follows https://github.com/pytorch/pytorch/pull/129374#pullrequestreview-2136555775 cc @malfet:

> Lots of formatting changes unrelated to PR goal, please keep them as part of separate PR (and please add lint rule if you want to enforce those, or at least cite one)

`usort` allows empty lines within import segments. For example, `usort` does not change any of the following code snippets:

```python
import torch.aaa
import torch.bbb
import torch.ccc

x = ...  # some code
```

```python
import torch.aaa

import torch.bbb
import torch.ccc

x = ...  # some code
```

```python
import torch.aaa

import torch.bbb

import torch.ccc

x = ...  # some code
```

This PR first sorts imports via `isort`, then re-sorts the file using `ufmt` (`usort` + `black`). This enforces the following import style:

1. no empty lines within segments.
2. single empty line between segments.
3. two spaces after import statements.

All the code snippets above will be formatted to:

```python
import torch.aaa
import torch.bbb
import torch.ccc

x = ...  # some code
```

which produces a consistent code style.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129751
Approved by: https://github.com/malfet
2024-06-29 14:15:24 +00:00

119 lines
4.5 KiB
Python

# mypy: allow-untyped-defs
from pickle import ( # type: ignore[attr-defined]
_compat_pickle,
_extension_registry,
_getattribute,
_Pickler,
EXT1,
EXT2,
EXT4,
GLOBAL,
Pickler,
PicklingError,
STACK_GLOBAL,
)
from struct import pack
from types import FunctionType
from .importer import Importer, ObjMismatchError, ObjNotFoundError, sys_importer
class PackagePickler(_Pickler):
    """Package-aware pickler.

    This behaves the same as a normal pickler, except it uses an `Importer`
    to find objects and modules to save.
    """

    def __init__(self, importer: Importer, *args, **kwargs):
        """Create a pickler that resolves globals through ``importer``.

        Args:
            importer: The `Importer` used to look up the module and name of
                each global object written by ``save_global``.
            *args, **kwargs: Forwarded unchanged to ``pickle._Pickler``
                (e.g. the output buffer and ``protocol``).
        """
        self.importer = importer
        super().__init__(*args, **kwargs)

        # Make sure the dispatch table copied from _Pickler is up-to-date.
        # Previous issues have been encountered where a library (e.g. dill)
        # mutates _Pickler.dispatch, PackagePickler makes a copy when this lib
        # is imported, then the offending library removes its dispatch entries,
        # leaving PackagePickler with a stale dispatch table that may cause
        # unwanted behavior.
        self.dispatch = _Pickler.dispatch.copy()  # type: ignore[misc]
        # Route plain functions through the importer-aware save_global below.
        self.dispatch[FunctionType] = PackagePickler.save_global  # type: ignore[assignment]

    def save_global(self, obj, name=None):
        """Save ``obj`` as a reference to a global (module name + attribute name).

        This is a copy of ``pickle._Pickler.save_global`` in which the owning
        module and name are resolved through ``self.importer`` instead of
        ``__import__``.

        Args:
            obj: The global object being pickled.
            name: Optional name to pickle ``obj`` under; passed to
                ``self.importer.get_name``.

        Raises:
            PicklingError: if ``self.importer.get_name`` raises
                ``ObjNotFoundError`` or ``ObjMismatchError``, or if (for
                protocols < 3) the module/name cannot be ASCII-encoded.
        """
        # Unfortunately the pickler code is factored in a way that forces us
        # to copy/paste this function. Deviations from the upstream copy are
        # marked CHANGED below.
        write = self.write  # type: ignore[attr-defined]

        # CHANGED: import module from module environment instead of __import__
        try:
            module_name, name = self.importer.get_name(obj, name)
        except (ObjNotFoundError, ObjMismatchError) as err:
            raise PicklingError(f"Can't pickle {obj}: {err}") from None

        module = self.importer.import_module(module_name)
        _, parent = _getattribute(module, name)
        # END CHANGED

        if self.proto >= 2:  # type: ignore[attr-defined]
            # Globals with an entry in the extension registry are written as an
            # integer code (EXT1/EXT2/EXT4 depending on its size) and, because
            # of the early return, are not memoized.
            code = _extension_registry.get((module_name, name))
            if code:
                assert code > 0
                if code <= 0xFF:
                    write(EXT1 + pack("<B", code))
                elif code <= 0xFFFF:
                    write(EXT2 + pack("<H", code))
                else:
                    write(EXT4 + pack("<i", code))
                return

        lastname = name.rpartition(".")[2]
        if parent is module:
            # ``obj`` is found directly on the module: refer to it by the last
            # component of its name.
            name = lastname

        # Non-ASCII identifiers are supported only with protocols >= 3.
        if self.proto >= 4:  # type: ignore[attr-defined]
            self.save(module_name)  # type: ignore[attr-defined]
            self.save(name)  # type: ignore[attr-defined]
            write(STACK_GLOBAL)
        elif parent is not module:
            # Older protocols cannot name a nested attribute directly, so
            # pickle it as ``getattr(parent, lastname)``.
            self.save_reduce(getattr, (parent, lastname))  # type: ignore[attr-defined]
        elif self.proto >= 3:  # type: ignore[attr-defined]
            write(
                GLOBAL
                + bytes(module_name, "utf-8")
                + b"\n"
                + bytes(name, "utf-8")
                + b"\n"
            )
        else:
            if self.fix_imports:  # type: ignore[attr-defined]
                # Translate names through pickle's reverse compatibility
                # tables (``_compat_pickle``) before writing them.
                r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
                r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING
                if (module_name, name) in r_name_mapping:
                    module_name, name = r_name_mapping[(module_name, name)]
                elif module_name in r_import_mapping:
                    module_name = r_import_mapping[module_name]
            try:
                write(
                    GLOBAL
                    + bytes(module_name, "ascii")
                    + b"\n"
                    + bytes(name, "ascii")
                    + b"\n"
                )
            except UnicodeEncodeError:
                # CHANGED: report the module *name*. Formatting the module
                # object itself would render its repr (e.g. "<module 'x' ...>")
                # instead of a dotted identifier.
                raise PicklingError(
                    "can't pickle global identifier '%s.%s' using "
                    "pickle protocol %i" % (module_name, name, self.proto)  # type: ignore[attr-defined]
                ) from None

        self.memoize(obj)  # type: ignore[attr-defined]
def create_pickler(data_buf, importer, protocol=4):
    """Build a pickler that writes into ``data_buf`` using ``protocol``.

    Any importer other than ``sys_importer`` needs the package-aware
    ``PackagePickler`` so that globals are resolved through that importer.
    With ``sys_importer`` the normal import system is in use, so the faster
    C implementation of pickle (``pickle.Pickler``) is returned instead.
    """
    if importer is not sys_importer:
        return PackagePickler(importer, data_buf, protocol=protocol)
    # Regular import machinery: no custom lookup needed, use the C pickler.
    return Pickler(data_buf, protocol=protocol)