enable torch.cpu.amp.autocast (#57386)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57386

Here is the PR for what's discussed in the RFC https://github.com/pytorch/pytorch/issues/55374 to enable the autocast for CPU device. Currently, this PR only enable BF16 as the lower precision datatype.

Changes:
1.  Enable new API `torch.cpu.amp.autocast` for autocast on CPU device: include the python API, C++ API, new Dispatchkey etc.
2.  Consolidate the implementation for each cast policy sharing between CPU and GPU devices.
3.  Add the operation lists to corresponding cast policy for cpu autocast.

Test Plan: Imported from OSS

Reviewed By: soulitzer

Differential Revision: D28572219

Pulled By: ezyang

fbshipit-source-id: db3db509973b16a5728ee510b5e1ee716b03a152
This commit is contained in:
leslie-fang-intel
2021-05-20 17:45:18 -07:00
committed by Facebook GitHub Bot
parent b6dcdeacc9
commit 0ede83db7a
15 changed files with 630 additions and 106 deletions

View File

@ -13,19 +13,24 @@ namespace at {
namespace autocast {
bool is_enabled() {
//return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastCUDA) ||
// !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastCPU);
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastCUDA);
}
void set_enabled(bool new_enabled) {
//c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastCPU, !new_enabled);
c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastCUDA, !new_enabled);
}
bool is_cpu_enabled() {
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastCPU);
}
void set_cpu_enabled(bool new_enabled) {
c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastCPU, !new_enabled);
}
namespace {
// Imitate Apex and cache some of the casts to streamline parameter reuse.
// Our heuristic is to cache fp16 casts of fp32 model weights (see cached_cast below).
// Our heuristic is to cache lower_precision_fp casts of fp32 model weights (see cached_cast below).
//
// After discussion with @ezyang, the cache uses the following structure:
// The key is the fp32 source tensor's TensorImpl*, a proxy for a Tensor uuid that's
@ -51,6 +56,9 @@ thread_local std::unordered_map<TensorImpl*, val_type> cached_casts;
// it calls clear_cache() to ensure cached Tensors don't leak outside the autocasting region.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
thread_local int nesting = 0;
// autocast_cpu_dtype is the lower_precision_fp used by AutocastCPU.
thread_local at::ScalarType autocast_cpu_dtype = at::kBFloat16;
}
void clear_cache() {
@ -65,15 +73,28 @@ int decrement_nesting() {
return --nesting;
}
at::ScalarType get_autocast_cpu_dtype() {
return autocast_cpu_dtype;
}
void set_autocast_cpu_dtype(at::ScalarType dtype) {
TORCH_CHECK(
dtype == at::kBFloat16,
"Currently, AutocastCPU only support Bfloat16 as the autocast_cpu_dtype");
autocast_cpu_dtype = dtype;
}
// Overload to catch Tensor args
// TODO (possible optimization):
// Move cast_cache to an inline function in a header with cached_casts declared as
// extern thread_local in the header.
Tensor cached_cast(at::ScalarType to_type, const Tensor& arg) {
if (is_eligible(arg) && (arg.scalar_type() != to_type)) {
// Heuristic: Do what Apex does, and cache fp16 casts of fp32 model weights (leaves).
Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_type) {
if (is_eligible(arg, device_type) && (arg.scalar_type() != to_type)) {
// Heuristic: Do what Apex does, and cache lower_precision_fp casts of fp32 model weights (leaves).
// See cached_casts declaration above for detailed strategy.
bool can_try_cache = (to_type == at::kHalf && arg.scalar_type() == at::kFloat && arg.requires_grad() && arg.is_leaf() && !arg.is_view());
bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) &&
arg.scalar_type() == at::kFloat && arg.requires_grad() &&
arg.is_leaf() && !arg.is_view());
if (can_try_cache) {
auto it = cached_casts.find(arg.unsafeGetTensorImpl());
if (it != cached_casts.end()) {
@ -94,7 +115,8 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg) {
// Policies correspond to op categories that need code-divergent handling.
// Wrapper templates below are specialized based on a policy template parameter.
enum class CastPolicy : uint8_t {
fp16 = 0, // Cast all inputs to at::kHalf before running the op.
lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before running the op.
// Currently, lower_precision_fp is fp16 for AutocastCUDA, and is defined by user(default bf16) for AutocastCPU.
fp32, // Cast all inputs to at::kFloat before running the op.
fp32_set_opt_dtype, // Treats functions (like softmax) that
// 1. we'd like to run in fp32 and
@ -122,29 +144,29 @@ Interior WrapFunction_ specializations are defined for each CastPolicy.
********************************************************************************************************/
// Base template for WrapFunction_, which is specialized to contain a "call" method each CastPolicy
template<CastPolicy policy, class Redispatch, Redispatch* F, class Ret, class ArgList> struct WrapFunction_ {};
template<CastPolicy policy, DeviceType device_type, class Redispatch, Redispatch* F, class Ret, class ArgList> struct WrapFunction_ {};
// CastPolicy::fp16
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp16, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
// CastPolicy::lower_precision_fp General_DeviceType
template<DeviceType device_type, class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::lower_precision_fp, device_type, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
static Ret call(Args... args) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
return (*F)(cached_cast(at::kHalf, args)...);
c10::impl::ExcludeDispatchKeyGuard no_autocast(get_autocast_dispatch_key_from_device_type(device_type));
return (*F)(cached_cast(get_lower_precision_fp_from_device_type(device_type), args, device_type)...);
}
};
// CastPolicy::fp32
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp32, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
// CastPolicy::fp32 General_DeviceType
template<DeviceType device_type, class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp32, device_type, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
static Ret call(Args... args) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
return (*F)(cached_cast(at::kFloat, args)...);
c10::impl::ExcludeDispatchKeyGuard no_autocast(get_autocast_dispatch_key_from_device_type(device_type));
return (*F)(cached_cast(at::kFloat, args, device_type)...);
}
};
// CastPolicy::fp32_set_opt_dtype
// CastPolicy::fp32_set_opt_dtype DeviceType::CUDA
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp32_set_opt_dtype, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
struct WrapFunction_<CastPolicy::fp32_set_opt_dtype, DeviceType::CUDA, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
static Ret call(Args... args) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
if (firstarg_is_eligible(args...)) {
@ -157,9 +179,9 @@ struct WrapFunction_<CastPolicy::fp32_set_opt_dtype, Redispatch, F, Ret, guts::t
}
};
// CastPolicy::fp32_append_dtype
// CastPolicy::fp32_append_dtype DeviceType::CUDA
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp32_append_dtype, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
struct WrapFunction_<CastPolicy::fp32_append_dtype, DeviceType::CUDA, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
static Ret call(Args... args) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
at::ScalarType out_type = type_from_firstarg(at::kFloat, args...);
@ -167,18 +189,19 @@ struct WrapFunction_<CastPolicy::fp32_append_dtype, Redispatch, F, Ret, guts::ty
}
};
// CastPolicy::promote
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::promote, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
// CastPolicy::promote General_DeviceType
template<DeviceType device_type, class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::promote, device_type, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
static Ret call(Args... args) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
auto to_type = promote_type(at::kHalf, args...);
return (*F)(cached_cast(to_type, args)...);
c10::impl::ExcludeDispatchKeyGuard no_autocast(get_autocast_dispatch_key_from_device_type(device_type));
auto to_type = promote_type(get_lower_precision_fp_from_device_type(device_type), device_type, args...);
return (*F)(cached_cast(to_type, args, device_type)...);
}
};
// Wrapper to infer return_type and parameter_types for WrapFunction_ (imitating core/boxing/impl/WrapFunctionIntoFunctor.h)
template<CastPolicy policy,
DeviceType device_type,
class Registered, // The signature for which we're registering. The dispatcher's calling code invokes our
// registered functions with arguments matching Registered, so we register
// WrapFunction_::call methods with a matching signature to properly field those arguments.
@ -190,6 +213,7 @@ template<CastPolicy policy,
Redispatch* F> // The actual function we're redispatching to.
struct WrapFunction final {
using type = WrapFunction_<policy,
device_type,
Redispatch,
F,
typename guts::function_traits<Registered>::return_type,
@ -213,14 +237,15 @@ namespace {
This section performs load-time registration for autocast wrappers.
It's debatable at what level operations should be patched. We'd like casts to be autograd-exposed
and precede autograd history recording, so that for fp16 ops, input tensors are saved for backward
in fp16 rather than fp32. Saving inputs in fp16 can significantly reduce a model's memory footprint.
and precede autograd history recording, so that for lower_precision_fp ops, input tensors are saved for backward
in lower_precision_fp rather than fp32. Saving inputs in lower_precision_fp can significantly reduce
a model's memory footprint.
Option 1 (strawman): Patch only at the level of explicit calls into cudnn/cublas (cudnn_convolution, etc),
because those are the code paths that are guaranteed to use Tensor Cores, therefore they're the ones that
will benefit most from fp16. Potential pitfall: convolutions (and other ops) are wrapped in several
will benefit most from lower_precision_fp. Potential pitfall: convolutions (and other ops) are wrapped in several
layers of at::* calls. If one of those happens to record autograd history, then we've lost the
opportunity to save inputs in fp16.
opportunity to save inputs in lower_precision_fp.
Option 2: Patch the Python-exposed surface of calls, to make 100% sure autograd history
recording can't sneak in ahead of autocast. This mirrors Apex most closely.
@ -242,12 +267,17 @@ Therefore, for the moment, this is all copy pasted in from VariableTypeEverythin
// (that's why SIGNATURE is repeated in the WrapFunction instantiation)
#define KERNEL(FUNC, REGISTER_NAME, SIGNATURE, POLICY) \
m.impl(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
&WrapFunction<CastPolicy::POLICY, SIGNATURE, SIGNATURE, &FUNC>::type::call);
&WrapFunction<CastPolicy::POLICY, DeviceType::CUDA, SIGNATURE, SIGNATURE, &FUNC>::type::call);
// Less-common but still useful case: redispatching to a function with a new signature (e.g. appending a dtype)
#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(REDISPATCH_FUNC, REGISTER_NAME, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, POLICY) \
m.impl(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
&WrapFunction<CastPolicy::POLICY, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, &REDISPATCH_FUNC>::type::call);
&WrapFunction<CastPolicy::POLICY, DeviceType::CUDA, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, &REDISPATCH_FUNC>::type::call);
// KERNEL_CPU registration for AutocastCPU
#define KERNEL_CPU(FUNC, REGISTER_NAME, SIGNATURE, POLICY) \
m.impl(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
&WrapFunction<CastPolicy::POLICY, DeviceType::CPU, SIGNATURE, SIGNATURE, &FUNC>::type::call);
/*****************************************
Explicit registration for out-of-place ops
@ -257,65 +287,65 @@ TORCH_LIBRARY_IMPL(_, Autocast, m) {
}
TORCH_LIBRARY_IMPL(aten, Autocast, m) {
// fp16
KERNEL(ADD_NS(_convolution), "_convolution.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool), fp16)
KERNEL(ADD_NS(_convolution), "_convolution", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool, bool), fp16)
KERNEL(ADD_NS(_convolution_nogroup), "_convolution_nogroup", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef), fp16)
KERNEL(ADD_NS(conv1d), "conv1d", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), fp16)
KERNEL(ADD_NS(conv2d), "conv2d", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), fp16)
KERNEL(ADD_NS(conv3d), "conv3d", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), fp16)
KERNEL(ADD_NS(conv_tbc), "conv_tbc", Tensor (const Tensor &, const Tensor &, const Tensor &, int64_t), fp16)
KERNEL(ADD_NS(conv_transpose1d), "conv_transpose1d", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), fp16)
KERNEL(ADD_NS(conv_transpose2d), "conv_transpose2d.input", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), fp16)
KERNEL(ADD_NS(conv_transpose3d), "conv_transpose3d.input", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), fp16)
KERNEL(ADD_NS(convolution), "convolution", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t), fp16)
KERNEL(ADD_NS(cudnn_convolution), "cudnn_convolution.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), fp16)
KERNEL(ADD_NS(cudnn_convolution_transpose), "cudnn_convolution_transpose.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), fp16)
KERNEL(ADD_NS(cudnn_convolution), "cudnn_convolution.deprecated2", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), fp16)
KERNEL(ADD_NS(cudnn_convolution_transpose), "cudnn_convolution_transpose.deprecated2", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), fp16)
KERNEL(ADD_NS(cudnn_convolution), "cudnn_convolution", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool, bool), fp16)
KERNEL(ADD_NS(cudnn_convolution_transpose), "cudnn_convolution_transpose", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool, bool), fp16)
KERNEL(ADD_NS(prelu), "prelu", Tensor (const Tensor &, const Tensor &), fp16)
KERNEL(ADD_NS(addmm), "addmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), fp16)
KERNEL(ADD_NS(addmv), "addmv", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), fp16)
KERNEL(ADD_NS(addr), "addr", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), fp16)
KERNEL(ADD_NS(matmul), "matmul", Tensor (const Tensor &, const Tensor &), fp16)
KERNEL(ADD_NS(mm), "mm", Tensor (const Tensor &, const Tensor &), fp16)
KERNEL(ADD_NS(mv), "mv", Tensor (const Tensor &, const Tensor &), fp16)
KERNEL(ADD_NS(linear), "linear", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&), fp16)
KERNEL(ADD_NS(addbmm), "addbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), fp16)
KERNEL(ADD_NS(baddbmm), "baddbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), fp16)
KERNEL(ADD_NS(bmm), "bmm", Tensor (const Tensor &, const Tensor &), fp16)
KERNEL(ADD_NS(chain_matmul), "chain_matmul", Tensor (TensorList), fp16)
KERNEL(ADD_NS(linalg_multi_dot), "linalg_multi_dot", Tensor (TensorList), fp16)
// lower_precision_fp
KERNEL(ADD_NS(_convolution), "_convolution.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool), lower_precision_fp)
KERNEL(ADD_NS(_convolution), "_convolution", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool, bool), lower_precision_fp)
KERNEL(ADD_NS(_convolution_nogroup), "_convolution_nogroup", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef), lower_precision_fp)
KERNEL(ADD_NS(conv1d), "conv1d", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), lower_precision_fp)
KERNEL(ADD_NS(conv2d), "conv2d", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), lower_precision_fp)
KERNEL(ADD_NS(conv3d), "conv3d", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), lower_precision_fp)
KERNEL(ADD_NS(conv_tbc), "conv_tbc", Tensor (const Tensor &, const Tensor &, const Tensor &, int64_t), lower_precision_fp)
KERNEL(ADD_NS(conv_transpose1d), "conv_transpose1d", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), lower_precision_fp)
KERNEL(ADD_NS(conv_transpose2d), "conv_transpose2d.input", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), lower_precision_fp)
KERNEL(ADD_NS(conv_transpose3d), "conv_transpose3d.input", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), lower_precision_fp)
KERNEL(ADD_NS(convolution), "convolution", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t), lower_precision_fp)
KERNEL(ADD_NS(cudnn_convolution), "cudnn_convolution.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), lower_precision_fp)
KERNEL(ADD_NS(cudnn_convolution_transpose), "cudnn_convolution_transpose.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), lower_precision_fp)
KERNEL(ADD_NS(cudnn_convolution), "cudnn_convolution.deprecated2", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), lower_precision_fp)
KERNEL(ADD_NS(cudnn_convolution_transpose), "cudnn_convolution_transpose.deprecated2", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool), lower_precision_fp)
KERNEL(ADD_NS(cudnn_convolution), "cudnn_convolution", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool, bool), lower_precision_fp)
KERNEL(ADD_NS(cudnn_convolution_transpose), "cudnn_convolution_transpose", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool, bool, bool), lower_precision_fp)
KERNEL(ADD_NS(prelu), "prelu", Tensor (const Tensor &, const Tensor &), lower_precision_fp)
KERNEL(ADD_NS(addmm), "addmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp)
KERNEL(ADD_NS(addmv), "addmv", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp)
KERNEL(ADD_NS(addr), "addr", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp)
KERNEL(ADD_NS(matmul), "matmul", Tensor (const Tensor &, const Tensor &), lower_precision_fp)
KERNEL(ADD_NS(mm), "mm", Tensor (const Tensor &, const Tensor &), lower_precision_fp)
KERNEL(ADD_NS(mv), "mv", Tensor (const Tensor &, const Tensor &), lower_precision_fp)
KERNEL(ADD_NS(linear), "linear", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&), lower_precision_fp)
KERNEL(ADD_NS(addbmm), "addbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp)
KERNEL(ADD_NS(baddbmm), "baddbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp)
KERNEL(ADD_NS(bmm), "bmm", Tensor (const Tensor &, const Tensor &), lower_precision_fp)
KERNEL(ADD_NS(chain_matmul), "chain_matmul", Tensor (TensorList), lower_precision_fp)
KERNEL(ADD_NS(linalg_multi_dot), "linalg_multi_dot", Tensor (TensorList), lower_precision_fp)
// The macro doesn't like these (I think it chokes on commas inside <>) so write them manually
m.impl(TORCH_SELECTIVE_NAME("aten::_thnn_fused_lstm_cell"),
TORCH_FN((&WrapFunction<CastPolicy::fp16,
TORCH_FN((&WrapFunction<CastPolicy::lower_precision_fp, DeviceType::CUDA,
std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const Tensor &, const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&),
std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const Tensor &, const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&),
&ADD_NS(_thnn_fused_lstm_cell)>::type::call)));
m.impl("_thnn_fused_gru_cell",
TORCH_FN((&WrapFunction<CastPolicy::fp16,
TORCH_FN((&WrapFunction<CastPolicy::lower_precision_fp, DeviceType::CUDA,
std::tuple<Tensor,Tensor> (const Tensor &, const Tensor &, const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&),
std::tuple<Tensor,Tensor> (const Tensor &, const Tensor &, const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&),
&ADD_NS(_thnn_fused_gru_cell)>::type::call)));
m.impl("lstm_cell",
TORCH_FN((&WrapFunction<CastPolicy::fp16,
TORCH_FN((&WrapFunction<CastPolicy::lower_precision_fp, DeviceType::CUDA,
std::tuple<Tensor,Tensor> (const Tensor &, TensorList, const Tensor &, const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&),
std::tuple<Tensor,Tensor> (const Tensor &, TensorList, const Tensor &, const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&),
&ADD_NS(lstm_cell)>::type::call)));
m.impl("gru_cell",
TORCH_FN((&WrapFunction<CastPolicy::fp16,
TORCH_FN((&WrapFunction<CastPolicy::lower_precision_fp, DeviceType::CUDA,
Tensor (const Tensor &, const Tensor &, const Tensor &, const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&),
Tensor (const Tensor &, const Tensor &, const Tensor &, const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&),
&ADD_NS(gru_cell)>::type::call)));
m.impl("rnn_tanh_cell", // tanh unary op is executed as a cuda math library call.
TORCH_FN((&WrapFunction<CastPolicy::fp16,
TORCH_FN((&WrapFunction<CastPolicy::lower_precision_fp, DeviceType::CUDA,
Tensor (const Tensor &, const Tensor &, const Tensor &, const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&),
Tensor (const Tensor &, const Tensor &, const Tensor &, const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&),
&ADD_NS(rnn_tanh_cell)>::type::call)));
m.impl("rnn_relu_cell",
TORCH_FN((&WrapFunction<CastPolicy::fp16,
TORCH_FN((&WrapFunction<CastPolicy::lower_precision_fp, DeviceType::CUDA,
Tensor (const Tensor &, const Tensor &, const Tensor &, const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&),
Tensor (const Tensor &, const Tensor &, const Tensor &, const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&),
&ADD_NS(rnn_relu_cell)>::type::call)));
@ -342,7 +372,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
KERNEL(ADD_NS(layer_norm), "layer_norm", Tensor (const Tensor &, IntArrayRef, const c10::optional<Tensor>&, const c10::optional<Tensor>&, double, bool), fp32)
// The macro doesn't like this one (I think it chokes on commas inside <>) so write it manually
m.impl(TORCH_SELECTIVE_NAME("aten::native_layer_norm"),
TORCH_FN((&WrapFunction<CastPolicy::fp32,
TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CUDA,
std::tuple<Tensor,Tensor,Tensor> (const Tensor&, IntArrayRef, const c10::optional<Tensor>&, const c10::optional<Tensor>&, double),
std::tuple<Tensor,Tensor,Tensor> (const Tensor&, IntArrayRef, const c10::optional<Tensor>&, const c10::optional<Tensor>&, double),
&ADD_NS(native_layer_norm)>::type::call)));
@ -419,7 +449,86 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
}
TORCH_LIBRARY_IMPL(_, AutocastCPU, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
// lower_precision_fp cast policy
KERNEL_CPU(ADD_NS(conv1d), "conv1d", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), lower_precision_fp)
KERNEL_CPU(ADD_NS(conv2d), "conv2d", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), lower_precision_fp)
KERNEL_CPU(ADD_NS(conv3d), "conv3d", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, int64_t), lower_precision_fp)
KERNEL_CPU(ADD_NS(_log_softmax), "_log_softmax", Tensor (const Tensor &, int64_t, bool), lower_precision_fp)
KERNEL_CPU(ADD_NS(bmm), "bmm", Tensor (const Tensor &, const Tensor &), lower_precision_fp)
KERNEL_CPU(ADD_NS(mm), "mm", Tensor (const Tensor &, const Tensor &), lower_precision_fp)
KERNEL_CPU(ADD_NS(baddbmm), "baddbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp)
KERNEL_CPU(ADD_NS(addmm), "addmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp)
KERNEL_CPU(ADD_NS(addbmm), "addbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, const Scalar&, const Scalar&), lower_precision_fp)
KERNEL_CPU(ADD_NS(linear), "linear", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor> &), lower_precision_fp)
// fp32 cast policy
KERNEL_CPU(ADD_NS(conv_transpose3d), "conv_transpose3d.input", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor> &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, IntArrayRef), fp32)
KERNEL_CPU(ADD_NS(batch_norm), "batch_norm", Tensor (const Tensor &, const c10::optional<Tensor> &, const c10::optional<Tensor> &, const c10::optional<Tensor> &, const c10::optional<Tensor> &, bool, double, double, bool), fp32)
KERNEL_CPU(ADD_NS(max_pool2d), "max_pool2d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, bool), fp32)
KERNEL_CPU(ADD_NS(adaptive_avg_pool2d), "adaptive_avg_pool2d", Tensor (const Tensor &, IntArrayRef), fp32)
KERNEL_CPU(ADD_NS(convolution), "convolution", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t), fp32)
KERNEL_CPU(ADD_NS(dropout), "dropout", Tensor (const Tensor &, double, bool), fp32)
KERNEL_CPU(ADD_NS(avg_pool2d), "avg_pool2d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional<int64_t>), fp32)
KERNEL_CPU(ADD_NS(avg_pool3d), "avg_pool3d", Tensor (const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, bool, bool, c10::optional<int64_t>), fp32)
KERNEL_CPU(ADD_NS(gelu), "gelu", Tensor (const Tensor &), fp32)
KERNEL_CPU(ADD_NS(upsample_nearest1d), "upsample_nearest1d", Tensor (const Tensor &, IntArrayRef, c10::optional<double>), fp32)
KERNEL_CPU(ADD_NS(upsample_nearest1d), "upsample_nearest1d.vec", Tensor (const Tensor &, c10::optional<IntArrayRef>, c10::optional<ArrayRef<double>>), fp32)
KERNEL_CPU(ADD_NS(upsample_nearest2d), "upsample_nearest2d", Tensor (const Tensor &, IntArrayRef, c10::optional<double>, c10::optional<double>), fp32)
KERNEL_CPU(ADD_NS(upsample_nearest2d), "upsample_nearest2d.vec", Tensor (const Tensor &, c10::optional<IntArrayRef>, c10::optional<ArrayRef<double>>), fp32)
KERNEL_CPU(ADD_NS(upsample_nearest3d), "upsample_nearest3d", Tensor (const Tensor &, IntArrayRef, c10::optional<double>, c10::optional<double>, c10::optional<double>), fp32)
KERNEL_CPU(ADD_NS(upsample_nearest3d), "upsample_nearest3d.vec", Tensor (const Tensor &, c10::optional<IntArrayRef>, c10::optional<ArrayRef<double>>), fp32)
KERNEL_CPU(ADD_NS(upsample_linear1d), "upsample_linear1d", Tensor (const Tensor &, IntArrayRef, bool, c10::optional<double>), fp32)
KERNEL_CPU(ADD_NS(upsample_linear1d), "upsample_linear1d.vec", Tensor (const Tensor &, c10::optional<IntArrayRef>, bool, c10::optional<ArrayRef<double>>), fp32)
KERNEL_CPU(ADD_NS(upsample_bilinear2d), "upsample_bilinear2d", Tensor (const Tensor &, IntArrayRef, bool, c10::optional<double>, c10::optional<double>), fp32)
KERNEL_CPU(ADD_NS(upsample_bilinear2d), "upsample_bilinear2d.vec", Tensor (const Tensor &, c10::optional<IntArrayRef>, bool, c10::optional<ArrayRef<double>>), fp32)
KERNEL_CPU(ADD_NS(upsample_trilinear3d), "upsample_trilinear3d", Tensor (const Tensor &, IntArrayRef, bool, c10::optional<double>, c10::optional<double>, c10::optional<double>), fp32)
KERNEL_CPU(ADD_NS(upsample_trilinear3d), "upsample_trilinear3d.vec", Tensor (const Tensor &, c10::optional<IntArrayRef>, bool, c10::optional<ArrayRef<double>>), fp32)
KERNEL_CPU(ADD_NS(binary_cross_entropy), "binary_cross_entropy", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, int64_t), fp32)
KERNEL_CPU(ADD_NS(binary_cross_entropy_with_logits), "binary_cross_entropy_with_logits", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, int64_t), fp32)
KERNEL_CPU(ADD_NS(pow), "pow.Tensor_Scalar", Tensor (const Tensor &, const Scalar &), fp32)
KERNEL_CPU(ADD_NS(pow), "pow.Tensor_Tensor", Tensor (const Tensor &, const Tensor &), fp32)
KERNEL_CPU(ADD_NS(pow), "pow.Scalar", Tensor (const Scalar&, const Tensor &), fp32)
KERNEL_CPU(ADD_NS(smooth_l1_loss), "smooth_l1_loss", Tensor (const Tensor &, const Tensor &, int64_t, double), fp32)
KERNEL_CPU(ADD_NS(reflection_pad1d), "reflection_pad1d", Tensor (const Tensor &, IntArrayRef), fp32)
KERNEL_CPU(ADD_NS(std), "std", Tensor (const Tensor &, bool), fp32)
KERNEL_CPU(ADD_NS(std), "std.dim", Tensor (const Tensor &, IntArrayRef, bool, bool), fp32)
KERNEL_CPU(ADD_NS(instance_norm), "instance_norm", Tensor (const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, const c10::optional<Tensor>&, const c10::optional<Tensor>&, bool, double, double, bool), fp32)
KERNEL_CPU(ADD_NS(fake_quantize_per_tensor_affine), "fake_quantize_per_tensor_affine", Tensor (const Tensor &, double, int64_t, int64_t, int64_t), fp32)
// promote
KERNEL_CPU(ADD_NS(cat), "cat", Tensor (TensorList, int64_t), promote)
KERNEL_CPU(ADD_NS(stack), "stack", Tensor (TensorList, int64_t), promote)
m.impl(TORCH_SELECTIVE_NAME("aten::topk"),
TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
std::tuple<Tensor,Tensor> (const Tensor &, int64_t, int64_t, bool, bool),
std::tuple<Tensor,Tensor> (const Tensor &, int64_t, int64_t, bool, bool),
&ADD_NS(topk)>::type::call)));
m.impl(TORCH_SELECTIVE_NAME("aten::sort"),
TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
std::tuple<Tensor,Tensor> (const Tensor &, int64_t, bool),
std::tuple<Tensor,Tensor> (const Tensor &, int64_t, bool),
&ADD_NS(sort)>::type::call)));
m.impl(TORCH_SELECTIVE_NAME("aten::kthvalue"),
TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
std::tuple<Tensor,Tensor> (const Tensor &, int64_t, int64_t, bool),
std::tuple<Tensor,Tensor> (const Tensor &, int64_t, int64_t, bool),
&ADD_NS(kthvalue)>::type::call)));
m.impl(TORCH_SELECTIVE_NAME("aten::kthvalue.dimname"),
TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
std::tuple<Tensor,Tensor> (const Tensor &, int64_t, at::Dimname, bool),
std::tuple<Tensor,Tensor> (const Tensor &, int64_t, at::Dimname, bool),
&ADD_NS(kthvalue)>::type::call)));
}
} // namespace
} // namespace autocast
} // namespace at

View File

@ -3,17 +3,49 @@
namespace at {
namespace autocast {
namespace {
bool is_autocast_eligible(const Tensor& tensor) {
return (tensor.is_cuda() || tensor.is_xla()) && tensor.is_floating_point();
}
} // namespace
TORCH_API bool is_enabled();
TORCH_API void set_enabled(bool enabled);
TORCH_API void clear_cache();
TORCH_API int increment_nesting();
TORCH_API int decrement_nesting();
TORCH_API bool is_cpu_enabled();
TORCH_API void set_cpu_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_cpu_dtype();
TORCH_API void set_autocast_cpu_dtype(at::ScalarType dtype);
namespace {
bool is_autocast_eligible(const Tensor& tensor, DeviceType device_type) {
return device_type == DeviceType::CUDA
? (tensor.is_cuda() || tensor.is_xla()) && tensor.is_floating_point()
: (tensor.is_cpu() || tensor.is_mkldnn()) && tensor.is_floating_point();
}
} // namespace
inline DispatchKey get_autocast_dispatch_key_from_device_type(
DeviceType device_type) {
switch (device_type) {
case DeviceType::CUDA:
return DispatchKey::Autocast;
case DeviceType::CPU:
return DispatchKey::AutocastCPU;
default:
throw std::runtime_error(
"unknown device type for autocast in get_autocast_dispatch_key_from_device_type");
}
}
inline at::ScalarType get_lower_precision_fp_from_device_type(
DeviceType device_type) {
switch (device_type) {
case DeviceType::CUDA:
return at::kHalf;
case DeviceType::CPU:
return get_autocast_cpu_dtype();
default:
throw std::runtime_error(
"unknown device type for autocast in get_lower_precision_fp_from_device_type");
}
}
/********************************************************************
Logic to extract the promote type from any Tensor or TensorList args.
@ -22,19 +54,24 @@ Logic to extract the promote type from any Tensor or TensorList args.
// Overload to catch Tensor args.
// If nextArg is floating-point, compare its scalar_type with our
// current best guess for the promote type, and update if necessary.
inline at::ScalarType prioritize(at::ScalarType current, const Tensor& nextArg) {
inline at::ScalarType prioritize(
at::ScalarType current,
const Tensor& nextArg,
DeviceType device_type=DeviceType::CUDA) {
if (current == at::kDouble) {
AT_ERROR("promote type is double in at::autocast::prioritize");
return current;
}
if (is_autocast_eligible(nextArg)) {
at::ScalarType lower_precision_fp =
get_lower_precision_fp_from_device_type(device_type);
if (is_autocast_eligible(nextArg, device_type)) {
auto next = nextArg.scalar_type();
if (next == at::kDouble) {
return current; // ignores double tensors
} else if (current == at::kFloat || next == at::kFloat) {
return at::kFloat; // prioritizes float over half
} else if (current == at::kHalf && next == at::kHalf) {
return at::kHalf;
return at::kFloat; // prioritizes float over lower_precision_fp
} else if (current == lower_precision_fp && next == lower_precision_fp) {
return lower_precision_fp;
} else {
AT_ERROR("Unexpected floating ScalarType in at::autocast::prioritize");
return current;
@ -46,64 +83,92 @@ inline at::ScalarType prioritize(at::ScalarType current, const Tensor& nextArg)
// Overload to catch TensorList args (for e.g. cat, stack).
// Reuses the overload above to process each Tensor in the list.
inline at::ScalarType prioritize(at::ScalarType current, const TensorList& list) {
inline at::ScalarType prioritize(
at::ScalarType current,
const TensorList& list,
DeviceType device_type=DeviceType::CUDA) {
for (const auto& tensor : list) {
current = prioritize(current, tensor);
current = prioritize(current, tensor, device_type);
}
return current;
}
// Template to catch non-Tensor args (no-op that returns current best guess)
template<typename T>
inline at::ScalarType prioritize(at::ScalarType current, T nextArg) {
inline at::ScalarType prioritize(
at::ScalarType current,
T nextArg,
DeviceType device_type=DeviceType::CUDA) {
return current;
}
// Overload for the tail case.
inline at::ScalarType promote_type(at::ScalarType current) {
inline at::ScalarType promote_type(
at::ScalarType current,
DeviceType device_type) {
return current;
}
// Unpack args and determine if incoming float16 tensors need to be promoted to float32.
// Unpack args and determine if incoming lower_precision_fp tensors need to be promoted to float32.
// Non-Tensor arguments are ignored.
template<typename Arg0, typename... Args>
inline at::ScalarType promote_type(at::ScalarType current, Arg0 arg0, Args... args) {
auto new_current = prioritize(current, arg0);
return promote_type(new_current, args...);
inline at::ScalarType promote_type(
at::ScalarType current,
DeviceType device_type,
Arg0 arg0,
Args... args) {
auto new_current = prioritize(current, arg0, device_type);
return promote_type(new_current, device_type, args...);
}
/****************************************************
Logic to apply cached casting to any Tensor argument.
****************************************************/
inline bool is_eligible(const Tensor& arg) {
return (arg.defined() && is_autocast_eligible(arg) && (arg.scalar_type() != at::kDouble));
inline bool is_eligible(
const Tensor& arg,
DeviceType device_type=DeviceType::CUDA) {
return (arg.defined() &&
is_autocast_eligible(arg, device_type) &&
(arg.scalar_type() != at::kDouble));
}
// Overload to catch Tensor args
TORCH_API Tensor cached_cast(at::ScalarType to_type, const Tensor& arg);
TORCH_API Tensor cached_cast(
at::ScalarType to_type,
const Tensor& arg,
DeviceType device_type=DeviceType::CUDA);
// Overload to process optional<Tensor>
inline c10::optional<Tensor> cached_cast(at::ScalarType to_type, const c10::optional<Tensor>& arg) {
inline c10::optional<Tensor> cached_cast(
at::ScalarType to_type,
const c10::optional<Tensor>& arg,
DeviceType device_type=DeviceType::CUDA) {
if (arg.has_value()) {
return cached_cast(to_type, *arg);
return cached_cast(to_type, *arg, device_type);
} else {
return c10::nullopt;
}
}
// Overload to process TensorLists
inline std::vector<Tensor> cached_cast(at::ScalarType to_type, const TensorList& arg) {
inline std::vector<Tensor> cached_cast(
at::ScalarType to_type,
const TensorList& arg,
DeviceType device_type=DeviceType::CUDA) {
std::vector<Tensor> vec;
vec.reserve(arg.size());
for (const auto& t : arg) {
vec.push_back(cached_cast(to_type, t));
vec.push_back(cached_cast(to_type, t, device_type));
}
return vec;
}
// Template to catch non-Tensor args.
template<typename T>
inline T cached_cast(at::ScalarType to_type, T arg) {
inline T cached_cast(
at::ScalarType to_type,
T arg,
DeviceType device_type=DeviceType::CUDA) {
return arg;
}

View File

@ -228,7 +228,7 @@ enum class DispatchKey : uint8_t {
// Autocasting precedes VariableTypeId, to ensure casts are autograd-exposed
// and inputs are saved for backward in the post-autocast type.
// AutocastCPU,
AutocastCPU,
AutocastCUDA,
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~ WRAPPERS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //

View File

@ -78,8 +78,8 @@ DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
DispatchKeySet getAutocastRelatedKeySetFromBackend(DispatchKey t) {
switch (t) {
// case DispatchKey::CPU:
// return DispatchKeySet(DispatchKey::AutocastCPU);
case DispatchKey::CPU:
return DispatchKeySet(DispatchKey::AutocastCPU);
case DispatchKey::CUDA:
return DispatchKeySet(DispatchKey::AutocastCUDA);
default:

View File

@ -223,7 +223,7 @@ constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({
});
constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({
// DispatchKey::AutocastCPU,
DispatchKey::AutocastCPU,
DispatchKey::AutocastCUDA,
});
@ -234,7 +234,7 @@ constexpr DispatchKeySet default_included_set = DispatchKeySet({
});
constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
// DispatchKey::AutocastCPU,
DispatchKey::AutocastCPU,
DispatchKey::AutocastCUDA,
});

View File

@ -428,4 +428,8 @@ inline std::ostream& operator<<(
return stream << toString(scalar_type);
}
#define AT_FORAUTOCAST_SCALAR_TYPES(_) \
_(half, Half) /* 0 */ \
_(bfloat16, BFloat16) /* 1 */
} // namespace c10

123
test/test_autocast.py Normal file
View File

@ -0,0 +1,123 @@
import collections
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.autocast_test_lists import AutocastCPUTestLists
class TestAutocastCPU(TestCase):
def setUp(self):
super(TestAutocastCPU, self).setUp()
self.autocast_lists = AutocastCPUTestLists(torch.device('cpu'))
def tearDown(self):
del self.autocast_lists
super(TestAutocastCPU, self).tearDown()
def _run_autocast_outofplace(self, op, args, run_as_type, out_type=None, module=torch, add_kwargs=None):
# helper to cast args
def cast(val, to_type):
if isinstance(val, torch.Tensor):
return val.to(to_type) if val.is_floating_point() else val
elif isinstance(val, collections.abc.Iterable):
return type(val)(cast(v, to_type) for v in val)
else:
return val
if add_kwargs is None:
add_kwargs = {}
self.assertFalse(torch.is_autocast_cpu_enabled())
with torch.cpu.amp.autocast():
self.assertTrue(torch.is_autocast_cpu_enabled())
out_type = out_type if out_type is not None else run_as_type
output = output_method = None
# Try module.* variant, if requested:
if module is not None and hasattr(module, op):
output = getattr(module, op)(*args, **add_kwargs)
if isinstance(output, torch.Tensor):
self.assertTrue(out_type == output.dtype,
"autocast for torch.{} produced {}, should produce {}"
.format(op, output.dtype, out_type))
# Try Tensor.* variant:
if hasattr(torch.Tensor, op):
output_method = getattr(args[0], op)(*args[1:], **add_kwargs)
if isinstance(output_method, torch.Tensor):
self.assertTrue(out_type == output_method.dtype,
"autocast for torch.{} produced {}, should produce torch.{}"
.format(op, output_method.dtype, out_type))
self.assertTrue((output is not None) or (output_method is not None),
"{} not found as an attribute on either Tensor or the requested module {}".format(
op, module))
# Accounts for ops that return Tensors, iterables, and other non-Tensors.
# For example, lstm_cell returns a tuple and equal returns bool.
def compare(first, second):
if isinstance(first, torch.Tensor):
return torch.equal(first, second)
elif isinstance(first, collections.abc.Iterable):
return all(compare(f, s) for f, s in zip(first, second))
else:
return first == second
# If both torch.* and Tensor.* variants were found, check outputs are identical
if (output is not None) and (output_method is not None):
self.assertTrue(type(output) == type(output_method))
comparison = compare(output, output_method)
self.assertTrue(comparison, "torch.{0} result did not match Tensor.{0} result".format(op))
# Compare numerics to Python-side "autocasting" that (we expect) does the same thing
# as the C++-side autocasting, and should be bitwise accurate.
output_to_compare = output if output is not None else output_method
with torch.cpu.amp.autocast(enabled=False):
self.assertFalse(torch.is_autocast_cpu_enabled())
if module is not None and hasattr(module, op):
control = getattr(module, op)(*cast(args, run_as_type), **add_kwargs)
else:
control = getattr(args[0].to(run_as_type), op)(*cast(args[1:], run_as_type), **add_kwargs)
self.assertTrue(type(output_to_compare) == type(control))
comparison = compare(output_to_compare, control)
self.assertTrue(comparison, "torch.{} result did not match control".format(op))
self.assertTrue(torch.is_autocast_cpu_enabled())
self.assertFalse(torch.is_autocast_cpu_enabled())
def args_maybe_kwargs(self, op_with_args):
if len(op_with_args) == 2:
return op_with_args[0], op_with_args[1], {}
else:
return op_with_args[0], op_with_args[1], op_with_args[2]
def test_autocast_torch_expect_builtin_promote(self):
for op, args, out_type in self.autocast_lists.torch_expect_builtin_promote:
self._run_autocast_outofplace(op, args, torch.float32, out_type=out_type)
def test_autocast_methods_expect_builtin_promote(self):
for op, args, out_type in self.autocast_lists.methods_expect_builtin_promote:
self._run_autocast_outofplace(op, args, torch.float32, module=None, out_type=out_type)
def test_autocast_torch_bf16(self):
for op_with_args in self.autocast_lists.torch_bf16:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(op, args, torch.bfloat16, add_kwargs=maybe_kwargs)
def test_autocast_nn_bf16(self):
for op, args in self.autocast_lists.nn_bf16:
self._run_autocast_outofplace(op, args, torch.bfloat16, module=torch._C._nn)
def test_autocast_torch_fp32(self):
for op_with_args in self.autocast_lists.torch_fp32:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(op, args, torch.float32, add_kwargs=maybe_kwargs)
def test_autocast_nn_fp32(self):
for op_with_args in self.autocast_lists.nn_fp32:
op, args, maybe_kwargs = self.args_maybe_kwargs(op_with_args)
self._run_autocast_outofplace(op, args, torch.float32, module=torch._C._nn, add_kwargs=maybe_kwargs)
def test_autocast_torch_need_autocast_promote(self):
for op, args in self.autocast_lists.torch_need_autocast_promote:
self._run_autocast_outofplace(op, args, torch.float32)
if __name__ == '__main__':
run_tests()

View File

@ -612,6 +612,10 @@ def is_inference_mode_enabled() -> _bool: ...
def set_autocast_enabled(enabled: _bool) -> None: ...
def is_autocast_enabled() -> _bool: ...
def clear_autocast_cache() -> None: ...
def set_autocast_cpu_enabled(enabled: _bool) -> None: ...
def is_autocast_cpu_enabled() -> _bool: ...
def set_autocast_cpu_dtype(dtype: _dtype) -> None: ...
def get_autocast_cpu_dtype() -> _dtype: ...
def autocast_increment_nesting() -> _int: ...
def autocast_decrement_nesting() -> _int: ...
def set_anomaly_enabled(enabled: _bool) -> None: ...

View File

@ -670,6 +670,7 @@ def _assert(condition, message):
# side effect of adding to the imported module's members for other users.
from torch import cuda as cuda
from torch import cpu as cpu
from torch import autograd as autograd
from torch.autograd import (
no_grad as no_grad,

1
torch/cpu/__init__.py Normal file
View File

@ -0,0 +1 @@
from . import amp

View File

@ -0,0 +1 @@
from .autocast_mode import autocast

View File

@ -0,0 +1,36 @@
import torch
import functools
import warnings
class autocast(object):
def __init__(self, enabled=True, dtype=torch.bfloat16):
supported_dtype = [torch.bfloat16]
if dtype not in supported_dtype :
warnings.warn("In CPU autocast, but the target dtype is not supported. Disable the autocast.")
warnings.warn("CPU Autocast only support dtype of torch.bfloat16 currently.")
enabled = False
dtype = torch.bfloat16
self._enabled = enabled
self._dtype = dtype
def __enter__(self):
self.prev = torch.is_autocast_cpu_enabled()
self.prev_dtype = torch.get_autocast_cpu_dtype()
torch.set_autocast_cpu_enabled(self._enabled)
torch.set_autocast_cpu_dtype(self._dtype)
torch.autocast_increment_nesting()
def __exit__(self, *args):
# Drop the cache when we exit to a nesting level that's outside any instance of autocast.
if torch.autocast_decrement_nesting() == 0:
torch.clear_autocast_cache()
torch.set_autocast_cpu_enabled(self.prev)
torch.set_autocast_cpu_dtype(self.prev_dtype)
return False
def __call__(self, func):
@functools.wraps(func)
def decorate_autocast(*args, **kwargs):
with self:
return func(*args, **kwargs)
return decorate_autocast

View File

@ -13,6 +13,7 @@
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/autograd/utils/python_arg_parsing.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <c10/core/ScalarType.h>
#include <set>
@ -278,6 +279,57 @@ static PyObject * is_autocast_enabled(PyObject* _unused, PyObject *arg) {
END_HANDLE_TH_ERRORS
}
static PyObject * set_autocast_cpu_enabled(PyObject* _unused, PyObject *arg) {
HANDLE_TH_ERRORS
if (!PyBool_Check(arg)) {
throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name);
}
at::autocast::set_cpu_enabled(arg == Py_True);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject * is_autocast_cpu_enabled(PyObject* _unused, PyObject *arg) {
HANDLE_TH_ERRORS
if (at::autocast::is_cpu_enabled()) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
END_HANDLE_TH_ERRORS
}
static PyObject * set_autocast_cpu_dtype(PyObject* _unused, PyObject *arg) {
HANDLE_TH_ERRORS
if (!THPDtype_Check(arg)) {
throw TypeError(
"dtype must be a torch.dtype (got %s)", Py_TYPE(arg)->tp_name);
}
at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type;
at::autocast::set_autocast_cpu_dtype(targetType);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static const char* scalarTypeName(const at::ScalarType type) {
switch (type) {
#define DEFINE_CASE(ctype, name) \
case at::ScalarType::name: \
return #ctype;
AT_FORAUTOCAST_SCALAR_TYPES(DEFINE_CASE)
#undef DEFINE_CASE
default:
throw std::runtime_error("unknown scalar type for autocast");
}
}
static PyObject * get_autocast_cpu_dtype(PyObject* _unused, PyObject *arg){
HANDLE_TH_ERRORS
at::ScalarType current_dtype = at::autocast::get_autocast_cpu_dtype();
return THPDtype_New(current_dtype, scalarTypeName(current_dtype));
END_HANDLE_TH_ERRORS
}
static PyObject * clear_autocast_cache(PyObject* _unused, PyObject *arg) {
HANDLE_TH_ERRORS
at::autocast::clear_cache();
@ -377,6 +429,10 @@ static PyMethodDef methods[] = { // NOLINT
{"set_autocast_enabled", set_autocast_enabled, METH_O, nullptr},
{"is_autocast_enabled", is_autocast_enabled, METH_NOARGS, nullptr},
{"clear_autocast_cache", clear_autocast_cache, METH_NOARGS, nullptr},
{"set_autocast_cpu_enabled", set_autocast_cpu_enabled, METH_O, nullptr},
{"is_autocast_cpu_enabled", is_autocast_cpu_enabled, METH_NOARGS, nullptr},
{"set_autocast_cpu_dtype", set_autocast_cpu_dtype, METH_O, nullptr},
{"get_autocast_cpu_dtype", get_autocast_cpu_dtype, METH_NOARGS, nullptr},
{"autocast_increment_nesting", autocast_increment_nesting, METH_NOARGS, nullptr},
{"autocast_decrement_nesting", autocast_decrement_nesting, METH_NOARGS, nullptr},
{"set_anomaly_enabled", set_anomaly_mode_enabled, METH_O, nullptr},

View File

@ -184,6 +184,10 @@ def get_ignored_functions() -> Set[Callable]:
torch.set_autocast_enabled,
torch.is_autocast_enabled,
torch.clear_autocast_cache,
torch.set_autocast_cpu_enabled,
torch.is_autocast_cpu_enabled,
torch.set_autocast_cpu_dtype,
torch.get_autocast_cpu_dtype,
torch.autocast_increment_nesting,
torch.autocast_decrement_nesting,
torch.nn.functional.hardswish,

View File

@ -237,3 +237,123 @@ class AutocastTestLists(object):
("binary_cross_entropy", (torch.rand((n, n), device=dev, dtype=torch.float32),
torch.rand((n, n), device=dev, dtype=torch.float32)), torch._C._nn),
]
class AutocastCPUTestLists(object):
# Supplies ops and arguments for test_autocast_* in test/test_cpu.py
def __init__(self, dev):
super().__init__()
n = 8
# Utility arguments, created as one-element tuples
pointwise0_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),)
pointwise1_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),)
pointwise2_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),)
mat0_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
mat1_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
mat2_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),)
dummy_dimsets = ((n,), (n, n), (n, n, n), (n, n, n, n), (n, n, n, n, n))
dummy_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev),)
for dimset in dummy_dimsets]
dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n))
conv_args_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev),
torch.randn(dimset, dtype=torch.bfloat16, device=dev))
for dimset in dimsets]
conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),
torch.randn(dimset, dtype=torch.float32, device=dev))
for dimset in dimsets]
bias_fp32 = (torch.randn((n,), dtype=torch.float32, device=dev),)
element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),)
pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
pointwise1_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),)
mat0_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
mat1_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),)
dummy_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),)
for dimset in dummy_dimsets]
# The lists below organize ops that autocast needs to test.
# self.list_name corresponds to test_autocast_list_name in test/test_cpu.py.
# Each op is associated with a tuple of valid arguments.
# Some ops implement built-in type promotion. These don't need autocasting,
# but autocasting relies on their promotion, so we include tests to double-check.
self.torch_expect_builtin_promote = [
("eq", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("ge", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("gt", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("le", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("lt", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("ne", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("add", pointwise0_fp32 + pointwise1_bf16, torch.float32),
("div", pointwise0_fp32 + pointwise1_bf16, torch.float32),
("mul", pointwise0_fp32 + pointwise1_bf16, torch.float32),
]
self.methods_expect_builtin_promote = [
("__eq__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("__ge__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("__gt__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("__le__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("__lt__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("__ne__", pointwise0_fp32 + pointwise1_bf16, torch.bool),
("__add__", pointwise0_fp32 + pointwise1_bf16, torch.float32),
("__div__", pointwise0_fp32 + pointwise1_bf16, torch.float32),
("__mul__", pointwise0_fp32 + pointwise1_bf16, torch.float32),
]
# The remaining lists organize ops that autocast treats explicitly.
self.torch_bf16 = [
("conv1d", conv_args_fp32[0]),
("conv2d", conv_args_fp32[1]),
("conv3d", conv_args_fp32[2]),
("log_softmax", pointwise0_fp32 + (0,)),
("bmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32))),
("mm", mat0_fp32 + mat1_fp32),
("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32))),
("addmm", mat1_fp32 + mat2_fp32 + mat3_fp32),
("addbmm", mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32))),
]
self.torch_fp32 = [
("conv_transpose3d", conv_args_bf16[2]),
("batch_norm", dummy_bf16[2], {"weight": None, "bias": None, "running_mean": torch.rand((n), dtype=torch.float32),
"running_var": torch.rand((n), dtype=torch.float32), "training": False,
"momentum": 0.1, "eps": 1e-5, "cudnn_enabled": False}),
("max_pool2d", dummy_bf16[2], {"kernel_size": (3, 2), "stride": (1, 1)}),
("dropout", dummy_bf16[2], {"p": 0.1, "train": False}),
("binary_cross_entropy_with_logits", mat0_bf16 + (torch.rand((n, n), device=dev, dtype=torch.bfloat16),)),
("pow", ((pointwise0_bf16[0] + 1.).clamp(0.0, 100.0),) + pointwise1_bf16),
("pow", ((pointwise0_bf16[0] + 1.).clamp(0.0, 100.0),) + (1.7,)),
("instance_norm", dummy_bf16[2], {"weight": None, "bias": None, "running_mean": torch.rand((n), dtype=torch.float32),
"running_var": torch.rand((n), dtype=torch.float32), "use_input_stats": False,
"momentum": 0.1, "eps": 1e-5, "cudnn_enabled": False}),
]
self.nn_bf16 = [
("linear", mat0_fp32 + mat1_fp32),
]
self.nn_fp32 = [
("adaptive_avg_pool2d", dummy_bf16[2], {"output_size": (3, 2)}),
("avg_pool2d", dummy_bf16[2], {"kernel_size": (3, 2), "stride": (1, 1)}),
("avg_pool3d", dummy_bf16[3], {"kernel_size": (3, 3, 3), "stride": (1, 1, 1)}),
("gelu", dummy_bf16[3]),
("upsample_nearest1d", dummy_bf16[2], {"output_size": (n)}),
("upsample_nearest2d", dummy_bf16[3], {"output_size": (n, n)}),
("upsample_nearest3d", dummy_bf16[4], {"output_size": (n, n, n)}),
("upsample_linear1d", dummy_bf16[2], {"output_size": (n), "align_corners": False}),
("upsample_bilinear2d", dummy_bf16[3], {"output_size": (n, n), "align_corners": False}),
("upsample_trilinear3d", dummy_bf16[4], {"output_size": (n, n, n), "align_corners": False}),
("binary_cross_entropy", (torch.rand((n, n), device=dev, dtype=torch.bfloat16),) +
(torch.rand((n, n), device=dev, dtype=torch.bfloat16),)),
("smooth_l1_loss", mat0_bf16 + mat1_bf16),
("reflection_pad1d", dummy_bf16[2], {"padding": (3, 3)}),
("std", dummy_bf16[2]),
]
self.torch_need_autocast_promote = [
("cat", (pointwise0_bf16 + pointwise1_fp32,)),
("stack", (pointwise0_bf16 + pointwise1_fp32,)),
]