Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
cpp_wrapper: use largeTensorTest for test memory checks (#146991)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146991
Approved by: https://github.com/desertfire
Committed by: PyTorch MergeBot
Parent: 723f3a9eab
Commit: f98cd84b04
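Summary of the change: largeTensorTest gains an inductor flag (defaulting to TEST_WITH_TORCHINDUCTOR), and when GPU cpp_wrapper mode is active the decorator doubles the requested size, since cpp_wrapper autotuning allocates an extra array the size of the input. This replaces the hand-rolled _has_sufficient_memory checks previously copied into each large-tensor test. A minimal usage sketch, assuming a typical Inductor test class; the test name and body are illustrative, not part of this diff:

    @largeTensorTest("4GB", inductor=True)
    def test_big_pointwise(self):
        # Under GPU cpp_wrapper the decorator checks for 8GB (2 x 4GB) to
        # cover the extra autotuning buffer; otherwise it checks 4GB.
        a = torch.zeros(2**30, device=self.device)  # 2**30 float32 = 4GB
        self.assertTrue(torch.equal(a + 1, torch.ones_like(a)))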
@@ -462,7 +462,7 @@ class TestFxGraphCache(TestCase):
         # And the results should be the same.
         self.assertEqual(grads1, grads2)
 
-    @largeTensorTest("64GB", device=GPU_TYPE)
+    @largeTensorTest("64GB", device=GPU_TYPE, inductor=True)
     @config.patch({"fx_graph_cache": True})
     @config.patch({"fx_graph_remote_cache": False})
     @parametrize("device", (GPU_TYPE,))
@@ -95,7 +95,7 @@ class TestMetrics(TestCase):
         kernel_code = kernel_list[0]
         self.assertEqual(metrics._count_pattern(kernel_code, "tl.atomic_add"), 1)
 
-    @largeTensorTest(25e7 * 2 * 4, device=GPU_TYPE)
+    @largeTensorTest(25e7 * 2 * 4, device=GPU_TYPE, inductor=True)
     @config.patch("fx_graph_remote_cache", False)
     @config.patch("benchmark_kernel", True)
     def test_kernel_args_num_gb(self):
@@ -75,8 +75,8 @@ from torch.testing._internal.common_cuda import (
     with_tf32_off,
 )
 from torch.testing._internal.common_device_type import (
-    _has_sufficient_memory,
     expectedFailureXPU,
+    largeTensorTest,
 )
 from torch.testing._internal.common_dtype import all_types, get_all_dtypes
 from torch.testing._internal.common_quantization import (
@@ -773,6 +773,16 @@ def is_cpp_backend(device):
     return getattr(device, "type", device) == "cpu" and config.cpu_backend == "cpp"
 
 
+def skip_if_cpu(fn):
+    @functools.wraps(fn)
+    def wrapper(self):
+        if self.device == "cpu":
+            raise unittest.SkipTest("cpu not supported")
+        return fn(self)
+
+    return wrapper
+
+
 def skip_if_halide(fn):
     @functools.wraps(fn)
     def wrapper(self):
@@ -3329,18 +3339,10 @@ class CommonTemplate:
 
         self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
 
+    @skip_if_cpu
     @skip_if_halide  # only 32-bit indexing
+    @largeTensorTest("4GB", inductor=True)
     def test_large_tensor_reduction(self):
-        if self.device == "cpu":
-            raise unittest.SkipTest("Fails on CPU")
-
-        # If this is running with cpp_wrapper, the auto-tuning step will generate an
-        # additional array of the same size as the input. Numbers derived
-        # experimentally.
-        required_memory = 2**33 if config.cpp_wrapper else 2**32 + 2**16
-        if not _has_sufficient_memory(self.device, required_memory):
-            raise unittest.SkipTest("insufficient memory")
-
         # Test 64-bit indexing works correctly
         def fn(a):
             return torch.max(a)
@@ -3354,11 +3356,9 @@ class CommonTemplate:
         expect = torch.tensor(2, dtype=torch.int8, device=self.device)
         self.assertEqual(actual, expect)
 
+    @skip_if_cpu
     @skip_if_gpu_halide  # only 32-bit indexing
     def test_large_broadcast_reduction(self):
-        if self.device == "cpu":
-            raise unittest.SkipTest("Fails on CPU")
-
         # Test 64-bit indexing works correctly when inputs are less than 32-bit
         # but intermediate tensors require 64-bit indexing
         def fn(a, b):
@@ -3377,16 +3377,8 @@ class CommonTemplate:
         self.assertEqual(actual, expect)
 
     @skip_if_halide  # only 32-bit indexing
+    @largeTensorTest("4GB", inductor=True)
     def test_large_pointwise(self):
-        # If this is running with cpp_wrapper, the auto-tuning step will generate an
-        # additional array of the same size as the input. Numbers derived
-        # experimentally.
-        required_memory = (
-            2**32 + 2**31 + 2**15 if config.cpp_wrapper else 2**31 + 2**15
-        )
-        if not _has_sufficient_memory(self.device, required_memory):
-            raise unittest.SkipTest("insufficient memory")
-
         def fn(a):
             return a + 1
 
@@ -3402,16 +3394,11 @@ class CommonTemplate:
         self.assertTrue((actual == 2).all())
 
     @skip_if_halide  # only 32-bit indexing
+    @largeTensorTest("3GB", inductor=True)
     def test_large_offset_pointwise(self):
         # Test 64-bit indexing is used when input views a tensor that can be
         # indexed with 32-bit strides but the storage offset pushes it over
         # INT_MAX
-
-        # Memory requirements derived experimentally.
-        required_memory = 2**32 + 2**16
-        if not _has_sufficient_memory(self.device, required_memory):
-            raise unittest.SkipTest("insufficient memory")
-
         def fn(a):
             return a + 4
 
@@ -3422,17 +3409,10 @@ class CommonTemplate:
         self.assertTrue((actual == 4).all())
 
     @skip_if_halide  # only 32-bit indexing
+    @largeTensorTest("2GB", inductor=True)
     def test_large_strided_reduction(self):
         # Test 64-bit indexing is used when input numel is less than INT_MAX
         # but stride calculations go above INT_MAX
-
-        # If this is running with cpp_wrapper, the auto-tuning step will generate an
-        # additional array of the same size as the input. Numbers derived
-        # experimentally.
-        required_memory = 2**32 + 2**16 if config.cpp_wrapper else 2**31 + 2**16
-        if not _has_sufficient_memory(self.device, required_memory):
-            raise unittest.SkipTest("insufficient memory")
-
         def fn(a):
             return torch.max(a)
 
@@ -11775,6 +11755,7 @@ class CommonTemplate:
         "triton.autotune_pointwise", True
     )  # needed to introduce config that exceed max shared memory usage
     @serialTest()
+    @largeTensorTest("13GB", inductor=True)
     def test_large_block_sizes(self):
         """
         Inductor will try triton configs like x = 64 and y = 1024 which will
@@ -11783,16 +11764,6 @@ class CommonTemplate:
         Currently inductor will skip such bad configs and pick the best one
         from the remaining configs.
         """
-        # If this is running with cpp_wrapper, the auto-tuning step will generate an
-        # additional array of the same size as the input. Numbers derived
-        # experimentally.
-        required_memory = (
-            2**34 + 2**32 + 2**31
-            if config.cpp_wrapper
-            else 2**33 + 2**32 + 2**31
-        )
-        if not _has_sufficient_memory(self.device, required_memory):
-            raise unittest.SkipTest("insufficient memory")
 
         @torch.compile
         def fn(x, y):
@@ -12220,16 +12191,8 @@ class CommonTemplate:
         t = rand_strided((2, 3), (3, 1), device=self.device, dtype=torch.float8_e4m3fn)
         self.assertTrue(t.dtype is torch.float8_e4m3fn)
 
+    @largeTensorTest("1GB", inductor=True)
     def test_large_grid(self):
-        # If this is running with cpp_wrapper, the auto-tuning step will generate an
-        # additional array of the same size as the input. Numbers derived
-        # experimentally.
-        required_memory = (
-            2**30 + 2**29 + 2**15 if config.cpp_wrapper else 2**30 + 2**15
-        )
-        if not _has_sufficient_memory(self.device, required_memory):
-            raise unittest.SkipTest("insufficient memory")
-
         # https://github.com/pytorch/pytorch/issues/123210
         def fn(primals_5):
             view = torch.ops.aten.reshape.default(primals_5, [-1, 2, 4])
@@ -1779,8 +1779,7 @@ else:
         self.assertEqual(res0, res_cpu, atol=1e-3, rtol=1e-2)
 
     @onlyCUDA
-    @largeTensorTest('24GB', device='cuda')
-    @largeTensorTest('24GB', device='cpu')
+    @largeTensorTest('49GB')
     def test_cumsum_64bit_indexing(self, device):
         b = torch.ones(2 * 4096 * 8, 100000, dtype=torch.float, device='cuda')
         b /= 100000
@@ -1338,7 +1338,7 @@ def _has_sufficient_memory(device, size):
     return psutil.virtual_memory().available >= effective_size
 
 
-def largeTensorTest(size, device=None):
+def largeTensorTest(size, device=None, inductor=TEST_WITH_TORCHINDUCTOR):
     """Skip test if the device has insufficient memory to run the test
 
     size may be a number of bytes, a string of the form "N GB", or a callable
@@ -1354,8 +1354,19 @@ def largeTensorTest(size, device=None):
     def inner(fn):
         @wraps(fn)
         def dep_fn(self, *args, **kwargs):
-            size_bytes = size(self, *args, **kwargs) if callable(size) else size
-            _device = device if device is not None else self.get_primary_device()
+            size_bytes: int = size(self, *args, **kwargs) if callable(size) else size
+            _device = device
+            if _device is None:
+                if hasattr(self, "get_primary_device"):
+                    _device = self.get_primary_device()
+                else:
+                    _device = self.device
+
+            # If this is running with GPU cpp_wrapper, the autotuning step will generate
+            # an additional array of the same size as the input.
+            if inductor and torch._inductor.config.cpp_wrapper and _device != "cpu":
+                size_bytes *= 2
 
             if not _has_sufficient_memory(_device, size_bytes):
                 raise unittest.SkipTest(f"Insufficient {_device} memory")
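For reference, size still accepts raw bytes, an "N GB" string, or a callable evaluated against the test arguments; only the final byte count is doubled. A standalone sketch of the new doubling rule (the names below are stand-ins, not the decorator's internals):

    CPP_WRAPPER = True  # stand-in for torch._inductor.config.cpp_wrapper

    def effective_size(size_bytes: int, inductor: bool, device: str) -> int:
        # GPU cpp_wrapper autotuning allocates a second array the size of
        # the input, so the memory requirement doubles off-CPU.
        if inductor and CPP_WRAPPER and device != "cpu":
            return size_bytes * 2
        return size_bytes

    assert effective_size(2**32, inductor=True, device="cuda") == 2**33
    assert effective_size(2**32, inductor=True, device="cpu") == 2**32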