Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Enable xdoctest runner in CI for real this time (#83816)
Builds on #83317 and enables running the doctests. Just need to figure out what is causing the failures.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83816
Approved by: https://github.com/ezyang, https://github.com/malfet
Committed by: PyTorch MergeBot
Parent: fb4fc0dabe
Commit: ad782ff7df
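The diff below touches docstring examples in torch/serialization.py, adding xdoctest directives so the runner can skip examples with side effects (the torch.save example writes tensor.pt into the current working directory). As a hedged, self-contained illustration of how such a directive reads, using a hypothetical helper name rather than anything from this diff:

    import torch

    def save_tensor_demo():
        """Hypothetical function, used only to show the skip directive.

        Example:
            >>> # xdoctest: +SKIP("makes cwd dirty")
            >>> x = torch.tensor([0, 1, 2, 3, 4])
            >>> torch.save(x, 'tensor.pt')  # the runner skips this example
        """

A doctest runner such as xdoctest collects the >>> blocks from docstrings and honors the +SKIP directive; exactly how this CI job invokes the runner is not shown on this page.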
torch/serialization.py

@@ -49,6 +49,7 @@ __all__ = [
    'StorageType',
]


class SourceChangeWarning(Warning):
    pass
@@ -186,10 +187,12 @@ def _cuda_deserialize(obj, location):
    else:
        return obj.cuda(device)


def _mps_deserialize(obj, location):
    if location == 'mps':
        return obj.mps()


def _meta_deserialize(obj, location):
    if location == 'meta':
        return torch.UntypedStorage(obj.nbytes(), device='meta')
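The _cuda_deserialize, _mps_deserialize, and _meta_deserialize helpers shown above each restore a storage for one location tag ('cuda:<n>', 'mps', 'meta'). The sketch below illustrates the general tag-dispatch idea with purely hypothetical names; the actual registry wiring in torch.serialization is not part of this diff:

    import torch

    # Hypothetical registry mapping a location prefix to a deserializer;
    # illustrative only, not the module's real data structure.
    _DESERIALIZERS = {
        'cpu': lambda obj, location: obj,
        'meta': lambda obj, location: torch.UntypedStorage(obj.nbytes(), device='meta'),
    }

    def _dispatch_deserialize(obj, location):
        # Use the deserializer whose tag prefixes the location string,
        # e.g. a location of 'cuda:1' would match a 'cuda' entry.
        for tag, fn in _DESERIALIZERS.items():
            if location.startswith(tag):
                return fn(obj, location)
        raise RuntimeError("don't know how to restore data location " + location)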
@@ -356,6 +359,7 @@ def _check_seekable(f) -> bool:
        raise_err_msg(["seek", "tell"], e)
    return False


def _check_dill_version(pickle_module) -> None:
    '''Checks if using dill as the pickle module, and if so, checks if it is the correct version.
    If dill version is lower than 0.3.1, a ValueError is raised.
@@ -375,12 +379,14 @@ def _check_dill_version(pickle_module) -> None:
                pickle_module.__version__
            ))


def _check_save_filelike(f):
    if not isinstance(f, (str, os.PathLike)) and not hasattr(f, 'write'):
        raise AttributeError((
            "expected 'f' to be string, path, or a file-like object with "
            "a 'write' attribute"))


def save(
    obj: object,
    f: FILE_LIKE,
@@ -420,6 +426,7 @@ def save(
        to use the old format, pass the kwarg ``_use_new_zipfile_serialization=False``.

    Example:
        >>> # xdoctest: +SKIP("makes cwd dirty")
        >>> # Save to file
        >>> x = torch.tensor([0, 1, 2, 3, 4])
        >>> torch.save(x, 'tensor.pt')
@@ -753,7 +760,7 @@ def load(
         # Load all tensors onto GPU 1
         >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
         # Map tensors from GPU 1 to GPU 0
-        >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
+        >>> torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'})
         # Load tensor from io.BytesIO object
         >>> with open('tensor.pt', 'rb') as f:
         ...     buffer = io.BytesIO(f.read())
@@ -1087,6 +1094,7 @@ def _get_restore_location(map_location):
            return result
    return restore_location


class StorageType():
    def __init__(self, name):
        self.dtype = _get_dtype_from_pickle_storage_type(name)
@@ -1094,6 +1102,7 @@ class StorageType():
    def __str__(self):
        return f'StorageType(dtype={self.dtype})'


def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickle_load_args):
    restore_location = _get_restore_location(map_location)
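All of the docstring examples touched above revolve around the torch.save / torch.load pair. A minimal round trip through an in-memory buffer, sketched here as an aside rather than taken from the diff, sidesteps the dirty-cwd issue that the +SKIP directive works around:

    import io
    import torch

    # Serialize a tensor into an in-memory buffer instead of a file on disk.
    x = torch.tensor([0, 1, 2, 3, 4])
    buffer = io.BytesIO()
    torch.save(x, buffer)

    # Rewind and load it back, mapping all storages onto the CPU.
    buffer.seek(0)
    y = torch.load(buffer, map_location='cpu')
    assert torch.equal(x, y)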