Compare commits

..

1 Commits

Author SHA1 Message Date
6a4bb0f11e Update operator benchmarks README 2025-11-19 07:58:11 +00:00
292 changed files with 2674 additions and 3203 deletions

View File

@ -188,7 +188,7 @@ case "$tag" in
fi
GCC_VERSION=11
VISION=yes
ROCM_VERSION=7.1
ROCM_VERSION=7.0
NINJA_VERSION=1.9.0
TRITON=yes
KATEX=yes

View File

@ -60,16 +60,14 @@ EOF
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated rocm-llvm-dev
fi
if [[ $(ver $ROCM_VERSION) -lt $(ver 7.1) ]]; then
# precompiled miopen kernels added in ROCm 3.5, renamed in ROCm 5.5, removed in ROCm 7.1
# search for all unversioned packages
# if search fails it will abort this script; use true to avoid case where search fails
MIOPENHIPGFX=$(apt-cache search --names-only miopen-hip-gfx | awk '{print $1}' | grep -F -v . || true)
if [[ "x${MIOPENHIPGFX}" = x ]]; then
echo "miopen-hip-gfx package not available" && exit 1
else
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ${MIOPENHIPGFX}
fi
# precompiled miopen kernels added in ROCm 3.5, renamed in ROCm 5.5
# search for all unversioned packages
# if search fails it will abort this script; use true to avoid case where search fails
MIOPENHIPGFX=$(apt-cache search --names-only miopen-hip-gfx | awk '{print $1}' | grep -F -v . || true)
if [[ "x${MIOPENHIPGFX}" = x ]]; then
echo "miopen-hip-gfx package not available" && exit 1
else
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ${MIOPENHIPGFX}
fi
# ROCm 6.0 had a regression where journal_mode was enabled on the kdb files resulting in permission errors at runtime

View File

@ -12,8 +12,8 @@ function do_install() {
rocm_version_nodot=${rocm_version//./}
# https://github.com/icl-utk-edu/magma/pull/65
MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec
# post merge of https://github.com/icl-utk-edu/magma/pull/65
MAGMA_VERSION=c0792ae825fb36872784892ea643dd6f3456bc5f
magma_archive="magma-rocm${rocm_version_nodot}-${MAGMA_VERSION}-1.tar.bz2"
rocm_dir="/opt/rocm"

View File

@ -75,11 +75,9 @@ if [[ "$ARCH" == "aarch64" ]]; then
# ARM system libraries
DEPS_LIST+=(
"/usr/lib64/libgfortran.so.5"
"/opt/OpenBLAS/lib/libopenblas.so.0"
)
DEPS_SONAME+=(
"libgfortran.so.5"
"libopenblas.so.0"
)
fi

View File

@ -1 +1 @@
94631807d22c09723dd006f7be5beb649d5f88d0
e4d25697f9dc5eedaf8f0a5bf085c62c5455a53a

View File

@ -144,7 +144,7 @@ inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
}
inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")";
out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ')';
return out;
}

View File

@ -9,7 +9,7 @@ namespace indexing {
const EllipsisIndexType Ellipsis = EllipsisIndexType();
std::ostream& operator<<(std::ostream& stream, const Slice& slice) {
stream << slice.start() << ":" << slice.stop() << ":" << slice.step();
stream << slice.start() << ':' << slice.stop() << ':' << slice.step();
return stream;
}
@ -31,12 +31,12 @@ std::ostream& operator<<(std::ostream& stream, const TensorIndex& tensor_index)
}
std::ostream& operator<<(std::ostream& stream, const std::vector<TensorIndex>& tensor_indices) {
stream << "(";
stream << '(';
for (const auto i : c10::irange(tensor_indices.size())) {
stream << tensor_indices[i];
if (i < tensor_indices.size() - 1) stream << ", ";
}
stream << ")";
stream << ')';
return stream;
}

View File

@ -113,7 +113,7 @@ void TensorNames::checkUnique(const char* op_name) const {
std::ostream& operator<<(std::ostream& out, const TensorName& tensorname) {
out << tensorname.name_ << " (index ";
out << tensorname.origin_idx_ << " of ";
out << tensorname.origin_ << ")";
out << tensorname.origin_ << ')';
return out;
}

View File

@ -13,9 +13,9 @@ std::ostream& operator<<(std::ostream & out, const TensorGeometryArg& t) {
if (t.pos == 0) {
// 0 is distinguished; it usually indicates 'self' or the return
// tensor
out << "'" << t.name << "'";
out << '\'' << t.name << '\'';
} else {
out << "argument #" << t.pos << " '" << t.name << "'";
out << "argument #" << t.pos << " '" << t.name << '\'';
}
return out;
}
@ -154,7 +154,7 @@ void checkSameGPU(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
oss << "Tensor for " << t2 << " is on CPU, ";
}
oss << "but expected " << ((!t1->is_cpu() && !t2->is_cpu()) ? "them" : "it")
<< " to be on GPU (while checking arguments for " << c << ")";
<< " to be on GPU (while checking arguments for " << c << ')';
TORCH_CHECK(false, oss.str());
}
TORCH_CHECK(
@ -199,7 +199,7 @@ void checkScalarTypes(CheckedFrom c, const TensorArg& t,
i++;
}
oss << "; but got " << t->toString()
<< " instead (while checking arguments for " << c << ")";
<< " instead (while checking arguments for " << c << ')';
TORCH_CHECK(false, oss.str());
}
}

View File

@ -43,8 +43,8 @@ std::string get_mkldnn_version() {
// https://github.com/intel/ideep/issues/29
{
const dnnl_version_t* ver = dnnl_version();
ss << "Intel(R) MKL-DNN v" << ver->major << "." << ver->minor << "." << ver->patch
<< " (Git Hash " << ver->hash << ")";
ss << "Intel(R) MKL-DNN v" << ver->major << '.' << ver->minor << '.' << ver->patch
<< " (Git Hash " << ver->hash << ')';
}
#else
ss << "MKLDNN not found";
@ -81,7 +81,7 @@ std::string get_openmp_version() {
break;
}
if (ver_str) {
ss << " (a.k.a. OpenMP " << ver_str << ")";
ss << " (a.k.a. OpenMP " << ver_str << ')';
}
}
#else
@ -135,38 +135,38 @@ std::string show_config() {
#if defined(__GNUC__)
{
ss << " - GCC " << __GNUC__ << "." << __GNUC_MINOR__ << "\n";
ss << " - GCC " << __GNUC__ << '.' << __GNUC_MINOR__ << '\n';
}
#endif
#if defined(__cplusplus)
{
ss << " - C++ Version: " << __cplusplus << "\n";
ss << " - C++ Version: " << __cplusplus << '\n';
}
#endif
#if defined(__clang_major__)
{
ss << " - clang " << __clang_major__ << "." << __clang_minor__ << "." << __clang_patchlevel__ << "\n";
ss << " - clang " << __clang_major__ << '.' << __clang_minor__ << '.' << __clang_patchlevel__ << '\n';
}
#endif
#if defined(_MSC_VER)
{
ss << " - MSVC " << _MSC_FULL_VER << "\n";
ss << " - MSVC " << _MSC_FULL_VER << '\n';
}
#endif
#if AT_MKL_ENABLED()
ss << " - " << get_mkl_version() << "\n";
ss << " - " << get_mkl_version() << '\n';
#endif
#if AT_MKLDNN_ENABLED()
ss << " - " << get_mkldnn_version() << "\n";
ss << " - " << get_mkldnn_version() << '\n';
#endif
#ifdef _OPENMP
ss << " - " << get_openmp_version() << "\n";
ss << " - " << get_openmp_version() << '\n';
#endif
#if AT_BUILD_WITH_LAPACK()
@ -183,7 +183,7 @@ std::string show_config() {
ss << " - Cross compiling on MacOSX\n";
#endif
ss << " - "<< used_cpu_capability() << "\n";
ss << " - "<< used_cpu_capability() << '\n';
if (hasCUDA()) {
ss << detail::getCUDAHooks().showConfig();
@ -200,10 +200,10 @@ std::string show_config() {
ss << " - Build settings: ";
for (const auto& pair : caffe2::GetBuildOptions()) {
if (!pair.second.empty()) {
ss << pair.first << "=" << pair.second << ", ";
ss << pair.first << '=' << pair.second << ", ";
}
}
ss << "\n";
ss << '\n';
// TODO: do HIP
// TODO: do XLA

View File

@ -209,7 +209,7 @@ struct CodeTemplate {
// to indent correctly in the context.
void emitIndent(std::ostream& out, size_t indent) const {
for ([[maybe_unused]] const auto i : c10::irange(indent)) {
out << " ";
out << ' ';
}
}
void emitStringWithIndents(

View File

@ -10,7 +10,7 @@ std::ostream& operator<<(std::ostream& out, const Dimname& dimname) {
if (dimname.type() == NameType::WILDCARD) {
out << "None";
} else {
out << "'" << dimname.symbol().toUnqualString() << "'";
out << '\'' << dimname.symbol().toUnqualString() << '\'';
}
return out;
}

View File

@ -5,7 +5,7 @@
namespace at {
std::ostream& operator<<(std::ostream& out, const Range& range) {
out << "Range[" << range.begin << ", " << range.end << "]";
out << "Range[" << range.begin << ", " << range.end << ']';
return out;
}

View File

@ -71,7 +71,7 @@ void TensorBase::enforce_invariants() {
void TensorBase::print() const {
if (defined()) {
std::cerr << "[" << toString() << " " << sizes() << "]" << '\n';
std::cerr << '[' << toString() << ' ' << sizes() << ']' << '\n';
} else {
std::cerr << "[UndefinedTensor]" << '\n';
}

View File

@ -245,9 +245,6 @@ class TORCH_API TensorBase {
size_t weak_use_count() const noexcept {
return impl_.weak_use_count();
}
bool is_uniquely_owned() const noexcept {
return impl_.is_uniquely_owned();
}
std::string toString() const;

View File

@ -9,8 +9,8 @@ APIVitals VitalsAPI;
std::ostream& operator<<(std::ostream& os, TorchVital const& tv) {
for (const auto& m : tv.attrs) {
os << "[TORCH_VITAL] " << tv.name << "." << m.first << "\t\t "
<< m.second.value << "\n";
os << "[TORCH_VITAL] " << tv.name << '.' << m.first << "\t\t "
<< m.second.value << '\n';
}
return os;
}

View File

@ -100,18 +100,18 @@ inline bool operator==(const AliasInfo& lhs, const AliasInfo& rhs) {
// this does match the way things are represented in the schema
inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) {
out << "(";
out << '(';
bool first = true;
for (const auto& set : aliasInfo.beforeSets()) {
if (first) {
first = false;
} else {
out << "|";
out << '|';
}
out << set.toUnqualString();
}
if (aliasInfo.isWrite()) {
out << "!";
out << '!';
}
if (aliasInfo.beforeSets() != aliasInfo.afterSets()) {
out << " -> ";
@ -120,12 +120,12 @@ inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) {
if (first) {
first = false;
} else {
out << "|";
out << '|';
}
out << set.toUnqualString();
}
}
out << ")";
out << ')';
return out;
}
} // namespace c10

View File

@ -198,7 +198,7 @@ inline void swap(Blob& lhs, Blob& rhs) noexcept {
}
inline std::ostream& operator<<(std::ostream& out, const Blob& v) {
return out << "Blob[" << v.TypeName() << "]";
return out << "Blob[" << v.TypeName() << ']';
}
} // namespace caffe2

View File

@ -456,8 +456,8 @@ bool ClassType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
*why_not << "Method on class '" << repr_str()
<< "' (1) is not compatible with interface '"
<< rhs.repr_str() << "' (2)\n"
<< " (1) " << self_method->getSchema() << "\n"
<< " (2) " << schema << "\n";
<< " (1) " << self_method->getSchema() << '\n'
<< " (2) " << schema << '\n';
}
return false;
}

View File

@ -100,7 +100,7 @@ struct TORCH_API ClassType : public NamedType {
std::string repr_str() const override {
std::stringstream ss;
ss << str()
<< " (of Python compilation unit at: " << compilation_unit().get() << ")";
<< " (of Python compilation unit at: " << compilation_unit().get() << ')';
return ss.str();
}

View File

@ -58,12 +58,12 @@ std::string DispatchKeyExtractor::dumpState() const {
std::ostringstream oss;
for (const auto i : c10::irange(c10::utils::bitset::NUM_BITS())) {
if (dispatch_arg_indices_reverse_.get(i)) {
oss << "1";
oss << '1';
} else {
oss << "0";
oss << '0';
}
}
oss << " " << nonFallthroughKeys_ << "\n";
oss << ' ' << nonFallthroughKeys_ << '\n';
return oss.str();
}

View File

@ -69,8 +69,8 @@ private:
void _print_dispatch_trace(const std::string& label, const std::string& op_name, const DispatchKeySet& dispatchKeySet) {
auto nesting_value = dispatch_trace_nesting_value();
for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
std::cerr << label << " op=[" << op_name << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
for (int64_t i = 0; i < nesting_value; ++i) std::cerr << ' ';
std::cerr << label << " op=[" << op_name << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << ']' << std::endl;
}
} // namespace detail

View File

@ -570,7 +570,7 @@ void OperatorEntry::checkInvariants() const {
std::string OperatorEntry::listAllDispatchKeys() const {
std::ostringstream str;
str << "[";
str << '[';
bool has_kernels = false;
for (auto k : allDispatchKeysInFullSet()) {
@ -584,7 +584,7 @@ std::string OperatorEntry::listAllDispatchKeys() const {
str << k;
has_kernels = true;
}
str << "]";
str << ']';
return str.str();
}
@ -683,12 +683,12 @@ void OperatorEntry::setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> c
// This WON'T report backend fallbacks.
std::string OperatorEntry::dumpState() const {
std::ostringstream oss;
oss << "name: " << name_ << "\n";
oss << "name: " << name_ << '\n';
if (schema_) {
oss << "schema: " << schema_->schema << "\n";
oss << "debug: " << schema_->debug << "\n";
oss << "schema: " << schema_->schema << '\n';
oss << "debug: " << schema_->debug << '\n';
oss << "alias analysis kind: " << toString(schema_->schema.aliasAnalysis())
<< (schema_->schema.isDefaultAliasAnalysisKind() ? " (default)" : "") << "\n";
<< (schema_->schema.isDefaultAliasAnalysisKind() ? " (default)" : "") << '\n';
} else {
oss << "schema: (none)\n";
}

View File

@ -7,7 +7,7 @@
namespace c10 {
void FunctionSchema::dump() const {
std::cout << *this << "\n";
std::cout << *this << '\n';
}
const std::vector<Argument>& FunctionSchema::getCorrectList(SchemaArgType type) const {
@ -210,9 +210,9 @@ std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
out << schema.name();
if (!schema.overload_name().empty()) {
out << "." << schema.overload_name();
out << '.' << schema.overload_name();
}
out << "(";
out << '(';
bool seen_kwarg_only = false;
for (const auto i : c10::irange(schema.arguments().size())) {
@ -273,7 +273,7 @@ std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
}
if (need_paren) {
out << "(";
out << '(';
}
for (const auto i : c10::irange(returns.size())) {
if (i > 0) {
@ -288,7 +288,7 @@ std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
out << "...";
}
if (need_paren) {
out << ")";
out << ')';
}
return out;
}
@ -471,7 +471,7 @@ bool FunctionSchema::isForwardCompatibleWith(
if (!arguments().at(i).isForwardCompatibleWith(old.arguments().at(i))) {
if (why_not) {
why_not
<< "'" << arguments().at(i).name() << "'"
<< '\'' << arguments().at(i).name() << '\''
<< " is not forward compatible with the older version of the schema";
}
return false;
@ -511,7 +511,7 @@ bool FunctionSchema::isForwardCompatibleWith(
.isForwardCompatibleWith(old.arguments().at(i))) {
if (why_not) {
why_not << "Out argument '"
<< "'" << arguments().at(i).name()
<< '\'' << arguments().at(i).name()
<< " is not FC with the older version of the schema";
}
return false;

View File

@ -571,7 +571,7 @@ inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
if (arg.N()) {
N = std::to_string(*arg.N());
}
out << "[" << N << "]";
out << '[' << N << ']';
} else {
out << unopt_type->str();
}
@ -582,15 +582,15 @@ inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
}
if (is_opt) {
out << "?";
out << '?';
}
if (!arg.name().empty()) {
out << " " << arg.name();
out << ' ' << arg.name();
}
if (arg.default_value()) {
out << "=";
out << '=';
if ((type->kind() == c10::TypeKind::StringType ||
unopt_type->kind() == c10::TypeKind::StringType) &&
arg.default_value().value().isString()) {

View File

@ -66,7 +66,7 @@ bool operator==(const ivalue::Tuple& lhs, const ivalue::Tuple& rhs) {
}
std::ostream& operator<<(std::ostream& out, const ivalue::EnumHolder& v) {
out << v.qualifiedClassName() << "." << v.name();
out << v.qualifiedClassName() << '.' << v.name();
return out;
}
@ -526,7 +526,7 @@ std::ostream& printMaybeAnnotatedList(
!elementTypeCanBeInferredFromMembers(list_elem_type)) {
out << "annotate(" << the_list.type<c10::Type>()->annotation_str() << ", ";
printList(out, the_list.toListRef(), "[", "]", formatter);
out << ")";
out << ')';
return out;
} else {
return printList(out, the_list.toListRef(), "[", "]", formatter);
@ -538,7 +538,7 @@ std::ostream& printDict(
std::ostream& out,
const Dict& v,
const IValueFormatter& formatter) {
out << "{";
out << '{';
bool first = true;
for (const auto& pair : v) {
@ -552,7 +552,7 @@ std::ostream& printDict(
first = false;
}
out << "}";
out << '}';
return out;
}
}
@ -565,8 +565,8 @@ static std::ostream& printMaybeAnnotatedDict(
auto value_type = the_dict.type()->castRaw<DictType>()->getValueType();
if (the_dict.toGenericDict().empty() ||
!elementTypeCanBeInferredFromMembers(value_type)) {
out << "annotate(" << the_dict.type<c10::Type>()->annotation_str() << ",";
printDict(out, the_dict.toGenericDict(), formatter) << ")";
out << "annotate(" << the_dict.type<c10::Type>()->annotation_str() << ',';
printDict(out, the_dict.toGenericDict(), formatter) << ')';
} else {
return printDict(out, the_dict.toGenericDict(), formatter);
}
@ -577,7 +577,7 @@ static std::ostream& printComplex(std::ostream & out, const IValue & v) {
c10::complex<double> d = v.toComplexDouble();
IValue real(d.real()), imag(std::abs(d.imag()));
auto sign = d.imag() >= 0 ? '+' : '-';
return out << real << sign << imag << "j";
return out << real << sign << imag << 'j';
}
std::ostream& IValue::repr(
@ -605,9 +605,9 @@ std::ostream& IValue::repr(
if (static_cast<double>(i) == d) {
// -0.0 (signed zero) needs to be parsed as -0.
if (i == 0 && std::signbit(d)) {
return out << "-" << i << ".";
return out << '-' << i << '.';
}
return out << i << ".";
return out << i << '.';
}
}
auto orig_prec = out.precision();
@ -643,20 +643,20 @@ std::ostream& IValue::repr(
device_stream << v.toDevice();
out << "torch.device(";
c10::printQuotedString(out, device_stream.str());
return out << ")";
return out << ')';
}
case IValue::Tag::Generator: {
auto generator = v.toGenerator();
out << "torch.Generator(device=";
c10::printQuotedString(out, generator.device().str());
out << ", seed=" << generator.current_seed() << ")";
out << ", seed=" << generator.current_seed() << ')';
return out;
}
case IValue::Tag::GenericDict:
return printMaybeAnnotatedDict(out, v, formatter);
case IValue::Tag::Enum: {
auto enum_holder = v.toEnumHolder();
return out << enum_holder->qualifiedClassName() << "." <<
return out << enum_holder->qualifiedClassName() << '.' <<
enum_holder->name();
}
case IValue::Tag::Object: {
@ -801,7 +801,7 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
if (c == FP_NORMAL || c == FP_ZERO) {
int64_t i = static_cast<int64_t>(d);
if (static_cast<double>(i) == d) {
return out << i << ".";
return out << i << '.';
}
}
auto orig_prec = out.precision();
@ -852,7 +852,7 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
return printDict(out, v.toGenericDict(), formatter);
case IValue::Tag::PyObject: {
auto py_obj = v.toPyObject();
return out << "<PyObject at" << py_obj << ">";
return out << "<PyObject at" << py_obj << '>';
}
case IValue::Tag::Generator:
return out << "Generator";
@ -862,22 +862,22 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
// TODO we should attempt to call __str__ if the object defines it.
auto obj = v.toObject();
// print this out the way python would do it
return out << "<" << obj->name() << " object at " << obj.get() << ">";
return out << '<' << obj->name() << " object at " << obj.get() << '>';
}
case IValue::Tag::Enum: {
auto enum_holder = v.toEnumHolder();
return out << "Enum<" << enum_holder->unqualifiedClassName() << "." <<
enum_holder->name() << ">";
return out << "Enum<" << enum_holder->unqualifiedClassName() << '.' <<
enum_holder->name() << '>';
}
}
return out << "<Invalid IValue tag=" << std::to_string(static_cast<uint32_t>(v.tag)) << ">";
return out << "<Invalid IValue tag=" << std::to_string(static_cast<uint32_t>(v.tag)) << '>';
}
#undef TORCH_FORALL_TAGS
void IValue::dump() const {
std::cout << *this << "\n";
std::cout << *this << '\n';
}
std::shared_ptr<ClassType> ivalue::Object::type() const {
@ -1050,7 +1050,7 @@ c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(
std::stringstream err;
err << "Cannot serialize custom bound C++ class";
if (auto qualname = type()->name()) {
err << " " << qualname->qualifiedName();
err << ' ' << qualname->qualifiedName();
}
err << ". Please define serialization methods via def_pickle() for "
"this class.";

View File

@ -211,7 +211,7 @@ struct TORCH_API OptionalType : public UnionType {
std::string str() const override {
std::stringstream ss;
ss << getElementType()->str() << "?";
ss << getElementType()->str() << '?';
return ss.str();
}
@ -240,7 +240,7 @@ struct TORCH_API OptionalType : public UnionType {
std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
std::stringstream ss;
ss << "Optional[" << getElementType()->annotation_str(printer) << "]";
ss << "Optional[" << getElementType()->annotation_str(printer) << ']';
return ss.str();
}
};
@ -906,7 +906,7 @@ struct TORCH_API ListType
std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
std::stringstream ss;
ss << "List[" << getElementType()->annotation_str(printer) << "]";
ss << "List[" << getElementType()->annotation_str(printer) << ']';
return ss.str();
}
};
@ -946,7 +946,7 @@ struct TORCH_API DictType : public SharedType {
std::string str() const override {
std::stringstream ss;
ss << "Dict(" << getKeyType()->str() << ", " << getValueType()->str()
<< ")";
<< ')';
return ss.str();
}
@ -1018,7 +1018,7 @@ struct TORCH_API FutureType
std::string str() const override {
std::stringstream ss;
ss << "Future(" << getElementType()->str() << ")";
ss << "Future(" << getElementType()->str() << ')';
return ss.str();
}
TypePtr createWithContained(
@ -1041,7 +1041,7 @@ struct TORCH_API FutureType
std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
std::stringstream ss;
ss << "Future[" << getElementType()->annotation_str(printer) << "]";
ss << "Future[" << getElementType()->annotation_str(printer) << ']';
return ss.str();
}
};
@ -1060,7 +1060,7 @@ struct TORCH_API AwaitType
std::string str() const override {
std::stringstream ss;
ss << "Await(" << getElementType()->str() << ")";
ss << "Await(" << getElementType()->str() << ')';
return ss.str();
}
TypePtr createWithContained(
@ -1083,7 +1083,7 @@ struct TORCH_API AwaitType
std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
std::stringstream ss;
ss << "Await[" << getElementType()->annotation_str(printer) << "]";
ss << "Await[" << getElementType()->annotation_str(printer) << ']';
return ss.str();
}
};
@ -1102,7 +1102,7 @@ struct TORCH_API RRefType
std::string str() const override {
std::stringstream ss;
ss << "RRef(" << getElementType()->str() << ")";
ss << "RRef(" << getElementType()->str() << ')';
return ss.str();
}
TypePtr createWithContained(
@ -1115,7 +1115,7 @@ struct TORCH_API RRefType
std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
std::stringstream ss;
ss << "RRef[" << getElementType()->annotation_str(printer) << "]";
ss << "RRef[" << getElementType()->annotation_str(printer) << ']';
return ss.str();
}
};

View File

@ -11,7 +11,7 @@ std::string toString(const OperatorName& opName) {
std::ostream& operator<<(std::ostream& os, const OperatorName& opName) {
os << opName.name;
if (!opName.overload_name.empty()) {
os << "." << opName.overload_name;
os << '.' << opName.overload_name;
}
return os;
}

View File

@ -65,7 +65,7 @@ VaryingShape<T> VaryingShape<T>::merge(const VaryingShape<T>& other) const {
template <typename T>
std::ostream& operator<<(std::ostream& out, const VaryingShape<T>& vs) {
out << "(";
out << '(';
if (!vs.size()) {
out << "*)";
return out;
@ -79,10 +79,10 @@ std::ostream& operator<<(std::ostream& out, const VaryingShape<T>& vs) {
if (v.has_value()) {
out << v.value();
} else {
out << "*";
out << '*';
}
}
out << ")";
out << ')';
return out;
}
@ -105,7 +105,7 @@ std::ostream& operator<<(
}
auto sizes_opt = ss.sizes();
os << "(";
os << '(';
for (size_t i = 0; i < rank_opt.value(); i++) {
if (i > 0) {
os << ", ";
@ -113,10 +113,10 @@ std::ostream& operator<<(
if(sizes_opt.has_value() && sizes_opt.value()[i].is_static()) {
os << sizes_opt.value()[i];
} else {
os << "*";
os << '*';
}
}
os << ")";
os << ')';
return os;
}
@ -131,17 +131,17 @@ std::ostream& operator<<(std::ostream& os, const ShapeSymbol& s) {
}
std::ostream& operator<<(std::ostream& os, const Stride& s) {
os << "{";
os << '{';
if (s.stride_index_.has_value()) {
os << *s.stride_index_;
} else {
os << "*";
os << '*';
}
os << ":";
os << ':';
if (s.stride_.has_value()) {
os << *s.stride_;
} else {
os << "*";
os << '*';
}
os << '}';
return os;

View File

@ -67,7 +67,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
bool has_valid_strides_info = ndim > 0 &&
value->strides().isComplete() && value->strides().size() == ndim;
out << "(";
out << '(';
size_t i = 0;
bool symbolic = type_verbosity() == TypeVerbosity::Symbolic;
for (i = 0; i < *ndim; ++i) {
@ -79,7 +79,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
} else if (symbolic) {
out << value->symbolic_sizes().at(i);
} else {
out << "*";
out << '*';
}
}
if (has_valid_strides_info &&
@ -91,7 +91,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
}
out << value->strides()[i].value();
}
out << "]";
out << ']';
}
if (type_verbosity() >= TypeVerbosity::Full) {
if (value->requiresGrad()) {
@ -107,12 +107,12 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
out << "device=" << *value->device();
}
}
out << ")";
out << ')';
} else {
if (type_verbosity() >= TypeVerbosity::Full) {
size_t i = 0;
if (value->requiresGrad()) {
out << "("
out << '('
<< "requires_grad=" << *value->requiresGrad();
i++;
}
@ -120,7 +120,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
out << ((i++ > 0) ? ", " : "(") << "device=" << *value->device();
}
if (i > 0) {
out << ")";
out << ')';
}
}
}
@ -133,18 +133,18 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
out << *prim << "[]";
} else if (t.kind() == TypeKind::OptionalType) {
auto prim = t.castRaw<OptionalType>()->getElementType();
out << *prim << "?";
out << *prim << '?';
} else if(t.kind() == TypeKind::FutureType) {
auto elem = t.castRaw<FutureType>()->getElementType();
out << "Future[" << *elem << "]";
out << "Future[" << *elem << ']';
} else if(t.kind() == TypeKind::RRefType) {
auto elem = t.castRaw<RRefType>()->getElementType();
out << "RRef[" << *elem << "]";
out << "RRef[" << *elem << ']';
} else if(auto tup = t.cast<TupleType>()) {
if (tup->schema()) {
out << "NamedTuple";
}
out << "(";
out << '(';
for(size_t i = 0; i < tup->elements().size(); ++i) {
if(i > 0)
out << ", ";
@ -160,7 +160,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
out << *(tup->elements()[i]);
}
}
out << ")";
out << ')';
} else if (t.kind() == TypeKind::FunctionType) {
out << "Function";
} else {
@ -475,7 +475,7 @@ std::optional<TypePtr> unifyTypeList(
why_not << "Could not unify type list since element " << i << " of type "
<< elements.at(i)->repr_str()
<< " did not match the types before it ("
<< ret_type->repr_str() << ")";
<< ret_type->repr_str() << ')';
return std::nullopt;
}
ret_type = *maybe_unified;
@ -907,13 +907,13 @@ std::string TupleType::str() const {
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
ss << name()->qualifiedName();
} else {
ss << "(";
ss << '(';
for(size_t i = 0; i < elements().size(); ++i) {
if(i > 0)
ss << ", ";
ss << elements()[i]->str();
}
ss << ")";
ss << ')';
}
return ss.str();
}
@ -1003,8 +1003,8 @@ bool InterfaceType::isSubTypeImpl(
*why_not << "Method on interface '" << lhs.repr_str()
<< "' (1) is not compatible with interface '"
<< rhs.repr_str() << "' (2)\n"
<< " (1) " << *self_schema << "\n"
<< " (2) " << schema << "\n";
<< " (1) " << *self_schema << '\n'
<< " (2) " << schema << '\n';
return false;
}
return false;
@ -1078,7 +1078,7 @@ SymbolicShape SymbolicShape::merge(const SymbolicShape& other) const {
}
void SymbolicShape::dump() const {
std::cout << *this << "\n";
std::cout << *this << '\n';
}
bool EnumType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {

View File

@ -205,9 +205,9 @@ UnionType::UnionType(std::vector<TypePtr> reference, TypeKind kind) : SharedType
for (const auto i : c10::irange(reference.size())) {
msg << reference[i]->repr_str();
if (i > 0) {
msg << ",";
msg << ',';
}
msg << " ";
msg << ' ';
}
msg << "} has the single type " << types_[0]->repr_str()
<< ". Use the common supertype instead of creating a Union"

View File

@ -80,7 +80,7 @@ std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
}
stream << buf[i];
}
stream << "]";
stream << ']';
return stream;
}

View File

@ -55,7 +55,7 @@ std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
}
stream << buf[i];
}
stream << "]";
stream << ']';
return stream;
}

View File

@ -3,7 +3,6 @@
#include <cstdint>
#include <map>
#include <shared_mutex>
#include <cuda_runtime_api.h>
#include <cusparse.h>
@ -89,13 +88,8 @@ TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
TORCH_CUDA_CPP_API void clearCublasWorkspaces();
struct WorkspaceMapWithMutex {
std::map<std::tuple<void*, void*>, at::DataPtr> map;
std::shared_mutex mutex;
};
TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublas_handle_stream_to_workspace();
TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace();
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace();
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace();
TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize();
TORCH_CUDA_CPP_API size_t getCUDABlasLtWorkspaceSize();
TORCH_CUDA_CPP_API void* getCUDABlasLtWorkspace();

View File

@ -99,7 +99,7 @@ void destroyCublasHandle(cublasHandle_t handle) {
// - Comments of @soumith copied from cuDNN handle pool implementation
#ifdef NO_CUDNN_DESTROY_HANDLE
#else
cublasDestroy(handle);
cublasDestroy(handle);
#endif
}
@ -107,27 +107,19 @@ using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle
} // namespace
WorkspaceMapWithMutex& cublas_handle_stream_to_workspace() {
static auto& instance = *new WorkspaceMapWithMutex;
std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
return instance;
}
WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace() {
static auto& instance = *new WorkspaceMapWithMutex;
std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace() {
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
return instance;
}
void clearCublasWorkspaces() {
{
auto& workspace = cublas_handle_stream_to_workspace();
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
workspace.map.clear();
}
{
auto& workspace = cublaslt_handle_stream_to_workspace();
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
workspace.map.clear();
}
cublas_handle_stream_to_workspace().clear();
cublaslt_handle_stream_to_workspace().clear();
}
size_t parseChosenWorkspaceSize() {
@ -241,38 +233,6 @@ at::DataPtr getNewCUDABlasLtWorkspace() {
return c10::cuda::CUDACachingAllocator::get()->allocate(getCUDABlasLtWorkspaceSize());
}
void setWorkspaceForHandle(cublasHandle_t handle, c10::cuda::CUDAStream stream) {
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto& workspace = cublas_handle_stream_to_workspace();
size_t workspace_size = getChosenWorkspaceSize();
// Fast path: check if workspace already exists
{
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.find(key);
if (workspace_it != workspace.map.end()) {
TORCH_CUDABLAS_CHECK(cublasSetWorkspace(
handle, workspace_it->second.get(), workspace_size));
return;
}
}
// Slow path: allocate workspace outside the lock
auto new_workspace = getNewWorkspace();
// Insert with lock (double-check in case another thread inserted while we
// were allocating)
{
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.try_emplace(key, std::move(new_workspace)).first;
TORCH_CUDABLAS_CHECK(
cublasSetWorkspace(handle, workspace_it->second.get(), workspace_size));
}
}
void* getCUDABlasLtWorkspace() {
#ifndef USE_ROCM
static bool unified = c10::utils::check_env(TORCH_CUBLASLT_UNIFIED_WORKSPACE) == true;
@ -281,10 +241,8 @@ void* getCUDABlasLtWorkspace() {
auto stream = c10::cuda::getCurrentCUDAStream();
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto& workspace = at::cuda::cublas_handle_stream_to_workspace();
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.find(key);
TORCH_INTERNAL_ASSERT(workspace_it != workspace.map.end());
auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key);
TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end());
return workspace_it->second.mutable_get();
}
#endif
@ -292,29 +250,11 @@ void* getCUDABlasLtWorkspace() {
auto stream = c10::cuda::getCurrentCUDAStream();
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto& workspace = cublaslt_handle_stream_to_workspace();
// Fast path: check if workspace already exists
{
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it = workspace.map.find(key);
if (workspace_it != workspace.map.end()) {
return workspace_it->second.mutable_get();
}
}
// Slow path: allocate workspace outside the lock
auto new_workspace = getNewCUDABlasLtWorkspace();
// Insert with lock (double-check in case another thread inserted while we
// were allocating)
{
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
auto workspace_it =
workspace.map.try_emplace(key, std::move(new_workspace)).first;
return workspace_it->second.mutable_get();
auto workspace_it = cublaslt_handle_stream_to_workspace().find(key);
if (workspace_it == cublaslt_handle_stream_to_workspace().end()) {
workspace_it = cublaslt_handle_stream_to_workspace().insert(workspace_it, {key, getNewCUDABlasLtWorkspace()});
}
return workspace_it->second.mutable_get();
}
cublasHandle_t getCurrentCUDABlasHandle() {
@ -358,8 +298,13 @@ cublasHandle_t getCurrentCUDABlasHandle() {
// will allocate memory dynamically (even if they're cheap) outside
// PyTorch's CUDA caching allocator. It's possible that CCA used up
// all the memory and cublas's cudaMallocAsync will return OOM
setWorkspaceForHandle(handle, stream);
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
auto workspace_it = cublas_handle_stream_to_workspace().find(key);
if (workspace_it == cublas_handle_stream_to_workspace().end()) {
workspace_it = cublas_handle_stream_to_workspace().insert(workspace_it, {key, getNewWorkspace()});
}
TORCH_CUDABLAS_CHECK(cublasSetWorkspace(handle, workspace_it->second.get(), getChosenWorkspaceSize()));
#if !defined(USE_ROCM)
// On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
// FP32 data type calculations based on the value of the allow_tf32 flag.

View File

@ -411,16 +411,16 @@ std::string CUDAHooks::showConfig() const {
// HIP_VERSION value format was changed after ROCm v4.2 to include the patch number
if(v < 500) {
// If major=xx, minor=yy then format -> xxyy
oss << (v / 100) << "." << (v % 10);
oss << (v / 100) << '.' << (v % 10);
}
else {
// If major=xx, minor=yy & patch=zzzzz then format -> xxyyzzzzz
oss << (v / 10000000) << "." << (v / 100000 % 100) << "." << (v % 100000);
oss << (v / 10000000) << '.' << (v / 100000 % 100) << '.' << (v % 100000);
}
#else
oss << (v / 1000) << "." << (v / 10 % 100);
oss << (v / 1000) << '.' << (v / 10 % 100);
if (v % 10 != 0) {
oss << "." << (v % 10);
oss << '.' << (v % 10);
}
#endif
};
@ -431,16 +431,16 @@ std::string CUDAHooks::showConfig() const {
oss << " - HIP Runtime ";
#endif
printCudaStyleVersion(runtimeVersion);
oss << "\n";
oss << '\n';
// TODO: Make HIPIFY understand CUDART_VERSION macro
#if !defined(USE_ROCM)
if (runtimeVersion != CUDART_VERSION) {
oss << " - Built with CUDA Runtime ";
printCudaStyleVersion(CUDART_VERSION);
oss << "\n";
oss << '\n';
}
oss << " - NVCC architecture flags: " << NVCC_FLAGS_EXTRA << "\n";
oss << " - NVCC architecture flags: " << NVCC_FLAGS_EXTRA << '\n';
#endif
#if !defined(USE_ROCM)
@ -448,9 +448,9 @@ std::string CUDAHooks::showConfig() const {
auto printCudnnStyleVersion = [&](size_t v) {
oss << (v / 1000) << "." << (v / 100 % 10);
oss << (v / 1000) << '.' << (v / 100 % 10);
if (v % 100 != 0) {
oss << "." << (v % 100);
oss << '.' << (v % 100);
}
};
@ -461,22 +461,22 @@ std::string CUDAHooks::showConfig() const {
if (cudnnCudartVersion != CUDART_VERSION) {
oss << " (built against CUDA ";
printCudaStyleVersion(cudnnCudartVersion);
oss << ")";
oss << ')';
}
oss << "\n";
oss << '\n';
if (cudnnVersion != CUDNN_VERSION) {
oss << " - Built with CuDNN ";
printCudnnStyleVersion(CUDNN_VERSION);
oss << "\n";
oss << '\n';
}
#endif
#else
// TODO: Check if miopen has the functions above and unify
oss << " - MIOpen " << MIOPEN_VERSION_MAJOR << "." << MIOPEN_VERSION_MINOR << "." << MIOPEN_VERSION_PATCH << "\n";
oss << " - MIOpen " << MIOPEN_VERSION_MAJOR << '.' << MIOPEN_VERSION_MINOR << '.' << MIOPEN_VERSION_PATCH << '\n';
#endif
#if AT_MAGMA_ENABLED()
oss << " - Magma " << MAGMA_VERSION_MAJOR << "." << MAGMA_VERSION_MINOR << "." << MAGMA_VERSION_MICRO << "\n";
oss << " - Magma " << MAGMA_VERSION_MAJOR << '.' << MAGMA_VERSION_MINOR << '.' << MAGMA_VERSION_MICRO << '\n';
#endif
return oss.str();

View File

@ -42,7 +42,7 @@ static inline void launch_jitted_vectorized_kernel_dynamic(
// The cache key includes all the parameters to generate_code + vec_size + dev_idx
std::stringstream ss;
ss << nInputs << "_" << nOutputs << f;
ss << nInputs << '_' << nOutputs << f;
ss << f_inputs_type_str << compute_type_str << result_type_str;
ss << static_cast<int>(at::cuda::jit::BinaryFuncVariant::NoScalar);
ss << extra_args_types;
@ -144,7 +144,7 @@ static inline void launch_jitted_unrolled_kernel_dynamic(
// The cache key includes all the parameters to generate_code + dev_idx
std::stringstream ss;
ss << nInputs << "_" << nOutputs << f;
ss << nInputs << '_' << nOutputs << f;
ss << f_inputs_type_str << compute_type_str << result_type_str;
ss << contiguous << dynamic_casting;
ss << static_cast<int>(at::cuda::jit::BinaryFuncVariant::NoScalar);

View File

@ -52,10 +52,10 @@ TuningContext* getTuningContext() {
std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry) {
static const bool blaslog = c10::utils::get_env("PYTORCH_TUNABLEOP_BLAS_LOG") == "1";
if (!blaslog) {
return stream << entry.key_ << "," << entry.time_;
return stream << entry.key_ << ',' << entry.time_;
}
else {
return stream << entry.key_ << "," << entry.time_ << ",BLAS_PARAMS: " << entry.blas_sig_;
return stream << entry.key_ << ',' << entry.time_ << ",BLAS_PARAMS: " << entry.blas_sig_;
}
}
@ -156,10 +156,10 @@ void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std
if (isNew) {
static const bool blaslog = c10::utils::get_env("PYTORCH_TUNABLEOP_BLAS_LOG") == "1";
if (!blaslog) {
untuned_file << op_signature << "," << params_signature << std::endl;
untuned_file << op_signature << ',' << params_signature << std::endl;
}
else {
untuned_file << op_signature << "," << params_signature << ",BLAS_PARAMS: " << blas_signature << std::endl;
untuned_file << op_signature << ',' << params_signature << ",BLAS_PARAMS: " << blas_signature << std::endl;
}
TUNABLE_LOG3("Untuned,", op_signature, ",", params_signature);
}
@ -201,7 +201,7 @@ void TuningResultsManager::InitRealtimeAppend(const std::string& filename, const
if(!file_exists || file_empty) {
for(const auto& [key, val] : validators) {
(*realtime_out_) << "Validator," << key << "," << val << std::endl;
(*realtime_out_) << "Validator," << key << ',' << val << std::endl;
realtime_out_->flush();
}
validators_written_ = true;
@ -219,7 +219,7 @@ void TuningResultsManager::AppendResultLine(const std::string& op_sig, const std
return;
}
(*realtime_out_) << op_sig << "," << param_sig << "," << result << std::endl;
(*realtime_out_) << op_sig << ',' << param_sig << ',' << result << std::endl;
realtime_out_->flush(); //ensure immediate write to disk
TUNABLE_LOG3("Realtime append: ", op_sig, "(", param_sig, ") -> ", result);

View File

@ -93,31 +93,31 @@ std::string cudnnTypeToString(cudnnDataType_t dtype) {
return "CUDNN_DATA_UINT8x4";
default:
std::ostringstream oss;
oss << "(unknown data-type " << static_cast<int>(dtype) << ")";
oss << "(unknown data-type " << static_cast<int>(dtype) << ')';
return oss.str();
}
}
std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d) {
out << "TensorDescriptor " << static_cast<void*>(d.desc()) << "\n";
out << "TensorDescriptor " << static_cast<void*>(d.desc()) << '\n';
int nbDims = 0;
int dimA[CUDNN_DIM_MAX];
int strideA[CUDNN_DIM_MAX];
cudnnDataType_t dtype{};
cudnnGetTensorNdDescriptor(d.desc(), CUDNN_DIM_MAX, &dtype, &nbDims, dimA, strideA);
out << " type = " << cudnnTypeToString(dtype) << "\n";
out << " nbDims = " << nbDims << "\n";
out << " type = " << cudnnTypeToString(dtype) << '\n';
out << " nbDims = " << nbDims << '\n';
// Read out only nbDims of the arrays!
out << " dimA = ";
for (auto i : ArrayRef<int>{dimA, static_cast<size_t>(nbDims)}) {
out << i << ", ";
}
out << "\n";
out << '\n';
out << " strideA = ";
for (auto i : ArrayRef<int>{strideA, static_cast<size_t>(nbDims)}) {
out << i << ", ";
}
out << "\n";
out << '\n';
return out;
}
@ -168,27 +168,27 @@ std::string cudnnMemoryFormatToString(cudnnTensorFormat_t tformat) {
return "CUDNN_TENSOR_NHWC";
default:
std::ostringstream oss;
oss << "(unknown cudnn tensor format " << static_cast<int>(tformat) << ")";
oss << "(unknown cudnn tensor format " << static_cast<int>(tformat) << ')';
return oss.str();
}
}
std::ostream& operator<<(std::ostream & out, const FilterDescriptor& d) {
out << "FilterDescriptor " << static_cast<void*>(d.desc()) << "\n";
out << "FilterDescriptor " << static_cast<void*>(d.desc()) << '\n';
int nbDims = 0;
int dimA[CUDNN_DIM_MAX];
cudnnDataType_t dtype{};
cudnnTensorFormat_t tformat{};
cudnnGetFilterNdDescriptor(d.desc(), CUDNN_DIM_MAX, &dtype, &tformat, &nbDims, dimA);
out << " type = " << cudnnTypeToString(dtype) << "\n";
out << " tensor_format = " << cudnnMemoryFormatToString(tformat) << "\n";
out << " nbDims = " << nbDims << "\n";
out << " type = " << cudnnTypeToString(dtype) << '\n';
out << " tensor_format = " << cudnnMemoryFormatToString(tformat) << '\n';
out << " nbDims = " << nbDims << '\n';
// Read out only nbDims of the arrays!
out << " dimA = ";
for (auto i : ArrayRef<int>{dimA, static_cast<size_t>(nbDims)}) {
out << i << ", ";
}
out << "\n";
out << '\n';
return out;
}

View File

@ -346,15 +346,15 @@ void foreachTensorInplaceWithFlag(std::vector<IValue>& args, int64_t begin, int6
}
std::ostream& operator<< (std::ostream& os, const DynamicLayer& layer) {
os << layer.layerId() << ":" << layer.key();
os << layer.layerId() << ':' << layer.key();
return os;
}
std::ostream& operator<< (std::ostream& os, const std::vector<DynamicLayer>& dls) {
os << "DynamicLayerStack[ ";
for (const auto& layer : dls) {
os << layer << " ";
os << layer << ' ';
}
os << "]";
os << ']';
return os;
}

View File

@ -22,7 +22,7 @@ void dumpTensor(std::ostream& ss, const Tensor& tensor) {
if (batched) {
ss << "Batched[lvl=" << batched->level() << " dim=" << batched->bdim() << ", ";
dumpTensor(ss, batched->value());
ss << "]";
ss << ']';
return;
}
ss << "Tensor" << tensor.sizes();
@ -36,7 +36,7 @@ void dumpTensor(std::ostream& ss, const Tensor& tensor) {
ss << "dead, ";
}
dumpTensor(ss, wrapped->value());
ss << "]";
ss << ']';
}
void TensorWrapper::refreshMetadata() {

View File

@ -73,32 +73,32 @@ std::string miopenTypeToString(miopenDataType_t dtype) {
return "miopenBFloat16";
default:
std::ostringstream oss;
oss << "(unknown data-type " << static_cast<int>(dtype) << ")";
oss << "(unknown data-type " << static_cast<int>(dtype) << ')';
return oss.str();
}
}
std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d) {
out << "TensorDescriptor " << static_cast<void*>(d.desc()) << "\n";
out << "TensorDescriptor " << static_cast<void*>(d.desc()) << '\n';
int nbDims = 0;
int dimA[MIOPEN_DIM_MAX];
int strideA[MIOPEN_DIM_MAX];
miopenDataType_t dtype;
miopenGetTensorDescriptorSize(d.desc(), &nbDims);
miopenGetTensorDescriptor(d.desc(), &dtype, dimA, strideA);
out << " type = " << miopenTypeToString(dtype) << "\n";
out << " nbDims = " << nbDims << "\n";
out << " type = " << miopenTypeToString(dtype) << '\n';
out << " nbDims = " << nbDims << '\n';
// Read out only nbDims of the arrays!
out << " dimA = ";
for (auto i : ArrayRef<int>{dimA, static_cast<size_t>(nbDims)}) {
out << i << ", ";
}
out << "\n";
out << '\n';
out << " strideA = ";
for (auto i : ArrayRef<int>{strideA, static_cast<size_t>(nbDims)}) {
out << i << ", ";
}
out << "\n";
out << '\n';
return out;
}

View File

@ -91,7 +91,7 @@ struct OperationInfo : BaseInfo {
std::stringstream kernelStr;
kernelStr << kernelName;
for (const Tensor& tensor : tensors) {
kernelStr << ":" << BaseInfo::buildTensorString(tensor, includeBufferId);
kernelStr << ':' << BaseInfo::buildTensorString(tensor, includeBufferId);
}
return kernelStr.str();
}

View File

@ -39,9 +39,9 @@ std::string BaseInfo::buildTensorString(const Tensor& tensor, bool includeBuffer
// see comments for INCLUDE_BUFFER_ID
if (includeBufferId && deviceType == at::kMPS) {
id<MTLBuffer> buffer = __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer)) << ":" << buffer.retainCount << ")";
tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer)) << ':' << buffer.retainCount << ')';
}
tensorStr << ":" << tensor.scalar_type() << tensor.sizes();
tensorStr << ':' << tensor.scalar_type() << tensor.sizes();
return tensorStr.str();
} else {
return "undefined";

View File

@ -167,7 +167,7 @@ static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, co
std::stringstream ss;
ss << arg_name << " should be greater than zero but got (";
std::copy(args.begin(), args.end() - 1, std::ostream_iterator<int>(ss,", "));
ss << args.back() << ")" << " (while checking arguments for " << c << ")";
ss << args.back() << ")" << " (while checking arguments for " << c << ')';
TORCH_CHECK(false, ss.str());
}
}

View File

@ -639,7 +639,7 @@ static std::ostream& operator<<(std::ostream & out, const ConvParams<T>& params)
<< " deterministic = " << params.deterministic
<< " cudnn_enabled = " << params.cudnn_enabled
<< " allow_tf32 = " << params.allow_tf32
<< "}";
<< '}';
return out;
}

View File

@ -3541,9 +3541,9 @@ Tensor _dyn_quant_matmul_4bit_cpu(
const int64_t out_features) {
auto M = inp.size(0);
TORCH_CHECK(
inp.dtype() == kFloat || (inp.dtype() == kBFloat16 && block_size == in_features),
inp.dtype() == kFloat,
__func__,
" : expect input to be float32 or bfloat16 tensor.");
" : expect input to be 32-bit float tensor.");
TORCH_CHECK(
block_size == in_features ||
(!(block_size % 32) && !(in_features % block_size)),

View File

@ -847,7 +847,7 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const std::optional<int64_t
<< ", hop_length=" << hop_length << ", win_length=" << win_length \
<< ", window="; \
if (window.defined()) { \
SS << window.toString() << "{" << window.sizes() << "}"; \
SS << window.toString() << '{' << window.sizes() << '}'; \
} else { \
SS << "None"; \
} \
@ -1046,7 +1046,7 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const std::optional<int64_
<< ", hop_length=" << hop_length << ", win_length=" << win_length \
<< ", window="; \
if (window.defined()) { \
SS << window.toString() << "{" << window.sizes() << "}"; \
SS << window.toString() << '{' << window.sizes() << '}'; \
} else { \
SS << "None"; \
} \

View File

@ -523,7 +523,7 @@ Tensor _functional_assert_async_msg_cpu(
}
void _print(std::string_view s) {
std::cout << s << "\n";
std::cout << s << '\n';
}
// Sorting-based algorithm for isin(); used when the number of test elements is

View File

@ -8,7 +8,6 @@
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/int_mm_kernel.h>
#include <ATen/native/cpu/utils.h>
#include <cmath>
#include <c10/util/Unroll.h>
#include <c10/util/irange.h>
@ -794,139 +793,6 @@ bool can_use_kleidiai(
}
#endif
static void ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
size_t m,
size_t n,
size_t k,
const uint16_t* lhs_bf16,
const uint8_t* rhs_qs4cx,
const float* rhs_scales,
uint16_t* dst_bf16,
float scalar_min,
float scalar_max,
const float* bias) {
// Roundup lambda for internal stride calculations
auto roundup = [](size_t a, size_t b) { return ((a + b - 1) / b) * b; };
// Cast bfloat16 to float32 inline
auto cast_bf16_to_f32 = [](uint16_t bf16_val) {
uint32_t tmp = static_cast<uint32_t>(bf16_val) << 16;
float f;
std::memcpy(&f, &tmp, sizeof(f));
return f;
};
// Cast float32 to bfloat16 inline
auto cast_f32_to_bf16 = [](float f) {
uint32_t bits;
std::memcpy(&bits, &f, sizeof(bits));
return static_cast<uint16_t>(bits >> 16);
};
// Quantization pack lambda (channelwise QA8DX)
auto quant_pack_8bit_channelwise =
[&](size_t M, size_t K, const uint16_t* src_bf16, int8_t* dst_qa8dx) {
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
const size_t dst_stride =
K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
for (size_t i = 0; i < M; ++i) {
const uint16_t* row_ptr = src_bf16 + i * K;
// find min/max
float mn = FLT_MAX, mx = -FLT_MAX;
for (size_t j = 0; j < K; ++j) {
float v = cast_bf16_to_f32(row_ptr[j]);
mn = std::min(mn, v);
mx = std::max(mx, v);
}
float rmin = std::min(0.0f, mn);
float rmax = std::max(0.0f, mx);
constexpr float qmin = static_cast<float>(kI8Min);
constexpr float qmax = static_cast<float>(kI8Max);
float scale = (rmin == rmax) ? 1.f : (qmax - qmin) / (rmax - rmin);
float recip = scale ? 1.0f / scale : 0.0f;
int32_t zp;
float des_min = rmin * scale;
float des_max = rmax * scale;
float err_min = qmin + des_min;
float err_max = qmax + des_max;
float zp_f =
(err_min + err_max) > 0 ? qmin - des_min : qmax - des_max;
zp_f = std::clamp(zp_f, qmin, qmax);
zp = std::lrintf(zp_f);
int8_t* out_ptr = dst_qa8dx + i * dst_stride;
// store header
*reinterpret_cast<float*>(out_ptr) = recip;
*reinterpret_cast<int32_t*>(out_ptr + sizeof(float)) = -zp;
out_ptr += sizeof(float) + sizeof(int32_t);
// quantize
for (size_t j = 0; j < K; ++j) {
float v = cast_bf16_to_f32(row_ptr[j]);
int32_t q = static_cast<int32_t>(std::round(v * scale)) + zp;
q = std::clamp(
q, static_cast<int32_t>(kI8Min), static_cast<int32_t>(kI8Max));
*out_ptr++ = static_cast<int8_t>(q);
}
}
};
// MatMul lambda (MXN x MXK -> MNXK BF16)
auto matmul_kernel = [&](size_t M,
size_t N,
size_t K,
const int8_t* lhs,
const uint8_t* rhs,
const float* scales,
uint16_t* dst,
float lo,
float hi) {
const size_t lhs_stride =
K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
const size_t rhs_stride = roundup(K, 2) / 2;
for (size_t i = 0; i < M; ++i) {
const int8_t* lhs_row = lhs + i * lhs_stride;
for (size_t j = 0; j < N; ++j) {
int32_t acc = 0;
const int8_t* lptr = lhs_row;
const uint8_t* rptr = rhs + j * rhs_stride;
float lhs_scale = *reinterpret_cast<const float*>(lptr);
int32_t lhs_off =
*reinterpret_cast<const int32_t*>(lptr + sizeof(float));
lptr += sizeof(float) + sizeof(int32_t);
for (size_t t = 0; t < K; ++t) {
int32_t lv = static_cast<int32_t>(lptr[t]);
uint8_t bv = rptr[t / 2];
int32_t rv = ((t & 1) == 0) ? (static_cast<int32_t>(bv & 0xF) - 8)
: (static_cast<int32_t>(bv >> 4) - 8);
acc += lv * rv + lhs_off * rv;
}
float res = static_cast<float>(acc) * scales[j] * lhs_scale;
if (bias) {
res += bias[j];
}
res = std::clamp(res, lo, hi);
*dst++ = cast_f32_to_bf16(res);
}
}
};
// allocate and run
std::unique_ptr<int8_t[]> packed(
new int8_t[m * (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t))]);
quant_pack_8bit_channelwise(m, k, lhs_bf16, packed.get());
matmul_kernel(
m,
n,
k,
packed.get(),
rhs_qs4cx,
rhs_scales,
dst_bf16,
scalar_min,
scalar_max);
}
/**
* The Int4 quantized weights must be represented as a uint8 tensor
* For matrix multiplication with a weight shape of (N x K)
@ -953,21 +819,21 @@ void dyn_quant_pack_4bit_weight_kernel(
#if AT_KLEIDIAI_ENABLED()
if (can_use_kleidiai(scales_zeros, K, block_size)) {
const int64_t weight_packed_size =
kleidiai::kai_pack_rhs_int4_size(N, K, block_size, weights.scalar_type());
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
packed_weights.resize_({weight_packed_size});
kleidiai::kai_pack_int4_rhs(
packed_weights, weights, scales_zeros, bias, N, K, block_size);
} else
#endif
{
TORCH_CHECK(
bias.has_value() == 0,
__func__,
" : Bias is unsupported in reference implementation");
packed_weights = packed_weights.to(kFloat);
auto weight_reshaped = weights.reshape({-1}).to(kFloat);
auto scales_zeros_reshaped = scales_zeros.reshape({-1}).to(kFloat);
std::vector<at::Tensor> tensors_to_cat = {weight_reshaped, scales_zeros_reshaped};
if (bias.has_value()) {
tensors_to_cat.push_back(bias.value().view({-1}).to(kFloat));
}
auto res = at::cat(tensors_to_cat, 0);
auto weight_reshaped = weights.view({-1}).to(kFloat);
auto scales_zeros_reshaped = scales_zeros.view({-1}).to(kFloat);
auto res = at::cat({weight_reshaped, scales_zeros_reshaped}, 0);
packed_weights.resize_(res.sizes()).copy_(res);
}
}
@ -981,8 +847,7 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
const float* rhs_scales_f32,
float* dst_f32,
float scalar_min,
float scalar_max,
const float* bias) {
float scalar_max) {
const size_t input_size_8bit = m * (k + sizeof(int32_t) + sizeof(float));
auto lhs_qa8dx_buffer = std::make_unique<uint8_t[]>(input_size_8bit);
@ -992,9 +857,6 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
// required format for matmul
auto input_quant_pack_8bit_channelwise =
[&](size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) {
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
const size_t dst_stride =
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
@ -1015,8 +877,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
}
// Maximum/minimum int8 values
constexpr float qmin = static_cast<float>(kI8Min);
constexpr float qmax = static_cast<float>(kI8Max);
const float qmin = (float)INT8_MIN;
const float qmax = (float)INT8_MAX;
const float rmin0 = std::min(0.0f, min0);
const float rmax0 = std::max(0.0f, max0);
@ -1042,7 +904,7 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
zero_point0 = std::min(zero_point0, qmax);
// Round to nearest integer
const int32_t nudged_zero_point0 = std::lrintf(zero_point0);
const int32_t nudged_zero_point0 = lrintf(zero_point0);
int8_t* dst_ptr = lhs_qa8dx + m_idx * dst_stride;
@ -1060,8 +922,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
v0_s32 = v0_s32 + nudged_zero_point0;
v0_s32 = std::max(v0_s32, static_cast<int32_t>(kI8Min));
v0_s32 = std::min(v0_s32, static_cast<int32_t>(kI8Max));
v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT8_MIN));
v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT8_MAX));
dst_ptr[0] = (int8_t)v0_s32;
dst_ptr += sizeof(int8_t);
}
@ -1125,10 +987,6 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
main_acc = main_acc * lhs_scale;
if (bias) {
main_acc += bias[n_idx];
}
// Clamp (min-max) operation
main_acc = std::max(main_acc, scalar_min);
main_acc = std::min(main_acc, scalar_max);
@ -1149,16 +1007,12 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
const float* rhs_scales_fp32,
float* dst_f32,
float scalar_min,
float scalar_max,
const float* bias) {
float scalar_max) {
// Lambda for LHS quantization
auto lhs_quant_pack = [&](size_t m,
size_t k,
const float* lhs_f32,
int8_t* lhs_qa8dx) {
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
const size_t dst_stride =
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
@ -1174,8 +1028,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
min0 = std::min(src0_0, min0);
}
constexpr float qmin = static_cast<float>(kI8Min);
constexpr float qmax = static_cast<float>(kI8Max);
const float qmin = (float)INT8_MIN;
const float qmax = (float)INT8_MAX;
const float rmin0 = std::min(0.0f, min0);
const float rmax0 = std::max(0.0f, max0);
@ -1192,7 +1046,7 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
zero_point0 = std::max(zero_point0, qmin);
zero_point0 = std::min(zero_point0, qmax);
const int32_t nudged_zero_point0 = std::lrintf(zero_point0);
const int32_t nudged_zero_point0 = lrintf(zero_point0);
int8_t* dst_ptr = lhs_qa8dx + row_idx * dst_stride;
@ -1205,8 +1059,9 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
const float src0_0 = src_ptr[k_idx];
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
v0_s32 = std::max(
std::min(v0_s32 + nudged_zero_point0, static_cast<int32_t>(kI8Max)),
static_cast<int32_t>(kI8Min));
std::min(
v0_s32 + nudged_zero_point0, static_cast<int32_t>(INT8_MAX)),
static_cast<int32_t>(INT8_MIN));
dst_ptr[0] = (int8_t)v0_s32;
dst_ptr += sizeof(int8_t);
}
@ -1263,11 +1118,6 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
}
main_acc = main_acc * lhs_scale;
if (bias) {
main_acc += bias[col_idx];
}
main_acc = std::max(main_acc, scalar_min);
main_acc = std::min(main_acc, scalar_max);
@ -1278,27 +1128,28 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
}
/**
* Dynamic INT4 weight-only MatMul with per-row input quantization.
*
* Execution Flow:
*
* (INT4 Weights + FP Scales [+ optional Bias])
*
* Input (FP32 or BF16) Packed Weight Buffer
* | |
* Row-wise Quantization (INT8) |
* | |
* INT8 Input Activation INT4 Quantized Weights + Scales
* \ /
* \ /
* Quantized Matrix Multiply
* |
* Output Tensor (BF16 or FP32)
*
* Notes:
* - Groupwise kernels expect BF16 scales
* - Channelwise kernels expect FP32 scales
* - Bias is currently unsupported in fallback path
* Dynamic Input Quant 4 bit weights matmul execution flow
(INT4 Weights + FP scales + FP32 Bias)
FP32 Input Packed Buffer
| |
Quantize Cast
to INT8 to INT8
| |
v v
INT8 Input INT8 Weights
\ /
\ /
\ /
INT8 Matrix Multiplication
|
v
FP32 Dequantized and Accumulate in FP32
|
v
FP32 Final Output
* The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires
* Float32 Scales. If not provided, we will use fallback implementation.
*/
void dyn_quant_matmul_4bit_kernel(
const Tensor& output,
@ -1310,75 +1161,65 @@ void dyn_quant_matmul_4bit_kernel(
const int64_t block_size) {
#if AT_KLEIDIAI_ENABLED()
const int64_t weight_packed_size =
kleidiai::kai_pack_rhs_int4_size(N, K, block_size, inp.scalar_type());
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
if (weight_packed_size == packed_weights.numel()) {
// KleidiAI interface internally handles the Channelwise and groupwise
// distinction
kleidiai::kai_quant_pack_lhs_int4_mm(output, inp, packed_weights, M, N, K, block_size);
kleidiai::kai_quant_pack_lhs_int4_mm(
output, inp, packed_weights, M, N, K, block_size);
} else
#endif
{
{
void* input = inp.data_ptr();
void* dst = output.data_ptr();
// Extract weights, sclaes and biases form from packed tensor
const int weights_elements = N * K / 2;
const int scale_elements = N * (K / block_size);
TORCH_CHECK(packed_weights.numel() >= (weights_elements + scale_elements), "Invalid packed weight tensor size");
auto extracted_weights = packed_weights.narrow(0, 0, weights_elements).to(kByte);
auto extracted_scales_and_bias = packed_weights.narrow(0, weights_elements, packed_weights.size(0) - weights_elements).to(kFloat);
auto float32_scales = extracted_scales_and_bias.narrow(0, 0, scale_elements);
int bias_elements = packed_weights.numel() - (weights_elements + scale_elements);
float* weight_scales = float32_scales.data_ptr<float>();
void* bias_data = nullptr;
if (bias_elements) {
auto float32_bias = extracted_scales_and_bias.narrow(0, scale_elements, bias_elements);
TORCH_CHECK(float32_bias.size(0) == N, "Expected bias length to match output dimension");
bias_data = float32_bias.data_ptr();
}
// 2 elements of 4 bit weights are packed into 1 uint8 packet
uint8_t* weights_4bit = reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
// Dispatch to reference kernels
if (inp.scalar_type() == at::kBFloat16) {
// BF16 input, BF16 output
constexpr float BF16_MAX = 3.38953139e+38f;
constexpr float BF16_MIN = -BF16_MAX;
if (block_size == K) {
ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
M, N, K,
(uint16_t*)input, weights_4bit, weight_scales,
(uint16_t*)dst, BF16_MIN, BF16_MAX, (float*)bias_data);
} else {
TORCH_CHECK(false, "Unsupported block size for BF16 fallback");
}
} else if (inp.scalar_type() == at::kFloat) {
// FP32 input, FP32 output
if (block_size == K) {
ref_dyn_quant_matmul_4bit_channelwise_kernel(
M, N, K,
(float*)input, weights_4bit, weight_scales,
(float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data);
} else if (!(block_size % 32) && !(K % block_size)) {
ref_dyn_quant_matmul_4bit_groupwise_kernel(
M, N, K, block_size,
(float*)input, weights_4bit, weight_scales,
(float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data);
} else {
TORCH_CHECK(false, "Unsupported block size for FP32 fallback");
}
float* lhs_f32 = reinterpret_cast<float*>(inp.data_ptr());
const auto weights_size = N * K / 2;
// The weights needs to be in uint8_t data type after quantization
auto extracted_weights =
(packed_weights.narrow(0, 0, weights_size)).to(kByte);
auto float32_scales =
(packed_weights.narrow(
0, weights_size, packed_weights.size(0) - weights_size))
.to(kFloat);
uint8_t* rhs_4bit =
reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
float* rhs_scales_f32 = reinterpret_cast<float*>(float32_scales.data_ptr());
float* dst_f32 = reinterpret_cast<float*>(output.data_ptr());
if (block_size == K) {
ref_dyn_quant_matmul_4bit_channelwise_kernel(
M,
N,
K,
lhs_f32,
rhs_4bit,
rhs_scales_f32,
dst_f32,
-FLT_MAX,
FLT_MAX);
} else if (!(block_size % 32) && !(K % block_size)) {
ref_dyn_quant_matmul_4bit_groupwise_kernel(
M,
N,
K,
block_size,
lhs_f32,
rhs_4bit,
rhs_scales_f32,
dst_f32,
-FLT_MAX,
FLT_MAX);
} else {
TORCH_CHECK(false, "Unsupported input/output dtype combination for int4mm kernel");
TORCH_CHECK(
block_size == K || (!(block_size % 32) && !(K % block_size)),
__func__,
": Group size should be multiple 32 or in_features [",
K,
"]. Provided ",
block_size);
}
}
}
}
} // anonymous namespace
}
ALSO_REGISTER_AVX512_DISPATCH(weight_to_int4pack_stub, &weight_to_int4pack_kernel)
ALSO_REGISTER_AVX512_DISPATCH(int4pack_mm_stub, &int4pack_mm_kernel)
REGISTER_DISPATCH(dyn_quant_pack_4bit_weight_stub, &dyn_quant_pack_4bit_weight_kernel)

View File

@ -5,11 +5,69 @@
#include <cuda_bf16.h>
#endif
// ROCm 6.3 is planned to have these functions, but until then here they are.
#if defined(USE_ROCM)
#include <device_functions.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#define ATOMICADD unsafeAtomicAdd
__device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) {
#if (defined(__gfx942__)) && \
__has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16)
typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2;
static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw));
union {
__hip_bfloat162_raw bf162_raw;
vec_short2 vs2;
} u{static_cast<__hip_bfloat162_raw>(value)};
u.vs2 = __builtin_amdgcn_flat_atomic_fadd_v2bf16((vec_short2*)address, u.vs2);
return static_cast<__hip_bfloat162>(u.bf162_raw);
#else
static_assert(sizeof(unsigned int) == sizeof(__hip_bfloat162_raw));
union u_hold {
__hip_bfloat162_raw h2r;
unsigned int u32;
};
u_hold old_val, new_val;
old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
do {
new_val.h2r = __hadd2(old_val.h2r, value);
} while (!__hip_atomic_compare_exchange_strong(
(unsigned int*)address, &old_val.u32, new_val.u32,
__ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
return old_val.h2r;
#endif
}
__device__ inline __half2 preview_unsafeAtomicAdd(__half2* address, __half2 value) {
#if (defined(__gfx942__)) && \
__has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2f16)
// The api expects an ext_vector_type of half
typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162;
static_assert(sizeof(vec_fp162) == sizeof(__half2_raw));
union {
__half2_raw h2r;
vec_fp162 fp16;
} u {static_cast<__half2_raw>(value)};
u.fp16 = __builtin_amdgcn_flat_atomic_fadd_v2f16((vec_fp162*)address, u.fp16);
return static_cast<__half2>(u.h2r);
#else
static_assert(sizeof(__half2_raw) == sizeof(unsigned int));
union u_hold {
__half2_raw h2r;
unsigned int u32;
};
u_hold old_val, new_val;
old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
do {
new_val.h2r = __hadd2(old_val.h2r, value);
} while (!__hip_atomic_compare_exchange_strong(
(unsigned int*)address, &old_val.u32, new_val.u32,
__ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
return old_val.h2r;
#endif
}
#define ATOMICADD preview_unsafeAtomicAdd
#define NATIVE_ZERO_BF16 __float2bfloat16(0.0f)
#else
#define ATOMICADD atomicAdd

View File

@ -11,7 +11,7 @@ static inline std::ostream& operator<<(std::ostream& out, dim3 dim) {
if (dim.y == 1 && dim.z == 1) {
out << dim.x;
} else {
out << "[" << dim.x << "," << dim.y << "," << dim.z << "]";
out << '[' << dim.x << ',' << dim.y << ',' << dim.z << ']';
}
return out;
}
@ -27,7 +27,7 @@ std::ostream& operator<<(std::ostream& out, const ReduceConfig& config) {
out << "input_mult=[";
for (int i = 0; i < 3; i++) {
if (i != 0) {
out << ",";
out << ',';
}
out << config.input_mult[i];
}
@ -35,7 +35,7 @@ std::ostream& operator<<(std::ostream& out, const ReduceConfig& config) {
out << "output_mult=[";
for (int i = 0; i < 2; i++) {
if (i != 0) {
out << ",";
out << ',';
}
out << config.output_mult[i];
}
@ -49,7 +49,7 @@ std::ostream& operator<<(std::ostream& out, const ReduceConfig& config) {
out << "block=" << config.block() << ", ";
out << "grid=" << config.grid() << ", ";
out << "global_memory_size=" << config.global_memory_size();
out << ")";
out << ')';
return out;
}

View File

@ -1101,19 +1101,6 @@ _scaled_mxfp8_mxfp8(
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
}
void
_check_mxfp4_support() {
#ifndef USE_ROCM
auto dprops = at::cuda::getCurrentDeviceProperties();
// Only on B200 GPUs
TORCH_CHECK_NOT_IMPLEMENTED(
// B200 = 10.0, B300 = 10.3
dprops->major == 10,
"MXFP4 scaling only supported in CUDA for B200/B300"
);
#endif
}
Tensor&
_scaled_mxfp4_mxfp4(
@ -1126,7 +1113,6 @@ _scaled_mxfp4_mxfp4(
#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI))
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
#else
_check_mxfp4_support();
// Restrictions:
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",

View File

@ -364,9 +364,9 @@ void f8f8bf16_grouped_gemm_impl_sm90(
// reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(
// stride_output_h + group_count);
// std::cout << "PTRS " << mat_a.data_ptr() << " " << mat_b.data_ptr() << "
// std::cout << "PTRS " << mat_a.data_ptr() << ' ' << mat_b.data_ptr() << "
// "
// << out.data_ptr() << " " << scale_a.data_ptr() << " "
// << out.data_ptr() << ' ' << scale_a.data_ptr() << ' '
// << scale_b.data_ptr() << "\n";
// for (int i = 0; i < group_count; i++) {
// std::cout << "A " << (void*)inputA_ptrs_h[i] << "\n";

View File

@ -1057,14 +1057,14 @@ std::string generate_code(
// TODO these arrays are potentially of the different types, use function
// traits to determine the types
declare_load_arrays << f_inputs_type << " arg" << std::to_string(i)
<< "[" << std::to_string(thread_work_size) << "];\n";
<< '[' << std::to_string(thread_work_size) << "];\n";
}
env.s("declare_load_arrays", declare_load_arrays.str());
std::stringstream declare_store_arrays;
for (int i = 0; i < nOutputs; i++) {
declare_store_arrays << result_type << " out" << std::to_string(i)
<< "[" << std::to_string(thread_work_size) << "];\n";
<< '[' << std::to_string(thread_work_size) << "];\n";
}
env.s("declare_store_arrays", declare_store_arrays.str());
@ -1217,7 +1217,7 @@ std::string generate_code(
for (const auto i : c10::irange(nInputs)){
auto i_string = std::to_string(i);
vector_inputs << "auto * input" << i_string <<
" = reinterpret_cast<const scalar_t*>(data[" << i_string << "+" << nOutputs << "])" <<
" = reinterpret_cast<const scalar_t*>(data[" << i_string << '+' << nOutputs << "])" <<
" + block_work_size * idx;\n";
}
env.s("vector_inputs", vector_inputs.str());
@ -1543,17 +1543,17 @@ NvrtcFunction jit_pwise_function(
// Constructs file path by appending constructed cubin name to cache path
std::stringstream ss;
ss << *cache_dir << "/";
ss << *cache_dir << '/';
ss << kernel_name;
#ifdef USE_ROCM
ss << "_arch" << prop->gcnArchName;
#else
ss << "_arch" << cuda_major << "." << cuda_minor;
ss << "_arch" << cuda_major << '.' << cuda_minor;
#endif
ss << "_nvrtc" << nvrtc_major << "." << nvrtc_minor;
ss << "_nvrtc" << nvrtc_major << '.' << nvrtc_minor;
ss << (compile_to_sass ? "_sass" : "_ptx");
ss << "_" << code.length();
ss << "_" << hash_code;
ss << '_' << code.length();
ss << '_' << hash_code;
file_path = ss.str();
std::ifstream readin{file_path, std::ios::in | std::ifstream::binary};

View File

@ -82,15 +82,15 @@ namespace native {
std::ostream& operator<<(std::ostream& out, const ConvolutionParams& params) {
out << "ConvolutionParams \n"
<< " memory_format = " << params.memory_format << "\n"
<< " data_type = " << cudnnTypeToString(params.dataType) << "\n"
<< " padding = " << ArrayRef<int>{params.padding} << "\n"
<< " stride = " << ArrayRef<int>{params.stride} << "\n"
<< " dilation = " << ArrayRef<int>{params.dilation} << "\n"
<< " groups = " << params.groups << "\n"
<< " memory_format = " << params.memory_format << '\n'
<< " data_type = " << cudnnTypeToString(params.dataType) << '\n'
<< " padding = " << ArrayRef<int>{params.padding} << '\n'
<< " stride = " << ArrayRef<int>{params.stride} << '\n'
<< " dilation = " << ArrayRef<int>{params.dilation} << '\n'
<< " groups = " << params.groups << '\n'
<< " deterministic = " << (params.deterministic ? "true" : "false")
<< "\n"
<< " allow_tf32 = " << (params.allow_tf32 ? "true" : "false") << "\n";
<< '\n'
<< " allow_tf32 = " << (params.allow_tf32 ? "true" : "false") << '\n';
return out;
}
@ -173,16 +173,16 @@ std::string repro_from_args(const ConvolutionParams& params) {
at::globalContext().float32Precision(
at::Float32Backend::CUDA, at::Float32Op::MATMUL) ==
at::Float32Precision::TF32)
<< "\n";
<< '\n';
ss << "torch.backends.cudnn.benchmark = "
<< pybool(at::globalContext().benchmarkCuDNN()) << "\n";
<< pybool(at::globalContext().benchmarkCuDNN()) << '\n';
ss << "torch.backends.cudnn.deterministic = " << pybool(params.deterministic)
<< "\n";
<< '\n';
ss << "torch.backends.cudnn.allow_tf32 = " << pybool(params.allow_tf32)
<< "\n";
<< '\n';
ss << "data = torch.randn(" << ArrayRef<int>(params.input_size, dim)
<< ", dtype=" << full_dtype << ", ";
ss << "device='cuda', requires_grad=True)" << to_channels_last << "\n";
ss << "device='cuda', requires_grad=True)" << to_channels_last << '\n';
ss << "net = torch.nn.Conv" << dim - 2 << "d(" << in_channels << ", "
<< out_channels << ", ";
ss << "kernel_size=" << ArrayRef<int>(&params.weight_size[2], dim - 2)
@ -192,7 +192,7 @@ std::string repro_from_args(const ConvolutionParams& params) {
ss << "dilation=" << ArrayRef<int>(params.dilation, dim - 2) << ", ";
ss << "groups=" << params.groups << ")\n";
ss << "net = net.cuda()." << partial_dtype << "()" << to_channels_last
<< "\n";
<< '\n';
ss << "out = net(data)\n";
ss << "out.backward(torch.randn_like(out))\n";
ss << "torch.cuda.synchronize()\n\n";

View File

@ -93,11 +93,10 @@ std::ostream& operator<<(std::ostream& out, const ConvolutionArgs& args) {
<< "input: " << args.idesc // already has a trailing newline
<< "output: " << args.odesc // already has a trailing newline
<< "weight: " << args.wdesc // already has a trailing newline
<< "Pointer addresses: "
<< "\n"
<< " input: " << args.input.const_data_ptr() << "\n"
<< " output: " << args.output.const_data_ptr() << "\n"
<< " weight: " << args.weight.const_data_ptr() << "\n";
<< "Pointer addresses: " << '\n'
<< " input: " << args.input.const_data_ptr() << '\n'
<< " output: " << args.output.const_data_ptr() << '\n'
<< " weight: " << args.weight.const_data_ptr() << '\n';
return out;
}

View File

@ -21,27 +21,18 @@ void kai_pack_int4_rhs(
const int64_t n,
const int64_t k,
const int64_t bl) {
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
// Channelwise
if (weight.scalar_type() == at::kBFloat16) {
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod);
auto& params = kernel_packet.rhs_pack_params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_bf16_qa8dxp_qs4cxp>(
kernel_packet, weight_packed, weight, scales, bias, n, k);
} else {
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
auto& params = kernel_packet.rhs_pack_params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
kernel_packet, weight_packed, weight, scales, bias, n, k);
}
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
auto& params = kernel_packet.rhs_pack_params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
kernel_packet, weight_packed, weight, scales, bias, n, k);
} else if (!(bl % 32) && !(k % bl)) {
// Groupwise
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
@ -72,29 +63,19 @@ void kai_pack_int4_rhs(
size_t kai_pack_rhs_int4_size(
const int64_t n,
const int64_t k,
const int64_t bl,
at::ScalarType tensor_dtype) {
const int64_t bl) {
size_t packed_size = n * k;
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
if (tensor_dtype == at::kBFloat16) {
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod);
const auto& ukernel = kernel_packet.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
} else {
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
const auto& ukernel = kernel_packet.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
}
// Channelwise
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
const auto& ukernel = kernel_packet.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
} else if (!(bl % 32) && !(k % bl)) {
// Groupwise
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
@ -167,7 +148,8 @@ static void kai_quant_pack_lhs_int4_mm_groupwise(
const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride;
const int64_t m_idx = thread_id * vec_per_thread;
auto lhs_packed_ptr = lhs_packed_base +
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
m_idx, k, mr, kr, sr);
const int64_t vec_num = (thread_id == num_threads - 1)
? (m - vec_per_thread * thread_id)
: vec_per_thread;
@ -277,7 +259,8 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride;
const int64_t m_idx = thread_id * vec_per_thread;
auto lhs_packed_ptr = lhs_packed_base +
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
m_idx, k, mr, kr, sr);
const int64_t vec_num = (thread_id == num_threads - 1)
? (m - vec_per_thread * thread_id)
: vec_per_thread;
@ -337,144 +320,19 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
});
}
static void kai_quant_pack_lhs_int4_mm_bf16_channelwise(
void kai_quant_pack_lhs_int4_mm(
const Tensor& output,
const Tensor& input,
const Tensor& weight,
const int64_t m,
const int64_t n,
const int64_t k) {
// Kernel IDs for GEMM and GEMV
constexpr kai_kernel_id gemm_id =
kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm;
constexpr kai_kernel_id gemv_id =
kai_kernel_id::matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod;
// Get total threads and select kernel
const int64_t total_threads = at::get_num_threads();
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemv_id);
if (cpuinfo_has_arm_i8mm() && m > 1) {
kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemm_id);
}
// Thread blocking parameters
const int64_t n_step = kernel_packet.ukernel.get_n_step();
const size_t mr = kernel_packet.ukernel.get_mr();
const size_t kr = kernel_packet.ukernel.get_kr();
const size_t sr = kernel_packet.ukernel.get_sr();
const size_t lhs_packed_size =
kernel_packet.kai_get_lhs_packed_size(m, k, mr, kr, sr);
auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size);
uint8_t* dst_act_mtx_bf16 = reinterpret_cast<uint8_t*>(output.data_ptr());
const uint8_t* lhs_native_mtx_bf16 =
reinterpret_cast<const uint8_t*>(input.data_ptr());
const uint8_t* rhs_packed_mtx_qs4cx =
reinterpret_cast<const uint8_t*>(weight.data_ptr());
uint8_t* lhs_packed_base = lhs_packed.get();
constexpr int32_t element_size = sizeof(uint16_t);
const size_t lhs_stride = k * element_size;
const size_t dst_stride = n * element_size;
// LHS quantization packing
int64_t vec_per_thread = get_vec_per_thread(m, total_threads, mr);
int64_t num_threads = (m + vec_per_thread - 1) / vec_per_thread;
const size_t src_stride = vec_per_thread * lhs_stride;
auto lhs_quant_pack = [=, &kernel_packet](int64_t thread_id) {
const auto lhs_src_ptr = lhs_native_mtx_bf16 + thread_id * src_stride;
const int64_t m_idx = thread_id * vec_per_thread;
auto lhs_packed_ptr = lhs_packed_base +
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
const int64_t vec_num = (thread_id == num_threads - 1)
? (m - vec_per_thread * thread_id)
: vec_per_thread;
kernel_packet.kai_run_lhs_quant_pack(
vec_num,
k,
mr,
kr,
sr,
0,
(const uint16_t*)lhs_src_ptr,
lhs_stride,
lhs_packed_ptr);
};
at::parallel_for(
0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
for (int64_t thread_id = begin; thread_id < end; ++thread_id) {
lhs_quant_pack(thread_id);
}
});
// Matrix multiplication
vec_per_thread = get_vec_per_thread(n, total_threads, n_step);
num_threads = (n + vec_per_thread - 1) / vec_per_thread;
auto mm = [=, &kernel_packet](int64_t thread_id) {
const auto rhs_packed_ptr = rhs_packed_mtx_qs4cx +
kernel_packet.ukernel.get_rhs_packed_offset(
thread_id * vec_per_thread, k);
auto dst_ptr = dst_act_mtx_bf16 +
kernel_packet.ukernel.get_dst_offset(
0, thread_id * vec_per_thread, dst_stride);
const int64_t vec_num = (thread_id == num_threads - 1)
? (n - vec_per_thread * thread_id)
: vec_per_thread;
kernel_packet.ukernel.run_matmul(
m,
vec_num,
k,
lhs_packed_base,
rhs_packed_ptr,
(uint16_t*)dst_ptr,
dst_stride,
element_size, // dst_stride_col
-FLT_MAX,
FLT_MAX);
};
at::parallel_for(
0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
for (int64_t thread_id = begin; thread_id < end; ++thread_id) {
mm(thread_id);
}
});
}
void kai_quant_pack_lhs_int4_mm(
const at::Tensor& output,
const at::Tensor& input,
const at::Tensor& weight,
const int64_t m,
const int64_t n,
const int64_t k,
const int64_t bl) {
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
const auto input_dtype = input.dtype();
if (input_dtype == at::kBFloat16) {
if (cpuinfo_has_arm_bf16()) {
kleidiai::kai_quant_pack_lhs_int4_mm_bf16_channelwise(
output, input, weight, m, n, k);
} else {
TORCH_CHECK(
false,
"BF16 Unsupported: CPU does not support BF16. Please use a CPU with BF16 support.");
}
} else if (input_dtype == at::kFloat) {
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
output, input, weight, m, n, k);
} else {
TORCH_CHECK(
false,
"Unsupported input data type: Only Bfloat16 and Float inputs are supported.");
}
} else if ((bl % 32 == 0) && (k % bl == 0)) {
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
output, input, weight, m, n, k);
} else if (!(bl % 32) && !(k % bl)) {
kleidiai::kai_quant_pack_lhs_int4_mm_groupwise(
output, input, weight, m, n, k, bl);
}

View File

@ -25,8 +25,7 @@ void kai_pack_int4_rhs(
size_t kai_pack_rhs_int4_size(
const int64_t n,
const int64_t k,
const int64_t bl,
at::ScalarType tensor_dtype = at::kFloat);
const int64_t bl);
/**
* @brief Run 2 operations ( Input quantize and pack -> 4 bit Matmul )

View File

@ -36,8 +36,7 @@ void kai_pack_rhs_groupwise_int4(
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
}
float* bias_ptr =
bias.has_value() ? bias.value().to(kFloat).data_ptr<float>() : NULL;
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
auto& params = kernel.rhs_pack_params;
kernel.kai_run_rhs_pack(
@ -74,8 +73,7 @@ void kai_pack_rhs_channelwise_int4(
auto weight_packed_data =
reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
const auto weight_data = weight.data_ptr<uint8_t>();
const auto scales_data = scales.to(kFloat).data_ptr<float>();
const auto scales_data = scales.data_ptr<float>();
if (weight_data == nullptr) {
AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null");
@ -85,8 +83,7 @@ void kai_pack_rhs_channelwise_int4(
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
}
float* bias_ptr =
bias.has_value() ? bias.value().to(kFloat).data_ptr<float>() : NULL;
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
auto& params = kernel.rhs_pack_params;
kernel.kai_run_rhs_pack(

View File

@ -68,39 +68,5 @@ kai_matmul_ukernel_f32_qa8dxp_qs4cxp kai_select_channelwise_matmul_ukernel(
const kai_kernel_id id) {
return channelwise_8bit_4bit_kernels.at(id);
}
// Kernel Mapping - BF16 Channelwise
std::unordered_map<kai_kernel_id, kai_matmul_ukernel_bf16_qa8dxp_qs4cxp>
bf16_channelwise_8bit_4bit_kernels = {
{kai_kernel_id::
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
{{kai_get_m_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_n_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_mr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_nr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_kr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_sr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_dst_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_get_dst_size_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
kai_run_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod}}},
{kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
{{kai_get_m_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_n_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_mr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_nr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_kr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_sr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_dst_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_get_dst_size_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
kai_run_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm}}}};
kai_matmul_ukernel_bf16_qa8dxp_qs4cxp kai_select_bf16_channelwise_matmul_ukernel(
const kai_kernel_id id) {
return bf16_channelwise_8bit_4bit_kernels.at(id);
}
} // namespace at::native::kleidiai
#endif

View File

@ -10,32 +10,21 @@
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h>
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.h>
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_interface.h>
#include <kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h>
#include <kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h>
#include <kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h>
#include <kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h>
namespace at::native::kleidiai {
enum class kai_kernel_id {
// FP32 inputs, 4-bit weights, FP32 output
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod =
0, // Groupwise 4-bit GEMV (per-group scales, NEON DOTPROD)
0, // Groupwise 4 bit GEMV
matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm =
1, // Groupwise 4-bit GEMM (per-group scales, NEON I8MM)
1, // Groupwise 4 bit GEMM
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod =
2, // Channelwise 4-bit GEMV (per-channel scales, NEON DOTPROD)
2, // Channelwise 4 bit GEMV
matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm =
3, // Channelwise 4-bit GEMM (per-channel scales, NEON I8MM)
// BF16 inputs, 4-bit weights, BF16 output
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod =
4, // Channelwise 4-bit GEMV with BF16 input/output
matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm =
5 // Channelwise 4-bit GEMM with BF16 input/output
3 // Channelwise 4 bit GEMM
};
// Channelwise Kernel mapping
@ -77,9 +66,6 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
void* rhs_packed,
size_t extra_bytes,
const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
size_t(*kai_get_lhs_quant_pack_offset)(
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
);
kai_matmul_ukernel_f32_qa8dxp_qs4cxp(
const kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel& kernel)
@ -89,71 +75,12 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
kai_get_rhs_packed_size(
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32){}
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0) {}
};
struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp
kai_select_channelwise_matmul_ukernel(const kai_kernel_id id);
// bf16 Channelwise Kernel mapping
struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp {
struct kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel ukernel;
struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params rhs_pack_params;
size_t (*kai_get_lhs_packed_size)(
size_t m,
size_t k,
size_t mr,
size_t kr,
size_t sr);
size_t (*kai_get_rhs_packed_size)(
size_t n,
size_t k,
size_t nr,
size_t kr,
size_t sr);
void (*kai_run_lhs_quant_pack)(
size_t m,
size_t k,
size_t mr,
size_t kr,
size_t sr,
size_t m_idx_start,
const void* lhs,
size_t lhs_stride,
void* lhs_packed);
void (*kai_run_rhs_pack)(
size_t num_groups,
size_t n,
size_t k,
size_t nr,
size_t kr,
size_t sr,
const uint8_t* rhs,
const float* bias,
const float* scale,
void* rhs_packed,
size_t extra_bytes,
const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
size_t(*kai_get_lhs_quant_pack_offset)(
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
);
kai_matmul_ukernel_bf16_qa8dxp_qs4cxp(
const kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel& kernel)
: ukernel(kernel),
kai_get_lhs_packed_size(
&kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon),
kai_get_rhs_packed_size(
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_bf16_neon),
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon){}
};
struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp
kai_select_bf16_channelwise_matmul_ukernel(const kai_kernel_id id);
// Groupwise Kernel mapping
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel;
@ -198,9 +125,6 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
void* rhs_packed,
size_t extra_bytes,
const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params);
size_t(*kai_get_lhs_quant_pack_offset)(
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
);
kai_matmul_ukernel_f32_qa8dxp_qs4c32p(
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& kernel)
@ -210,8 +134,7 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
kai_get_rhs_packed_size(
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32) {}
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0) {}
};
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel(

View File

@ -115,7 +115,7 @@ std::ostream& operator<<(
std::copy(
strides.begin(), strides.end() - 1, std::ostream_iterator<int>(oss, ","));
oss << sizes.back();
output << oss.str() << "}";
output << oss.str() << '}';
return output;
}

View File

@ -53,7 +53,7 @@ std::ostream& operator<<(std::ostream& out, const ConvParams& params) {
<< " transposed = " << params.transposed
<< " output_padding = " << IntArrayRef{params.output_padding}
<< " groups = " << params.groups << " benchmark = " << params.benchmark
<< " deterministic = " << params.deterministic << "}";
<< " deterministic = " << params.deterministic << '}';
return out;
}

View File

@ -91,30 +91,25 @@ static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
#include <ATen/native/mps/Repeat_metallib.h>
#endif
Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output_size) {
TORCH_CHECK(repeat.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
template <typename index_t>
void computeRepeatIndices(const index_t* repeat_ptr,
const int64_t* cumsum_ptr,
index_t* result_ptr,
int64_t size,
int64_t result_size) {
id<MTLBuffer> repeatBuffer = reinterpret_cast<id<MTLBuffer>>(repeat_ptr);
id<MTLBuffer> cumsumBuffer = reinterpret_cast<id<MTLBuffer>>(cumsum_ptr);
id<MTLBuffer> resultBuffer = reinterpret_cast<id<MTLBuffer>>(result_ptr);
TORCH_CHECK(repeatBuffer && cumsumBuffer && resultBuffer);
std::string scalar_type;
if (repeat.scalar_type() == kInt) {
if constexpr (std::is_same_v<index_t, int32_t>) {
scalar_type = "int32_t";
} else if (repeat.scalar_type() == kLong) {
} else if constexpr (std::is_same_v<index_t, int64_t>) {
scalar_type = "int64_t";
} else {
TORCH_CHECK(false, "repeats has to be Long or Int tensor");
TORCH_CHECK(false, "repeat_interleave: unsupported indexing data type");
}
if (repeat.size(0) == 0) {
return at::empty_like(repeat, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}
Tensor repeat_ = repeat.contiguous();
Tensor cumsum = repeat.cumsum(0);
int64_t total = 0;
if (output_size.has_value()) {
total = output_size.value();
} else {
total = cumsum[-1].item<int64_t>();
TORCH_CHECK((repeat >= 0).all().item<uint8_t>(), "repeats can not be negative");
}
auto result = at::empty({total}, repeat.options());
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@ -126,13 +121,20 @@ Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output
getMPSProfiler().beginProfileKernel(pipelineState, "repeat_interleave:" + scalar_type, false);
[computeEncoder setComputePipelineState:pipelineState];
mps::mtl_setArgs(computeEncoder, repeat_, cumsum, result, repeat.size(0));
mps::mtl_dispatch1DJob(computeEncoder, pipelineState, repeat.size(0));
mps::mtl_setArgs(computeEncoder, repeatBuffer, cumsumBuffer, resultBuffer, size);
mps::mtl_dispatch1DJob(computeEncoder, pipelineState, size);
getMPSProfiler().endProfileKernel(pipelineState);
}
});
return result;
}
Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output_size) {
Tensor output;
AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() {
output = repeat_interleave_common<index_t, computeRepeatIndices<index_t>>(repeat, output_size);
});
return output;
}
} // namespace at::native

View File

@ -5,7 +5,6 @@
#include <ATen/native/Resize.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/mps/OperationUtils.h>
#include <algorithm>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -90,21 +89,13 @@ static void check_min_max_dims(const OptionalTensorRef clamp_opt, const Tensor&
auto clamp_shape = clamp_opt->sizes();
auto input_shape = input_t.sizes();
if (num_clamp_dims > num_input_dims) {
auto leading_dims = num_clamp_dims - num_input_dims;
for (int64_t i = 0; i < leading_dims; ++i) {
TORCH_CHECK(clamp_shape[i] == 1,
op_name + ": clamp tensor leading shape must be 1 to broadcast with input tensor");
}
}
TORCH_CHECK(num_clamp_dims <= num_input_dims,
op_name + ": clamp tensor number of dims must not be greater than that of input tensor")
auto clamp_idx = num_clamp_dims - 1;
auto input_idx = num_input_dims - 1;
auto common_dims = std::min(num_clamp_dims, num_input_dims);
for (int64_t i = 0; i < common_dims; ++i)
for (int i = 0; i < num_clamp_dims; i++)
// One of the indices is allowed to be 1; will be handled by broadcast
TORCH_CHECK(clamp_shape[clamp_idx - i] == input_shape[input_idx - i] || clamp_shape[clamp_idx - i] == 1 ||
input_shape[input_idx - i] == 1,
TORCH_CHECK(clamp_shape[num_clamp_dims - 1 - i] == input_shape[num_input_dims - 1 - i] ||
clamp_shape[num_clamp_dims - 1 - i] == 1 || input_shape[num_input_dims - 1 - i] == 1,
op_name + ": clamp tensor trailing shape must match input tensor")
}
}
@ -145,6 +136,9 @@ static void clamp_tensor_out_mps(const Tensor& input_t,
auto result_type = output_t.scalar_type();
IntArrayRef new_min_shape;
IntArrayRef new_max_shape;
auto num_min_dims = min_opt->dim();
auto num_max_dims = max_opt->dim();
auto num_input_dims = input_t.dim();
@ -152,32 +146,24 @@ static void clamp_tensor_out_mps(const Tensor& input_t,
std::vector<int64_t> new_min_arr(num_input_dims);
std::vector<int64_t> new_max_arr(num_input_dims);
if (has_min && num_min_dims < num_input_dims) {
fill_new_shape(num_input_dims, num_min_dims, new_min_arr.data(), min_opt->sizes());
new_min_shape = IntArrayRef(new_min_arr);
}
if (has_max && num_max_dims < num_input_dims) {
fill_new_shape(num_input_dims, num_max_dims, new_max_arr.data(), max_opt->sizes());
new_max_shape = IntArrayRef(new_max_arr);
}
Tensor min_opt_tensor;
Tensor max_opt_tensor;
auto reshape_clamp_tensor = [&](const OptionalTensorRef clamp_tensor_ref,
int64_t num_clamp_dims,
std::vector<int64_t>& new_shape_storage) -> Tensor {
IntArrayRef clamp_shape = clamp_tensor_ref->sizes();
bool requires_view = false;
if (num_clamp_dims > num_input_dims) {
clamp_shape = clamp_shape.slice(num_clamp_dims - num_input_dims);
requires_view = true;
} else if (num_clamp_dims < num_input_dims) {
fill_new_shape(num_input_dims, num_clamp_dims, new_shape_storage.data(), clamp_shape);
clamp_shape = IntArrayRef(new_shape_storage);
requires_view = true;
}
return requires_view ? (*clamp_tensor_ref).view(clamp_shape) : *clamp_tensor_ref;
};
if (has_min) {
min_opt_tensor = reshape_clamp_tensor(min_opt, num_min_dims, new_min_arr);
min_opt_tensor = (num_min_dims < num_input_dims) ? (*min_opt).view(new_min_shape) : *min_opt;
}
if (has_max) {
max_opt_tensor = reshape_clamp_tensor(max_opt, num_max_dims, new_max_arr);
max_opt_tensor = (num_max_dims < num_input_dims) ? (*max_opt).view(new_max_shape) : *max_opt;
}
@autoreleasepool {

View File

@ -301,12 +301,12 @@ class AvgPoolMicrokernelTester {
ASSERT_NEAR(
float(int32_t(y[i * yStride() + k])), yFP[i * kc() + k], 0.5001f)
<< "at pixel " << i << ", channel " << k << ", n = " << n()
<< ", ks = " << kh() << "x" << kw() << " (" << ks()
<< ", ks = " << kh() << 'x' << kw() << " (" << ks()
<< "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k];
ASSERT_EQ(
uint32_t(yRef[i * kc() + k]), uint32_t(y[i * yStride() + k]))
<< "at pixel " << i << ", channel " << k << ", n = " << n()
<< ", ks = " << kh() << "x" << kw() << " (" << ks()
<< ", ks = " << kh() << 'x' << kw() << " (" << ks()
<< "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k];
}
}
@ -396,12 +396,12 @@ class AvgPoolMicrokernelTester {
ASSERT_NEAR(
float(int32_t(y[i * yStride() + k])), yFP[i * kc() + k], 0.5001f)
<< "at pixel " << i << ", channel " << k << ", n = " << n()
<< ", ks = " << kh() << "x" << kw() << " (" << ks()
<< ", ks = " << kh() << 'x' << kw() << " (" << ks()
<< "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k];
ASSERT_EQ(
uint32_t(yRef[i * kc() + k]), uint32_t(y[i * yStride() + k]))
<< "at pixel " << i << ", channel " << k << ", n = " << n()
<< ", ks = " << kh() << "x" << kw() << " (" << ks()
<< ", ks = " << kh() << 'x' << kw() << " (" << ks()
<< "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k];
}
}

View File

@ -232,7 +232,7 @@ class MaxPoolMicrokernelTester {
ASSERT_EQ(
uint32_t(yRef[i * kc() + k]), uint32_t(y[i * yStride() + k]))
<< "at pixel " << i << ", channel " << k << ", n = " << n()
<< ", ks = " << kh() << "x" << kw() << " (" << ks()
<< ", ks = " << kh() << 'x' << kw() << " (" << ks()
<< "), kc = " << kc();
}
}

View File

@ -17,7 +17,7 @@ inline std::vector<T> _expand_param_if_needed(
std::ostringstream ss;
ss << "expected " << param_name << " to be a single integer value or a "
<< "list of " << expected_dim << " values to match the convolution "
<< "dimensions, but got " << param_name << "=" << list_param;
<< "dimensions, but got " << param_name << '=' << list_param;
TORCH_CHECK(false, ss.str());
} else {
return list_param.vec();

View File

@ -358,9 +358,9 @@ std::string Adapter::stringize() const {
std::string device_type = get_device_type_str(properties.deviceType);
VkPhysicalDeviceLimits limits = properties.limits;
ss << "{" << std::endl;
ss << '{' << std::endl;
ss << " Physical Device Info {" << std::endl;
ss << " apiVersion: " << v_major << "." << v_minor << std::endl;
ss << " apiVersion: " << v_major << '.' << v_minor << std::endl;
ss << " driverversion: " << properties.driverVersion << std::endl;
ss << " deviceType: " << device_type << std::endl;
ss << " deviceName: " << properties.deviceName << std::endl;
@ -371,7 +371,7 @@ std::string Adapter::stringize() const {
#define PRINT_LIMIT_PROP_VEC3(name) \
ss << " " << std::left << std::setw(36) << #name << limits.name[0] \
<< "," << limits.name[1] << "," << limits.name[2] << std::endl;
<< ',' << limits.name[1] << ',' << limits.name[2] << std::endl;
ss << " Physical Device Limits {" << std::endl;
PRINT_LIMIT_PROP(maxImageDimension1D);
@ -425,7 +425,7 @@ std::string Adapter::stringize() const {
;
}
ss << " ]" << std::endl;
ss << "}";
ss << '}';
return ss.str();
}

View File

@ -33,7 +33,7 @@ std::ostream& operator<<(std::ostream& out, const VkResult result) {
VK_RESULT_CASE(VK_ERROR_FORMAT_NOT_SUPPORTED)
VK_RESULT_CASE(VK_ERROR_FRAGMENTED_POOL)
default:
out << "VK_ERROR_UNKNOWN (VkResult " << result << ")";
out << "VK_ERROR_UNKNOWN (VkResult " << result << ')';
break;
}
return out;
@ -46,7 +46,7 @@ std::ostream& operator<<(std::ostream& out, const VkResult result) {
//
std::ostream& operator<<(std::ostream& out, const SourceLocation& loc) {
out << loc.function << " at " << loc.file << ":" << loc.line;
out << loc.function << " at " << loc.file << ':' << loc.line;
return out;
}
@ -66,7 +66,7 @@ Error::Error(SourceLocation source_location, const char* cond, std::string msg)
: msg_(std::move(msg)), source_location_{source_location} {
std::ostringstream oss;
oss << "Exception raised from " << source_location_ << ": ";
oss << "(" << cond << ") is false! ";
oss << '(' << cond << ") is false! ";
oss << msg_;
what_ = oss.str();
}

View File

@ -173,8 +173,8 @@ void QueryPool::extract_results() {
static std::string stringize(const VkExtent3D& extents) {
std::stringstream ss;
ss << "{" << extents.width << ", " << extents.height << ", " << extents.depth
<< "}";
ss << '{' << extents.width << ", " << extents.height << ", " << extents.depth
<< '}';
return ss.str();
}

View File

@ -149,7 +149,7 @@ VKAPI_ATTR VkBool32 VKAPI_CALL debug_report_callback_fn(
(void)flags;
std::stringstream stream;
stream << layer_prefix << " " << message_code << " " << message << std::endl;
stream << layer_prefix << ' ' << message_code << ' ' << message << std::endl;
const std::string log = stream.str();
std::cout << log;

View File

@ -253,7 +253,7 @@ using vec4 = vec<4u>;
// uvec3 is the type representing tensor extents. Useful for debugging.
inline std::ostream& operator<<(std::ostream& os, const uvec3& v) {
os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ")";
os << '(' << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ')';
return os;
}

View File

@ -61,7 +61,6 @@ list(APPEND ATen_CUDA_TEST_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_math_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_cub_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_cublas_handle_pool_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda_device_test.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cuda_distributions_test.cu
${CMAKE_CURRENT_SOURCE_DIR}/cuda_dlconvertor_test.cpp

View File

@ -246,7 +246,7 @@ void TestToCFloat() {
void TestToString() {
Tensor b = ones({3, 7}) * .0000001f;
std::stringstream s;
s << b << "\n";
s << b << '\n';
std::string expect = "1e-07 *";
ASSERT_EQ_RESOLVED(s.str().substr(0, expect.size()), expect);
}

View File

@ -1,77 +0,0 @@
#include <gtest/gtest.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGuard.h>
#include <atomic>
#include <thread>
#include <vector>
// Test concurrent access to getCurrentCUDABlasHandle and getCUDABlasLtWorkspace
// to verify that the data race fix is working correctly
TEST(CUDABlasHandlePoolTest, ConcurrentGetAndClearWorkspaces) {
if (!at::cuda::is_available()) {
return;
}
constexpr int num_accessor_threads = 15;
constexpr int num_clear_threads = 5;
constexpr int iterations_per_thread = 50;
std::atomic<bool> stop{false};
std::atomic<int> error_count{0};
std::vector<std::thread> threads;
threads.reserve(num_accessor_threads + num_clear_threads);
// Launch accessor threads
for (int i = 0; i < num_accessor_threads; ++i) {
threads.emplace_back([&stop, &error_count]() {
try {
at::cuda::CUDAGuard device_guard(0);
while (!stop.load(std::memory_order_relaxed)) {
const auto handle = at::cuda::getCurrentCUDABlasHandle();
const auto workspace = at::cuda::getCUDABlasLtWorkspace();
if (handle == nullptr || workspace == nullptr) {
error_count++;
}
}
} catch (const std::exception& e) {
error_count++;
}
});
}
// Launch threads that clear workspaces
for (int i = 0; i < num_clear_threads; ++i) {
threads.emplace_back([&error_count]() {
try {
for (int j = 0; j < iterations_per_thread; ++j) {
at::cuda::clearCublasWorkspaces();
std::this_thread::yield();
}
} catch (const std::exception& e) {
error_count++;
}
});
}
// Let them run for a bit
std::this_thread::sleep_for(std::chrono::milliseconds(100));
stop.store(true, std::memory_order_relaxed);
for (auto& thread : threads) {
thread.join();
}
EXPECT_EQ(error_count.load(), 0);
}
int main(int argc, char* argv[]) {
::testing::InitGoogleTest(&argc, argv);
c10::cuda::CUDACachingAllocator::init(1);
return RUN_ALL_TESTS();
}

View File

@ -33,7 +33,7 @@ struct Foo {
static void apply(Tensor a, Tensor b) {
scalar_type s = 1;
std::stringstream ss;
ss << "hello, dispatch: " << a.toString() << s << "\n";
ss << "hello, dispatch: " << a.toString() << s << '\n';
auto data = (scalar_type*)a.data_ptr();
(void)data;
}
@ -73,8 +73,8 @@ TEST(TestScalar, TestScalar) {
Scalar bar = 3.0;
Half h = bar.toHalf();
Scalar h2 = h;
cout << "H2: " << h2.toDouble() << " " << what.toFloat() << " "
<< bar.toDouble() << " " << what.isIntegral(false) << "\n";
cout << "H2: " << h2.toDouble() << ' ' << what.toFloat() << ' '
<< bar.toDouble() << ' ' << what.isIntegral(false) << '\n';
auto gen = at::detail::getDefaultCPUGenerator();
{
// See Note [Acquire lock when using random generators]
@ -84,7 +84,7 @@ TEST(TestScalar, TestScalar) {
}
if (at::hasCUDA()) {
auto t2 = zeros({4, 4}, at::kCUDA);
cout << &t2 << "\n";
cout << &t2 << '\n';
}
auto t = ones({4, 4});
@ -129,7 +129,7 @@ TEST(TestScalar, TestScalar) {
std::stringstream ss;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
ASSERT_NO_THROW(
ss << "hello, dispatch" << x.toString() << s << "\n");
ss << "hello, dispatch" << x.toString() << s << '\n');
auto data = (scalar_t*)x.data_ptr();
(void)data;
});

View File

@ -1,5 +1,5 @@
#include <ATen/ATen.h>
int main() {
std::cout << at::ones({3,4}, at::CPU(at::kFloat)) << "\n";
std::cout << at::ones({3,4}, at::CPU(at::kFloat)) << '\n';
}

View File

@ -1828,9 +1828,9 @@ namespace {
#endif
EXPECT_EQ(u16, c10::detail::fp16_ieee_from_fp32_value(f32s[i]))
<< "Test failed for float to uint16 " << f32s[i] << "\n";
<< "Test failed for float to uint16 " << f32s[i] << '\n';
EXPECT_EQ(x, c10::detail::fp16_ieee_to_fp32_value(u16))
<< "Test failed for uint16 to float " << u16 << "\n";
<< "Test failed for uint16 to float " << u16 << '\n';
}
}
TEST(FP8E4M3Test, FP8E4M3ConversionFloat) {
@ -1848,10 +1848,10 @@ namespace {
EXPECT_TRUE(std::isnan(f32));
} else {
EXPECT_EQ(f32, c10::detail::fp8e4m3fn_to_fp32_value(input))
<< "Test failed for u8 to float " << input << "\n";
<< "Test failed for u8 to float " << input << '\n';
}
EXPECT_EQ(u8, c10::detail::fp8e4m3fn_from_fp32_value(f32))
<< "Test failed for float to u8 " << f32 << "\n";
<< "Test failed for float to u8 " << f32 << '\n';
}
}
TEST(FP8E4M3Test, FP8E4M3BinaryAdd) {
@ -2015,10 +2015,10 @@ namespace {
EXPECT_TRUE(std::isnan(f32));
} else {
EXPECT_EQ(f32, c10::detail::fp8e5m2_to_fp32_value(input))
<< "Test failed for u8 to float " << input << "\n";
<< "Test failed for u8 to float " << input << '\n';
}
EXPECT_EQ(u8, c10::detail::fp8e5m2_from_fp32_value(f32))
<< "Test failed for float to u8 " << f32 << "\n";
<< "Test failed for float to u8 " << f32 << '\n';
}
}
TEST(FP8E5M2Test, FP8E5M2BinaryAdd) {

View File

@ -19,7 +19,7 @@ TEST(Vitals, Basic) {
c10::utils::set_env("TORCH_VITAL", "1");
TORCH_VITAL_DEFINE(Testing);
TORCH_VITAL(Testing, Attribute0) << 1;
TORCH_VITAL(Testing, Attribute1) << "1";
TORCH_VITAL(Testing, Attribute1) << '1';
TORCH_VITAL(Testing, Attribute2) << 1.0f;
TORCH_VITAL(Testing, Attribute3) << 1.0;
auto t = at::ones({1, 1});

View File

@ -129,14 +129,14 @@ void showRtol(const at::Tensor& a, const at::Tensor& b) {
std::cout << "Max Diff allowed: " << maxDiff << std::endl;
if (diff.sizes().size() == 2) {
for (const auto y : c10::irange(diff.sizes()[0])) {
std::cout << y << ":";
std::cout << y << ':';
for (const auto x : c10::irange(diff.sizes()[1])) {
float diff_xy = diff[y][x].item<float>();
if (diff_xy > maxDiff) {
std::cout << std::setw(5) << x;
}
else {
std::cout << std::setw(5) << " ";
std::cout << std::setw(5) << ' ';
}
}
std::cout << std::endl;
@ -3276,7 +3276,7 @@ TEST_F(VulkanAPITest, masked_fill_invalidinputs_exceptions) {
void print_shape(const std::vector<int64_t>& shape) {
for (const auto& num : shape) {
std::cout << num << " ";
std::cout << num << ' ';
}
}
@ -3367,7 +3367,7 @@ void test_masked_fill_scalar(
print_shape(tmp_curr_input_shape);
std::cout << "], and mask of shape [";
print_shape(tmp_curr_mask_shape);
std::cout << "]" << std::endl;
std::cout << ']' << std::endl;
}
ASSERT_TRUE(check);
@ -4542,9 +4542,9 @@ void test_softmax(const at::IntArrayRef shape, bool log_softmax = false) {
if (!check) {
std::cout << "Softmax test failed on axis " << dim << "for tensor dims {";
for (uint32_t place = 0; place < shape.size() - 1; place++) {
std::cout << shape[place] << " ";
std::cout << shape[place] << ' ';
}
std::cout << shape.back() << "}" << std::endl;
std::cout << shape.back() << '}' << std::endl;
showRtol(out_cpu, out_vulkan.cpu());
}
ASSERT_TRUE(check);

View File

@ -95,7 +95,7 @@ void showRtol(
std::cout << "Max Diff found is: " << diff.max().item<double>() << std::endl;
if (diff.sizes().size() == 2) {
for (const auto y : c10::irange(diff.sizes()[0])) {
std::cout << y << ":";
std::cout << y << ':';
for (const auto x : c10::irange(diff.sizes()[1])) {
double diff_xy = diff[y][x].item<double>();
if (diff_xy > maxDiff) {
@ -109,7 +109,7 @@ void showRtol(
}
}
} else {
std::cout << std::setw(5) << " ";
std::cout << std::setw(5) << ' ';
}
}
std::cout << std::endl;
@ -148,19 +148,19 @@ using at::native::vulkan::api::utils::ivec4;
using at::native::vulkan::api::utils::vec4;
std::ostream& operator<<(std::ostream& os, const vec4& v) {
os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
<< v.data[3u] << ")";
os << '(' << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
<< v.data[3u] << ')';
return os;
}
std::ostream& operator<<(std::ostream& os, const ivec3& v) {
os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ")";
os << '(' << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ')';
return os;
}
std::ostream& operator<<(std::ostream& os, const ivec4& v) {
os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
<< v.data[3u] << ")";
os << '(' << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
<< v.data[3u] << ')';
return os;
}
@ -3379,51 +3379,51 @@ bool _test_quantized_linear(
showRtol(out_cpu_dequant, out_vk_to_cpu_dequant);
}
if (xpos != -1 && ypos != -1) {
std::cout << "\nFailure caused on row/col: " << ypos << "/" << xpos
<< "\n";
std::cout << "\nFailure caused on row/col: " << ypos << '/' << xpos
<< '\n';
std::cout << "Input tensor scale: " << scale << " zerop: " << zero_point
<< "\n";
std::cout << "Input tensor row " << ypos << "\n";
<< '\n';
std::cout << "Input tensor row " << ypos << '\n';
for (int i = 0; i < input_cpu.sizes()[1]; i++) {
std::cout << input_cpu[ypos][i].item<double>() << ", ";
}
std::cout << "\n";
std::cout << '\n';
std::cout << "Weight tensor scale: " << w_scale
<< " zerop: " << w_zero_point << "\n";
std::cout << "Weight tensor col " << xpos << "\n";
<< " zerop: " << w_zero_point << '\n';
std::cout << "Weight tensor col " << xpos << '\n';
for (int i = 0; i < weight.sizes()[1]; i++) {
std::cout << weight[xpos][i].item<double>() << ", ";
}
std::cout << "\n";
std::cout << '\n';
std::cout << "Input tensor quantized row " << ypos << " with dtype "
<< (input_quant_dtype_int8 ? "QInt8" : "QUInt8") << "\n";
<< (input_quant_dtype_int8 ? "QInt8" : "QUInt8") << '\n';
for (int i = 0; i < input_cpu.sizes()[1]; i++) {
std::cout << input_cpu_quantized[ypos][i].item<double>() << ", ";
}
std::cout << "\n";
std::cout << '\n';
std::cout << "Weight tensor quantized col " << xpos << " with dtype "
<< (weight_quant_dtype_int8 ? "QInt8" : "QUInt8") << "\n";
<< (weight_quant_dtype_int8 ? "QInt8" : "QUInt8") << '\n';
for (int i = 0; i < weight.sizes()[1]; i++) {
std::cout << weight_cpu_quantized[xpos][i].item<double>() << ", ";
}
std::cout << "\n";
std::cout << '\n';
std::cout << "bias tensor\n";
for (int i = 0; i < bias.sizes()[0]; i++) {
std::cout << bias[i].item<double>() << ", ";
}
std::cout << "\n";
std::cout << '\n';
std::cout << "out_scale: " << out_scale
<< " out_zero_point: " << out_zero_point << "\n";
<< " out_zero_point: " << out_zero_point << '\n';
std::cout << "cpu unmatched output: "
<< out_cpu_dequant[ypos][xpos].item<double>() << "\n";
<< out_cpu_dequant[ypos][xpos].item<double>() << '\n';
std::cout << "vk unmatched output: "
<< out_vk_to_cpu_dequant[ypos][xpos].item<double>() << "\n";
<< out_vk_to_cpu_dequant[ypos][xpos].item<double>() << '\n';
}
}
return check;

View File

@ -10,13 +10,6 @@
...
}
{
ignore_empty_generic_uninitialised_conditional_jump
Memcheck:Cond
fun:_ZN2at6detail13empty_genericEN3c108ArrayRefIlEEPNS1_9AllocatorENS1_14DispatchKeySetENS1_10ScalarTypeESt8optionalINS1_12MemoryFormatEE
...
}
{
Cond_cuda
Memcheck:Cond

View File

@ -145,6 +145,64 @@ Run torch.add benchmark with tag 'long':
python -m pt.add_test --tag-filter long
```
## CI Regression Tracking
The operator benchmarks are continuously monitored in CI to track performance regressions across a diverse set of CPU and GPU devices. Two GitHub Actions workflows run these benchmarks on a regular schedule:
### CPU Benchmarks
The [operator_benchmark.yml](../../.github/workflows/operator_benchmark.yml) workflow runs operator benchmarks on CPU devices:
**Devices:**
- x86_64: `linux.12xlarge` (Intel/AMD CPUs)
- aarch64: `linux.arm64.m8g.4xlarge` (ARM64 CPUs)
**Operators Tracked:** All operators in the `pt/` directory with tag : `short`
**Schedule:** Weekly on Sundays at 07:00 UTC
**Test Modes:** `short`, `long`, or `all` (default: `short`)
**Triggers:**
- Scheduled runs (weekly)
- Manual workflow dispatch with configurable test mode
- Push to `ciflow/op-benchmark/*` tags
- Pull requests that modify benchmark files
### GPU Microbenchmarks
The [operator_microbenchmark.yml](../../.github/workflows/operator_microbenchmark.yml) workflow runs operator microbenchmarks on GPU devices:
**CUDA Devices:**
- H100 GPUs (`linux.aws.h100`) - CUDA 12.8, sm_80
- A100 GPUs (`linux.aws.a100`) - CUDA 12.8, sm_80
- B200 GPUs (`linux.dgx.b200`) - CUDA 12.8, sm_100
**ROCm Devices:**
- MI300X GPUs (`linux.rocm.gpu.gfx942.1`) - gfx942
**Operators Tracked in CI:** `matmul`, `mm`, `addmm`, `bmm`, `conv` (with tag `long`)
- Other operators in the `pt/` directory can be run ad-hoc using the workflow dispatch
**Schedule:** Daily at 06:00 UTC
**Performance Dashboard:** [PyTorch Operator Microbenchmark Dashboard](https://hud.pytorch.org/benchmark/v3/dashboard/pytorch_operator_microbenchmark)
**Triggers:**
- Scheduled runs (daily)
- Manual workflow dispatch
- Push to `ciflow/op-benchmark/*` tags
### Running Manual Benchmarks
To trigger a manual run of the benchmarks:
1. Navigate to the [GitHub Actions workflows](https://github.com/pytorch/pytorch/actions)
2. Select either `operator_benchmark` or `operator_microbenchmark`
3. Click "Run workflow" in the top right
4. For CPU benchmarks, optionally select a test mode (`short`, `long`, or `all`)
5. Click "Run workflow" to start the benchmark run
## Adding New Operators to the Benchmark Suite
In the previous sections, we gave several examples to show how to run the already available operators in the benchmark suite. In the following sections, we'll step through the complete flow of adding PyTorch operators to the benchmark suite. Existing benchmarks for operators are in the `pt` directory and we highly recommend putting your new operators in those directories as well.

View File

@ -176,7 +176,7 @@ std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) {
os << k;
first = false;
}
os << ")";
os << ')';
return os;
}

View File

@ -44,7 +44,7 @@ struct C10_API SafePyObject {
(*other.pyinterpreter_)->incref(other.data_);
}
if (data_ != nullptr) {
(*pyinterpreter_)->decref(data_);
(*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
}
data_ = other.data_;
pyinterpreter_ = other.pyinterpreter_;
@ -53,7 +53,7 @@ struct C10_API SafePyObject {
~SafePyObject() {
if (data_ != nullptr) {
(*pyinterpreter_)->decref(data_);
(*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
}
}

View File

@ -48,30 +48,6 @@ void warnDeprecatedDataPtr() {
TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid.");
}
void StorageImpl::incref_pyobject() const {
// Because intrusive_ptr incref uses relaxed memory order, we need to
// do an acquire fence to ensure that the kHasPyObject bit was
// observed before the load of the PyObject* below.
// NB: This is a no-op on x86/x86-64
std::atomic_thread_fence(std::memory_order_acquire);
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
}
void StorageImpl::decref_pyobject() const {
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
}
bool StorageImpl::try_incref_pyobject() const {
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
if (C10_UNLIKELY(!interp)) {
return false;
}
return (*interp)->try_incref(pyobj_slot_);
}
void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
// Allowlist verification.
// Only if the devicetype is in the allowlist,

View File

@ -105,12 +105,6 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
data_ptr_.clear();
}
void incref_pyobject() const override final;
void decref_pyobject() const override final;
bool try_incref_pyobject() const override final;
size_t nbytes() const {
// OK to do this instead of maybe_as_int as nbytes is guaranteed positive
TORCH_CHECK(!size_bytes_is_heap_allocated_);
@ -376,18 +370,4 @@ C10_API c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
bool resizable,
std::optional<at::Device> device_opt);
namespace detail {
#ifndef C10_MOBILE
template <class T>
struct TargetTraits<
T,
std::enable_if_t<
std::is_base_of_v<c10::StorageImpl, std::remove_cv_t<T>>>> {
static constexpr bool can_have_pyobject = true;
};
#endif
} // namespace detail
} // namespace c10

View File

@ -277,6 +277,7 @@ void TensorImpl::release_resources() {
if (storage_) {
storage_ = {};
}
pyobj_slot_.maybe_destroy_pyobj();
}
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
@ -988,30 +989,6 @@ void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) {
}
}
void TensorImpl::incref_pyobject() const {
// Because intrusive_ptr incref uses relaxed memory order, we need to
// do an acquire fence to ensure that the kHasPyObject bit was
// observed before the load of the PyObject* below.
// NB: This is a no-op on x86/x86-64
std::atomic_thread_fence(std::memory_order_acquire);
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
}
void TensorImpl::decref_pyobject() const {
PyObject* obj = pyobj_slot_.load_pyobj();
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
}
bool TensorImpl::try_incref_pyobject() const {
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
if (C10_UNLIKELY(!interp)) {
return false;
}
return (*interp)->try_incref(pyobj_slot_);
}
namespace impl {
namespace {

View File

@ -2178,12 +2178,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return &pyobj_slot_;
}
void incref_pyobject() const override final;
void decref_pyobject() const override final;
bool try_incref_pyobject() const override final;
private:
// See NOTE [std::optional operator usage in CUDA]
// We probably don't want to expose this publicly until
@ -3085,19 +3079,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
friend class C10_TensorImpl_Size_Check_Dummy_Class;
};
namespace detail {
#ifndef C10_MOBILE
template <class T>
struct TargetTraits<
T,
std::enable_if_t<std::is_base_of_v<c10::TensorImpl, std::remove_cv_t<T>>>> {
static constexpr bool can_have_pyobject = true;
};
#endif
} // namespace detail
// Note [TensorImpl size constraints]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Changed the size of TensorImpl? If the size went down, good for

View File

@ -33,7 +33,7 @@ std::ostream& operator<<(std::ostream& stream, const TensorOptions& options) {
} else {
stream << "(nullopt)";
}
stream << ")";
stream << ')';
return stream;
}

View File

@ -11,11 +11,8 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
void incref(PyObject* pyobj) const override {} // do nothing
void decref(PyObject* pyobj) const override {} // do nothing
bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const override {
return false;
}
void decref(PyObject* pyobj, bool has_pyobj_slot) const override {
} // do nothing
#define PANIC(m) \
TORCH_INTERNAL_ASSERT( \
@ -23,10 +20,6 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
"attempted to call " #m \
" on a Tensor with nontrivial PyObject after corresponding interpreter died")
size_t refcnt(PyObject* pyobj) const override {
PANIC(refcnt);
}
c10::intrusive_ptr<TensorImpl> detach(const TensorImpl* self) const override {
PANIC(detach);
}

View File

@ -18,9 +18,6 @@ namespace c10 {
struct IValue;
class OperatorHandle;
struct TensorImpl;
namespace impl {
struct PyObjectSlot;
} // namespace impl
} // namespace c10
namespace torch::jit {
@ -129,12 +126,9 @@ struct C10_API PyInterpreterVTable {
// Run Py_INCREF on a PyObject.
virtual void incref(PyObject* pyobj) const = 0;
// Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call.
virtual void decref(PyObject* pyobj) const = 0;
// Run PyUnstable_TryIncRef on a PyObject if it's not NULL.
virtual bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const = 0;
// Run Py_REFCNT on a PyObject.
virtual size_t refcnt(PyObject* pyobj) const = 0;
// Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call
// See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0;
// Perform a detach by deferring to the __torch_dispatch__ implementation of
// detach, which will also arrange for the PyObject to get copied in this

View File

@ -0,0 +1,56 @@
#include <c10/core/impl/PyObjectSlot.h>
namespace c10::impl {
PyObjectSlot::PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
PyObjectSlot::~PyObjectSlot() {
maybe_destroy_pyobj();
}
void PyObjectSlot::maybe_destroy_pyobj() {
if (owns_pyobj()) {
TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr);
TORCH_INTERNAL_ASSERT(pyobj_ != nullptr);
(*pyobj_interpreter_.load(std::memory_order_acquire))
->decref(_unchecked_untagged_pyobj(), /*has_pyobj_slot*/ true);
// NB: this destructor can only be entered when there are no
// references to this C++ object (obviously), NOR any references
// to the PyObject (if there are references to the PyObject,
// then the PyObject holds an owning reference to the tensor).
// So it is OK to clear pyobj_ here as it is impossible for it to
// be used again (modulo weak reference races)
pyobj_ = nullptr; // for safety
}
}
PyInterpreter* PyObjectSlot::pyobj_interpreter() {
return pyobj_interpreter_.load(std::memory_order_acquire);
}
PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<PyObject*>(
reinterpret_cast<uintptr_t>(pyobj_) & ~0x1ULL);
}
PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const {
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
if (interpreter) {
return *interpreter;
}
TORCH_CHECK(false, "cannot access PyObject for Tensor - no interpreter set");
}
bool PyObjectSlot::owns_pyobj() {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
return reinterpret_cast<uintptr_t>(pyobj_) & 1;
}
void PyObjectSlot::set_owns_pyobj(bool b) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
pyobj_ = reinterpret_cast<PyObject*>(
reinterpret_cast<uintptr_t>(_unchecked_untagged_pyobj()) | b);
}
} // namespace c10::impl

View File

@ -8,58 +8,117 @@
#include <atomic>
namespace torch::utils {
class PyObjectPreservation;
}
namespace c10::impl {
struct C10_API PyObjectSlot {
public:
PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
PyObjectSlot();
~PyObjectSlot();
void maybe_destroy_pyobj();
// Associate the TensorImpl with the specified PyObject, and, if necessary,
// also tag the interpreter.
//
// NB: This lives in a header so that we can inline away the switch on status
//
// NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after
// PyObject if necessary!
void init_pyobj(PyObject* pyobj) {
pyobj_interpreter_.store(
getGlobalPyInterpreter(), std::memory_order_relaxed);
pyobj_ = pyobj;
}
// Query the PyObject interpreter. This may return null if there is no
// interpreter.
PyInterpreter* pyobj_interpreter() const {
return pyobj_interpreter_.load(std::memory_order_acquire);
// interpreter. This is racy!
PyInterpreter* pyobj_interpreter();
PyObject* _unchecked_untagged_pyobj() const;
// Test the interpreter tag. If tagged for the current interpreter, return
// a non-nullopt (but possibly null) PyObject. If (possibly) untagged,
// returns a nullopt. If it is definitely invalid, raises an error.
//
// If `ignore_hermetic_tls` is false and this function is called from a
// hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then
// nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic
// context is ignored, allowing you to check the interpreter tag of a
// nonhermetic PyObject from within a hermetic context. This is necessary
// because there are some cases where the deallocator function of a
// nonhermetic PyObject is called from within a hermetic context, so it must
// be properly treated as a nonhermetic PyObject.
//
// NB: this lives in header so that we can avoid actually creating the
// std::optional
// @todo alban: I'm not too sure what's going on here, we can probably delete
// it but it's worthwhile making sure
std::optional<PyObject*> check_pyobj(bool ignore_hermetic_tls = false) const {
impl::PyInterpreter* interpreter =
pyobj_interpreter_.load(std::memory_order_acquire);
if (interpreter == nullptr) {
return std::nullopt;
}
if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
return std::nullopt;
} else {
return _unchecked_untagged_pyobj();
}
}
PyInterpreter& load_pyobj_interpreter() const {
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
TORCH_INTERNAL_ASSERT(
interpreter, "cannot access PyObject for Tensor - no interpreter set");
return *interpreter;
}
PyInterpreter& load_pyobj_interpreter() const;
PyObject* load_pyobj() const {
return pyobj_.load(std::memory_order_acquire);
}
bool owns_pyobj();
void store_pyobj(PyObject* obj) {
pyobj_.store(obj, std::memory_order_release);
}
bool has_unique_reference() const {
PyObject* pyobj = load_pyobj();
return pyobj != nullptr && load_pyobj_interpreter()->refcnt(pyobj) == 1;
}
void clear() {
pyobj_.store(nullptr, std::memory_order_relaxed);
pyobj_interpreter_.store(nullptr, std::memory_order_relaxed);
}
void set_owns_pyobj(bool b);
private:
// This is now always the global interpreter if the PyObject is set.
// Maybe we can remove this field some day...
// This field contains the interpreter tag for this object. See
// Note [Python interpreter tag] for general context
//
// Note [Memory ordering on Python interpreter tag]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// What memory_order do we need when accessing this atomic? We don't
// need a single total modification order (as provided by
// memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only
// transition from -1 to some positive integer and never changes afterwards.
// Because there is only one modification, it trivially already has a total
// modification order (e.g., we don't need fences or locked instructions on
// x86)
//
// In fact, one could make a reasonable argument that relaxed reads are OK,
// due to the presence of external locking (GIL) to ensure that interactions
// with other data structures are still correctly synchronized, so that
// we fall in the "Single-Location Data Structures" case as described in
// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
// However, on x86, it doesn't matter if I use acquire or relaxed on the load
// as I get the same assembly in both cases. So I just use the more
// conservative acquire (which will impede compiler optimizations but I don't
// care)
std::atomic<PyInterpreter*> pyobj_interpreter_;
// The PyObject representing this Tensor or nullptr. Ownership is managed
// by intrusive_ptr. By the time the PyObjectSlot is destroyed, this
// reference is already dead.
std::atomic<PyObject*> pyobj_;
friend class torch::utils::PyObjectPreservation;
// This field contains a reference to a PyObject representing this Tensor.
// If pyobj is nullptr, when we transfer Tensor to Python, we allocate a new
// PyObject for it and set this field. This field does not have to be
// protected by an atomic as it is only allowed to be accessed when you hold
// the GIL, or during destruction of the tensor.
//
// When a PyObject dies, you are obligated to clear this field
// (otherwise, you will try to use-after-free the pyobj); this currently
// occurs in THPVariable_clear in torch/csrc/autograd/python_variable.cpp
//
// NB: Ordinarily, this should not be a strong reference, as if the
// PyObject owns the Tensor, this would create a reference cycle.
// However, sometimes this ownership flips. To track who owns
// who, this has a single pointer tag indicating whether or not the
// C++ object owns the PyObject (the common case, zero, means PyObject
// owns the C++ object); see _unchecked_untagged_pyobj for raw access
// or check_pyobj for checked access. See references to PyObject
// resurrection in torch/csrc/autograd/python_variable.cpp
PyObject* pyobj_;
};
} // namespace c10::impl

View File

@ -136,7 +136,7 @@ std::string c10_retrieve_device_side_assertion_info() {
// Something failed, let's talk about that
oss << failures_found
<< " CUDA device-side assertion failures were found on GPU #"
<< device_num << "!" << std::endl;
<< device_num << '!' << std::endl;
if (assertion_data_for_device.assertion_count >
C10_CUDA_DSA_ASSERTION_COUNT) {
oss << "But at least " << assertion_data_for_device.assertion_count
@ -151,17 +151,17 @@ std::string c10_retrieve_device_side_assertion_info() {
oss << "Assertion failure " << i << std::endl;
oss << " GPU assertion failure message = " << self.assertion_msg
<< std::endl;
oss << " File containing assertion = " << self.filename << ":"
oss << " File containing assertion = " << self.filename << ':'
<< self.line_number << std::endl;
oss << " Device function containing assertion = " << self.function_name
<< std::endl;
oss << " Thread ID that failed assertion = [" << self.thread_id[0] << ","
<< self.thread_id[1] << "," << self.thread_id[2] << "]" << std::endl;
oss << " Block ID that failed assertion = [" << self.block_id[0] << ","
<< self.block_id[1] << "," << self.block_id[2] << "]" << std::endl;
oss << " Thread ID that failed assertion = [" << self.thread_id[0] << ','
<< self.thread_id[1] << ',' << self.thread_id[2] << ']' << std::endl;
oss << " Block ID that failed assertion = [" << self.block_id[0] << ','
<< self.block_id[1] << ',' << self.block_id[2] << ']' << std::endl;
if (launch_info.generation_number == self.caller) {
oss << " File containing kernel launch = "
<< launch_info.launch_filename << ":" << launch_info.launch_linenum
<< launch_info.launch_filename << ':' << launch_info.launch_linenum
<< std::endl;
oss << " Function containing kernel launch = "
<< launch_info.launch_function << std::endl;
@ -175,7 +175,7 @@ std::string c10_retrieve_device_side_assertion_info() {
if (launch_registry.gather_launch_stacktrace) {
oss << "Launch stacktracing disabled." << std::endl;
} else {
oss << "\n" << launch_info.launch_stacktrace << std::endl;
oss << '\n' << launch_info.launch_stacktrace << std::endl;
}
} else {
oss << " CPU launch site info: Unavailable, the circular queue wrapped around. Increase `CUDAKernelLaunchRegistry::max_size`."

View File

@ -435,7 +435,7 @@ TEST(DispatchKeySet, TestFunctionalityDispatchKeyToString) {
if (i > 0) {
ASSERT_TRUE(res.find("Unknown") == std::string::npos)
<< i << " (before is " << toString(static_cast<DispatchKey>(i - 1))
<< ")";
<< ')';
} else {
ASSERT_TRUE(res.find("Unknown") == std::string::npos) << i;
}

View File

@ -96,10 +96,10 @@ TEST(HalfConversionTest, TestPorableConversion) {
for (auto x : inputs) {
auto target = c10::detail::fp16_ieee_to_fp32_value(x);
EXPECT_EQ(halfbits2float(x), target)
<< "Test failed for uint16 to float " << x << "\n";
<< "Test failed for uint16 to float " << x << '\n';
EXPECT_EQ(
float2halfbits(target), c10::detail::fp16_ieee_from_fp32_value(target))
<< "Test failed for float to uint16" << target << "\n";
<< "Test failed for float to uint16" << target << '\n';
}
}

View File

@ -98,7 +98,7 @@ struct Noncopyable {
};
std::ostream& operator<<(std::ostream& out, const Noncopyable& nc) {
out << "Noncopyable(" << nc.x << ")";
out << "Noncopyable(" << nc.x << ')';
return out;
}
} // namespace

Some files were not shown because too many files have changed in this diff Show More