Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Fix ThroughputBenchmark issue (#144669)
Fixes [#144461](https://github.com/pytorch/pytorch/issues/144461): propagate the calling thread's autocast and grad-mode state to the benchmark's worker threads.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144669
Approved by: https://github.com/leslie-fang-intel, https://github.com/williamwen42, https://github.com/jansel
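For context, a minimal sketch of the scenario this fix addresses (the Net module and shapes are hypothetical stand-ins; the ThroughputBenchmark calls mirror those in the test below). Before this change, the benchmark's worker threads ran without the caller's autocast/no-grad thread-local state, so a module compiled under autocast could hit Dynamo recompiles or guard failures inside benchmark():

    import torch
    from torch.utils.throughput_benchmark import ThroughputBenchmark


    class Net(torch.nn.Module):  # hypothetical stand-in module
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(10, 5)

        def forward(self, x):
            return self.fc(x)


    module = torch.compile(Net())
    x = torch.randn(8, 10)

    with torch.no_grad(), torch.amp.autocast("cpu", dtype=torch.bfloat16):
        module(x)  # compile once under autocast
        bench = ThroughputBenchmark(module)
        bench.add_input(x)
        # Worker threads now inherit the caller's autocast/grad-mode TLS,
        # so this no longer recompiles or errors.
        stats = bench.benchmark(
            num_calling_threads=4, num_warmup_iters=10, num_iters=100
        )
        print(stats)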
Committed by: PyTorch MergeBot
Parent: cb814c0b96
Commit: 73622fc5fa
@@ -79,6 +79,54 @@ class TestThroughputBenchmark(TestCase):
        with TemporaryFileName() as fname:
            self.linear_test(TwoLayerNetModule, profiler_output_path=fname)

    def linear_with_compile_test(self, Module, dtype):
        from contextlib import nullcontext

        from torch._dynamo import config
        from torch._inductor import config as inductor_config

        config.error_on_recompile = True
        inductor_config.cpp_wrapper = True
        inductor_config.freezing = True
        D_in = 10
        H = 5
        D_out = 15
        B = 8

        autocast = dtype != torch.float32
        module = Module(D_in, H, D_out)

        input = (torch.randn(B, D_in), torch.randn(B, D_in))

        with torch.no_grad(), torch.amp.autocast("cpu", enabled=autocast, dtype=dtype):
            torch._dynamo.reset()
            module(*input)
            module = torch.compile(module)
            module(*input)
            module(*input)

        ctx = nullcontext()
        if dtype == torch.float16 or dtype == torch.bfloat16:
            ctx = torch.amp.autocast("cpu", enabled=autocast, dtype=dtype)
        with torch.no_grad(), ctx:
            bench = ThroughputBenchmark(module)
            bench.add_input(*input)

            module_result = module(*input)
            bench_result = bench.run_once(*input)
            torch.testing.assert_close(bench_result, module_result)

            stats = bench.benchmark(
                num_calling_threads=4, num_warmup_iters=100, num_iters=1000
            )

            print(stats)

    def test_compile(self):
        dtypes = [torch.float32, torch.float16, torch.bfloat16]
        for dtype in dtypes:
            self.linear_with_compile_test(TwoLayerNetModule, dtype)


if __name__ == "__main__":
    run_tests()
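For reference, the stats printed by the test is an ExecutionStats object from torch.utils.throughput_benchmark; its docstring-documented fields include latency_avg_ms and num_iters, and it derives iters_per_second. A minimal sketch of consuming it instead of printing, reusing bench from the test above:

    stats = bench.benchmark(
        num_calling_threads=4, num_warmup_iters=100, num_iters=1000
    )
    print(f"avg latency (ms): {stats.latency_avg_ms:.3f}")
    print(f"throughput (it/s): {stats.iters_per_second:.1f}")
    print(f"measured iters:    {stats.num_iters}")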
@@ -539,6 +539,10 @@ struct AutocastState {

  bool operator==(const AutocastState& o) const {
    for (size_t i = 0; i < DEVICES.size(); i++) {
      // If autocast is disabled on this device in both states, skip the
      // autocast_dtype comparison.
      if (enabled[i] == false && o.enabled[i] == false) {
        continue;
      }
      if (enabled[i] != o.enabled[i] || dtype[i] != o.dtype[i]) {
        return false;
      }
@@ -8,6 +8,7 @@
#include <torch/csrc/utils/pybind.h>

#include <ATen/Parallel.h>
#include <ATen/autocast_mode.h>
#include <c10/core/GradMode.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/util/irange.h>
@@ -61,6 +62,14 @@ BenchmarkExecutionStats BenchmarkHelper<Input, Output, Model>::benchmark(

  callers.reserve(config.num_calling_threads);

  // Snapshot the calling thread's autocast state for every supported device
  // so it can be replayed on each worker thread.
  static constexpr auto& DEVICES = at::autocast::_AUTOCAST_SUPPORTED_DEVICES;
  std::array<bool, DEVICES.size()> autocast_enabled;
  std::array<at::ScalarType, DEVICES.size()> autocast_dtype;
  for (size_t i = 0; i < DEVICES.size(); i++) {
    autocast_enabled[i] = at::autocast::is_autocast_enabled(DEVICES[i]);
    autocast_dtype[i] = at::autocast::get_autocast_dtype(DEVICES[i]);
  }
  bool autocast_cache_enabled = at::autocast::is_autocast_cache_enabled();
  bool tls_grad_enabled = c10::GradMode::is_enabled();
  c10::impl::LocalDispatchKeySet tls_key_set =
      c10::impl::tls_local_dispatch_key_set();
@@ -71,6 +80,11 @@ BenchmarkExecutionStats BenchmarkHelper<Input, Output, Model>::benchmark(
      // Perform the required warmup iterations before we start measuring.
      c10::GradMode::set_enabled(tls_grad_enabled);
      c10::impl::_force_tls_local_dispatch_key_set(tls_key_set);
      // Restore the captured autocast state on this worker thread.
      for (size_t i = 0; i < DEVICES.size(); i++) {
        at::autocast::set_autocast_enabled(DEVICES[i], autocast_enabled[i]);
        at::autocast::set_autocast_dtype(DEVICES[i], autocast_dtype[i]);
      }
      at::autocast::set_autocast_cache_enabled(autocast_cache_enabled);

      for (const auto j : c10::irange(config.num_warmup_iters)) {
        (void)j;
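The effect of the propagation above can be checked from Python: modules invoked on the benchmark's worker threads now observe the caller's thread-local state. A hypothetical probe (not part of this commit), assuming a PyTorch build where torch.is_autocast_enabled accepts a device-type string:

    import torch
    from torch.utils.throughput_benchmark import ThroughputBenchmark


    class Probe(torch.nn.Module):  # hypothetical probe module
        def forward(self, x):
            # With the fix, worker threads see the caller's autocast/grad state.
            assert torch.is_autocast_enabled("cpu")
            assert not torch.is_grad_enabled()
            return x + 1


    with torch.no_grad(), torch.amp.autocast("cpu", dtype=torch.bfloat16):
        bench = ThroughputBenchmark(Probe())
        bench.add_input(torch.randn(4))
        bench.benchmark(num_calling_threads=2, num_warmup_iters=1, num_iters=10)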