Files
pytorch/torch/csrc/jit/codegen/fuser/codegen.cpp
Jeff Daily ae9a4fa63c [ROCm] enforce ROCM_VERSION >= 6.0 (#125646)
Remove any code relying on ROCM_VERSION < 6.0.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125646
Approved by: https://github.com/albanD, https://github.com/eqy
2024-05-12 18:01:28 +00:00

691 lines
24 KiB
C++

#include <torch/csrc/jit/codegen/fuser/codegen.h>
#include <ATen/ATen.h>
#include <ATen/code_template.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/codegen/fuser/compiler.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/codegen/fuser/tensor_info.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/codegen/fuser/cpu/resource_strings.h>
#include <torch/csrc/jit/codegen/fuser/cuda/resource_strings.h>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <tuple>
#include <vector>
namespace torch {
namespace jit {
namespace fuser {
// Template for computing the offset into the tensor to access a value
static auto dim_calc = at::jit::CodeTemplate(R"(
//printf("tensor ${tensor} sizes[${d}] = %d, strides[${d}] = %d\n", ${tensor}.sizes[${d}],${tensor}.strides[${d}]);
size_t ${tensor}_dimIndex${d} = ${tensor}_linearIndex ${mod_sizes};
${tensor}_offset += ${tensor}_dimIndex${d} ${times_stride};
)");
static std::string valueName(const Value* n) {
return "n" + c10::to_string(n->unique());
}
static std::string scalarValue(const int64_t v) {
return c10::to_string(v);
}
static std::string scalarValue(const bool v) {
return c10::to_string(v);
}
// Note: The NAN, NEG_INFINITY and POS_INFINITY strings map to device-specific
// implementations of these special values. These macros are found in the
// resource strings for each device.
static std::string scalarValue(const double v) {
std::ostringstream out;
if (std::isnan(v)) {
out << "NAN";
} else if (std::isinf(v)) {
if (v < 0) {
out << "NEG_INFINITY";
} else {
out << "POS_INFINITY";
}
} else {
out << std::setprecision(16) << v;
}
return out.str();
}
// Note: Half is special-cased to avoid returning at::Half
static const char* scalarTypeName(const at::ScalarType type) {
if (type == at::ScalarType::Half) {
return "half";
}
if (type == at::ScalarType::BFloat16) {
return "__nv_bfloat16";
}
switch (type) {
#define DEFINE_CASE(ctype, name) \
case at::ScalarType::name: \
return #ctype;
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_CASE)
#undef DEFINE_CASE
default:
throw std::runtime_error("unknown scalar type");
}
}
static const char* calcScalarTypeName(const at::ScalarType type) {
if (type == at::ScalarType::Half) {
return "float";
}
if (type == at::ScalarType::BFloat16) {
return "float";
}
return scalarTypeName(type);
}
static std::string variableType(const c10::Type& t) {
if (t.kind() == TypeKind::IntType) {
return "int64_t";
} else if (t.kind() == TypeKind::FloatType) {
return "double";
} else if (t.kind() == TypeKind::BoolType) {
return "bool";
} else if (auto scalar_type = t.expectRef<TensorType>().scalarType()) {
return calcScalarTypeName(*scalar_type);
}
// something went wrong with the type analysis during shape propagation
throw std::runtime_error(
"unknown scalar type during JIT fusion code generation");
}
static std::string typeCastedValueName(
const c10::Type& t,
const at::ScalarType outtype,
const std::string& vn) {
if (t.kind() == TypeKind::IntType || t.kind() == TypeKind::BoolType) {
if (!isIntegralType(outtype, /*includeBool=*/false)) {
return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
}
return vn;
} else if (t.kind() == TypeKind::FloatType) {
// We don't guard this on anything because in our type system for scalars,
// there is not a distinction between `float` and `double`, however there
// *is* a distinction in tensor scalar types. We conservatively insert a
// cast here, which may end up being a no-op if the tensor's scalar type
// is `double`.
return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
} else if (t.kind() == TypeKind::NoneType) {
// Support None value for optional arguments like memory format
return vn;
} else if (auto scalar_type = t.expectRef<TensorType>().scalarType()) {
if (*scalar_type != outtype) {
return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
}
return vn;
}
// something went wrong with the type analysis during shape propagation
throw std::runtime_error(
"unknown scalar type during JIT fusion code generation");
}
// Writes RHS of special handling "simple mappable" ops
static std::string encodeSpecialRHS(const Node* n, at::jit::TemplateEnv& env) {
// special case for clamp fusion on missing min/max inputs
// Note: It may seem unusual to have the bounds as the first case below,
// this is so that if min or max is NaN, they are "ignored"
// and when the input is NaN, the output is, too
if (n->kind() == aten::clamp) {
const auto min = n->input(1);
const auto max = n->input(2);
env.s("0", valueName(n->input(0)));
if (!min->node()->mustBeNone() && !max->node()->mustBeNone()) {
env.s("1", valueName(min));
env.s("2", valueName(max));
return format("(${0} < ${1} ? ${1} : (${0} > ${2}? ${2} : ${0}))", env);
} else if (min->node()->mustBeNone()) {
env.s("1", valueName(max));
return format("(${0} > ${1} ? ${1} : ${0})", env);
} else if (max->node()->mustBeNone()) {
env.s("1", valueName(min));
return format("(${0} < ${1} ? ${1} : ${0})", env);
} else {
throw std::runtime_error(
"At least one of 'min' or 'max' must not be None");
}
} else {
throw std::runtime_error("Cannot encode RHS of the node, op not supported");
}
}
// This struct specifies a template for dispatching specific aten:: operators.
// The current variants of RHS code selection we support are for double and
// float output values. For example, an aten::log operator which is assigned
// to a float value would emit logf(), whereas an aten::log operator which is
// assigned to a double would emit log().
struct RHSTemplate {
// Common case: float and double dispatch are identical
RHSTemplate(const char* for_float)
: for_float(for_float), for_double(for_float) {}
RHSTemplate(const char* for_float, const char* for_double)
: for_float(for_float), for_double(for_double) {}
const char* for_float;
const char* for_double;
};
// Writes "simple mappable" ops
static std::string encodeRHS(const Node* n) {
static std::unordered_map<NodeKind, RHSTemplate> simple_map_ops = {
// unary
{aten::_cast_Float, "static_cast<float>(${0})"},
{aten::abs, "fabs(${0})"},
{aten::sigmoid, {"1.f / (1.f + expf(-${0}))", "1. / (1. + exp(-${0}))"}},
{aten::relu, "${0} < 0 ? 0.f : ${0} "},
{aten::threshold,
"${0} <= ${1} ? static_cast<decltype(${0})>(${2}) : ${0} "},
{aten::log, {"logf(${0})", "log(${0})"}},
{aten::log10, {"log10f(${0})", "log10(${0})"}},
{aten::log1p, {"log1pf(${0})", "log1p(${0})"}},
{aten::log2, {"log2f(${0})", "log2(${0})"}},
{aten::lgamma, {"lgammaf(${0})", "lgamma(${0})"}},
{aten::exp, {"expf(${0})", "exp(${0})"}},
{aten::expm1, {"expm1f(${0})", "expm1(${0})"}},
{aten::erf, {"erff(${0})", "erf(${0})"}},
{aten::erfc, {"erfcf(${0})", "erfc(${0})"}},
{aten::cos, {"cosf(${0})", "cos(${0})"}},
{aten::acos, {"acosf(${0})", "acos(${0})"}},
{aten::cosh, {"coshf(${0})", "cosh(${0})"}},
{aten::sin, {"sinf(${0})", "sin(${0})"}},
{aten::asin, {"asinf(${0})", "asin(${0})"}},
{aten::sinh, {"sinhf(${0})", "sinh(${0})"}},
{aten::tan, {"tanf(${0})", "tan(${0})"}},
{aten::atan, {"atanf(${0})", "atan(${0})"}},
{aten::tanh, {"tanhf(${0})", "tanh(${0})"}},
{aten::sqrt, {"sqrtf(${0})", "sqrt(${0})"}},
{aten::rsqrt, {"rsqrtf(${0})", "rsqrt(${0})"}},
{aten::ceil, {"ceilf(${0})", "ceil(${0})"}},
{aten::floor, {"floorf(${0})", "floor(${0})"}},
{aten::round, {"roundf(${0})", "round(${0})"}},
{aten::trunc, {"truncf(${0})", "trunc(${0})"}},
{aten::frac, {"${0} - truncf(${0})", "${0} - trunc(${0})"}},
{aten::reciprocal, {"1.f/(${0})", "1./(${0})"}},
{aten::neg, "-${0}"},
// simple binary
{aten::atan2, "atan2(${0}, ${1})"},
{aten::min,
"isnan(${0}) ? ${0} : (isnan(${1}) ? ${1} : (${0} < ${1} ? ${0} : ${1}))"},
{aten::max,
"isnan(${0}) ? ${0} : (isnan(${1}) ? ${1} : (${0} < ${1} ? ${1} : ${0}))"},
// binary with other
// TODO: some of these ops will not get generated because
// we only work on float inputs/outputs, but they are here to record
// that they are valid mappable ops once we handle more type
{aten::__and__, "${0} && ${1}"},
{aten::__lshift__, "${0} << ${1}"},
{aten::__or__, "${0} || ${1}"},
{aten::__rshift__, "${0} >> ${1}"},
{aten::__xor__, "${0} ^ ${1}"},
{aten::addcmul, "${0} + ${3} * ${1} * ${2}"},
{aten::div, "${0} / ${1}"},
{aten::eq, "${0_nocast} == ${1_nocast}"},
{aten::fmod, "fmodf(${0}, ${1})"},
{aten::ge, "(${0_nocast} >= ${1_nocast})"},
{aten::gt, "${0_nocast} > ${1_nocast}"},
{aten::le, "(${0_nocast} <= ${1_nocast})"},
{aten::lt, "${0_nocast} < ${1_nocast}"},
{aten::lerp, "${0} + ${2} * (${1} - ${0})"},
{aten::type_as, "(${0})"},
{aten::mul, "${0} * ${1}"},
{aten::ne, "${0_nocast} != ${1_nocast}"},
{aten::remainder, "fmod((${1} + fmod(${0}, ${1})), ${1})"},
{aten::pow, {"powf(${0}, ${1})", "pow(${0}, ${1})"}},
// alpha
{aten::add, "${0} + ${2}*${1}"},
{aten::sub, "(${0} - ${2}*${1})"},
{aten::rand_like, "uniform(rnd())"},
// where
{aten::where, "(${0} ? ${1} : ${2})"},
};
at::jit::TemplateEnv env;
if (simple_map_ops.find(n->kind()) == simple_map_ops.end()) {
return encodeSpecialRHS(n, env);
} else {
size_t i = 0;
auto outtype = n->output()->type()->expectRef<TensorType>().scalarType();
TORCH_INTERNAL_ASSERT(outtype);
for (auto in : n->inputs()) {
// PyTorch converts (scalar) argument types to result before applying the
// operator e.g. 1.4-torch.tensor(3) = -2
env.s(
c10::to_string(i),
typeCastedValueName(*in->type(), *outtype, valueName(in)));
// Uncasted operands only used for comparison operators
env.s(c10::to_string(i) + "_nocast", valueName(in));
i++;
}
const auto& templ = simple_map_ops.at(n->kind());
const char* str = nullptr;
if (*outtype == at::kFloat) {
str = templ.for_float;
} else {
str = templ.for_double;
}
AT_ASSERT(str);
return format(str, env);
}
}
static void emitIndexingFor(
std::ostream& out,
const std::string& tensor,
const int ndim,
const bool last_is_cont) {
at::jit::TemplateEnv env;
env.s("tensor", tensor);
out << format("IndexType ${tensor}_offset = 0;\n", env);
out << format("IndexType ${tensor}_linearIndex = linearIndex;\n", env);
for (int d = ndim - 1; d >= 0; --d) {
env.d("d", d);
env.s("mod_sizes", d > 0 ? format("% ${tensor}.sizes[${d}]", env) : "");
env.s(
"times_stride",
(d < ndim - 1 || !last_is_cont)
? format("* ${tensor}.strides[${d}]", env)
: "");
out << dim_calc.format(env);
if (d > 0) {
out << format("${tensor}_linearIndex /= ${tensor}.sizes[${d}];\n", env);
}
}
}
static void emitCheckFor(
std::ostream& out,
const std::string& tensor,
const int ndim,
const TensorDesc& desc) {
at::jit::TemplateEnv env;
env.s("tensor", tensor);
env.s("scalar_type", scalarTypeName(desc.scalar_type));
// allocate buffer to load 4
out << format("${scalar_type} ${tensor}_buf[4];\n", env);
// check if last dim is contiguous
if (!desc.lastIsContiguous()) {
out << "flag_vec4 = false;\n";
return;
}
// disable on dtype > 4 bytes for performance
if (at::elementSize(desc.scalar_type) > 4) {
out << "flag_vec4 = false;\n";
return;
}
// last dim size multiple of 4, other dim stride multiple of 4
for (int d = ndim - 1; d >= 0; --d) {
env.d("d", d);
if (d == ndim - 1) {
// last dim stride already checked above at compile time
out << format(
"if(${tensor}.sizes[${d}] % 4 != 0) flag_vec4 = false;\n", env);
} else {
out << format(
"if(${tensor}.strides[${d}] % 4 != 0) flag_vec4 = false;\n", env);
}
}
// pointer aligned
out << format(
"if(((uint64_t) ${tensor}.data) % (4 * sizeof(${scalar_type})) != 0) flag_vec4 = false;\n",
env);
}
// TODO: handle cases where we need to generate > 2^32 element tensors
std::string generateKernel(
const std::string& name,
const Graph& graph,
const std::vector<std::pair<const Value*, const c10::optional<TensorDesc>>>&
inputs,
const std::vector<std::pair<const Value*, const TensorDesc>>& outputs,
const bool use_cuda) {
at::jit::TemplateEnv env;
env.s("kernelName", name);
env.s(
"IndexType",
"unsigned int"); // Note: not uint32_t to avoid including cstdint
std::stringstream tensorChecks;
std::stringstream body;
std::stringstream body_vec4;
std::stringstream load;
std::stringstream store;
std::stringstream tensorOffsets;
std::vector<std::string> formals;
std::vector<std::string> argument_loads;
// Lambda for writing arguments
auto emitFormal = [&](const Value* n, const TensorDesc& desc) {
env.d(
"formal_index",
formals.size() +
1); // + 1 because the first argument is the linearIndex
std::string tensor =
"t" +
c10::to_string(
formals.size()); // can't be unique() because Param may be an output
const auto nDim = desc.nDim();
emitCheckFor(tensorChecks, tensor, nDim, desc);
emitIndexingFor(tensorOffsets, tensor, nDim, desc.lastIsContiguous());
env.s("tensor", tensor);
env.d("nDim", nDim);
env.s("scalar_type", scalarTypeName(desc.scalar_type));
formals.push_back(
format("const TensorInfo<${scalar_type},${nDim}> ${tensor}", env));
argument_loads.push_back(format(
"*static_cast<TensorInfo<${scalar_type},${nDim}>*>(args[${formal_index}])",
env));
};
auto emitScalarFormal = [&](const Value* n) {
env.d(
"formal_index",
formals.size() +
1); // + 1 because the first argument is the linearIndex
std::string scalar =
"s" +
c10::to_string(
formals.size()); // can't be unique() because Param may be an output
env.d(
"formal_index",
formals.size() +
1); // + 1 because the first argument is the linearIndex
env.s("scalar", scalar);
env.s("scalar_type", variableType(*n->type()));
formals.push_back(format("${scalar_type} ${scalar}", env));
argument_loads.push_back(
format("*static_cast<${scalar_type}*>(args[${formal_index}])", env));
};
// Writes input parameters
for (const auto& input : inputs) {
if (input.second.has_value()) {
emitFormal(input.first, *input.second);
} else {
emitScalarFormal(input.first);
}
}
// Writes output parameters
for (const auto& output : outputs) {
emitFormal(output.first, output.second);
}
// Acquires input values
bool has_half_tensor = false;
bool has_bfloat_tensor = false;
size_t formal_count = 0;
for (const auto& input : inputs) {
auto p = input.first;
env.s("node", valueName(p));
env.d("formal", formal_count++);
// Acquires and converts (if needed) inputs
// Note: conversion from half is only supported for CUDA kernels.
// The conversion immediately converts fp16 inputs to float.
// Access for other types is common to CUDA and CPU kernels.
if (input.second.has_value()) {
const auto is_half = input.second.has_value() &&
((*input.second).scalar_type == at::ScalarType::Half);
const auto is_bfloat = input.second.has_value() &&
((*input.second).scalar_type == at::ScalarType::BFloat16);
const auto is_bool = input.second.has_value() &&
((*input.second).scalar_type == at::ScalarType::Bool);
if (is_half) {
AT_ASSERT(use_cuda);
env.s(
"access",
format("__half2float(t${formal}.data[t${formal}_offset])", env));
env.s("access_vec4", format("__half2float(t${formal}_buf[i])", env));
has_half_tensor = true;
} else if (is_bfloat) {
AT_ASSERT(use_cuda);
env.s(
"access",
format(
"__bfloat162float(t${formal}.data[t${formal}_offset])", env));
env.s(
"access_vec4", format("__bfloat162float(t${formal}_buf[i])", env));
has_bfloat_tensor = true;
} else if (use_cuda) {
// No __ldg overload for bool
if (is_bool) {
env.s("access", format("t${formal}.data[t${formal}_offset]", env));
} else {
env.s(
"access",
format("__ldg(&t${formal}.data[t${formal}_offset])", env));
}
env.s("access_vec4", format("t${formal}_buf[i]", env));
} else {
env.s("access", format("t${formal}.data[t${formal}_offset]", env));
env.s("access_vec4", format("t${formal}_buf[i]", env));
}
env.s("lhs_type", calcScalarTypeName(input.second->scalar_type));
// load input in vectorized code path
auto ele_size = at::elementSize((*input.second).scalar_type);
if (ele_size == 1) {
env.s(
"load4",
format(
"*(reinterpret_cast<float*>(t${formal}_buf)) = *(reinterpret_cast<float*>(t${formal}.data + t${formal}_offset))",
env));
} else if (ele_size == 2) {
env.s(
"load4",
format(
"*(reinterpret_cast<float2*>(t${formal}_buf)) = *(reinterpret_cast<float2*>(t${formal}.data + t${formal}_offset))",
env));
} else if (ele_size == 4) {
env.s(
"load4",
format(
"*(reinterpret_cast<float4*>(t${formal}_buf)) = *(reinterpret_cast<float4*>(t${formal}.data + t${formal}_offset))",
env));
} else {
env.s(
"load4",
format(
"for(int i = 0; i<4; i++) t${formal}_buf[i] = t${formal}.data[t${formal}_offset + i]",
env));
}
load << format("${load4};\n", env);
} else {
env.s("access", format("s${formal}", env));
env.s("access_vec4", format("s${formal}", env));
env.s("lhs_type", variableType(*input.first->type()));
}
body << format("${lhs_type} ${node} = ${access};\n", env);
body_vec4 << format("${lhs_type} ${node} = ${access_vec4};\n", env);
}
bool has_random = false;
// Generates code for intermediate nodes
// Note: Concat and Chunk are implicitly generated
// Note: Random number generation is only supported for CUDA kernels.
// Note: Constant None node is ignored and we will handle it in the
// places where the constant None node is used
// Note: No need to iterate over reference as n is a pointer
for (const auto n : graph.nodes()) {
static_assert(std::is_pointer<decltype(n)>::value, "n must be a pointer");
// Note: FusedConcat nodes work by narrowing the output Tensors before the
// kernel runs
if (n->kind() == prim::FusedConcat)
continue;
if (n->kind() == prim::ConstantChunk)
continue;
if (n->mustBeNone())
continue;
if (n->kind() == aten::rand_like) {
AT_ASSERT(use_cuda);
has_random = true;
}
// Always emit double for prim::Constant. This will be narrowed later based
// on either:
// - Tensor-Scalar operator type rules
// - Math function rules
if (n->kind() == prim::Constant) {
const auto val = toIValue(n->output()).value();
std::string rhs;
if (val.isDouble()) {
rhs = scalarValue(val.toDouble());
} else if (val.isBool()) {
rhs = scalarValue(val.toBool());
} else {
AT_ASSERT(val.isInt());
rhs = scalarValue(val.toInt());
}
env.s("node", valueName(n->output()));
env.s("rhs", rhs);
env.s("lhs_type", variableType(*n->output()->type()));
} else {
env.s("node", valueName(n->output()));
env.s("rhs", encodeRHS(n));
env.s("lhs_type", variableType(*n->output()->type()));
}
body << format("${lhs_type} ${node} = ${rhs};\n", env);
body_vec4 << format("${lhs_type} ${node} = ${rhs};\n", env);
}
// Generates writes to output tensors
for (const auto& output : outputs) {
env.d("formal", formal_count++);
env.s("access", format("t${formal}.data[t${formal}_offset]", env));
env.s("access_vec4", format("t${formal}_buf[i]", env));
env.s("node", valueName(output.first));
// Acquires and converts (if needed) outputs
// Note: conversion to half is only supported for CUDA kernels.
const auto is_half = (output.second.scalar_type == at::ScalarType::Half);
const auto is_bfloat =
(output.second.scalar_type == at::ScalarType::BFloat16);
if (is_half) {
AT_ASSERT(use_cuda);
body << format("${access} = __float2half(${node});\n", env);
body_vec4 << format("${access_vec4} = __float2half(${node});\n", env);
has_half_tensor = true;
} else if (is_bfloat) {
AT_ASSERT(use_cuda);
body << format("${access} = __float2bfloat16(${node});\n", env);
body_vec4 << format("${access_vec4} = __float2bfloat16(${node});\n", env);
has_bfloat_tensor = true;
} else {
body << format("${access} = ${node};\n", env);
body_vec4 << format("${access_vec4} = ${node};\n", env);
}
// store output in vectorized code path
auto ele_size = at::elementSize(output.second.scalar_type);
if (ele_size == 1) {
env.s(
"store4",
format(
"*(reinterpret_cast<float*>(t${formal}.data + t${formal}_offset)) = *(reinterpret_cast<float*>(t${formal}_buf))",
env));
} else if (ele_size == 2) {
env.s(
"store4",
format(
"*(reinterpret_cast<float2*>(t${formal}.data + t${formal}_offset)) = *(reinterpret_cast<float2*>(t${formal}_buf))",
env));
} else if (ele_size == 4) {
env.s(
"store4",
format(
"*(reinterpret_cast<float4*>(t${formal}.data + t${formal}_offset)) = *(reinterpret_cast<float4*>(t${formal}_buf))",
env));
} else {
env.s(
"store4",
format(
"for(int i = 0; i<4; i++) t${formal}.data[t${formal}_offset + i] = t${formal}_buf[i]",
env));
}
store << format("${store4};\n", env);
}
// Includes headers
// Note: CUDA kernels support halfs and random generation, CPU kernels do not
if (has_half_tensor) {
env.s("HalfHeader", cuda::half_support_literal);
} else {
env.s("HalfHeader", "");
}
if (has_bfloat_tensor) {
env.s("BFloat16Header", cuda::bfloat16_support_literal);
} else {
env.s("BFloat16Header", "");
}
if (has_random) {
env.s("RandHeader", cuda::rand_support_literal);
env.s("RandParam", cuda::rand_param);
env.s("RandInit", cuda::rand_init);
} else {
env.s("RandHeader", "");
env.s("RandParam", "");
env.s("RandInit", "");
}
// clang-format on
// Instantiates the CUDA or CPU-specific templates
env.s("tensorOffsets", tensorOffsets.str());
env.s("tensorChecks", tensorChecks.str());
env.s("kernelBody", body.str());
env.s("kernelBody_vec4", body_vec4.str());
env.s("kernelLoad", load.str());
env.s("kernelStore", store.str());
env.v("formals", formals);
env.v("argument_loads", argument_loads);
std::string code_string;
if (use_cuda) {
env.s("type_declarations", cuda::type_declarations_template.format(env));
code_string = cuda::cuda_compilation_unit_template.format(env);
} else {
env.s("type_declarations", cpu::type_declarations_template.format(env));
code_string = cpu::cpu_compilation_unit_template.format(env);
}
if (debugFuser()) {
std::cerr << "fusion code:" << code_string << std::endl;
}
return code_string;
}
} // namespace fuser
} // namespace jit
} // namespace torch