mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-14 14:15:07 +08:00
Problem: migrating the `AT_DISPATCH_V2` macros to headeronly cannot be a simple copy-paste of macro definitions from one header file to another, because the macros `AT_DISPATCH_SWITCH` and `AT_DISPATCH_CASE` may use functions that cannot be migrated to headeronly; e.g., when the selective build feature is enabled, these macros call generated functions. On the other hand, when selective build is not used, the dtype-dispatch macros are perfectly suitable for migrating to headeronly.
In this PR, the migration problem above is tackled by refactoring `AT_DISPATCH` related macros into headeronly macros and non-headeronly macros while preserving the current API and semantics. For instance, consider the current V2 macro definitions:
```c++
#define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, AT_AP_VAR(AT_WRAP(BODY), TYPE, __VA_ARGS__))
#define AT_AP_VAR(N, T, ...) \
AT_EXPAND(AT_CONCAT(AT_AP, AT_NUM_ARGS(__VA_ARGS__))(AT_WRAP(N), __VA_ARGS__))
#define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N)
...
```
where the parts that are problematic for the headeronly migration are the uses of the `AT_DISPATCH_SWITCH` and `AT_DISPATCH_CASE` macros (defined in ATen/Dispatch.h). In this PR, we introduce parametric versions of the `AT_DISPATCH_V2` and `AT_AP1` macros that have `_TMPL` suffixes, take `DISPATCH_SWITCH` and `DISPATCH_CASE` arguments, and are defined in `torch/headeronly/core/Dispatch_v2.h`:
```c++
#define THO_DISPATCH_V2_TMPL( \
DISPATCH_SWITCH, DISPATCH_CASE, TYPE, NAME, BODY, ...) \
DISPATCH_SWITCH( \
TYPE, \
NAME, \
THO_AP_VAR_TMPL(DISPATCH_CASE, AT_WRAP(BODY), TYPE, __VA_ARGS__))
#define THO_AP_VAR_TMPL(C, N, T, ...) \
AT_EXPAND( \
AT_CONCAT(THO_AP, AT_NUM_ARGS(__VA_ARGS__))(C, AT_WRAP(N), __VA_ARGS__))
#define THO_AP1(C, N, _1) C(_1, N)
...
```
so that the original V2 macro definition, in ATen/Dispatch_v2.h, becomes:
```c++
#define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \
THO_DISPATCH_V2_TMPL( \
AT_DISPATCH_SWITCH, \
AT_DISPATCH_CASE, \
TYPE, \
NAME, \
AT_WRAP(BODY), \
__VA_ARGS__)
```
which has exactly the same API and semantics as the original definition.
Note 1: ~we have changed the definition of `AT_AP1(N, _1) ...` to `AT_AP1(C, N, _1) ...` without renaming `AT_AP1` because `AT_AP1` is a helper macro that is not a part of public API (for instance, nothing in pytorch explicitly uses `AT_AP1`).~ UPDATE: restored the original `AT_AP` macros and introduced new `THO_AP` macros.
Note 2: this PR introduces a new API macro, `THO_DISPATCH_V2_TMPL`, that will be available to stable ABI users, who can use it by providing custom versions of the `AT_DISPATCH_SWITCH` and `AT_DISPATCH_CASE` macros, say, with selective build features removed.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165856
Approved by: https://github.com/janeyx99
83 lines
3.5 KiB
C++
83 lines
3.5 KiB
C++
#include <gtest/gtest.h>
|
|
|
|
#include <torch/headeronly/core/Dispatch.h>
|
|
#include <torch/headeronly/core/Dispatch_v2.h>
|
|
|
|
// Selective-build hook that is expanded at the start of every case block.
// The tests do not restrict the build, so for any enum argument this hook
// expands to nothing at all.
#define MY_PRIVATE_CHECK_SELECTIVE_BUILD(ENUM_TYPE) /* intentionally empty */

// Case-type helper for the tests: forwards all of its arguments to the
// headeronly THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL template and supplies
// MY_PRIVATE_CHECK_SELECTIVE_BUILD as that template's selective-build hook.
#define MY_PRIVATE_CASE_TYPE_USING_HINT(...)                       \
  THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL(MY_PRIVATE_CHECK_SELECTIVE_BUILD, \
                                        __VA_ARGS__)

// Case macro handed to THO_DISPATCH_V2_TMPL by MY_DISPATCH_V2. It builds one
// case block through the headeronly THO_DISPATCH_CASE_TMPL template, using
// MY_PRIVATE_CASE_TYPE_USING_HINT as the case-type helper.
#define MY_DISPATCH_CASE(...)                                 \
  THO_DISPATCH_CASE_TMPL(                                     \
      MY_PRIVATE_CASE_TYPE_USING_HINT,                        \
      __VA_ARGS__)

// MY_RECORD_KERNEL_FUNCTION_DTYPE is a prelude to the dispatch switch
// statement. The test records nothing. It only consumes DISPATCHNAME so that
// the variable is not reported as unused. ENUMTYPE is ignored.
// static_cast<void>(...) parenthesizes the whole argument, unlike the C-style
// `(void)DISPATCHNAME`, which would bind only to the first operand of a
// compound argument such as `a + b` and fail to compile.
#define MY_RECORD_KERNEL_FUNCTION_DTYPE(DISPATCHNAME, ENUMTYPE) \
  static_cast<void>(DISPATCHNAME)

// Handler for the default branch of the dispatch switch. In the tests it
// ignores every argument and just tallies a dtype that matched no case, by
// incrementing a `default_count` visible at the expansion site.
#define MY_CHECK_NOT_IMPLEMENTED(...) (default_count++)

// Switch macro handed to THO_DISPATCH_V2_TMPL by MY_DISPATCH_V2. It builds
// the switch via the headeronly THO_DISPATCH_SWITCH_TMPL template, plugging
// in the test's record hook (before the switch) and default-branch handler.
#define MY_DISPATCH_SWITCH(...)           \
  THO_DISPATCH_SWITCH_TMPL(               \
      MY_RECORD_KERNEL_FUNCTION_DTYPE,    \
      MY_CHECK_NOT_IMPLEMENTED,           \
      __VA_ARGS__)

// Body passed to MY_DISPATCH_V2 by the tests. Each call increments the
// `count` captured by reference from the enclosing scope. Declaring a
// `scalar_t` local only compiles if the dispatch case has put a `scalar_t`
// type alias in scope.
#define MY_CASE_FUNCTION                \
  [&]() {                               \
    ++count;                            \
    [[maybe_unused]] scalar_t tmp;      \
  }

// Index-type counterpart of MY_CASE_FUNCTION. Each call increments the
// captured `count`. Declaring an `index_t` local only compiles if the case
// has put an `index_t` type alias in scope.
// NOTE(review): TEST_DISPATCH_V2 hard-codes MY_CASE_FUNCTION, so this macro
// appears not to be exercised by the tests in this file.
#define MY_INDEX_CASE_FUNCTION          \
  [&]() {                               \
    ++count;                            \
    [[maybe_unused]] index_t tmp;       \
  }

// X-macro callback: maps one (C++ type, ScalarType name) entry to the
// enumerator `ScalarType::<name>` followed by a comma, so a type-list macro
// can populate a braced initializer list. The C++ type is ignored.
#define DEFINE_ITEM(CPP_TYPE, ENUM_NAME) ScalarType::ENUM_NAME,

// Test-local analogue of AT_DISPATCH_V2. It has the same
// (TYPE, NAME, BODY, dtype groups...) interface but is built from the
// headeronly THO_DISPATCH_V2_TMPL, using MY_DISPATCH_SWITCH / MY_DISPATCH_CASE
// in place of AT_DISPATCH_SWITCH / AT_DISPATCH_CASE.
#define MY_DISPATCH_V2(TYPE, NAME, BODY, ...)                        \
  THO_DISPATCH_V2_TMPL(                                              \
      MY_DISPATCH_SWITCH, MY_DISPATCH_CASE, TYPE, NAME, AT_WRAP(BODY), \
      __VA_ARGS__)

// Defines the gtest case TestDispatchV2.NAME. It runs MY_DISPATCH_V2 once for
// every ScalarType listed by AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS,
// passing the dtype groups in __VA_ARGS__ as the dispatch cases:
//   - `count` is incremented by MY_CASE_FUNCTION whenever a case matches;
//   - `default_count` is incremented by MY_CHECK_NOT_IMPLEMENTED in the
//     switch's default branch.
// The test checks that EXPECTEDCOUNT of the iterated types matched, and that
// matched plus defaulted iterations add up to the number of types iterated.
// The counters are plain `int` rather than `int8_t`. gtest prints `int8_t`
// (a `signed char`) as a character in EXPECT_EQ failure messages, and a
// narrow counter would wrap if the type list ever exceeded 127 entries.
// NOTE(review): ScalarTypeToCPPTypeT is brought into scope presumably for the
// expanded case macros -- TODO confirm it is required.
#define TEST_DISPATCH_V2(NAME, EXPECTEDCOUNT, ...)                              \
  TEST(TestDispatchV2, NAME) {                                                  \
    using torch::headeronly::ScalarType;                                        \
    using torch::headeronly::impl::ScalarTypeToCPPTypeT;                        \
    int total_count = 0;                                                        \
    int count = 0;                                                              \
    int default_count = 0;                                                      \
    for (ScalarType t :                                                         \
         {AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ITEM)}) {        \
      total_count++;                                                            \
      MY_DISPATCH_V2(t, "test_my_dispatch_v2", MY_CASE_FUNCTION, __VA_ARGS__);  \
    }                                                                           \
    EXPECT_EQ(count, EXPECTEDCOUNT);                                            \
    EXPECT_EQ(default_count + count, total_count);                              \
  }

// EXPECTEDCOUNT in each registration is how many of the iterated ScalarTypes
// fall into the given dtype group. The composite counts are consistent with
// their parts, presumably INTEGRAL_V2 = INTEGRAL + BAREBONES_UNSIGNED (5 + 3),
// ALL_TYPES = INTEGRAL + FLOATING (5 + 2), and
// ALL_TYPES_AND_COMPLEX = ALL_TYPES + COMPLEX (7 + 2). TODO confirm against
// the group definitions in Dispatch_v2.h.

// Integer groups.
TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_, 5, AT_INTEGRAL_TYPES);
TEST_DISPATCH_V2(AT_BAREBONES_UNSIGNED_TYPES_, 3, AT_BAREBONES_UNSIGNED_TYPES);
TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_V2_, 8, AT_INTEGRAL_TYPES_V2);

// Floating-point groups.
TEST_DISPATCH_V2(AT_FLOATING_TYPES_, 2, AT_FLOATING_TYPES);
TEST_DISPATCH_V2(AT_FLOAT8_TYPES_, 5, AT_FLOAT8_TYPES);

// Complex and quantized groups.
TEST_DISPATCH_V2(AT_COMPLEX_TYPES_, 2, AT_COMPLEX_TYPES);
TEST_DISPATCH_V2(AT_QINT_TYPES_, 3, AT_QINT_TYPES);

// Composite groups.
TEST_DISPATCH_V2(AT_ALL_TYPES_, 7, AT_ALL_TYPES);
TEST_DISPATCH_V2(AT_ALL_TYPES_AND_COMPLEX_, 9, AT_ALL_TYPES_AND_COMPLEX);

// DEFINE_ITEM is only needed by the TEST_DISPATCH_V2 expansions above.
#undef DEFINE_ITEM