mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This PR follows https://github.com/pytorch/pytorch/pull/129374#pullrequestreview-2136555775 cc @malfet: > Lots of formatting changes unrelated to PR goal, please keep them as part of separate PR (and please add lint rule if you want to enforce those, or at least cite one) `usort` allows empty lines within import segments. For example, `usort` does not change the following code: ```python import torch.aaa import torch.bbb import torch.ccc x = ... # some code ``` ```python import torch.aaa import torch.bbb import torch.ccc x = ... # some code ``` ```python import torch.aaa import torch.bbb import torch.ccc x = ... # some code ``` This PR first sorts imports via `isort`, then re-sorts the file using `ufmt` (`usort` + `black`). This enforces the following import style: 1. no empty lines within segments. 2. single empty line between segments. 3. two empty lines after import statements. All the code snippets above will be formatted to: ```python import torch.aaa import torch.bbb import torch.ccc x = ... # some code ``` which produces a consistent code style. Pull Request resolved: https://github.com/pytorch/pytorch/pull/129751 Approved by: https://github.com/malfet
119 lines
4.5 KiB
Python
119 lines
4.5 KiB
Python
# mypy: allow-untyped-defs
|
|
from pickle import ( # type: ignore[attr-defined]
|
|
_compat_pickle,
|
|
_extension_registry,
|
|
_getattribute,
|
|
_Pickler,
|
|
EXT1,
|
|
EXT2,
|
|
EXT4,
|
|
GLOBAL,
|
|
Pickler,
|
|
PicklingError,
|
|
STACK_GLOBAL,
|
|
)
|
|
from struct import pack
|
|
from types import FunctionType
|
|
|
|
from .importer import Importer, ObjMismatchError, ObjNotFoundError, sys_importer
|
|
|
|
|
|
class PackagePickler(_Pickler):
    """Package-aware pickler.

    This behaves the same as a normal pickler, except it uses an `Importer`
    to find objects and modules to save.

    It subclasses the pure-Python ``pickle._Pickler`` and overrides
    ``save_global`` so that globals are named and imported through the
    importer instead of ``__import__``. Function objects are routed to that
    override through a per-instance copy of the dispatch table.
    """

    def __init__(self, importer: Importer, *args, **kwargs):
        """Create a pickler that resolves globals through ``importer``.

        Args:
            importer: the `Importer` whose ``get_name`` / ``import_module``
                are used to name and locate objects pickled by reference.
            *args, **kwargs: forwarded unchanged to ``pickle._Pickler``
                (e.g. the output file and ``protocol``).
        """
        self.importer = importer
        super().__init__(*args, **kwargs)

        # Make sure the dispatch table copied from _Pickler is up-to-date.
        # Previous issues have been encountered where a library (e.g. dill)
        # mutate _Pickler.dispatch, PackagePickler makes a copy when this lib
        # is imported, then the offending library removes its dispatch entries,
        # leaving PackagePickler with a stale dispatch table that may cause
        # unwanted behavior.
        self.dispatch = _Pickler.dispatch.copy()  # type: ignore[misc]
        # Dispatch entries are plain functions, so point FunctionType at this
        # class's override explicitly.
        self.dispatch[FunctionType] = PackagePickler.save_global  # type: ignore[assignment]

    def save_global(self, obj, name=None):
        """Pickle ``obj`` by reference as ``<module_name>.<name>``.

        Adapted copy of ``pickle._Pickler.save_global``; the importer-specific
        change is marked CHANGED below.

        Args:
            obj: the object to save by reference.
            name: optional explicit name for ``obj``; passed through to
                ``self.importer.get_name``.

        Raises:
            PicklingError: if ``self.importer.get_name`` raises
                ``ObjNotFoundError`` or ``ObjMismatchError``, or if a protocol
                below 3 is used with a module/attribute name that is not ASCII.
        """
        # unfortunately the pickler code is factored in a way that
        # forces us to copy/paste this function. The importer-specific
        # change is marked CHANGED below.
        write = self.write  # type: ignore[attr-defined]

        # CHANGED: import module from module environment instead of __import__
        try:
            module_name, name = self.importer.get_name(obj, name)
        except (ObjNotFoundError, ObjMismatchError) as err:
            raise PicklingError(f"Can't pickle {obj}: {err}") from None

        module = self.importer.import_module(module_name)
        _, parent = _getattribute(module, name)
        # END CHANGED

        if self.proto >= 2:  # type: ignore[attr-defined]
            # Globals present in the extension registry are written as a
            # compact EXT1/EXT2/EXT4 opcode carrying their integer code.
            code = _extension_registry.get((module_name, name))
            if code:
                assert code > 0
                if code <= 0xFF:
                    write(EXT1 + pack("<B", code))
                elif code <= 0xFFFF:
                    write(EXT2 + pack("<H", code))
                else:
                    write(EXT4 + pack("<i", code))
                return

        lastname = name.rpartition(".")[2]
        if parent is module:
            name = lastname
        # Non-ASCII identifiers are supported only with protocols >= 3.
        if self.proto >= 4:  # type: ignore[attr-defined]
            self.save(module_name)  # type: ignore[attr-defined]
            self.save(name)  # type: ignore[attr-defined]
            write(STACK_GLOBAL)
        elif parent is not module:
            # Below protocol 4 a nested (dotted) name is saved as a
            # getattr(parent, lastname) reduce rather than via GLOBAL.
            self.save_reduce(getattr, (parent, lastname))  # type: ignore[attr-defined]
        elif self.proto >= 3:  # type: ignore[attr-defined]
            write(
                GLOBAL
                + bytes(module_name, "utf-8")
                + b"\n"
                + bytes(name, "utf-8")
                + b"\n"
            )
        else:
            if self.fix_imports:  # type: ignore[attr-defined]
                # Translate names through _compat_pickle's reverse mappings
                # (see the pickle docs for ``fix_imports``).
                r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
                r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING
                if (module_name, name) in r_name_mapping:
                    module_name, name = r_name_mapping[(module_name, name)]
                elif module_name in r_import_mapping:
                    module_name = r_import_mapping[module_name]
            try:
                write(
                    GLOBAL
                    + bytes(module_name, "ascii")
                    + b"\n"
                    + bytes(name, "ascii")
                    + b"\n"
                )
            except UnicodeEncodeError:
                # Report the (possibly remapped) module *name* that failed to
                # encode, not the module object returned by import_module.
                raise PicklingError(
                    "can't pickle global identifier '%s.%s' using "
                    "pickle protocol %i" % (module_name, name, self.proto)  # type: ignore[attr-defined]
                ) from None

        self.memoize(obj)  # type: ignore[attr-defined]
|
|
|
|
|
|
def create_pickler(data_buf, importer, protocol=4):
    """Build the pickler used to write ``data_buf`` with ``protocol``.

    A custom ``importer`` gets a `PackagePickler`, which resolves globals
    through that importer. When ``importer`` is the process-wide
    ``sys_importer``, the stock ``pickle.Pickler`` is returned instead.
    """
    uses_default_import_system = importer is sys_importer
    if not uses_default_import_system:
        return PackagePickler(importer, data_buf, protocol=protocol)
    # With the normal import library system we can use the C implementation
    # of pickle, which is faster.
    return Pickler(data_buf, protocol=protocol)
|