Compare commits

...

4 Commits

Author SHA1 Message Date
a2d2a30311 Add torch._dynamo.config.fail_on_cache_limit_hit (#136767)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136767
Approved by: https://github.com/albanD, https://github.com/jansel
ghstack dependencies: #136533
2024-09-27 03:58:00 +00:00
2521cd3874 Skip kernel saving if already existed. (#136389)
Summary:
We skip save_gpu_kernel if the kernel has already been saved.
This gives a more accurate Triton profiling result. The following traces show the before/after of the change for a benchmark of a trivial addmm:

Before:
<img width="1255" alt="Screenshot 2024-09-23 at 10 26 53 AM" src="https://github.com/user-attachments/assets/5aea05ef-6ef0-464c-8da9-17b31c97b43a">

After:
<img width="910" alt="Screenshot 2024-09-23 at 10 27 03 AM" src="https://github.com/user-attachments/assets/488b7d4f-268f-41cf-8553-cb16ceeae118">

We can see that before the change, the benchmarking includes two parts:
(1) the overhead of our triton_heuristic call, which includes the save/get and the (expensive) hash computation;
(2) the actual computation of the Triton kernel.

Part (1) accounts for more than 50% of the time, which often makes kernel selection during profiling choose aten kernels over Triton kernels.
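
For reference, here is a minimal sketch (not part of this change) of the kind of setup described above: compiling a trivial addmm so that Inductor benchmarks Triton candidates against aten during kernel selection. The use of mode="max-autotune" and the shapes/dtypes are my own assumptions for illustration.

# Illustrative sketch only: a trivial addmm compiled with max-autotune, which
# triggers the Triton-vs-aten kernel benchmarking whose measurements the
# save_gpu_kernel overhead used to distort. Requires a CUDA device.
import torch

@torch.compile(mode="max-autotune")
def addmm(bias, a, b):
    return torch.addmm(bias, a, b)

a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
bias = torch.randn(1024, device="cuda", dtype=torch.float16)

addmm(bias, a, b)  # first call compiles and runs kernel selection/benchmarking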

Test Plan:
Existing OSS CI
[Redacted, Some internal model results in D63441430]

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136389
Approved by: https://github.com/desertfire
2024-09-27 03:03:28 +00:00
d1382aaf3d skip test_out_of_memory for jetson (#133270)
Skip test_out_of_memory in test/test_cuda.py on Jetson, as OOM reporting on Jetson has issues due to partially missing NVML support. cc @eqy
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133270
Approved by: https://github.com/eqy, https://github.com/albanD, https://github.com/seemethere
2024-09-27 02:36:48 +00:00
26869d38e1 [Inductor] Further solve missing aoti_torch_check symbol issue (#136775)
Summary: https://github.com/pytorch/pytorch/pull/136669 didn't resolve all the internal test failures. Add more tests to OSS CI to catch the remaining issues, and fix some internal TARGETS dependencies.

Differential Revision: D63473744

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136775
Approved by: https://github.com/henrylhtsang
2024-09-27 02:26:49 +00:00
8 changed files with 40 additions and 3 deletions

View File

@@ -375,9 +375,8 @@ test_inductor_cpp_wrapper_abi_compatible() {
mkdir -p "$TEST_REPORTS_DIR"
echo "Testing Inductor cpp wrapper mode with TORCHINDUCTOR_ABI_COMPATIBLE=1"
# cpu stack allocation causes segfault and needs more investigation
PYTORCH_TESTING_DEVICE_ONLY_FOR="" python test/run_test.py --include inductor/test_cpu_cpp_wrapper
python test/run_test.py --include inductor/test_cuda_cpp_wrapper
python test/run_test.py --include inductor/test_cuda_cpp_wrapper inductor/test_cpu_repro
TORCHINDUCTOR_CPP_WRAPPER=1 python benchmarks/dynamo/timm_models.py --device cuda --accuracy --amp \
--training --inductor --disable-cudagraphs --only vit_base_patch16_224 \

View File

@@ -8,6 +8,7 @@ import torch._dynamo.config
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._logging
from torch._dynamo.exc import FailOnCacheLimitHit
from torch.testing._internal.logging_utils import kwargs_to_settings, log_settings
@@ -203,6 +204,19 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase):
            "expected type of 'L['b']' to be a tensor type, ' but found <class 'int'>",
        )

    @torch._dynamo.config.patch(cache_size_limit=1, fail_on_cache_limit_hit=True)
    def test_fail_on_cache_limit_hit(self):
        @torch.compile(backend="eager")
        def func(b, a):
            if a:
                return b * 2
            else:
                return b + 1

        func(torch.randn(5), True)
        with self.assertRaises(FailOnCacheLimitHit):
            func(torch.randn(5), False)

    @torch._dynamo.config.patch("cache_size_limit", 32)
    def test_multiple_guard_fails(self):
        failure_reasons = []

View File

@@ -232,6 +232,9 @@ class TestCuda(TestCase):
        device_properties_no_argument = torch.cuda.get_device_properties()
        self.assertEqual(current_device_properties, device_properties_no_argument)

    @unittest.skipIf(
        IS_JETSON, "oom reporting has issues on jetson igx due to partial nvml support"
    )
    def test_out_of_memory(self):
        tensor = torch.zeros(1024, device="cuda")

View File

@@ -51,6 +51,12 @@ accumulated_cache_size_limit = 256
# [@compile_ignored: runtime_behaviour] skip tracing recursively if cache limit is hit
skip_code_recursive_on_cache_limit_hit = True
# raise a hard error if cache limit is hit. If you are on a model where you
# know you've sized the cache correctly, this can help detect problems when
# you regress guards/specialization. This works best when cache_size_limit = 1.
# [@compile_ignored: runtime_behaviour]
fail_on_cache_limit_hit = False
# whether or not to specialize on int inputs. This only has an effect with
# dynamic_shapes; when dynamic_shapes is False, we ALWAYS specialize on int
# inputs. Note that assume_static_by_default will also cause ints to get
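
For illustration, a minimal usage sketch of the new flag (my own example, not taken from this diff); it mirrors the test added above and assumes only the names introduced here (fail_on_cache_limit_hit, FailOnCacheLimitHit) plus the existing cache_size_limit config.

# Illustrative sketch: opt in to a hard error, rather than a fallback, when the
# Dynamo recompile cache limit is hit.
import torch
import torch._dynamo.config as dynamo_config
from torch._dynamo.exc import FailOnCacheLimitHit

dynamo_config.cache_size_limit = 1            # per the comment above, a limit of 1 works best
dynamo_config.fail_on_cache_limit_hit = True  # the new flag

@torch.compile(backend="eager")
def func(b, a):
    return b * 2 if a else b + 1

func(torch.randn(5), True)       # first specialization fills the cache
try:
    func(torch.randn(5), False)  # recompiling for a=False exceeds the limit
except FailOnCacheLimitHit:
    print("cache limit hit; hard failure was requested")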

View File

@@ -70,6 +70,7 @@ from .exc import (
    augment_exc_message,
    BackendCompilerFailed,
    CacheLimitExceeded,
    FailOnCacheLimitHit,
    format_error_msg,
    InternalTorchDynamoError,
    SkipCodeRecursiveException,
@@ -858,7 +859,11 @@ def _compile(
                 format_guard_failures(),
                 troubleshooting_url,
             )
-            if config.skip_code_recursive_on_cache_limit_hit and justknobs_check(
+            if config.fail_on_cache_limit_hit:
+                raise FailOnCacheLimitHit(
+                    f"{limit_type} reached, because fail_on_cache_limit_hit = True this is a HARD failure"
+                )
+            elif config.skip_code_recursive_on_cache_limit_hit and justknobs_check(
                 "pytorch/compiler:skip_code_recursive_on_cache_limit_hit"
             ):
                 raise CacheLimitExceeded(f"{limit_type} reached")

View File

@@ -188,6 +188,13 @@ class IncorrectUsage(Exception):
    pass


# TODO: I'm a little uncertain about what error classification we should have
# for this. This is potentially a user error, but regressions in
# specialization in PyTorch proper could also trigger this problem
class FailOnCacheLimitHit(Exception):
    pass


class ObservedException(TorchDynamoException):
    # An exception observed during the tracing. This exception is used by Dynamo to handle exceptions.
    pass

View File

@@ -26,6 +26,7 @@
#include <c10/util/generic_math.h>
#include <c10/util/Half.h>
#include <c10/util/TypeCast.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX)
#define INDUCTOR_USE_VECTOR_TYPES() 1

View File

@@ -753,6 +753,8 @@ class CachingAutotuner(KernelInterface):
            self.save_cache_hook(self.launchers[0].config, self.autotune_time_taken_ns)

    def save_gpu_kernel(self, grid, stream, launcher):
        if self.cuda_kernel_saved:
            return
        if callable(grid):
            grid_x, grid_y, grid_z = grid(launcher.config.kwargs)
        else: