#include #include #include #include #include #include #include #include #include #include #include #include #include // Test that we can add and remove callbacks (both global and thread local.) TEST(RecordFunctionTest, AddRemove) { at::clearCallbacks(); ASSERT_FALSE(at::hasCallbacks()); auto start_callback = [](const at::RecordFunction& fn) -> std::unique_ptr { return nullptr; }; auto end_callback = [](const at::RecordFunction& fn, at::ObserverContext*) {}; auto handle = at::addThreadLocalCallback( at::RecordFunctionCallback(start_callback, end_callback)); ASSERT_TRUE(at::hasCallbacks()); ASSERT_TRUE(at::hasThreadLocalCallbacks()); ASSERT_FALSE(at::hasGlobalCallbacks()); at::removeCallback(handle); ASSERT_FALSE(at::hasCallbacks()); handle = at::addGlobalCallback( at::RecordFunctionCallback(start_callback, end_callback)); ASSERT_TRUE(at::hasCallbacks()); ASSERT_FALSE(at::hasThreadLocalCallbacks()); ASSERT_TRUE(at::hasGlobalCallbacks()); at::removeCallback(handle); ASSERT_FALSE(at::hasCallbacks()); } // Test that the callbacks that we register are actually run. 
TEST(RecordFunctionTest, ThreadLocalState) {
  at::clearCallbacks();
  ASSERT_FALSE(at::hasCallbacks());

  // The callbacks are captureless lambdas, so the counters they bump are
  // function-local statics; reset them explicitly in case the test is rerun.
  static int tls_test_start_counter;
  static int tls_test_end_counter;
  tls_test_start_counter = 0;
  tls_test_end_counter = 0;

  auto start_callback =
      [](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
    ++tls_test_start_counter;
    return nullptr;
  };
  auto end_callback = [](const at::RecordFunction&, at::ObserverContext*) {
    ++tls_test_end_counter;
  };

  auto handle = at::addThreadLocalCallback(
      at::RecordFunctionCallback(start_callback, end_callback));

  // Explicit guard: the start callback fires on before(), the end callback
  // fires when the guard leaves scope.
  {
    at::RecordFunction guard(at::RecordScope::USER_SCOPE);
    guard.before("Test");
    EXPECT_EQ(tls_test_start_counter, 1);
    EXPECT_EQ(tls_test_end_counter, 0);
  }
  EXPECT_EQ(tls_test_start_counter, 1);
  EXPECT_EQ(tls_test_end_counter, 1);

  // With DisableRecordFunctionGuard alive, neither callback may fire.
  {
    tls_test_start_counter = 0;
    tls_test_end_counter = 0;
    at::DisableRecordFunctionGuard no_profile_guard;
    at::RecordFunction guard(at::RecordScope::USER_SCOPE);
    guard.before("Test");
    EXPECT_EQ(tls_test_start_counter, 0);
    EXPECT_EQ(tls_test_end_counter, 0);
  }
  EXPECT_EQ(tls_test_start_counter, 0);
  EXPECT_EQ(tls_test_end_counter, 0);

  // The RECORD_FUNCTION macro must behave like the explicit guard above.
  {
    tls_test_start_counter = 0;
    tls_test_end_counter = 0;
    RECORD_FUNCTION("Test", {});
    EXPECT_EQ(tls_test_start_counter, 1);
    EXPECT_EQ(tls_test_end_counter, 0);
  }
  EXPECT_EQ(tls_test_start_counter, 1);
  EXPECT_EQ(tls_test_end_counter, 1);

  at::removeCallback(handle);
  ASSERT_FALSE(at::hasCallbacks());
}

// Test that callbacks are run in the order that they are registered.
TEST(RecordFunctionTest, CallOrder) { at::clearCallbacks(); ASSERT_FALSE(at::hasCallbacks()); static int current_index; current_index = 0; static std::array expected_order = { "Start Callback 0 Outer", "Start Callback 1 Outer", "Start Callback 0 Inner", "Start Callback 1 Inner", "End Callback 0 Inner", "End Callback 1 Inner", "End Callback 0 Outer", "End Callback 1 Outer", }; #define REGISTER_CALLBACK(index) \ at::addThreadLocalCallback( \ at::RecordFunctionCallback( \ [](const at::RecordFunction& fn) \ -> std::unique_ptr { \ EXPECT_EQ( \ fmt::format("Start Callback {} {}", index, fn.name()), \ expected_order[current_index++]); \ return nullptr; \ }, \ [](const at::RecordFunction& fn, at::ObserverContext*) { \ EXPECT_EQ( \ fmt::format("End Callback {} {}", index, fn.name()), \ expected_order[current_index++]); \ }) \ .scopes({at::RecordScope::FUNCTION})) REGISTER_CALLBACK(0); REGISTER_CALLBACK(1); #undef REGISTER_CALLBACK RECORD_FUNCTION("Outer", {}); { RECORD_FUNCTION("Inner", {}); } at::clearCallbacks(); ASSERT_FALSE(at::hasCallbacks()); } // Make sure TLS migrates when tasks are launched. TEST(RecordFunctionTest, ThreadMigration) { at::clearCallbacks(); ASSERT_FALSE(at::hasCallbacks()); static int call_count; call_count = 0; auto handle = at::addThreadLocalCallback( at::RecordFunctionCallback( [](const at::RecordFunction&) -> std::unique_ptr { return nullptr; }, [](const at::RecordFunction&, at::ObserverContext*) { ++call_count; }) .scopes({at::RecordScope::FUNCTION})); EXPECT_EQ(call_count, 0); std::condition_variable cv; std::mutex lock; at::launch([&cv]() { RECORD_FUNCTION("Test", {}); cv.notify_all(); }); auto guard = std::unique_lock(lock); cv.wait(guard, [] { return call_count > 0; }); EXPECT_EQ(call_count, 1); at::removeCallback(handle); ASSERT_FALSE(at::hasCallbacks()); } // Test sampling logic and validate that callbacks fire at the correct times. 
TEST(RecordFunctionTest, Sampling) { at::clearCallbacks(); ASSERT_FALSE(at::hasCallbacks()); static int sample_test_counter; sample_test_counter = 0; uint32_t seed = 12345; double p = 0.25; at::set_record_function_seed_for_testing(seed); std::mt19937 generator; generator.seed(seed); auto dist = std::geometric_distribution(p); // Make sure we know which steps should fire. auto outcomes = std::array{7, 0, 0, 6, 2}; for (const auto i : c10::irange(outcomes.size())) { ASSERT_EQ(dist(generator), outcomes[i]); } std::vector expected_counts; int running_count = 0; for (const auto i : c10::irange(outcomes.size())) { for ([[maybe_unused]] const auto j : c10::irange(outcomes[i])) { expected_counts.push_back(running_count); } expected_counts.push_back(++running_count); } auto start_callback = [](const at::RecordFunction& fn) -> std::unique_ptr { ++sample_test_counter; return nullptr; }; auto end_callback = [](const at::RecordFunction& fn, at::ObserverContext*) {}; auto handle = at::addThreadLocalCallback( at::RecordFunctionCallback(start_callback, end_callback) .samplingProb(p) .scopes({at::RecordScope::FUNCTION})); for (const auto i : c10::irange(expected_counts.size())) { RECORD_FUNCTION("Test", {}); EXPECT_EQ(sample_test_counter, expected_counts[i]); } at::removeCallback(handle); ASSERT_FALSE(at::hasCallbacks()); } // Validate sampling against a simple reference implementation for a complex set // of registered callbacks. TEST(RecordFunctionTest, MultipleCallbacks) { at::clearCallbacks(); ASSERT_FALSE(at::hasCallbacks()); uint32_t seed = 54321; std::mt19937 generator; generator.seed(seed); auto sample = [&](double p) { return (p < 1.0 ? 
std::geometric_distribution(p)(generator) : 0) + 1; }; std::array probabilities{0.1, 1.0, 1.0, 0.3}; std::array next_call; std::array counts; static std::array counts_from_rec_fn; counts_from_rec_fn.fill(0); auto end_callback = [](const at::RecordFunction& fn, at::ObserverContext*) {}; #define REGISTER_CALLBACK(register_fn, index) \ register_fn(at::RecordFunctionCallback( \ [](const at::RecordFunction& fn) \ -> std::unique_ptr { \ ++counts_from_rec_fn[index]; \ return nullptr; \ }, \ end_callback) \ .samplingProb(probabilities[index]) \ .scopes({at::RecordScope::FUNCTION})) REGISTER_CALLBACK(at::addGlobalCallback, 0); REGISTER_CALLBACK(at::addGlobalCallback, 1); REGISTER_CALLBACK(at::addThreadLocalCallback, 2); // The RecordFunction machinery will rebuild callbacks whenever a new observer // is registered, so we need to wait until the last callback to seed the // random number generator. at::set_record_function_seed_for_testing(seed); REGISTER_CALLBACK(at::addThreadLocalCallback, 3); #undef REGISTER_CALLBACK for (const auto i : c10::irange(probabilities.size())) { next_call[i] = sample(probabilities[i]); } for ([[maybe_unused]] const auto i : c10::irange(50)) { RECORD_FUNCTION("Test", {}); for (const auto j : c10::irange(next_call.size())) { if (!(--next_call[j])) { ++counts[j]; next_call[j] = sample(probabilities[j]); } EXPECT_EQ(counts[j], counts_from_rec_fn[j]); } } at::clearCallbacks(); ASSERT_FALSE(at::hasCallbacks()); } // Test that KwargsOnly callbacks are run in USER_SCOPE. 
TEST(RecordFunctionTest, KwargsOnly) { at::clearCallbacks(); ASSERT_FALSE(at::hasCallbacks()); static const std::unordered_map myMap = { {"a", 1}, {"b", 2.5}}; #define REGISTER_CALLBACK() \ at::addThreadLocalCallback( \ at::RecordFunctionCallback( \ [](const at::RecordFunction& fn) \ -> std::unique_ptr { \ EXPECT_EQ(myMap, fn.kwinputs()); \ return nullptr; \ }, \ [](const at::RecordFunction& fn, at::ObserverContext*) {}) \ .needsInputs(true) \ .scopes({at::RecordScope::USER_SCOPE})) REGISTER_CALLBACK(); #undef REGISTER_CALLBACK RECORD_USER_SCOPE_WITH_KWARGS_ONLY("Test", &myMap); at::clearCallbacks(); ASSERT_FALSE(at::hasCallbacks()); }