mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix ThroughputBenchmark issue (#144669)
Fixes [144461](https://github.com/pytorch/pytorch/issues/144461) Pull Request resolved: https://github.com/pytorch/pytorch/pull/144669 Approved by: https://github.com/leslie-fang-intel, https://github.com/williamwen42, https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
cb814c0b96
commit
73622fc5fa
@ -79,6 +79,54 @@ class TestThroughputBenchmark(TestCase):
|
|||||||
with TemporaryFileName() as fname:
|
with TemporaryFileName() as fname:
|
||||||
self.linear_test(TwoLayerNetModule, profiler_output_path=fname)
|
self.linear_test(TwoLayerNetModule, profiler_output_path=fname)
|
||||||
|
|
||||||
|
def linear_with_compile_test(self, Module, dtype):
|
||||||
|
from contextlib import nullcontext
|
||||||
|
|
||||||
|
from torch._dynamo import config
|
||||||
|
from torch._inductor import config as inductor_config
|
||||||
|
|
||||||
|
config.error_on_recompile = True
|
||||||
|
inductor_config.cpp_wrapper = True
|
||||||
|
inductor_config.freezing = True
|
||||||
|
D_in = 10
|
||||||
|
H = 5
|
||||||
|
D_out = 15
|
||||||
|
B = 8
|
||||||
|
|
||||||
|
autocast = dtype != torch.float32
|
||||||
|
module = Module(D_in, H, D_out)
|
||||||
|
|
||||||
|
input = (torch.randn(B, D_in), torch.randn(B, D_in))
|
||||||
|
|
||||||
|
with torch.no_grad(), torch.amp.autocast("cpu", enabled=autocast, dtype=dtype):
|
||||||
|
torch._dynamo.reset()
|
||||||
|
module(*input)
|
||||||
|
module = torch.compile(module)
|
||||||
|
module(*input)
|
||||||
|
module(*input)
|
||||||
|
|
||||||
|
ctx = nullcontext()
|
||||||
|
if dtype == torch.float16 or dtype == torch.bfloat16:
|
||||||
|
ctx = torch.amp.autocast("cpu", enabled=autocast, dtype=dtype)
|
||||||
|
with torch.no_grad(), ctx:
|
||||||
|
bench = ThroughputBenchmark(module)
|
||||||
|
bench.add_input(*input)
|
||||||
|
|
||||||
|
module_result = module(*input)
|
||||||
|
bench_result = bench.run_once(*input)
|
||||||
|
torch.testing.assert_close(bench_result, module_result)
|
||||||
|
|
||||||
|
stats = bench.benchmark(
|
||||||
|
num_calling_threads=4, num_warmup_iters=100, num_iters=1000
|
||||||
|
)
|
||||||
|
|
||||||
|
print(stats)
|
||||||
|
|
||||||
|
def test_compile(self):
|
||||||
|
dtypes = [torch.float32, torch.float16, torch.bfloat16]
|
||||||
|
for dtype in dtypes:
|
||||||
|
self.linear_with_compile_test(TwoLayerNetModule, dtype)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run_tests()
|
run_tests()
|
||||||
|
@ -539,6 +539,10 @@ struct AutocastState {
|
|||||||
|
|
||||||
bool operator==(const AutocastState& o) const {
|
bool operator==(const AutocastState& o) const {
|
||||||
for (size_t i = 0; i < DEVICES.size(); i++) {
|
for (size_t i = 0; i < DEVICES.size(); i++) {
|
||||||
|
// If autocast is disabled on both sides, skip the autocast_dtype comparison
|
||||||
|
if (enabled[i] == false && o.enabled[i] == false) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
if (enabled[i] != o.enabled[i] || dtype[i] != o.dtype[i]) {
|
if (enabled[i] != o.enabled[i] || dtype[i] != o.dtype[i]) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,7 @@
|
|||||||
#include <torch/csrc/utils/pybind.h>
|
#include <torch/csrc/utils/pybind.h>
|
||||||
|
|
||||||
#include <ATen/Parallel.h>
|
#include <ATen/Parallel.h>
|
||||||
|
#include <ATen/autocast_mode.h>
|
||||||
#include <c10/core/GradMode.h>
|
#include <c10/core/GradMode.h>
|
||||||
#include <c10/core/impl/LocalDispatchKeySet.h>
|
#include <c10/core/impl/LocalDispatchKeySet.h>
|
||||||
#include <c10/util/irange.h>
|
#include <c10/util/irange.h>
|
||||||
@ -61,6 +62,14 @@ BenchmarkExecutionStats BenchmarkHelper<Input, Output, Model>::benchmark(
|
|||||||
|
|
||||||
callers.reserve(config.num_calling_threads);
|
callers.reserve(config.num_calling_threads);
|
||||||
|
|
||||||
|
static constexpr auto& DEVICES = at::autocast::_AUTOCAST_SUPPORTED_DEVICES;
|
||||||
|
std::array<bool, DEVICES.size()> autocast_enabled;
|
||||||
|
std::array<at::ScalarType, DEVICES.size()> autocast_dtype;
|
||||||
|
for (size_t i = 0; i < DEVICES.size(); i++) {
|
||||||
|
autocast_enabled[i] = at::autocast::is_autocast_enabled(DEVICES[i]);
|
||||||
|
autocast_dtype[i] = at::autocast::get_autocast_dtype(DEVICES[i]);
|
||||||
|
}
|
||||||
|
bool autocast_cache_enabled = at::autocast::is_autocast_cache_enabled();
|
||||||
bool tls_grad_enabled = c10::GradMode::is_enabled();
|
bool tls_grad_enabled = c10::GradMode::is_enabled();
|
||||||
c10::impl::LocalDispatchKeySet tls_key_set =
|
c10::impl::LocalDispatchKeySet tls_key_set =
|
||||||
c10::impl::tls_local_dispatch_key_set();
|
c10::impl::tls_local_dispatch_key_set();
|
||||||
@ -71,6 +80,11 @@ BenchmarkExecutionStats BenchmarkHelper<Input, Output, Model>::benchmark(
|
|||||||
// performs required warmup iterations before we start measuring
|
// performs required warmup iterations before we start measuring
|
||||||
c10::GradMode::set_enabled(tls_grad_enabled);
|
c10::GradMode::set_enabled(tls_grad_enabled);
|
||||||
c10::impl::_force_tls_local_dispatch_key_set(tls_key_set);
|
c10::impl::_force_tls_local_dispatch_key_set(tls_key_set);
|
||||||
|
for (size_t i = 0; i < DEVICES.size(); i++) {
|
||||||
|
at::autocast::set_autocast_enabled(DEVICES[i], autocast_enabled[i]);
|
||||||
|
at::autocast::set_autocast_dtype(DEVICES[i], autocast_dtype[i]);
|
||||||
|
}
|
||||||
|
at::autocast::set_autocast_cache_enabled(autocast_cache_enabled);
|
||||||
|
|
||||||
for (const auto j : c10::irange(config.num_warmup_iters)) {
|
for (const auto j : c10::irange(config.num_warmup_iters)) {
|
||||||
(void)j;
|
(void)j;
|
||||||
|
Reference in New Issue
Block a user