wengshiy
2025-01-26 03:37:18 +00:00
committed by PyTorch MergeBot
parent cb814c0b96
commit 73622fc5fa
3 changed files with 66 additions and 0 deletions

@@ -79,6 +79,54 @@ class TestThroughputBenchmark(TestCase):
        with TemporaryFileName() as fname:
            self.linear_test(TwoLayerNetModule, profiler_output_path=fname)

    def linear_with_compile_test(self, Module, dtype):
        from contextlib import nullcontext

        from torch._dynamo import config
        from torch._inductor import config as inductor_config

        config.error_on_recompile = True
        inductor_config.cpp_wrapper = True
        inductor_config.freezing = True
        D_in = 10
        H = 5
        D_out = 15
        B = 8
        autocast = dtype != torch.float32
        module = Module(D_in, H, D_out)
        input = (torch.randn(B, D_in), torch.randn(B, D_in))
        with torch.no_grad(), torch.amp.autocast("cpu", enabled=autocast, dtype=dtype):
            torch._dynamo.reset()
            module(*input)
            module = torch.compile(module)
            module(*input)
            module(*input)

        ctx = nullcontext()
        if dtype == torch.float16 or dtype == torch.bfloat16:
            ctx = torch.amp.autocast("cpu", enabled=autocast, dtype=dtype)
        with torch.no_grad(), ctx:
            bench = ThroughputBenchmark(module)
            bench.add_input(*input)
            module_result = module(*input)
            bench_result = bench.run_once(*input)
            torch.testing.assert_close(bench_result, module_result)
            stats = bench.benchmark(
                num_calling_threads=4, num_warmup_iters=100, num_iters=1000
            )
            print(stats)

    def test_compile(self):
        dtypes = [torch.float32, torch.float16, torch.bfloat16]
        for dtype in dtypes:
            self.linear_with_compile_test(TwoLayerNetModule, dtype)


if __name__ == "__main__":
    run_tests()
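For context, the plain ThroughputBenchmark API that this test exercises under torch.compile and autocast looks roughly like the sketch below. The model and sizes are illustrative only and are not taken from the commit.

import torch
from torch.utils import ThroughputBenchmark

model = torch.nn.Linear(10, 15)
example = torch.randn(8, 10)

bench = ThroughputBenchmark(model)
bench.add_input(example)
# run_once executes the wrapped module once and returns its output,
# so it should match calling the module directly.
assert torch.allclose(bench.run_once(example), model(example))
stats = bench.benchmark(num_calling_threads=4, num_warmup_iters=10, num_iters=100)
print(stats)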

@@ -539,6 +539,10 @@ struct AutocastState {
  bool operator==(const AutocastState& o) const {
    for (size_t i = 0; i < DEVICES.size(); i++) {
      // If autocast is disabled in both states, skip the autocast_dtype comparison
      if (enabled[i] == false && o.enabled[i] == false) {
        continue;
      }
      if (enabled[i] != o.enabled[i] || dtype[i] != o.dtype[i]) {
        return false;
      }
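In plain terms, the comparison above now ignores the recorded autocast dtype for any device on which autocast is disabled in both states, so a dtype difference alone cannot make the guard fail. A hedged Python rendering of that logic (the function name and parameters are illustrative, not from the commit):

def autocast_state_equal(a_enabled, a_dtype, b_enabled, b_dtype):
    for a_en, a_dt, b_en, b_dt in zip(a_enabled, a_dtype, b_enabled, b_dtype):
        if not a_en and not b_en:
            # dtype is irrelevant while autocast is off on this device
            continue
        if a_en != b_en or a_dt != b_dt:
            return False
    return True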

@@ -8,6 +8,7 @@
#include <torch/csrc/utils/pybind.h>
#include <ATen/Parallel.h>
#include <ATen/autocast_mode.h>
#include <c10/core/GradMode.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/util/irange.h>
@@ -61,6 +62,14 @@ BenchmarkExecutionStats BenchmarkHelper<Input, Output, Model>::benchmark(
  callers.reserve(config.num_calling_threads);

  static constexpr auto& DEVICES = at::autocast::_AUTOCAST_SUPPORTED_DEVICES;
  std::array<bool, DEVICES.size()> autocast_enabled;
  std::array<at::ScalarType, DEVICES.size()> autocast_dtype;
  for (size_t i = 0; i < DEVICES.size(); i++) {
    autocast_enabled[i] = at::autocast::is_autocast_enabled(DEVICES[i]);
    autocast_dtype[i] = at::autocast::get_autocast_dtype(DEVICES[i]);
  }
  bool autocast_cache_enabled = at::autocast::is_autocast_cache_enabled();
  bool tls_grad_enabled = c10::GradMode::is_enabled();
  c10::impl::LocalDispatchKeySet tls_key_set =
      c10::impl::tls_local_dispatch_key_set();
@@ -71,6 +80,11 @@ BenchmarkExecutionStats BenchmarkHelper<Input, Output, Model>::benchmark(
      // performs required warmup iterations before we start measuring
      c10::GradMode::set_enabled(tls_grad_enabled);
      c10::impl::_force_tls_local_dispatch_key_set(tls_key_set);
      for (size_t i = 0; i < DEVICES.size(); i++) {
        at::autocast::set_autocast_enabled(DEVICES[i], autocast_enabled[i]);
        at::autocast::set_autocast_dtype(DEVICES[i], autocast_dtype[i]);
      }
      at::autocast::set_autocast_cache_enabled(autocast_cache_enabled);
      for (const auto j : c10::irange(config.num_warmup_iters)) {
        (void)j;
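The snippet above re-applies the captured grad mode, dispatch keys, and autocast settings inside each calling thread because that state is thread-local and is not inherited by newly spawned threads. A minimal sketch demonstrating this, assuming the device_type overload of torch.is_autocast_enabled() available in recent PyTorch releases:

import threading

import torch

def report(tag):
    # Autocast state lives in per-thread TLS, so a fresh worker thread
    # sees the default (disabled) state.
    print(tag, torch.is_autocast_enabled("cpu"))

with torch.amp.autocast("cpu", dtype=torch.bfloat16):
    report("main thread:")   # True
    worker = threading.Thread(target=report, args=("worker thread:",))
    worker.start()
    worker.join()            # worker thread: False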