# Motivation

As discussed in [#124479](https://github.com/pytorch/pytorch/pull/124479), `torch.amp.autocast` cannot be completely equivalent to `torch.cuda.amp.autocast` and `torch.cpu.amp.autocast`, because `torch.amp.autocast` does not pick up the per-backend default `dtype` (`torch.bfloat16` for CPU and `torch.float16` for CUDA). We would like `torch.amp.autocast` to be generic enough that developers can write device-agnostic code, since there is little reason to add a device-specific `torch.xxx.amp.autocast` for every device backend.

# Solution

When `None` is passed as `dtype`, use `torch.get_autocast_dtype` to look up the default dtype for the given backend. `torch.get_autocast_dtype` also needs to be supported in the JIT path for backward compatibility.

# Additional Context

With this PR, `torch.amp.autocast(device_type='cuda')` is equivalent to `torch.cuda.amp.autocast`. Two new unit tests cover this change in the eager and JIT paths respectively.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125103
Approved by: https://github.com/albanD, https://github.com/jgong5, https://github.com/gujinghui
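For illustration, a minimal sketch (not part of the PR) of the device-agnostic usage this enables; the shapes and variable names are placeholders:

```python
import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"

# With dtype left as None, autocast resolves the backend's default dtype
# via torch.get_autocast_dtype: torch.float16 on CUDA, torch.bfloat16 on CPU.
with torch.amp.autocast(device_type=device_type):
    x = torch.randn(8, 8, device=device_type)
    y = torch.randn(8, 8, device=device_type)
    out = x @ y  # runs in the backend's reduced-precision dtype

# On CUDA this matches the legacy torch.cuda.amp.autocast() behavior.
print(out.dtype)
```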
#include <torch/csrc/jit/passes/autocast.h>

#include <ATen/autocast_mode.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/quantization/helper.h>

#include <stack>
#include <unordered_set>
#include <vector>

namespace torch {
namespace jit {

namespace {

bool autocast_enabled = true;
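
// Per-device autocast state (CUDA/"gpu" and CPU) carried through the pass.
// Converts to true when autocasting is enabled for at least one backend.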
struct AutocastContext {
  bool gpu_enabled = false;
  bool cpu_enabled = false;
  c10::ScalarType gpu_scalar_type = c10::ScalarType::Undefined;
  c10::ScalarType cpu_scalar_type = c10::ScalarType::Undefined;

  operator bool() const {
    return gpu_enabled || cpu_enabled;
  }
};

struct AutocastScope {
  Value* instance = nullptr;
  AutocastContext context;
  void stack(const AutocastContext& parent_context) {}
};

bool isAutocastNode(Value* value) {
  const auto class_name = getModuleName(value);
  return class_name.has_value() &&
      (*class_name == "__torch__.torch.cuda.amp.autocast_mode.autocast" ||
       *class_name == "__torch__.torch.cpu.amp.autocast_mode.autocast" ||
       *class_name == "__torch__.torch.amp.autocast_mode.autocast");
}

// If we have an autocast instance, return it
//
// This is the pattern we're looking for (this is done after
// autocast.__init__() has been inlined)
//
//    %4 : bool = prim::Constant[value=1]()
//    %5 : __torch__.torch.cuda.amp.autocast_mode.autocast = prim::CreateObject()
//     = prim::SetAttr[name="_enabled"](%5, %4)
//
// Notes:
//  1. There's no guarantee that the autocast instance is in the same block
//     as the prim::Enter() node
//  2. `prim::SetAttr` must follow `prim::CreateObject()` in the same block,
//     but there might be other nodes in between
//
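//  3. The sequence above comes from a scripted `with autocast(...):` block
//     (the torch.amp, torch.cuda.amp, or torch.cpu.amp variants); parsing
//     recovers the `_enabled`, `device`, and `fast_dtype` attributes set on
//     that instance.
//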
c10::optional<AutocastScope> parseAutocast(
    Value* value,
    const AutocastContext& context) {
  if (!isAutocastNode(value)) {
    // Not an autocast...
    return c10::nullopt;
  }
  if (value->node()->kind() == prim::CreateObject) {
    AutocastScope scope;
    scope.instance = value;
    scope.context = context;
    c10::optional<bool> enabled;
    std::string device;
    c10::ScalarType dtype = c10::ScalarType::Undefined;
    for (Use use : value->uses()) {
      // TODO: support runtime flag
      if (use.user->kind() == prim::SetAttr &&
          use.user->s(attr::name) == "_enabled") {
        // Search for `prim::SetAttr[name="_enabled"]`
        auto ret = constant_as<bool>(use.user->input(1));
        TORCH_CHECK(
            ret.has_value(), "Autocast _enabled argument must be a constant");
        enabled = ret.value();
      } else if (
          use.user->kind() == prim::SetAttr &&
          use.user->s(attr::name) == "device") {
        // Search for `prim::SetAttr[name="device"]`
        auto ret = constant_as<std::string>(use.user->input(1));
        TORCH_CHECK(
            ret.has_value(), "Autocast device argument must be a constant");
        device = ret.value();
      } else if (
          use.user->kind() == prim::SetAttr &&
          use.user->s(attr::name) == "fast_dtype") {
        // Search for `prim::SetAttr[name="fast_dtype"]`
        auto ret = constant_as<c10::ScalarType>(use.user->input(1));
        if (ret.has_value()) {
          dtype = ret.value();
        }
      }
    }
    TORCH_CHECK(enabled.has_value(), "Autocast missing _enabled attribute");
    TORCH_CHECK(!device.empty(), "Autocast missing device attribute");
    if (dtype == c10::ScalarType::Undefined) {
      dtype = at::autocast::get_autocast_dtype(c10::Device(device).type());
    }
    TORCH_CHECK(
        dtype != c10::ScalarType::Undefined,
        "Autocast has invalid fast_dtype attribute");
    if (device == "cuda") {
      scope.context.gpu_enabled = enabled.value();
      scope.context.gpu_scalar_type = dtype;
    } else if (device == "cpu") {
      scope.context.cpu_enabled = enabled.value();
      scope.context.cpu_scalar_type = dtype;
    } else {
      TORCH_INTERNAL_ASSERT(
          false, "unrecognized device for autocast pass: ", device);
    }
    return scope;
  } else {
    // We only support simple and static autocast expressions. For example,
    // the following should report an error (since the autocast would not
    // work as expected)
    //
    //   autocast_on = autocast(enabled=True)
    //   autocast_off = autocast(enabled=False)
    //   with autocast_on if condition else autocast_off:
    //     ...
    //
    // TODO: better error message
    //
    AT_ERROR("Unsupported autocast syntax");
  }

  return c10::nullopt;
}
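
// Insert a cast op (`cast_op` is either aten::_autocast_to_full_precision or
// aten::_autocast_to_reduced_precision) for each distinct tensor input of
// `node`, then rewire `node` to consume the casted values.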
void castTensorInputs(
    Node* node,
    Symbol cast_op,
    const AutocastContext& context) {
  if (!context) {
    return;
  }

  const auto graph = node->owningGraph();

  std::unordered_set<Value*> casted_inputs;
  // need to also keep the inputs in order, otherwise tracing fails
  // sanity checks because casting ops are inserted in random order
  std::vector<Value*> casted_inputs_ordered;
  for (auto input : node->inputs()) {
    // TODO: update cast_op signature to take dynamic context flags
    auto input_tensor_type = input->type()->cast<TensorType>();
    if (input_tensor_type && input->node()->kind() != cast_op) {
      auto has_inserted = casted_inputs.insert(input);
      if (has_inserted.second) {
        casted_inputs_ordered.push_back(input);
      }
    }
  }

  WithInsertPoint insert_point(node);

  for (auto input : casted_inputs_ordered) {
    if (cast_op == aten::_autocast_to_full_precision) {
      const auto new_input = graph->insert(
          cast_op,
          {input,
           graph->insertConstant(IValue(context.gpu_enabled)),
           graph->insertConstant(IValue(context.cpu_enabled))});
      node->replaceInputWith(input, new_input);
    } else if (cast_op == aten::_autocast_to_reduced_precision) {
      const auto new_input = graph->insert(
          cast_op,
          {input,
           graph->insertConstant(IValue(context.gpu_enabled)),
           graph->insertConstant(IValue(context.cpu_enabled)),
           graph->insertConstant(IValue(context.gpu_scalar_type)),
           graph->insertConstant(IValue(context.cpu_scalar_type))});
      node->replaceInputWith(input, new_input);
    } else {
      TORCH_INTERNAL_ASSERT(
          false, "unrecognized cast_op symbol: ", cast_op.toQualString());
    }
  }
}
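
// True when `node` was called with an explicit (non-None) `dtype` argument,
// in which case the policies below must not override the requested dtype.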
bool hasExplicitDtypeArgument(Node* node) {
  if (node->hasNamedInput("dtype")) {
    Value* dtype_arg = node->namedInput("dtype");
    return dtype_arg->type()->kind() != TypeKind::NoneType;
  }
  return false;
}
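
// Promotion helper: if any floating-point input is (or may be) fp32, cast all
// tensor inputs of `node` to full precision so the op sees a single dtype.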
void castInputsToWidestType(Node* node, const AutocastContext& context) {
  if (!context) {
    return;
  }
  // Figure out the widest type
  // (really, just looking for any float32 inputs)
  //
  // TODO: revisit this (do we need to consider float64 types?)
  //
  for (auto input : node->inputs()) {
    if (auto tensor_type = input->type()->cast<TensorType>()) {
      const auto dtype = tensor_type->scalarType();
      if (!dtype.has_value() || *dtype == at::ScalarType::Float) {
        castTensorInputs(node, aten::_autocast_to_full_precision, context);
        return;
      }
    }
  }
}

// Users can call torch.is_autocast_enabled() or is_autocast_cpu_enabled() to
// determine whether autocasting is enabled. With JIT-scripted functions, we
// actually need to return true if either eager autocast or JIT autocast is
// enabled.
//
// In the case where JIT autocast is enabled, we replace
//    %x : bool = aten::is_autocast_enabled()
// with a constant "True".
//
// More context on eager vs JIT autocasting:
//
// Autocasting actually has two settings: eager autocasting, and JIT
// autocasting. Eager autocasting is the thread-local setting that turns on
// the relevant bit in the dispatcher settings. JIT autocasting is the pass
// implemented in this file, which makes changes to the graph to insert casting
// ops in order to achieve the same behavior as eager autocasting.
//
// If eager autocasting is enabled at the time when a JIT-scripted function is
// invoked, then autocasting will occur regardless of what the JIT-autocasting
// settings are.
void updateAutocastEnabledCheck(Node* node, bool is_jit_enabled) {
  if (!is_jit_enabled) {
    return;
  }

  auto graph = node->owningGraph();

  WithInsertPoint insert_point(node);

  Value* true_constant = graph->insertConstant(IValue(true));
  node->output()->replaceAllUsesWith(true_constant);
  node->destroy();
}

// [Note: implicit type promotion in Autocast]
//
// The casting policy below mostly follows pytorch/aten/src/ATen/autocast.cpp,
// with a few exceptions, e.g. `aten::add`, which needs to be put on the
// promotion list for JIT autocast.
// The reason is that in eager amp, some binary ops promote inputs implicitly
// inside the operation, e.g. `aten::add` with fp16 & fp32 inputs casts both
// to fp32. In backward, autograd casts dgrad to match each input's
// scalar_type in the forward graph, so inputs with mismatched scalar_types
// get different dgrads.
// In JIT, autodiff doesn't do this: the implicit cast is not visible to
// autodiff, so the backward dgrads for mismatched inputs end up with the
// same scalar_type. This has caused downstream operations, which expect
// dgrad to have a matching scalar type, to throw a mismatch error.
//
// TODO: Use the list from AMP eager directly
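//
// handleBlock() walks the nodes of `block` (and, recursively, any sub-blocks),
// tracking nested autocast scopes and rewriting eligible ops according to the
// cast policies below.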
void handleBlock(Block* block, AutocastContext initial_state) {
  std::stack<AutocastScope> autocast_stack;

  c10::optional<bool> incompatible_amp = c10::nullopt;

  // The current autocast enabled/disabled state
  auto current_state = [&] {
    return autocast_stack.empty() ? initial_state
                                  : autocast_stack.top().context;
  };

  for (Node* node : block->nodes()) {
    switch (node->kind()) {
      case prim::CallFunction:
        // TODO: limit it only to amp related node;
        if (current_state() == initial_state) {
          // if the current autocasting state is the same as the global state,
          // then autocasting will be done correctly on subsequent method and
          // function calls
          if (current_state()) {
            castTensorInputs(
                node, aten::_autocast_to_full_precision, current_state());
          }
          break;
        }
        TORCH_INTERNAL_ASSERT(
            !incompatible_amp.has_value() || incompatible_amp.value(),
            "Calls are not expected with AMP & JIT");
        incompatible_amp = true;
        break;

      case prim::CallMethod:
        // TODO: limit it only to amp related node;
        if (current_state() == initial_state) {
          // if the current autocasting state is the same as the global state,
          // then autocasting will be done correctly on subsequent method and
          // function calls
          if (current_state()) {
            castTensorInputs(
                node, aten::_autocast_to_full_precision, current_state());
          }
          break;
        }
        if (auto class_type = node->input(0)->type()->cast<ClassType>()) {
          const auto& name = node->s(attr::name);
          const auto& function = class_type->getMethod(name);
          if (!function.isGraphFunction()) {
            TORCH_INTERNAL_ASSERT(
                !incompatible_amp.has_value() || incompatible_amp.value(),
                "Calls are not expected with AMP & JIT");
            incompatible_amp = true;
          }
        } else {
          TORCH_INTERNAL_ASSERT(
              !incompatible_amp.has_value() || incompatible_amp.value(),
              "Unexpected prim::CallMethod form with AMP & JIT");
          incompatible_amp = true;
        }
        break;

      case prim::Enter:
        if (auto autocast_scope =
                parseAutocast(node->input(), current_state())) {
          if (node->hasUses()) {
            // TODO: better error message
            AT_ERROR("`with autocast() as ...` is not supported");
          }
          TORCH_INTERNAL_ASSERT(
              !incompatible_amp.has_value() || !incompatible_amp.value(),
              "Unsupported case by AMP & JIT");
          incompatible_amp = false;
          autocast_stack.push(*autocast_scope);
        }
        break;

      case prim::Exit:
        if (isAutocastNode(node->input(0))) {
          TORCH_INTERNAL_ASSERT(!autocast_stack.empty());
          TORCH_INTERNAL_ASSERT(autocast_stack.top().instance == node->input());
          TORCH_INTERNAL_ASSERT(
              !incompatible_amp.has_value() || !incompatible_amp.value(),
              "Unsupported case by AMP & JIT");
          incompatible_amp = false;
          autocast_stack.pop();
        }
        break;

      case aten::is_autocast_enabled:
        updateAutocastEnabledCheck(node, current_state().gpu_enabled);
        break;

      case aten::is_autocast_cpu_enabled:
        updateAutocastEnabledCheck(node, current_state().cpu_enabled);
        break;

      // CastPolicy::fp16 (cast all inputs to float16)
      case aten::_convolution:
      case aten::conv1d:
      case aten::conv2d:
      case aten::conv3d:
      case aten::conv_tbc:
      case aten::conv_transpose1d:
      case aten::convolution:
      case aten::cudnn_convolution:
      case aten::cudnn_convolution_transpose:
      case aten::prelu:
      case aten::addmm:
      case aten::addmv:
      case aten::addr:
      case aten::matmul:
      case aten::mm:
      case aten::mv:
      case aten::linear:
      case aten::addbmm:
      case aten::baddbmm:
      case aten::bmm:
      case aten::chain_matmul:
      case aten::_thnn_fused_lstm_cell:
      case aten::_thnn_fused_gru_cell:
      case aten::lstm_cell:
      case aten::gru_cell:
      case aten::rnn_tanh_cell:
      case aten::rnn_relu_cell:
        if (!node->schema().is_mutable()) {
          castTensorInputs(
              node, aten::_autocast_to_reduced_precision, current_state());
        }
        break;

      // CastPolicy::fp32 (cast all inputs to float32)
      case aten::native_layer_norm:
      case aten::acos:
      case aten::asin:
      case aten::cosh:
      case aten::erfinv:
      case aten::exp:
      case aten::expm1:
      case aten::log:
      case aten::log10:
      case aten::log2:
      case aten::log1p:
      case aten::reciprocal:
      case aten::rsqrt:
      case aten::sinh:
      case aten::tan:
      case aten::pow:
      case aten::softplus:
      case aten::gelu:
      case aten::layer_norm:
      case aten::group_norm:
      case aten::frobenius_norm:
      case aten::nuclear_norm:
      case aten::cosine_similarity:
      case aten::cosine_embedding_loss:
      case aten::nll_loss:
      case aten::nll_loss2d:
      case aten::hinge_embedding_loss:
      case aten::kl_div:
      case aten::l1_loss:
      case aten::smooth_l1_loss:
      case aten::mse_loss:
      case aten::margin_ranking_loss:
      case aten::multilabel_margin_loss:
      case aten::soft_margin_loss:
      case aten::triplet_margin_loss:
      case aten::multi_margin_loss:
      case aten::binary_cross_entropy_with_logits:
      case aten::dist:
      case aten::pdist:
      case aten::cdist:
      case aten::renorm:
      case aten::logsumexp:
        if (!node->schema().is_mutable()) {
          castTensorInputs(
              node, aten::_autocast_to_full_precision, current_state());
        }
        break;

      // CastPolicy::fp32_set_opt_dtype
      case aten::prod:
      case aten::log_softmax:
      case aten::cumprod:
      case aten::cumsum:
      case aten::sum:
        if (!node->schema().is_mutable() && !hasExplicitDtypeArgument(node)) {
          castTensorInputs(
              node, aten::_autocast_to_full_precision, current_state());
        }
        break;

      // cast softmax to fp32 only on GPU
      case aten::softmax:
        if (!node->schema().is_mutable() && !hasExplicitDtypeArgument(node)) {
          auto context = current_state();
          context.cpu_enabled = false;
          castTensorInputs(node, aten::_autocast_to_full_precision, context);
        }
        break;

      // CastPolicy::promote (promote inputs to the widest type)
      case aten::addcdiv:
      case aten::addcmul:
      case aten::atan2:
      case aten::bilinear:
      case aten::cat:
      case aten::cross:
      case aten::dot:
      case aten::equal:
      case aten::index_put:
      case aten::stack:
      case aten::tensordot:
      // add, sub, mul, div were added to autocast jit, because aten implicit
      // type promotion is not visible to JIT and could cause dtype mismatch on
      // backward
      // see [Note: implicit type promotion in Autocast]
      case aten::add:
      case aten::sub:
      case aten::mul:
      case aten::div:
        if (!node->schema().is_mutable()) {
          castInputsToWidestType(node, current_state());
        }
        break;

      // Banned in autocast, see binary_cross_entropy_banned()
      case aten::binary_cross_entropy:
        if (current_state()) {
          AT_ERROR("Unsafe to autocast");
        }
    }

    // process sub-blocks, if any
    for (Block* sub_block : node->blocks()) {
      handleBlock(sub_block, current_state());
    }
  }

  // Sanity check: make sure there's no unbalanced transition
  TORCH_INTERNAL_ASSERT(autocast_stack.empty());
}

} // namespace
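
// Enable or disable the JIT autocast pass globally. Returns the previous
// setting.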
bool setAutocastMode(bool value) {
  auto old_value = autocast_enabled;
  autocast_enabled = value;
  return old_value;
}

bool autocastEnabled() {
  return autocast_enabled;
}
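
// Entry point for the pass: seed the initial AutocastContext from the eager
// (thread-local) autocast state, then process the whole graph.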
void Autocast(const std::shared_ptr<Graph>& graph) {
  GRAPH_DUMP("\nBefore Autocast: ", graph);
  if (autocastEnabled()) {
    AutocastContext init = {
        at::autocast::is_autocast_enabled(at::kCUDA),
        at::autocast::is_autocast_enabled(at::kCPU),
        at::autocast::get_autocast_dtype(at::kCUDA),
        at::autocast::get_autocast_dtype(at::kCPU)};
    handleBlock(graph->block(), init);
  }
  GRAPH_DUMP("\nAfter Autocast: ", graph);
}

} // namespace jit
} // namespace torch