mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE][Easy] replace import pathlib
with from pathlib import Path
(#129426)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129426 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
7837a12474
commit
6d75604ef1
@ -1,40 +1,58 @@
|
||||
# Owner(s): ["module: serialization"]
|
||||
|
||||
import torch
|
||||
import unittest
|
||||
import io
|
||||
import tempfile
|
||||
import os
|
||||
import gc
|
||||
import sys
|
||||
import zipfile
|
||||
import warnings
|
||||
import gzip
|
||||
import copy
|
||||
import gc
|
||||
import gzip
|
||||
import io
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
import pathlib
|
||||
import platform
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
import warnings
|
||||
import zipfile
|
||||
from collections import namedtuple, OrderedDict
|
||||
from copy import deepcopy
|
||||
from itertools import product
|
||||
from pathlib import Path
|
||||
|
||||
from torch._utils_internal import get_file_path_2
|
||||
import torch
|
||||
from torch._utils import _rebuild_tensor
|
||||
from torch.utils._import_utils import import_dill
|
||||
from torch.serialization import check_module_version_greater_or_equal, get_default_load_endianness, \
|
||||
set_default_load_endianness, LoadEndianness, SourceChangeWarning
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_FILESYSTEM_UTF8_ENCODING, TemporaryDirectoryName,
|
||||
TestCase, IS_FBCODE, IS_WINDOWS, TEST_DILL, run_tests, download_file, BytesIOContext, TemporaryFileName,
|
||||
parametrize, instantiate_parametrized_tests, AlwaysWarnTypedStorageRemoval, serialTest, skipIfTorchDynamo)
|
||||
from torch._utils_internal import get_file_path_2
|
||||
from torch.serialization import (
|
||||
check_module_version_greater_or_equal,
|
||||
get_default_load_endianness,
|
||||
LoadEndianness,
|
||||
set_default_load_endianness,
|
||||
SourceChangeWarning,
|
||||
)
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_dtype import all_types_and_complex_and
|
||||
from torch.testing._internal.common_utils import (
|
||||
AlwaysWarnTypedStorageRemoval,
|
||||
BytesIOContext,
|
||||
download_file,
|
||||
instantiate_parametrized_tests,
|
||||
IS_FBCODE,
|
||||
IS_FILESYSTEM_UTF8_ENCODING,
|
||||
IS_WINDOWS,
|
||||
parametrize,
|
||||
run_tests,
|
||||
serialTest,
|
||||
skipIfTorchDynamo,
|
||||
TemporaryDirectoryName,
|
||||
TemporaryFileName,
|
||||
TEST_DILL,
|
||||
TestCase,
|
||||
)
|
||||
from torch.testing._internal.two_tensor import TwoTensor # noqa: F401
|
||||
from torch.utils._import_utils import import_dill
|
||||
|
||||
|
||||
if not IS_WINDOWS:
|
||||
from mmap import MAP_SHARED, MAP_PRIVATE
|
||||
from mmap import MAP_PRIVATE, MAP_SHARED
|
||||
else:
|
||||
MAP_SHARED, MAP_PRIVATE = None, None
|
||||
|
||||
@ -988,7 +1006,7 @@ class TestSerialization(TestCase, SerializationMixin):
|
||||
model = torch.nn.Conv2d(20, 3200, kernel_size=3)
|
||||
|
||||
with TemporaryFileName() as fname:
|
||||
path = pathlib.Path(fname)
|
||||
path = Path(fname)
|
||||
torch.save(model.state_dict(), path)
|
||||
torch.load(path, weights_only=weights_only)
|
||||
|
||||
@ -4008,7 +4026,7 @@ class TestSerialization(TestCase, SerializationMixin):
|
||||
finally:
|
||||
set_default_load_endianness(current_load_endian)
|
||||
|
||||
@parametrize('path_type', (str, pathlib.Path))
|
||||
@parametrize('path_type', (str, Path))
|
||||
@parametrize('weights_only', (True, False))
|
||||
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
|
||||
def test_serialization_mmap_loading(self, weights_only, path_type):
|
||||
|
Reference in New Issue
Block a user