[BE][Easy] replace import pathlib with from pathlib import Path (#129426)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129426
Approved by: https://github.com/malfet
This commit is contained in:
Xuehai Pan
2024-06-29 23:35:02 +08:00
committed by PyTorch MergeBot
parent 7837a12474
commit 6d75604ef1
33 changed files with 159 additions and 140 deletions

View File

@ -1,40 +1,58 @@
# Owner(s): ["module: serialization"]
import torch
import unittest
import io
import tempfile
import os
import gc
import sys
import zipfile
import warnings
import gzip
import copy
import gc
import gzip
import io
import os
import pickle
import shutil
import pathlib
import platform
import shutil
import sys
import tempfile
import unittest
import warnings
import zipfile
from collections import namedtuple, OrderedDict
from copy import deepcopy
from itertools import product
from pathlib import Path
from torch._utils_internal import get_file_path_2
import torch
from torch._utils import _rebuild_tensor
from torch.utils._import_utils import import_dill
from torch.serialization import check_module_version_greater_or_equal, get_default_load_endianness, \
set_default_load_endianness, LoadEndianness, SourceChangeWarning
from torch.testing._internal.common_utils import (
IS_FILESYSTEM_UTF8_ENCODING, TemporaryDirectoryName,
TestCase, IS_FBCODE, IS_WINDOWS, TEST_DILL, run_tests, download_file, BytesIOContext, TemporaryFileName,
parametrize, instantiate_parametrized_tests, AlwaysWarnTypedStorageRemoval, serialTest, skipIfTorchDynamo)
from torch._utils_internal import get_file_path_2
from torch.serialization import (
check_module_version_greater_or_equal,
get_default_load_endianness,
LoadEndianness,
set_default_load_endianness,
SourceChangeWarning,
)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_dtype import all_types_and_complex_and
from torch.testing._internal.common_utils import (
AlwaysWarnTypedStorageRemoval,
BytesIOContext,
download_file,
instantiate_parametrized_tests,
IS_FBCODE,
IS_FILESYSTEM_UTF8_ENCODING,
IS_WINDOWS,
parametrize,
run_tests,
serialTest,
skipIfTorchDynamo,
TemporaryDirectoryName,
TemporaryFileName,
TEST_DILL,
TestCase,
)
from torch.testing._internal.two_tensor import TwoTensor # noqa: F401
from torch.utils._import_utils import import_dill
if not IS_WINDOWS:
from mmap import MAP_SHARED, MAP_PRIVATE
from mmap import MAP_PRIVATE, MAP_SHARED
else:
MAP_SHARED, MAP_PRIVATE = None, None
@ -988,7 +1006,7 @@ class TestSerialization(TestCase, SerializationMixin):
model = torch.nn.Conv2d(20, 3200, kernel_size=3)
with TemporaryFileName() as fname:
path = pathlib.Path(fname)
path = Path(fname)
torch.save(model.state_dict(), path)
torch.load(path, weights_only=weights_only)
@ -4008,7 +4026,7 @@ class TestSerialization(TestCase, SerializationMixin):
finally:
set_default_load_endianness(current_load_endian)
@parametrize('path_type', (str, pathlib.Path))
@parametrize('path_type', (str, Path))
@parametrize('weights_only', (True, False))
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
def test_serialization_mmap_loading(self, weights_only, path_type):