mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Fix code descriptions in the test package. (#148145)
The parameter and function description have something wrong and make them correct. Pull Request resolved: https://github.com/pytorch/pytorch/pull/148145 Approved by: https://github.com/janeyx99
This commit is contained in:
@ -953,7 +953,7 @@ class TensorLikePair(Pair):
|
||||
),
|
||||
)
|
||||
|
||||
# Compressed and plain indices in the CSR / CSC / BSR / BSC sparse formates can be `torch.int32` _or_
|
||||
# Compressed and plain indices in the CSR / CSC / BSR / BSC sparse formats can be `torch.int32` _or_
|
||||
# `torch.int64`. While the same dtype is enforced for the compressed and plain indices of a single tensor, it
|
||||
# can be different between two tensors. Thus, we need to convert them to the same dtype, or the comparison will
|
||||
# fail.
|
||||
|
@ -222,7 +222,7 @@ class ModuleInfo:
|
||||
# channels last output
|
||||
train_and_eval_differ=False, # whether the module has differing behavior between train and eval
|
||||
module_error_inputs_func=None, # Function to generate module inputs that error
|
||||
gradcheck_fast_mode=None, # Whether to use the fast implmentation for gradcheck/gradgradcheck.
|
||||
gradcheck_fast_mode=None, # Whether to use the fast implementation for gradcheck/gradgradcheck.
|
||||
# When set to None, defers to the default value provided by the wrapper
|
||||
# function around gradcheck (testing._internal.common_utils.gradcheck)
|
||||
):
|
||||
@ -3575,7 +3575,7 @@ module_db: list[ModuleInfo] = [
|
||||
DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
|
||||
# Failure on ROCM for float32 issue #70125
|
||||
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
|
||||
# Not implmented for chalf on CPU
|
||||
# Not implemented for chalf on CPU
|
||||
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity',
|
||||
dtypes=(torch.chalf,), device_type='cuda'),
|
||||
# See #119108: MPSNDArrayConvolutionA14.mm:3976: failed assertion `destination datatype must be fp32'
|
||||
@ -3640,7 +3640,7 @@ module_db: list[ModuleInfo] = [
|
||||
# These fail only on ROCm
|
||||
DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
|
||||
dtypes=[torch.complex32, torch.complex64], active_if=TEST_WITH_ROCM),
|
||||
# Not implmented for chalf on CPU
|
||||
# Not implemented for chalf on CPU
|
||||
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity',
|
||||
dtypes=(torch.chalf,), device_type='cuda'),
|
||||
),
|
||||
|
@ -45,7 +45,7 @@ def freeze_rng_state():
|
||||
# In the long run torch.cuda.set_rng_state should probably be
|
||||
# an operator.
|
||||
#
|
||||
# NB: Mode disable is to avoid running cross-ref tests on thes seeding
|
||||
# NB: Mode disable is to avoid running cross-ref tests on this seeding
|
||||
with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch():
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.set_rng_state(cuda_rng_state) # type: ignore[possibly-undefined]
|
||||
|
Reference in New Issue
Block a user