From 73622fc5fa9713f46a5cef9704772e645591bce6 Mon Sep 17 00:00:00 2001
From: wengshiy
Date: Sun, 26 Jan 2025 03:37:18 +0000
Subject: [PATCH] Fix ThroughputBenchmark issue (#144669)

Fixes [144461](https://github.com/pytorch/pytorch/issues/144461)

ThroughputBenchmark drives the module under test from its own pool of
calling threads, but those threads did not inherit the launching
thread's autocast TLS state. When the module had been compiled with
torch.compile under autocast, Dynamo's AutocastState guard therefore
failed inside the worker threads and forced a recompile. This patch
propagates the autocast state (per-device enabled flags and dtypes,
plus the cache setting) to every benchmark thread, next to the
grad-mode and dispatch-key TLS that was already forwarded, and relaxes
the AutocastState comparison to skip the dtype check for devices on
which autocast is disabled in both states.
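A minimal sketch of the failure mode (the module, shapes, and thread
counts below are illustrative, not taken from the issue):

```python
import torch
from torch.utils import ThroughputBenchmark

# Turn the silent recompile into a hard error, as the new test does.
torch._dynamo.config.error_on_recompile = True

module = torch.nn.Linear(10, 5)
x = torch.randn(8, 10)

with torch.no_grad(), torch.amp.autocast("cpu", dtype=torch.bfloat16):
    compiled = torch.compile(module)
    compiled(x)  # compile once under autocast

    bench = ThroughputBenchmark(compiled)
    bench.add_input(x)
    # Before this fix the benchmark's worker threads ran with autocast
    # disabled, so the AutocastState guard failed and Dynamo attempted
    # to recompile.
    stats = bench.benchmark(
        num_calling_threads=4, num_warmup_iters=10, num_iters=100
    )
    print(stats)
```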

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144669
Approved by: https://github.com/leslie-fang-intel, https://github.com/williamwen42, https://github.com/jansel
---
 test/test_throughput_benchmark.py           | 48 +++++++++++++++++++++
 torch/csrc/dynamo/guards.cpp                |  4 ++
 torch/csrc/utils/throughput_benchmark-inl.h | 14 ++++++
 3 files changed, 66 insertions(+)

diff --git a/test/test_throughput_benchmark.py b/test/test_throughput_benchmark.py
index e317a840514f..fe838928b8e0 100644
--- a/test/test_throughput_benchmark.py
+++ b/test/test_throughput_benchmark.py
@@ -79,6 +79,54 @@ class TestThroughputBenchmark(TestCase):
         with TemporaryFileName() as fname:
             self.linear_test(TwoLayerNetModule, profiler_output_path=fname)
 
+    def linear_with_compile_test(self, Module, dtype):
+        from contextlib import nullcontext
+
+        from torch._dynamo import config
+        from torch._inductor import config as inductor_config
+
+        config.error_on_recompile = True
+        inductor_config.cpp_wrapper = True
+        inductor_config.freezing = True
+        D_in = 10
+        H = 5
+        D_out = 15
+        B = 8
+
+        autocast = dtype != torch.float32
+        module = Module(D_in, H, D_out)
+
+        input = (torch.randn(B, D_in), torch.randn(B, D_in))
+
+        with torch.no_grad(), torch.amp.autocast("cpu", enabled=autocast, dtype=dtype):
+            torch._dynamo.reset()
+            module(*input)
+            module = torch.compile(module)
+            module(*input)
+            module(*input)
+
+        ctx = nullcontext()
+        if dtype == torch.float16 or dtype == torch.bfloat16:
+            ctx = torch.amp.autocast("cpu", enabled=autocast, dtype=dtype)
+        with torch.no_grad(), ctx:
+            bench = ThroughputBenchmark(module)
+            bench.add_input(*input)
+
+            module_result = module(*input)
+            bench_result = bench.run_once(*input)
+            torch.testing.assert_close(bench_result, module_result)
+
+            stats = bench.benchmark(
+                num_calling_threads=4, num_warmup_iters=100, num_iters=1000
+            )
+
+            print(stats)
+
+    def test_compile(self):
+        dtypes = [torch.float32, torch.float16, torch.bfloat16]
+        for dtype in dtypes:
+            self.linear_with_compile_test(TwoLayerNetModule, dtype)
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp
index 4bea3472b761..b080ffe75c43 100644
--- a/torch/csrc/dynamo/guards.cpp
+++ b/torch/csrc/dynamo/guards.cpp
@@ -539,6 +539,10 @@ struct AutocastState {
 
   bool operator==(const AutocastState& o) const {
     for (size_t i = 0; i < DEVICES.size(); i++) {
+      // If autocast is disabled on both sides, skip the dtype comparison
+      if (enabled[i] == false && o.enabled[i] == false) {
+        continue;
+      }
       if (enabled[i] != o.enabled[i] || dtype[i] != o.dtype[i]) {
         return false;
       }
diff --git a/torch/csrc/utils/throughput_benchmark-inl.h b/torch/csrc/utils/throughput_benchmark-inl.h
index ead63d585a05..f32f15012461 100644
--- a/torch/csrc/utils/throughput_benchmark-inl.h
+++ b/torch/csrc/utils/throughput_benchmark-inl.h
@@ -8,6 +8,7 @@
 #include
 #include
 
+#include <ATen/autocast_mode.h>
 #include
 #include
 #include
@@ -61,6 +62,14 @@ BenchmarkExecutionStats BenchmarkHelper<Input, Output, Model>::benchmark(
 
   callers.reserve(config.num_calling_threads);
 
+  static constexpr auto& DEVICES = at::autocast::_AUTOCAST_SUPPORTED_DEVICES;
+  std::array<bool, DEVICES.size()> autocast_enabled;
+  std::array<at::ScalarType, DEVICES.size()> autocast_dtype;
+  for (size_t i = 0; i < DEVICES.size(); i++) {
+    autocast_enabled[i] = at::autocast::is_autocast_enabled(DEVICES[i]);
+    autocast_dtype[i] = at::autocast::get_autocast_dtype(DEVICES[i]);
+  }
+  bool autocast_cache_enabled = at::autocast::is_autocast_cache_enabled();
   bool tls_grad_enabled = c10::GradMode::is_enabled();
   c10::impl::LocalDispatchKeySet tls_key_set = c10::impl::tls_local_dispatch_key_set();
@@ -71,6 +80,11 @@ BenchmarkExecutionStats BenchmarkHelper<Input, Output, Model>::benchmark(
      // We use conditional variable as a barrier to make sure each thread
      // performs required warmup iterations before we start measuring
      c10::GradMode::set_enabled(tls_grad_enabled);
      c10::impl::_force_tls_local_dispatch_key_set(tls_key_set);
+      for (size_t i = 0; i < DEVICES.size(); i++) {
+        at::autocast::set_autocast_enabled(DEVICES[i], autocast_enabled[i]);
+        at::autocast::set_autocast_dtype(DEVICES[i], autocast_dtype[i]);
+      }
+      at::autocast::set_autocast_cache_enabled(autocast_cache_enabled);
      for (const auto j : c10::irange(config.num_warmup_iters)) {
        (void)j;