mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-20 02:24:54 +08:00
Compare commits
29 Commits
ciflow/tru
...
eqy-patch-
| Author | SHA1 | Date | |
|---|---|---|---|
| 4b312fb0d6 | |||
| 3619883eb8 | |||
| 77acc66df9 | |||
| 95d1df7d4e | |||
| 094e529c64 | |||
| a4c7bf7e8d | |||
| 22ccd44d73 | |||
| 39ebab1dd9 | |||
| 4c152a71ad | |||
| 1b43d6cd4e | |||
| 2b69673bbf | |||
| 2f74916e36 | |||
| 2b5eabc74b | |||
| 9ff95f6835 | |||
| 6fdb974f4a | |||
| 661d1653aa | |||
| 53809f9640 | |||
| 93ddd38ecd | |||
| 5804408f1b | |||
| 99117c1238 | |||
| b9bccec3bc | |||
| ca3aaef66e | |||
| f2e6f94081 | |||
| aa504d4d2a | |||
| d8ce6f8df9 | |||
| 4322354770 | |||
| 363385ad3e | |||
| e2e10753d7 | |||
| 5d99a795f5 |
@ -188,7 +188,7 @@ case "$tag" in
|
||||
fi
|
||||
GCC_VERSION=11
|
||||
VISION=yes
|
||||
ROCM_VERSION=7.0
|
||||
ROCM_VERSION=7.1
|
||||
NINJA_VERSION=1.9.0
|
||||
TRITON=yes
|
||||
KATEX=yes
|
||||
|
||||
@ -60,14 +60,16 @@ EOF
|
||||
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated rocm-llvm-dev
|
||||
fi
|
||||
|
||||
# precompiled miopen kernels added in ROCm 3.5, renamed in ROCm 5.5
|
||||
# search for all unversioned packages
|
||||
# if search fails it will abort this script; use true to avoid case where search fails
|
||||
MIOPENHIPGFX=$(apt-cache search --names-only miopen-hip-gfx | awk '{print $1}' | grep -F -v . || true)
|
||||
if [[ "x${MIOPENHIPGFX}" = x ]]; then
|
||||
echo "miopen-hip-gfx package not available" && exit 1
|
||||
else
|
||||
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ${MIOPENHIPGFX}
|
||||
if [[ $(ver $ROCM_VERSION) -lt $(ver 7.1) ]]; then
|
||||
# precompiled miopen kernels added in ROCm 3.5, renamed in ROCm 5.5, removed in ROCm 7.1
|
||||
# search for all unversioned packages
|
||||
# if search fails it will abort this script; use true to avoid case where search fails
|
||||
MIOPENHIPGFX=$(apt-cache search --names-only miopen-hip-gfx | awk '{print $1}' | grep -F -v . || true)
|
||||
if [[ "x${MIOPENHIPGFX}" = x ]]; then
|
||||
echo "miopen-hip-gfx package not available" && exit 1
|
||||
else
|
||||
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ${MIOPENHIPGFX}
|
||||
fi
|
||||
fi
|
||||
|
||||
# ROCm 6.0 had a regression where journal_mode was enabled on the kdb files resulting in permission errors at runtime
|
||||
|
||||
@ -12,8 +12,8 @@ function do_install() {
|
||||
|
||||
rocm_version_nodot=${rocm_version//./}
|
||||
|
||||
# post merge of https://github.com/icl-utk-edu/magma/pull/65
|
||||
MAGMA_VERSION=c0792ae825fb36872784892ea643dd6f3456bc5f
|
||||
# https://github.com/icl-utk-edu/magma/pull/65
|
||||
MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec
|
||||
magma_archive="magma-rocm${rocm_version_nodot}-${MAGMA_VERSION}-1.tar.bz2"
|
||||
|
||||
rocm_dir="/opt/rocm"
|
||||
|
||||
@ -75,9 +75,11 @@ if [[ "$ARCH" == "aarch64" ]]; then
|
||||
# ARM system libraries
|
||||
DEPS_LIST+=(
|
||||
"/usr/lib64/libgfortran.so.5"
|
||||
"/opt/OpenBLAS/lib/libopenblas.so.0"
|
||||
)
|
||||
DEPS_SONAME+=(
|
||||
"libgfortran.so.5"
|
||||
"libopenblas.so.0"
|
||||
)
|
||||
fi
|
||||
|
||||
|
||||
2
.github/ci_commit_pins/audio.txt
vendored
2
.github/ci_commit_pins/audio.txt
vendored
@ -1 +1 @@
|
||||
07b6cbde121417a70e4dc871adb6d27030e0ce3f
|
||||
ee1a1350eb37804b94334768f328144f058f14e9
|
||||
|
||||
2
.github/ci_commit_pins/xla.txt
vendored
2
.github/ci_commit_pins/xla.txt
vendored
@ -1 +1 @@
|
||||
e4d25697f9dc5eedaf8f0a5bf085c62c5455a53a
|
||||
94631807d22c09723dd006f7be5beb649d5f88d0
|
||||
|
||||
@ -144,7 +144,7 @@ inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(BatchDimsRef bdims) {
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
|
||||
out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ')';
|
||||
out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")";
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ namespace indexing {
|
||||
const EllipsisIndexType Ellipsis = EllipsisIndexType();
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream, const Slice& slice) {
|
||||
stream << slice.start() << ':' << slice.stop() << ':' << slice.step();
|
||||
stream << slice.start() << ":" << slice.stop() << ":" << slice.step();
|
||||
return stream;
|
||||
}
|
||||
|
||||
@ -31,12 +31,12 @@ std::ostream& operator<<(std::ostream& stream, const TensorIndex& tensor_index)
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream, const std::vector<TensorIndex>& tensor_indices) {
|
||||
stream << '(';
|
||||
stream << "(";
|
||||
for (const auto i : c10::irange(tensor_indices.size())) {
|
||||
stream << tensor_indices[i];
|
||||
if (i < tensor_indices.size() - 1) stream << ", ";
|
||||
}
|
||||
stream << ')';
|
||||
stream << ")";
|
||||
return stream;
|
||||
}
|
||||
|
||||
|
||||
@ -113,7 +113,7 @@ void TensorNames::checkUnique(const char* op_name) const {
|
||||
std::ostream& operator<<(std::ostream& out, const TensorName& tensorname) {
|
||||
out << tensorname.name_ << " (index ";
|
||||
out << tensorname.origin_idx_ << " of ";
|
||||
out << tensorname.origin_ << ')';
|
||||
out << tensorname.origin_ << ")";
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
@ -13,9 +13,9 @@ std::ostream& operator<<(std::ostream & out, const TensorGeometryArg& t) {
|
||||
if (t.pos == 0) {
|
||||
// 0 is distinguished; it usually indicates 'self' or the return
|
||||
// tensor
|
||||
out << '\'' << t.name << '\'';
|
||||
out << "'" << t.name << "'";
|
||||
} else {
|
||||
out << "argument #" << t.pos << " '" << t.name << '\'';
|
||||
out << "argument #" << t.pos << " '" << t.name << "'";
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@ -154,7 +154,7 @@ void checkSameGPU(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
|
||||
oss << "Tensor for " << t2 << " is on CPU, ";
|
||||
}
|
||||
oss << "but expected " << ((!t1->is_cpu() && !t2->is_cpu()) ? "them" : "it")
|
||||
<< " to be on GPU (while checking arguments for " << c << ')';
|
||||
<< " to be on GPU (while checking arguments for " << c << ")";
|
||||
TORCH_CHECK(false, oss.str());
|
||||
}
|
||||
TORCH_CHECK(
|
||||
@ -199,7 +199,7 @@ void checkScalarTypes(CheckedFrom c, const TensorArg& t,
|
||||
i++;
|
||||
}
|
||||
oss << "; but got " << t->toString()
|
||||
<< " instead (while checking arguments for " << c << ')';
|
||||
<< " instead (while checking arguments for " << c << ")";
|
||||
TORCH_CHECK(false, oss.str());
|
||||
}
|
||||
}
|
||||
|
||||
@ -43,8 +43,8 @@ std::string get_mkldnn_version() {
|
||||
// https://github.com/intel/ideep/issues/29
|
||||
{
|
||||
const dnnl_version_t* ver = dnnl_version();
|
||||
ss << "Intel(R) MKL-DNN v" << ver->major << '.' << ver->minor << '.' << ver->patch
|
||||
<< " (Git Hash " << ver->hash << ')';
|
||||
ss << "Intel(R) MKL-DNN v" << ver->major << "." << ver->minor << "." << ver->patch
|
||||
<< " (Git Hash " << ver->hash << ")";
|
||||
}
|
||||
#else
|
||||
ss << "MKLDNN not found";
|
||||
@ -81,7 +81,7 @@ std::string get_openmp_version() {
|
||||
break;
|
||||
}
|
||||
if (ver_str) {
|
||||
ss << " (a.k.a. OpenMP " << ver_str << ')';
|
||||
ss << " (a.k.a. OpenMP " << ver_str << ")";
|
||||
}
|
||||
}
|
||||
#else
|
||||
@ -135,38 +135,38 @@ std::string show_config() {
|
||||
|
||||
#if defined(__GNUC__)
|
||||
{
|
||||
ss << " - GCC " << __GNUC__ << '.' << __GNUC_MINOR__ << '\n';
|
||||
ss << " - GCC " << __GNUC__ << "." << __GNUC_MINOR__ << "\n";
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(__cplusplus)
|
||||
{
|
||||
ss << " - C++ Version: " << __cplusplus << '\n';
|
||||
ss << " - C++ Version: " << __cplusplus << "\n";
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(__clang_major__)
|
||||
{
|
||||
ss << " - clang " << __clang_major__ << '.' << __clang_minor__ << '.' << __clang_patchlevel__ << '\n';
|
||||
ss << " - clang " << __clang_major__ << "." << __clang_minor__ << "." << __clang_patchlevel__ << "\n";
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
{
|
||||
ss << " - MSVC " << _MSC_FULL_VER << '\n';
|
||||
ss << " - MSVC " << _MSC_FULL_VER << "\n";
|
||||
}
|
||||
#endif
|
||||
|
||||
#if AT_MKL_ENABLED()
|
||||
ss << " - " << get_mkl_version() << '\n';
|
||||
ss << " - " << get_mkl_version() << "\n";
|
||||
#endif
|
||||
|
||||
#if AT_MKLDNN_ENABLED()
|
||||
ss << " - " << get_mkldnn_version() << '\n';
|
||||
ss << " - " << get_mkldnn_version() << "\n";
|
||||
#endif
|
||||
|
||||
#ifdef _OPENMP
|
||||
ss << " - " << get_openmp_version() << '\n';
|
||||
ss << " - " << get_openmp_version() << "\n";
|
||||
#endif
|
||||
|
||||
#if AT_BUILD_WITH_LAPACK()
|
||||
@ -183,7 +183,7 @@ std::string show_config() {
|
||||
ss << " - Cross compiling on MacOSX\n";
|
||||
#endif
|
||||
|
||||
ss << " - "<< used_cpu_capability() << '\n';
|
||||
ss << " - "<< used_cpu_capability() << "\n";
|
||||
|
||||
if (hasCUDA()) {
|
||||
ss << detail::getCUDAHooks().showConfig();
|
||||
@ -200,10 +200,10 @@ std::string show_config() {
|
||||
ss << " - Build settings: ";
|
||||
for (const auto& pair : caffe2::GetBuildOptions()) {
|
||||
if (!pair.second.empty()) {
|
||||
ss << pair.first << '=' << pair.second << ", ";
|
||||
ss << pair.first << "=" << pair.second << ", ";
|
||||
}
|
||||
}
|
||||
ss << '\n';
|
||||
ss << "\n";
|
||||
|
||||
// TODO: do HIP
|
||||
// TODO: do XLA
|
||||
|
||||
@ -209,7 +209,7 @@ struct CodeTemplate {
|
||||
// to indent correctly in the context.
|
||||
void emitIndent(std::ostream& out, size_t indent) const {
|
||||
for ([[maybe_unused]] const auto i : c10::irange(indent)) {
|
||||
out << ' ';
|
||||
out << " ";
|
||||
}
|
||||
}
|
||||
void emitStringWithIndents(
|
||||
|
||||
@ -10,7 +10,7 @@ std::ostream& operator<<(std::ostream& out, const Dimname& dimname) {
|
||||
if (dimname.type() == NameType::WILDCARD) {
|
||||
out << "None";
|
||||
} else {
|
||||
out << '\'' << dimname.symbol().toUnqualString() << '\'';
|
||||
out << "'" << dimname.symbol().toUnqualString() << "'";
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
namespace at {
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const Range& range) {
|
||||
out << "Range[" << range.begin << ", " << range.end << ']';
|
||||
out << "Range[" << range.begin << ", " << range.end << "]";
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
@ -71,7 +71,7 @@ void TensorBase::enforce_invariants() {
|
||||
|
||||
void TensorBase::print() const {
|
||||
if (defined()) {
|
||||
std::cerr << '[' << toString() << ' ' << sizes() << ']' << '\n';
|
||||
std::cerr << "[" << toString() << " " << sizes() << "]" << '\n';
|
||||
} else {
|
||||
std::cerr << "[UndefinedTensor]" << '\n';
|
||||
}
|
||||
|
||||
@ -245,6 +245,9 @@ class TORCH_API TensorBase {
|
||||
size_t weak_use_count() const noexcept {
|
||||
return impl_.weak_use_count();
|
||||
}
|
||||
bool is_uniquely_owned() const noexcept {
|
||||
return impl_.is_uniquely_owned();
|
||||
}
|
||||
|
||||
std::string toString() const;
|
||||
|
||||
|
||||
@ -9,8 +9,8 @@ APIVitals VitalsAPI;
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, TorchVital const& tv) {
|
||||
for (const auto& m : tv.attrs) {
|
||||
os << "[TORCH_VITAL] " << tv.name << '.' << m.first << "\t\t "
|
||||
<< m.second.value << '\n';
|
||||
os << "[TORCH_VITAL] " << tv.name << "." << m.first << "\t\t "
|
||||
<< m.second.value << "\n";
|
||||
}
|
||||
return os;
|
||||
}
|
||||
|
||||
@ -100,18 +100,18 @@ inline bool operator==(const AliasInfo& lhs, const AliasInfo& rhs) {
|
||||
|
||||
// this does match the way things are represented in the schema
|
||||
inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) {
|
||||
out << '(';
|
||||
out << "(";
|
||||
bool first = true;
|
||||
for (const auto& set : aliasInfo.beforeSets()) {
|
||||
if (first) {
|
||||
first = false;
|
||||
} else {
|
||||
out << '|';
|
||||
out << "|";
|
||||
}
|
||||
out << set.toUnqualString();
|
||||
}
|
||||
if (aliasInfo.isWrite()) {
|
||||
out << '!';
|
||||
out << "!";
|
||||
}
|
||||
if (aliasInfo.beforeSets() != aliasInfo.afterSets()) {
|
||||
out << " -> ";
|
||||
@ -120,12 +120,12 @@ inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) {
|
||||
if (first) {
|
||||
first = false;
|
||||
} else {
|
||||
out << '|';
|
||||
out << "|";
|
||||
}
|
||||
out << set.toUnqualString();
|
||||
}
|
||||
}
|
||||
out << ')';
|
||||
out << ")";
|
||||
return out;
|
||||
}
|
||||
} // namespace c10
|
||||
|
||||
@ -198,7 +198,7 @@ inline void swap(Blob& lhs, Blob& rhs) noexcept {
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const Blob& v) {
|
||||
return out << "Blob[" << v.TypeName() << ']';
|
||||
return out << "Blob[" << v.TypeName() << "]";
|
||||
}
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
@ -456,8 +456,8 @@ bool ClassType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
|
||||
*why_not << "Method on class '" << repr_str()
|
||||
<< "' (1) is not compatible with interface '"
|
||||
<< rhs.repr_str() << "' (2)\n"
|
||||
<< " (1) " << self_method->getSchema() << '\n'
|
||||
<< " (2) " << schema << '\n';
|
||||
<< " (1) " << self_method->getSchema() << "\n"
|
||||
<< " (2) " << schema << "\n";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -100,7 +100,7 @@ struct TORCH_API ClassType : public NamedType {
|
||||
std::string repr_str() const override {
|
||||
std::stringstream ss;
|
||||
ss << str()
|
||||
<< " (of Python compilation unit at: " << compilation_unit().get() << ')';
|
||||
<< " (of Python compilation unit at: " << compilation_unit().get() << ")";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
@ -58,12 +58,12 @@ std::string DispatchKeyExtractor::dumpState() const {
|
||||
std::ostringstream oss;
|
||||
for (const auto i : c10::irange(c10::utils::bitset::NUM_BITS())) {
|
||||
if (dispatch_arg_indices_reverse_.get(i)) {
|
||||
oss << '1';
|
||||
oss << "1";
|
||||
} else {
|
||||
oss << '0';
|
||||
oss << "0";
|
||||
}
|
||||
}
|
||||
oss << ' ' << nonFallthroughKeys_ << '\n';
|
||||
oss << " " << nonFallthroughKeys_ << "\n";
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
|
||||
@ -69,8 +69,8 @@ private:
|
||||
|
||||
void _print_dispatch_trace(const std::string& label, const std::string& op_name, const DispatchKeySet& dispatchKeySet) {
|
||||
auto nesting_value = dispatch_trace_nesting_value();
|
||||
for (int64_t i = 0; i < nesting_value; ++i) std::cerr << ' ';
|
||||
std::cerr << label << " op=[" << op_name << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << ']' << std::endl;
|
||||
for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
|
||||
std::cerr << label << " op=[" << op_name << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
|
||||
@ -570,7 +570,7 @@ void OperatorEntry::checkInvariants() const {
|
||||
|
||||
std::string OperatorEntry::listAllDispatchKeys() const {
|
||||
std::ostringstream str;
|
||||
str << '[';
|
||||
str << "[";
|
||||
|
||||
bool has_kernels = false;
|
||||
for (auto k : allDispatchKeysInFullSet()) {
|
||||
@ -584,7 +584,7 @@ std::string OperatorEntry::listAllDispatchKeys() const {
|
||||
str << k;
|
||||
has_kernels = true;
|
||||
}
|
||||
str << ']';
|
||||
str << "]";
|
||||
return str.str();
|
||||
}
|
||||
|
||||
@ -683,12 +683,12 @@ void OperatorEntry::setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> c
|
||||
// This WON'T report backend fallbacks.
|
||||
std::string OperatorEntry::dumpState() const {
|
||||
std::ostringstream oss;
|
||||
oss << "name: " << name_ << '\n';
|
||||
oss << "name: " << name_ << "\n";
|
||||
if (schema_) {
|
||||
oss << "schema: " << schema_->schema << '\n';
|
||||
oss << "debug: " << schema_->debug << '\n';
|
||||
oss << "schema: " << schema_->schema << "\n";
|
||||
oss << "debug: " << schema_->debug << "\n";
|
||||
oss << "alias analysis kind: " << toString(schema_->schema.aliasAnalysis())
|
||||
<< (schema_->schema.isDefaultAliasAnalysisKind() ? " (default)" : "") << '\n';
|
||||
<< (schema_->schema.isDefaultAliasAnalysisKind() ? " (default)" : "") << "\n";
|
||||
} else {
|
||||
oss << "schema: (none)\n";
|
||||
}
|
||||
|
||||
@ -7,7 +7,7 @@
|
||||
namespace c10 {
|
||||
|
||||
void FunctionSchema::dump() const {
|
||||
std::cout << *this << '\n';
|
||||
std::cout << *this << "\n";
|
||||
}
|
||||
|
||||
const std::vector<Argument>& FunctionSchema::getCorrectList(SchemaArgType type) const {
|
||||
@ -210,9 +210,9 @@ std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
|
||||
|
||||
out << schema.name();
|
||||
if (!schema.overload_name().empty()) {
|
||||
out << '.' << schema.overload_name();
|
||||
out << "." << schema.overload_name();
|
||||
}
|
||||
out << '(';
|
||||
out << "(";
|
||||
|
||||
bool seen_kwarg_only = false;
|
||||
for (const auto i : c10::irange(schema.arguments().size())) {
|
||||
@ -273,7 +273,7 @@ std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
|
||||
}
|
||||
|
||||
if (need_paren) {
|
||||
out << '(';
|
||||
out << "(";
|
||||
}
|
||||
for (const auto i : c10::irange(returns.size())) {
|
||||
if (i > 0) {
|
||||
@ -288,7 +288,7 @@ std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
|
||||
out << "...";
|
||||
}
|
||||
if (need_paren) {
|
||||
out << ')';
|
||||
out << ")";
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@ -471,7 +471,7 @@ bool FunctionSchema::isForwardCompatibleWith(
|
||||
if (!arguments().at(i).isForwardCompatibleWith(old.arguments().at(i))) {
|
||||
if (why_not) {
|
||||
why_not
|
||||
<< '\'' << arguments().at(i).name() << '\''
|
||||
<< "'" << arguments().at(i).name() << "'"
|
||||
<< " is not forward compatible with the older version of the schema";
|
||||
}
|
||||
return false;
|
||||
@ -511,7 +511,7 @@ bool FunctionSchema::isForwardCompatibleWith(
|
||||
.isForwardCompatibleWith(old.arguments().at(i))) {
|
||||
if (why_not) {
|
||||
why_not << "Out argument '"
|
||||
<< '\'' << arguments().at(i).name()
|
||||
<< "'" << arguments().at(i).name()
|
||||
<< " is not FC with the older version of the schema";
|
||||
}
|
||||
return false;
|
||||
|
||||
@ -571,7 +571,7 @@ inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
|
||||
if (arg.N()) {
|
||||
N = std::to_string(*arg.N());
|
||||
}
|
||||
out << '[' << N << ']';
|
||||
out << "[" << N << "]";
|
||||
} else {
|
||||
out << unopt_type->str();
|
||||
}
|
||||
@ -582,15 +582,15 @@ inline std::ostream& operator<<(std::ostream& out, const Argument& arg) {
|
||||
}
|
||||
|
||||
if (is_opt) {
|
||||
out << '?';
|
||||
out << "?";
|
||||
}
|
||||
|
||||
if (!arg.name().empty()) {
|
||||
out << ' ' << arg.name();
|
||||
out << " " << arg.name();
|
||||
}
|
||||
|
||||
if (arg.default_value()) {
|
||||
out << '=';
|
||||
out << "=";
|
||||
if ((type->kind() == c10::TypeKind::StringType ||
|
||||
unopt_type->kind() == c10::TypeKind::StringType) &&
|
||||
arg.default_value().value().isString()) {
|
||||
|
||||
@ -66,7 +66,7 @@ bool operator==(const ivalue::Tuple& lhs, const ivalue::Tuple& rhs) {
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const ivalue::EnumHolder& v) {
|
||||
out << v.qualifiedClassName() << '.' << v.name();
|
||||
out << v.qualifiedClassName() << "." << v.name();
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -526,7 +526,7 @@ std::ostream& printMaybeAnnotatedList(
|
||||
!elementTypeCanBeInferredFromMembers(list_elem_type)) {
|
||||
out << "annotate(" << the_list.type<c10::Type>()->annotation_str() << ", ";
|
||||
printList(out, the_list.toListRef(), "[", "]", formatter);
|
||||
out << ')';
|
||||
out << ")";
|
||||
return out;
|
||||
} else {
|
||||
return printList(out, the_list.toListRef(), "[", "]", formatter);
|
||||
@ -538,7 +538,7 @@ std::ostream& printDict(
|
||||
std::ostream& out,
|
||||
const Dict& v,
|
||||
const IValueFormatter& formatter) {
|
||||
out << '{';
|
||||
out << "{";
|
||||
|
||||
bool first = true;
|
||||
for (const auto& pair : v) {
|
||||
@ -552,7 +552,7 @@ std::ostream& printDict(
|
||||
first = false;
|
||||
}
|
||||
|
||||
out << '}';
|
||||
out << "}";
|
||||
return out;
|
||||
}
|
||||
}
|
||||
@ -565,8 +565,8 @@ static std::ostream& printMaybeAnnotatedDict(
|
||||
auto value_type = the_dict.type()->castRaw<DictType>()->getValueType();
|
||||
if (the_dict.toGenericDict().empty() ||
|
||||
!elementTypeCanBeInferredFromMembers(value_type)) {
|
||||
out << "annotate(" << the_dict.type<c10::Type>()->annotation_str() << ',';
|
||||
printDict(out, the_dict.toGenericDict(), formatter) << ')';
|
||||
out << "annotate(" << the_dict.type<c10::Type>()->annotation_str() << ",";
|
||||
printDict(out, the_dict.toGenericDict(), formatter) << ")";
|
||||
} else {
|
||||
return printDict(out, the_dict.toGenericDict(), formatter);
|
||||
}
|
||||
@ -577,7 +577,7 @@ static std::ostream& printComplex(std::ostream & out, const IValue & v) {
|
||||
c10::complex<double> d = v.toComplexDouble();
|
||||
IValue real(d.real()), imag(std::abs(d.imag()));
|
||||
auto sign = d.imag() >= 0 ? '+' : '-';
|
||||
return out << real << sign << imag << 'j';
|
||||
return out << real << sign << imag << "j";
|
||||
}
|
||||
|
||||
std::ostream& IValue::repr(
|
||||
@ -605,9 +605,9 @@ std::ostream& IValue::repr(
|
||||
if (static_cast<double>(i) == d) {
|
||||
// -0.0 (signed zero) needs to be parsed as -0.
|
||||
if (i == 0 && std::signbit(d)) {
|
||||
return out << '-' << i << '.';
|
||||
return out << "-" << i << ".";
|
||||
}
|
||||
return out << i << '.';
|
||||
return out << i << ".";
|
||||
}
|
||||
}
|
||||
auto orig_prec = out.precision();
|
||||
@ -643,20 +643,20 @@ std::ostream& IValue::repr(
|
||||
device_stream << v.toDevice();
|
||||
out << "torch.device(";
|
||||
c10::printQuotedString(out, device_stream.str());
|
||||
return out << ')';
|
||||
return out << ")";
|
||||
}
|
||||
case IValue::Tag::Generator: {
|
||||
auto generator = v.toGenerator();
|
||||
out << "torch.Generator(device=";
|
||||
c10::printQuotedString(out, generator.device().str());
|
||||
out << ", seed=" << generator.current_seed() << ')';
|
||||
out << ", seed=" << generator.current_seed() << ")";
|
||||
return out;
|
||||
}
|
||||
case IValue::Tag::GenericDict:
|
||||
return printMaybeAnnotatedDict(out, v, formatter);
|
||||
case IValue::Tag::Enum: {
|
||||
auto enum_holder = v.toEnumHolder();
|
||||
return out << enum_holder->qualifiedClassName() << '.' <<
|
||||
return out << enum_holder->qualifiedClassName() << "." <<
|
||||
enum_holder->name();
|
||||
}
|
||||
case IValue::Tag::Object: {
|
||||
@ -801,7 +801,7 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
|
||||
if (c == FP_NORMAL || c == FP_ZERO) {
|
||||
int64_t i = static_cast<int64_t>(d);
|
||||
if (static_cast<double>(i) == d) {
|
||||
return out << i << '.';
|
||||
return out << i << ".";
|
||||
}
|
||||
}
|
||||
auto orig_prec = out.precision();
|
||||
@ -852,7 +852,7 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
|
||||
return printDict(out, v.toGenericDict(), formatter);
|
||||
case IValue::Tag::PyObject: {
|
||||
auto py_obj = v.toPyObject();
|
||||
return out << "<PyObject at" << py_obj << '>';
|
||||
return out << "<PyObject at" << py_obj << ">";
|
||||
}
|
||||
case IValue::Tag::Generator:
|
||||
return out << "Generator";
|
||||
@ -862,22 +862,22 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
|
||||
// TODO we should attempt to call __str__ if the object defines it.
|
||||
auto obj = v.toObject();
|
||||
// print this out the way python would do it
|
||||
return out << '<' << obj->name() << " object at " << obj.get() << '>';
|
||||
return out << "<" << obj->name() << " object at " << obj.get() << ">";
|
||||
}
|
||||
case IValue::Tag::Enum: {
|
||||
auto enum_holder = v.toEnumHolder();
|
||||
return out << "Enum<" << enum_holder->unqualifiedClassName() << '.' <<
|
||||
enum_holder->name() << '>';
|
||||
return out << "Enum<" << enum_holder->unqualifiedClassName() << "." <<
|
||||
enum_holder->name() << ">";
|
||||
}
|
||||
|
||||
}
|
||||
return out << "<Invalid IValue tag=" << std::to_string(static_cast<uint32_t>(v.tag)) << '>';
|
||||
return out << "<Invalid IValue tag=" << std::to_string(static_cast<uint32_t>(v.tag)) << ">";
|
||||
}
|
||||
|
||||
#undef TORCH_FORALL_TAGS
|
||||
|
||||
void IValue::dump() const {
|
||||
std::cout << *this << '\n';
|
||||
std::cout << *this << "\n";
|
||||
}
|
||||
|
||||
std::shared_ptr<ClassType> ivalue::Object::type() const {
|
||||
@ -1050,7 +1050,7 @@ c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(
|
||||
std::stringstream err;
|
||||
err << "Cannot serialize custom bound C++ class";
|
||||
if (auto qualname = type()->name()) {
|
||||
err << ' ' << qualname->qualifiedName();
|
||||
err << " " << qualname->qualifiedName();
|
||||
}
|
||||
err << ". Please define serialization methods via def_pickle() for "
|
||||
"this class.";
|
||||
|
||||
@ -211,7 +211,7 @@ struct TORCH_API OptionalType : public UnionType {
|
||||
|
||||
std::string str() const override {
|
||||
std::stringstream ss;
|
||||
ss << getElementType()->str() << '?';
|
||||
ss << getElementType()->str() << "?";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
@ -240,7 +240,7 @@ struct TORCH_API OptionalType : public UnionType {
|
||||
|
||||
std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
|
||||
std::stringstream ss;
|
||||
ss << "Optional[" << getElementType()->annotation_str(printer) << ']';
|
||||
ss << "Optional[" << getElementType()->annotation_str(printer) << "]";
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
@ -906,7 +906,7 @@ struct TORCH_API ListType
|
||||
|
||||
std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
|
||||
std::stringstream ss;
|
||||
ss << "List[" << getElementType()->annotation_str(printer) << ']';
|
||||
ss << "List[" << getElementType()->annotation_str(printer) << "]";
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
@ -946,7 +946,7 @@ struct TORCH_API DictType : public SharedType {
|
||||
std::string str() const override {
|
||||
std::stringstream ss;
|
||||
ss << "Dict(" << getKeyType()->str() << ", " << getValueType()->str()
|
||||
<< ')';
|
||||
<< ")";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
@ -1018,7 +1018,7 @@ struct TORCH_API FutureType
|
||||
|
||||
std::string str() const override {
|
||||
std::stringstream ss;
|
||||
ss << "Future(" << getElementType()->str() << ')';
|
||||
ss << "Future(" << getElementType()->str() << ")";
|
||||
return ss.str();
|
||||
}
|
||||
TypePtr createWithContained(
|
||||
@ -1041,7 +1041,7 @@ struct TORCH_API FutureType
|
||||
|
||||
std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
|
||||
std::stringstream ss;
|
||||
ss << "Future[" << getElementType()->annotation_str(printer) << ']';
|
||||
ss << "Future[" << getElementType()->annotation_str(printer) << "]";
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
@ -1060,7 +1060,7 @@ struct TORCH_API AwaitType
|
||||
|
||||
std::string str() const override {
|
||||
std::stringstream ss;
|
||||
ss << "Await(" << getElementType()->str() << ')';
|
||||
ss << "Await(" << getElementType()->str() << ")";
|
||||
return ss.str();
|
||||
}
|
||||
TypePtr createWithContained(
|
||||
@ -1083,7 +1083,7 @@ struct TORCH_API AwaitType
|
||||
|
||||
std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
|
||||
std::stringstream ss;
|
||||
ss << "Await[" << getElementType()->annotation_str(printer) << ']';
|
||||
ss << "Await[" << getElementType()->annotation_str(printer) << "]";
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
@ -1102,7 +1102,7 @@ struct TORCH_API RRefType
|
||||
|
||||
std::string str() const override {
|
||||
std::stringstream ss;
|
||||
ss << "RRef(" << getElementType()->str() << ')';
|
||||
ss << "RRef(" << getElementType()->str() << ")";
|
||||
return ss.str();
|
||||
}
|
||||
TypePtr createWithContained(
|
||||
@ -1115,7 +1115,7 @@ struct TORCH_API RRefType
|
||||
|
||||
std::string annotation_str_impl(const TypePrinter& printer = nullptr) const override {
|
||||
std::stringstream ss;
|
||||
ss << "RRef[" << getElementType()->annotation_str(printer) << ']';
|
||||
ss << "RRef[" << getElementType()->annotation_str(printer) << "]";
|
||||
return ss.str();
|
||||
}
|
||||
};
|
||||
|
||||
@ -11,7 +11,7 @@ std::string toString(const OperatorName& opName) {
|
||||
std::ostream& operator<<(std::ostream& os, const OperatorName& opName) {
|
||||
os << opName.name;
|
||||
if (!opName.overload_name.empty()) {
|
||||
os << '.' << opName.overload_name;
|
||||
os << "." << opName.overload_name;
|
||||
}
|
||||
return os;
|
||||
}
|
||||
|
||||
@ -65,7 +65,7 @@ VaryingShape<T> VaryingShape<T>::merge(const VaryingShape<T>& other) const {
|
||||
|
||||
template <typename T>
|
||||
std::ostream& operator<<(std::ostream& out, const VaryingShape<T>& vs) {
|
||||
out << '(';
|
||||
out << "(";
|
||||
if (!vs.size()) {
|
||||
out << "*)";
|
||||
return out;
|
||||
@ -79,10 +79,10 @@ std::ostream& operator<<(std::ostream& out, const VaryingShape<T>& vs) {
|
||||
if (v.has_value()) {
|
||||
out << v.value();
|
||||
} else {
|
||||
out << '*';
|
||||
out << "*";
|
||||
}
|
||||
}
|
||||
out << ')';
|
||||
out << ")";
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -105,7 +105,7 @@ std::ostream& operator<<(
|
||||
}
|
||||
auto sizes_opt = ss.sizes();
|
||||
|
||||
os << '(';
|
||||
os << "(";
|
||||
for (size_t i = 0; i < rank_opt.value(); i++) {
|
||||
if (i > 0) {
|
||||
os << ", ";
|
||||
@ -113,10 +113,10 @@ std::ostream& operator<<(
|
||||
if(sizes_opt.has_value() && sizes_opt.value()[i].is_static()) {
|
||||
os << sizes_opt.value()[i];
|
||||
} else {
|
||||
os << '*';
|
||||
os << "*";
|
||||
}
|
||||
}
|
||||
os << ')';
|
||||
os << ")";
|
||||
|
||||
return os;
|
||||
}
|
||||
@ -131,17 +131,17 @@ std::ostream& operator<<(std::ostream& os, const ShapeSymbol& s) {
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const Stride& s) {
|
||||
os << '{';
|
||||
os << "{";
|
||||
if (s.stride_index_.has_value()) {
|
||||
os << *s.stride_index_;
|
||||
} else {
|
||||
os << '*';
|
||||
os << "*";
|
||||
}
|
||||
os << ':';
|
||||
os << ":";
|
||||
if (s.stride_.has_value()) {
|
||||
os << *s.stride_;
|
||||
} else {
|
||||
os << '*';
|
||||
os << "*";
|
||||
}
|
||||
os << '}';
|
||||
return os;
|
||||
|
||||
@ -67,7 +67,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
|
||||
bool has_valid_strides_info = ndim > 0 &&
|
||||
value->strides().isComplete() && value->strides().size() == ndim;
|
||||
|
||||
out << '(';
|
||||
out << "(";
|
||||
size_t i = 0;
|
||||
bool symbolic = type_verbosity() == TypeVerbosity::Symbolic;
|
||||
for (i = 0; i < *ndim; ++i) {
|
||||
@ -79,7 +79,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
|
||||
} else if (symbolic) {
|
||||
out << value->symbolic_sizes().at(i);
|
||||
} else {
|
||||
out << '*';
|
||||
out << "*";
|
||||
}
|
||||
}
|
||||
if (has_valid_strides_info &&
|
||||
@ -91,7 +91,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
|
||||
}
|
||||
out << value->strides()[i].value();
|
||||
}
|
||||
out << ']';
|
||||
out << "]";
|
||||
}
|
||||
if (type_verbosity() >= TypeVerbosity::Full) {
|
||||
if (value->requiresGrad()) {
|
||||
@ -107,12 +107,12 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
|
||||
out << "device=" << *value->device();
|
||||
}
|
||||
}
|
||||
out << ')';
|
||||
out << ")";
|
||||
} else {
|
||||
if (type_verbosity() >= TypeVerbosity::Full) {
|
||||
size_t i = 0;
|
||||
if (value->requiresGrad()) {
|
||||
out << '('
|
||||
out << "("
|
||||
<< "requires_grad=" << *value->requiresGrad();
|
||||
i++;
|
||||
}
|
||||
@ -120,7 +120,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
|
||||
out << ((i++ > 0) ? ", " : "(") << "device=" << *value->device();
|
||||
}
|
||||
if (i > 0) {
|
||||
out << ')';
|
||||
out << ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -133,18 +133,18 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
|
||||
out << *prim << "[]";
|
||||
} else if (t.kind() == TypeKind::OptionalType) {
|
||||
auto prim = t.castRaw<OptionalType>()->getElementType();
|
||||
out << *prim << '?';
|
||||
out << *prim << "?";
|
||||
} else if(t.kind() == TypeKind::FutureType) {
|
||||
auto elem = t.castRaw<FutureType>()->getElementType();
|
||||
out << "Future[" << *elem << ']';
|
||||
out << "Future[" << *elem << "]";
|
||||
} else if(t.kind() == TypeKind::RRefType) {
|
||||
auto elem = t.castRaw<RRefType>()->getElementType();
|
||||
out << "RRef[" << *elem << ']';
|
||||
out << "RRef[" << *elem << "]";
|
||||
} else if(auto tup = t.cast<TupleType>()) {
|
||||
if (tup->schema()) {
|
||||
out << "NamedTuple";
|
||||
}
|
||||
out << '(';
|
||||
out << "(";
|
||||
for(size_t i = 0; i < tup->elements().size(); ++i) {
|
||||
if(i > 0)
|
||||
out << ", ";
|
||||
@ -160,7 +160,7 @@ std::ostream& operator<<(std::ostream & out, const Type & t) {
|
||||
out << *(tup->elements()[i]);
|
||||
}
|
||||
}
|
||||
out << ')';
|
||||
out << ")";
|
||||
} else if (t.kind() == TypeKind::FunctionType) {
|
||||
out << "Function";
|
||||
} else {
|
||||
@ -475,7 +475,7 @@ std::optional<TypePtr> unifyTypeList(
|
||||
why_not << "Could not unify type list since element " << i << " of type "
|
||||
<< elements.at(i)->repr_str()
|
||||
<< " did not match the types before it ("
|
||||
<< ret_type->repr_str() << ')';
|
||||
<< ret_type->repr_str() << ")";
|
||||
return std::nullopt;
|
||||
}
|
||||
ret_type = *maybe_unified;
|
||||
@ -907,13 +907,13 @@ std::string TupleType::str() const {
|
||||
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
||||
ss << name()->qualifiedName();
|
||||
} else {
|
||||
ss << '(';
|
||||
ss << "(";
|
||||
for(size_t i = 0; i < elements().size(); ++i) {
|
||||
if(i > 0)
|
||||
ss << ", ";
|
||||
ss << elements()[i]->str();
|
||||
}
|
||||
ss << ')';
|
||||
ss << ")";
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
@ -1003,8 +1003,8 @@ bool InterfaceType::isSubTypeImpl(
|
||||
*why_not << "Method on interface '" << lhs.repr_str()
|
||||
<< "' (1) is not compatible with interface '"
|
||||
<< rhs.repr_str() << "' (2)\n"
|
||||
<< " (1) " << *self_schema << '\n'
|
||||
<< " (2) " << schema << '\n';
|
||||
<< " (1) " << *self_schema << "\n"
|
||||
<< " (2) " << schema << "\n";
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
@ -1078,7 +1078,7 @@ SymbolicShape SymbolicShape::merge(const SymbolicShape& other) const {
|
||||
}
|
||||
|
||||
void SymbolicShape::dump() const {
|
||||
std::cout << *this << '\n';
|
||||
std::cout << *this << "\n";
|
||||
}
|
||||
|
||||
bool EnumType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
|
||||
|
||||
@ -205,9 +205,9 @@ UnionType::UnionType(std::vector<TypePtr> reference, TypeKind kind) : SharedType
|
||||
for (const auto i : c10::irange(reference.size())) {
|
||||
msg << reference[i]->repr_str();
|
||||
if (i > 0) {
|
||||
msg << ',';
|
||||
msg << ",";
|
||||
}
|
||||
msg << ' ';
|
||||
msg << " ";
|
||||
}
|
||||
msg << "} has the single type " << types_[0]->repr_str()
|
||||
<< ". Use the common supertype instead of creating a Union"
|
||||
|
||||
@ -80,7 +80,7 @@ std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
|
||||
}
|
||||
stream << buf[i];
|
||||
}
|
||||
stream << ']';
|
||||
stream << "]";
|
||||
return stream;
|
||||
}
|
||||
|
||||
|
||||
@ -55,7 +55,7 @@ std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
|
||||
}
|
||||
stream << buf[i];
|
||||
}
|
||||
stream << ']';
|
||||
stream << "]";
|
||||
return stream;
|
||||
}
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
#include <cstdint>
|
||||
#include <map>
|
||||
#include <shared_mutex>
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cusparse.h>
|
||||
@ -88,8 +89,13 @@ TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
|
||||
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
|
||||
|
||||
TORCH_CUDA_CPP_API void clearCublasWorkspaces();
|
||||
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace();
|
||||
TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace();
|
||||
struct WorkspaceMapWithMutex {
|
||||
std::map<std::tuple<void*, void*>, at::DataPtr> map;
|
||||
std::shared_mutex mutex;
|
||||
};
|
||||
|
||||
TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublas_handle_stream_to_workspace();
|
||||
TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace();
|
||||
TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize();
|
||||
TORCH_CUDA_CPP_API size_t getCUDABlasLtWorkspaceSize();
|
||||
TORCH_CUDA_CPP_API void* getCUDABlasLtWorkspace();
|
||||
|
||||
@ -99,7 +99,7 @@ void destroyCublasHandle(cublasHandle_t handle) {
|
||||
// - Comments of @soumith copied from cuDNN handle pool implementation
|
||||
#ifdef NO_CUDNN_DESTROY_HANDLE
|
||||
#else
|
||||
cublasDestroy(handle);
|
||||
cublasDestroy(handle);
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -107,19 +107,27 @@ using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle
|
||||
|
||||
} // namespace
|
||||
|
||||
std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
|
||||
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
|
||||
WorkspaceMapWithMutex& cublas_handle_stream_to_workspace() {
|
||||
static auto& instance = *new WorkspaceMapWithMutex;
|
||||
return instance;
|
||||
}
|
||||
|
||||
std::map<std::tuple<void *, void *>, at::DataPtr>& cublaslt_handle_stream_to_workspace() {
|
||||
static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
|
||||
WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace() {
|
||||
static auto& instance = *new WorkspaceMapWithMutex;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void clearCublasWorkspaces() {
|
||||
cublas_handle_stream_to_workspace().clear();
|
||||
cublaslt_handle_stream_to_workspace().clear();
|
||||
{
|
||||
auto& workspace = cublas_handle_stream_to_workspace();
|
||||
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
|
||||
workspace.map.clear();
|
||||
}
|
||||
{
|
||||
auto& workspace = cublaslt_handle_stream_to_workspace();
|
||||
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
|
||||
workspace.map.clear();
|
||||
}
|
||||
}
|
||||
|
||||
size_t parseChosenWorkspaceSize() {
|
||||
@ -233,6 +241,38 @@ at::DataPtr getNewCUDABlasLtWorkspace() {
|
||||
return c10::cuda::CUDACachingAllocator::get()->allocate(getCUDABlasLtWorkspaceSize());
|
||||
}
|
||||
|
||||
void setWorkspaceForHandle(cublasHandle_t handle, c10::cuda::CUDAStream stream) {
|
||||
cudaStream_t _stream = stream;
|
||||
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
|
||||
|
||||
auto& workspace = cublas_handle_stream_to_workspace();
|
||||
|
||||
size_t workspace_size = getChosenWorkspaceSize();
|
||||
|
||||
// Fast path: check if workspace already exists
|
||||
{
|
||||
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
|
||||
auto workspace_it = workspace.map.find(key);
|
||||
if (workspace_it != workspace.map.end()) {
|
||||
TORCH_CUDABLAS_CHECK(cublasSetWorkspace(
|
||||
handle, workspace_it->second.get(), workspace_size));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Slow path: allocate workspace outside the lock
|
||||
auto new_workspace = getNewWorkspace();
|
||||
|
||||
// Insert with lock (double-check in case another thread inserted while we
|
||||
// were allocating)
|
||||
{
|
||||
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
|
||||
auto workspace_it = workspace.map.try_emplace(key, std::move(new_workspace)).first;
|
||||
TORCH_CUDABLAS_CHECK(
|
||||
cublasSetWorkspace(handle, workspace_it->second.get(), workspace_size));
|
||||
}
|
||||
}
|
||||
|
||||
void* getCUDABlasLtWorkspace() {
|
||||
#ifndef USE_ROCM
|
||||
static bool unified = c10::utils::check_env(TORCH_CUBLASLT_UNIFIED_WORKSPACE) == true;
|
||||
@ -241,8 +281,10 @@ void* getCUDABlasLtWorkspace() {
|
||||
auto stream = c10::cuda::getCurrentCUDAStream();
|
||||
cudaStream_t _stream = stream;
|
||||
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
|
||||
auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key);
|
||||
TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end());
|
||||
auto& workspace = at::cuda::cublas_handle_stream_to_workspace();
|
||||
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
|
||||
auto workspace_it = workspace.map.find(key);
|
||||
TORCH_INTERNAL_ASSERT(workspace_it != workspace.map.end());
|
||||
return workspace_it->second.mutable_get();
|
||||
}
|
||||
#endif
|
||||
@ -250,11 +292,29 @@ void* getCUDABlasLtWorkspace() {
|
||||
auto stream = c10::cuda::getCurrentCUDAStream();
|
||||
cudaStream_t _stream = stream;
|
||||
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
|
||||
auto workspace_it = cublaslt_handle_stream_to_workspace().find(key);
|
||||
if (workspace_it == cublaslt_handle_stream_to_workspace().end()) {
|
||||
workspace_it = cublaslt_handle_stream_to_workspace().insert(workspace_it, {key, getNewCUDABlasLtWorkspace()});
|
||||
|
||||
auto& workspace = cublaslt_handle_stream_to_workspace();
|
||||
|
||||
// Fast path: check if workspace already exists
|
||||
{
|
||||
std::shared_lock<std::shared_mutex> lock(workspace.mutex);
|
||||
auto workspace_it = workspace.map.find(key);
|
||||
if (workspace_it != workspace.map.end()) {
|
||||
return workspace_it->second.mutable_get();
|
||||
}
|
||||
}
|
||||
|
||||
// Slow path: allocate workspace outside the lock
|
||||
auto new_workspace = getNewCUDABlasLtWorkspace();
|
||||
|
||||
// Insert with lock (double-check in case another thread inserted while we
|
||||
// were allocating)
|
||||
{
|
||||
std::unique_lock<std::shared_mutex> lock(workspace.mutex);
|
||||
auto workspace_it =
|
||||
workspace.map.try_emplace(key, std::move(new_workspace)).first;
|
||||
return workspace_it->second.mutable_get();
|
||||
}
|
||||
return workspace_it->second.mutable_get();
|
||||
}
|
||||
|
||||
cublasHandle_t getCurrentCUDABlasHandle() {
|
||||
@ -298,13 +358,8 @@ cublasHandle_t getCurrentCUDABlasHandle() {
|
||||
// will allocate memory dynamically (even if they're cheap) outside
|
||||
// PyTorch's CUDA caching allocator. It's possible that CCA used up
|
||||
// all the memory and cublas's cudaMallocAsync will return OOM
|
||||
cudaStream_t _stream = stream;
|
||||
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
|
||||
auto workspace_it = cublas_handle_stream_to_workspace().find(key);
|
||||
if (workspace_it == cublas_handle_stream_to_workspace().end()) {
|
||||
workspace_it = cublas_handle_stream_to_workspace().insert(workspace_it, {key, getNewWorkspace()});
|
||||
}
|
||||
TORCH_CUDABLAS_CHECK(cublasSetWorkspace(handle, workspace_it->second.get(), getChosenWorkspaceSize()));
|
||||
setWorkspaceForHandle(handle, stream);
|
||||
|
||||
#if !defined(USE_ROCM)
|
||||
// On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
|
||||
// FP32 data type calculations based on the value of the allow_tf32 flag.
|
||||
|
||||
@ -411,16 +411,16 @@ std::string CUDAHooks::showConfig() const {
|
||||
// HIP_VERSION value format was changed after ROCm v4.2 to include the patch number
|
||||
if(v < 500) {
|
||||
// If major=xx, minor=yy then format -> xxyy
|
||||
oss << (v / 100) << '.' << (v % 10);
|
||||
oss << (v / 100) << "." << (v % 10);
|
||||
}
|
||||
else {
|
||||
// If major=xx, minor=yy & patch=zzzzz then format -> xxyyzzzzz
|
||||
oss << (v / 10000000) << '.' << (v / 100000 % 100) << '.' << (v % 100000);
|
||||
oss << (v / 10000000) << "." << (v / 100000 % 100) << "." << (v % 100000);
|
||||
}
|
||||
#else
|
||||
oss << (v / 1000) << '.' << (v / 10 % 100);
|
||||
oss << (v / 1000) << "." << (v / 10 % 100);
|
||||
if (v % 10 != 0) {
|
||||
oss << '.' << (v % 10);
|
||||
oss << "." << (v % 10);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
@ -431,16 +431,16 @@ std::string CUDAHooks::showConfig() const {
|
||||
oss << " - HIP Runtime ";
|
||||
#endif
|
||||
printCudaStyleVersion(runtimeVersion);
|
||||
oss << '\n';
|
||||
oss << "\n";
|
||||
|
||||
// TODO: Make HIPIFY understand CUDART_VERSION macro
|
||||
#if !defined(USE_ROCM)
|
||||
if (runtimeVersion != CUDART_VERSION) {
|
||||
oss << " - Built with CUDA Runtime ";
|
||||
printCudaStyleVersion(CUDART_VERSION);
|
||||
oss << '\n';
|
||||
oss << "\n";
|
||||
}
|
||||
oss << " - NVCC architecture flags: " << NVCC_FLAGS_EXTRA << '\n';
|
||||
oss << " - NVCC architecture flags: " << NVCC_FLAGS_EXTRA << "\n";
|
||||
#endif
|
||||
|
||||
#if !defined(USE_ROCM)
|
||||
@ -448,9 +448,9 @@ std::string CUDAHooks::showConfig() const {
|
||||
|
||||
|
||||
auto printCudnnStyleVersion = [&](size_t v) {
|
||||
oss << (v / 1000) << '.' << (v / 100 % 10);
|
||||
oss << (v / 1000) << "." << (v / 100 % 10);
|
||||
if (v % 100 != 0) {
|
||||
oss << '.' << (v % 100);
|
||||
oss << "." << (v % 100);
|
||||
}
|
||||
};
|
||||
|
||||
@ -461,22 +461,22 @@ std::string CUDAHooks::showConfig() const {
|
||||
if (cudnnCudartVersion != CUDART_VERSION) {
|
||||
oss << " (built against CUDA ";
|
||||
printCudaStyleVersion(cudnnCudartVersion);
|
||||
oss << ')';
|
||||
oss << ")";
|
||||
}
|
||||
oss << '\n';
|
||||
oss << "\n";
|
||||
if (cudnnVersion != CUDNN_VERSION) {
|
||||
oss << " - Built with CuDNN ";
|
||||
printCudnnStyleVersion(CUDNN_VERSION);
|
||||
oss << '\n';
|
||||
oss << "\n";
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
// TODO: Check if miopen has the functions above and unify
|
||||
oss << " - MIOpen " << MIOPEN_VERSION_MAJOR << '.' << MIOPEN_VERSION_MINOR << '.' << MIOPEN_VERSION_PATCH << '\n';
|
||||
oss << " - MIOpen " << MIOPEN_VERSION_MAJOR << "." << MIOPEN_VERSION_MINOR << "." << MIOPEN_VERSION_PATCH << "\n";
|
||||
#endif
|
||||
|
||||
#if AT_MAGMA_ENABLED()
|
||||
oss << " - Magma " << MAGMA_VERSION_MAJOR << '.' << MAGMA_VERSION_MINOR << '.' << MAGMA_VERSION_MICRO << '\n';
|
||||
oss << " - Magma " << MAGMA_VERSION_MAJOR << "." << MAGMA_VERSION_MINOR << "." << MAGMA_VERSION_MICRO << "\n";
|
||||
#endif
|
||||
|
||||
return oss.str();
|
||||
|
||||
@ -42,7 +42,7 @@ static inline void launch_jitted_vectorized_kernel_dynamic(
|
||||
|
||||
// The cache key includes all the parameters to generate_code + vec_size + dev_idx
|
||||
std::stringstream ss;
|
||||
ss << nInputs << '_' << nOutputs << f;
|
||||
ss << nInputs << "_" << nOutputs << f;
|
||||
ss << f_inputs_type_str << compute_type_str << result_type_str;
|
||||
ss << static_cast<int>(at::cuda::jit::BinaryFuncVariant::NoScalar);
|
||||
ss << extra_args_types;
|
||||
@ -144,7 +144,7 @@ static inline void launch_jitted_unrolled_kernel_dynamic(
|
||||
|
||||
// The cache key includes all the parameters to generate_code + dev_idx
|
||||
std::stringstream ss;
|
||||
ss << nInputs << '_' << nOutputs << f;
|
||||
ss << nInputs << "_" << nOutputs << f;
|
||||
ss << f_inputs_type_str << compute_type_str << result_type_str;
|
||||
ss << contiguous << dynamic_casting;
|
||||
ss << static_cast<int>(at::cuda::jit::BinaryFuncVariant::NoScalar);
|
||||
|
||||
@ -52,10 +52,10 @@ TuningContext* getTuningContext() {
|
||||
std::ostream& operator<<(std::ostream& stream, const ResultEntry& entry) {
|
||||
static const bool blaslog = c10::utils::get_env("PYTORCH_TUNABLEOP_BLAS_LOG") == "1";
|
||||
if (!blaslog) {
|
||||
return stream << entry.key_ << ',' << entry.time_;
|
||||
return stream << entry.key_ << "," << entry.time_;
|
||||
}
|
||||
else {
|
||||
return stream << entry.key_ << ',' << entry.time_ << ",BLAS_PARAMS: " << entry.blas_sig_;
|
||||
return stream << entry.key_ << "," << entry.time_ << ",BLAS_PARAMS: " << entry.blas_sig_;
|
||||
}
|
||||
}
|
||||
|
||||
@ -156,10 +156,10 @@ void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std
|
||||
if (isNew) {
|
||||
static const bool blaslog = c10::utils::get_env("PYTORCH_TUNABLEOP_BLAS_LOG") == "1";
|
||||
if (!blaslog) {
|
||||
untuned_file << op_signature << ',' << params_signature << std::endl;
|
||||
untuned_file << op_signature << "," << params_signature << std::endl;
|
||||
}
|
||||
else {
|
||||
untuned_file << op_signature << ',' << params_signature << ",BLAS_PARAMS: " << blas_signature << std::endl;
|
||||
untuned_file << op_signature << "," << params_signature << ",BLAS_PARAMS: " << blas_signature << std::endl;
|
||||
}
|
||||
TUNABLE_LOG3("Untuned,", op_signature, ",", params_signature);
|
||||
}
|
||||
@ -201,7 +201,7 @@ void TuningResultsManager::InitRealtimeAppend(const std::string& filename, const
|
||||
|
||||
if(!file_exists || file_empty) {
|
||||
for(const auto& [key, val] : validators) {
|
||||
(*realtime_out_) << "Validator," << key << ',' << val << std::endl;
|
||||
(*realtime_out_) << "Validator," << key << "," << val << std::endl;
|
||||
realtime_out_->flush();
|
||||
}
|
||||
validators_written_ = true;
|
||||
@ -219,7 +219,7 @@ void TuningResultsManager::AppendResultLine(const std::string& op_sig, const std
|
||||
return;
|
||||
}
|
||||
|
||||
(*realtime_out_) << op_sig << ',' << param_sig << ',' << result << std::endl;
|
||||
(*realtime_out_) << op_sig << "," << param_sig << "," << result << std::endl;
|
||||
realtime_out_->flush(); //ensure immediate write to disk
|
||||
|
||||
TUNABLE_LOG3("Realtime append: ", op_sig, "(", param_sig, ") -> ", result);
|
||||
|
||||
@ -93,31 +93,31 @@ std::string cudnnTypeToString(cudnnDataType_t dtype) {
|
||||
return "CUDNN_DATA_UINT8x4";
|
||||
default:
|
||||
std::ostringstream oss;
|
||||
oss << "(unknown data-type " << static_cast<int>(dtype) << ')';
|
||||
oss << "(unknown data-type " << static_cast<int>(dtype) << ")";
|
||||
return oss.str();
|
||||
}
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d) {
|
||||
out << "TensorDescriptor " << static_cast<void*>(d.desc()) << '\n';
|
||||
out << "TensorDescriptor " << static_cast<void*>(d.desc()) << "\n";
|
||||
int nbDims = 0;
|
||||
int dimA[CUDNN_DIM_MAX];
|
||||
int strideA[CUDNN_DIM_MAX];
|
||||
cudnnDataType_t dtype{};
|
||||
cudnnGetTensorNdDescriptor(d.desc(), CUDNN_DIM_MAX, &dtype, &nbDims, dimA, strideA);
|
||||
out << " type = " << cudnnTypeToString(dtype) << '\n';
|
||||
out << " nbDims = " << nbDims << '\n';
|
||||
out << " type = " << cudnnTypeToString(dtype) << "\n";
|
||||
out << " nbDims = " << nbDims << "\n";
|
||||
// Read out only nbDims of the arrays!
|
||||
out << " dimA = ";
|
||||
for (auto i : ArrayRef<int>{dimA, static_cast<size_t>(nbDims)}) {
|
||||
out << i << ", ";
|
||||
}
|
||||
out << '\n';
|
||||
out << "\n";
|
||||
out << " strideA = ";
|
||||
for (auto i : ArrayRef<int>{strideA, static_cast<size_t>(nbDims)}) {
|
||||
out << i << ", ";
|
||||
}
|
||||
out << '\n';
|
||||
out << "\n";
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -168,27 +168,27 @@ std::string cudnnMemoryFormatToString(cudnnTensorFormat_t tformat) {
|
||||
return "CUDNN_TENSOR_NHWC";
|
||||
default:
|
||||
std::ostringstream oss;
|
||||
oss << "(unknown cudnn tensor format " << static_cast<int>(tformat) << ')';
|
||||
oss << "(unknown cudnn tensor format " << static_cast<int>(tformat) << ")";
|
||||
return oss.str();
|
||||
}
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream & out, const FilterDescriptor& d) {
|
||||
out << "FilterDescriptor " << static_cast<void*>(d.desc()) << '\n';
|
||||
out << "FilterDescriptor " << static_cast<void*>(d.desc()) << "\n";
|
||||
int nbDims = 0;
|
||||
int dimA[CUDNN_DIM_MAX];
|
||||
cudnnDataType_t dtype{};
|
||||
cudnnTensorFormat_t tformat{};
|
||||
cudnnGetFilterNdDescriptor(d.desc(), CUDNN_DIM_MAX, &dtype, &tformat, &nbDims, dimA);
|
||||
out << " type = " << cudnnTypeToString(dtype) << '\n';
|
||||
out << " tensor_format = " << cudnnMemoryFormatToString(tformat) << '\n';
|
||||
out << " nbDims = " << nbDims << '\n';
|
||||
out << " type = " << cudnnTypeToString(dtype) << "\n";
|
||||
out << " tensor_format = " << cudnnMemoryFormatToString(tformat) << "\n";
|
||||
out << " nbDims = " << nbDims << "\n";
|
||||
// Read out only nbDims of the arrays!
|
||||
out << " dimA = ";
|
||||
for (auto i : ArrayRef<int>{dimA, static_cast<size_t>(nbDims)}) {
|
||||
out << i << ", ";
|
||||
}
|
||||
out << '\n';
|
||||
out << "\n";
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
@ -346,15 +346,15 @@ void foreachTensorInplaceWithFlag(std::vector<IValue>& args, int64_t begin, int6
|
||||
}
|
||||
|
||||
std::ostream& operator<< (std::ostream& os, const DynamicLayer& layer) {
|
||||
os << layer.layerId() << ':' << layer.key();
|
||||
os << layer.layerId() << ":" << layer.key();
|
||||
return os;
|
||||
}
|
||||
std::ostream& operator<< (std::ostream& os, const std::vector<DynamicLayer>& dls) {
|
||||
os << "DynamicLayerStack[ ";
|
||||
for (const auto& layer : dls) {
|
||||
os << layer << ' ';
|
||||
os << layer << " ";
|
||||
}
|
||||
os << ']';
|
||||
os << "]";
|
||||
return os;
|
||||
}
|
||||
|
||||
|
||||
@ -22,7 +22,7 @@ void dumpTensor(std::ostream& ss, const Tensor& tensor) {
|
||||
if (batched) {
|
||||
ss << "Batched[lvl=" << batched->level() << " dim=" << batched->bdim() << ", ";
|
||||
dumpTensor(ss, batched->value());
|
||||
ss << ']';
|
||||
ss << "]";
|
||||
return;
|
||||
}
|
||||
ss << "Tensor" << tensor.sizes();
|
||||
@ -36,7 +36,7 @@ void dumpTensor(std::ostream& ss, const Tensor& tensor) {
|
||||
ss << "dead, ";
|
||||
}
|
||||
dumpTensor(ss, wrapped->value());
|
||||
ss << ']';
|
||||
ss << "]";
|
||||
}
|
||||
|
||||
void TensorWrapper::refreshMetadata() {
|
||||
|
||||
@ -73,32 +73,32 @@ std::string miopenTypeToString(miopenDataType_t dtype) {
|
||||
return "miopenBFloat16";
|
||||
default:
|
||||
std::ostringstream oss;
|
||||
oss << "(unknown data-type " << static_cast<int>(dtype) << ')';
|
||||
oss << "(unknown data-type " << static_cast<int>(dtype) << ")";
|
||||
return oss.str();
|
||||
}
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d) {
|
||||
out << "TensorDescriptor " << static_cast<void*>(d.desc()) << '\n';
|
||||
out << "TensorDescriptor " << static_cast<void*>(d.desc()) << "\n";
|
||||
int nbDims = 0;
|
||||
int dimA[MIOPEN_DIM_MAX];
|
||||
int strideA[MIOPEN_DIM_MAX];
|
||||
miopenDataType_t dtype;
|
||||
miopenGetTensorDescriptorSize(d.desc(), &nbDims);
|
||||
miopenGetTensorDescriptor(d.desc(), &dtype, dimA, strideA);
|
||||
out << " type = " << miopenTypeToString(dtype) << '\n';
|
||||
out << " nbDims = " << nbDims << '\n';
|
||||
out << " type = " << miopenTypeToString(dtype) << "\n";
|
||||
out << " nbDims = " << nbDims << "\n";
|
||||
// Read out only nbDims of the arrays!
|
||||
out << " dimA = ";
|
||||
for (auto i : ArrayRef<int>{dimA, static_cast<size_t>(nbDims)}) {
|
||||
out << i << ", ";
|
||||
}
|
||||
out << '\n';
|
||||
out << "\n";
|
||||
out << " strideA = ";
|
||||
for (auto i : ArrayRef<int>{strideA, static_cast<size_t>(nbDims)}) {
|
||||
out << i << ", ";
|
||||
}
|
||||
out << '\n';
|
||||
out << "\n";
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
@ -91,7 +91,7 @@ struct OperationInfo : BaseInfo {
|
||||
std::stringstream kernelStr;
|
||||
kernelStr << kernelName;
|
||||
for (const Tensor& tensor : tensors) {
|
||||
kernelStr << ':' << BaseInfo::buildTensorString(tensor, includeBufferId);
|
||||
kernelStr << ":" << BaseInfo::buildTensorString(tensor, includeBufferId);
|
||||
}
|
||||
return kernelStr.str();
|
||||
}
|
||||
|
||||
@ -39,9 +39,9 @@ std::string BaseInfo::buildTensorString(const Tensor& tensor, bool includeBuffer
|
||||
// see comments for INCLUDE_BUFFER_ID
|
||||
if (includeBufferId && deviceType == at::kMPS) {
|
||||
id<MTLBuffer> buffer = __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
|
||||
tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer)) << ':' << buffer.retainCount << ')';
|
||||
tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer)) << ":" << buffer.retainCount << ")";
|
||||
}
|
||||
tensorStr << ':' << tensor.scalar_type() << tensor.sizes();
|
||||
tensorStr << ":" << tensor.scalar_type() << tensor.sizes();
|
||||
return tensorStr.str();
|
||||
} else {
|
||||
return "undefined";
|
||||
|
||||
@ -167,7 +167,7 @@ static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, co
|
||||
std::stringstream ss;
|
||||
ss << arg_name << " should be greater than zero but got (";
|
||||
std::copy(args.begin(), args.end() - 1, std::ostream_iterator<int>(ss,", "));
|
||||
ss << args.back() << ")" << " (while checking arguments for " << c << ')';
|
||||
ss << args.back() << ")" << " (while checking arguments for " << c << ")";
|
||||
TORCH_CHECK(false, ss.str());
|
||||
}
|
||||
}
|
||||
|
||||
@ -639,7 +639,7 @@ static std::ostream& operator<<(std::ostream & out, const ConvParams<T>& params)
|
||||
<< " deterministic = " << params.deterministic
|
||||
<< " cudnn_enabled = " << params.cudnn_enabled
|
||||
<< " allow_tf32 = " << params.allow_tf32
|
||||
<< '}';
|
||||
<< "}";
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
@ -3541,9 +3541,9 @@ Tensor _dyn_quant_matmul_4bit_cpu(
|
||||
const int64_t out_features) {
|
||||
auto M = inp.size(0);
|
||||
TORCH_CHECK(
|
||||
inp.dtype() == kFloat,
|
||||
inp.dtype() == kFloat || (inp.dtype() == kBFloat16 && block_size == in_features),
|
||||
__func__,
|
||||
" : expect input to be 32-bit float tensor.");
|
||||
" : expect input to be float32 or bfloat16 tensor.");
|
||||
TORCH_CHECK(
|
||||
block_size == in_features ||
|
||||
(!(block_size % 32) && !(in_features % block_size)),
|
||||
|
||||
@ -847,7 +847,7 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const std::optional<int64_t
|
||||
<< ", hop_length=" << hop_length << ", win_length=" << win_length \
|
||||
<< ", window="; \
|
||||
if (window.defined()) { \
|
||||
SS << window.toString() << '{' << window.sizes() << '}'; \
|
||||
SS << window.toString() << "{" << window.sizes() << "}"; \
|
||||
} else { \
|
||||
SS << "None"; \
|
||||
} \
|
||||
@ -1046,7 +1046,7 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const std::optional<int64_
|
||||
<< ", hop_length=" << hop_length << ", win_length=" << win_length \
|
||||
<< ", window="; \
|
||||
if (window.defined()) { \
|
||||
SS << window.toString() << '{' << window.sizes() << '}'; \
|
||||
SS << window.toString() << "{" << window.sizes() << "}"; \
|
||||
} else { \
|
||||
SS << "None"; \
|
||||
} \
|
||||
|
||||
@ -523,7 +523,7 @@ Tensor _functional_assert_async_msg_cpu(
|
||||
}
|
||||
|
||||
void _print(std::string_view s) {
|
||||
std::cout << s << '\n';
|
||||
std::cout << s << "\n";
|
||||
}
|
||||
|
||||
// Sorting-based algorithm for isin(); used when the number of test elements is
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
#include <ATen/cpu/vec/vec.h>
|
||||
#include <ATen/native/cpu/int_mm_kernel.h>
|
||||
#include <ATen/native/cpu/utils.h>
|
||||
#include <cmath>
|
||||
#include <c10/util/Unroll.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
@ -793,6 +794,139 @@ bool can_use_kleidiai(
|
||||
}
|
||||
#endif
|
||||
|
||||
static void ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
|
||||
size_t m,
|
||||
size_t n,
|
||||
size_t k,
|
||||
const uint16_t* lhs_bf16,
|
||||
const uint8_t* rhs_qs4cx,
|
||||
const float* rhs_scales,
|
||||
uint16_t* dst_bf16,
|
||||
float scalar_min,
|
||||
float scalar_max,
|
||||
const float* bias) {
|
||||
// Roundup lambda for internal stride calculations
|
||||
auto roundup = [](size_t a, size_t b) { return ((a + b - 1) / b) * b; };
|
||||
|
||||
// Cast bfloat16 to float32 inline
|
||||
auto cast_bf16_to_f32 = [](uint16_t bf16_val) {
|
||||
uint32_t tmp = static_cast<uint32_t>(bf16_val) << 16;
|
||||
float f;
|
||||
std::memcpy(&f, &tmp, sizeof(f));
|
||||
return f;
|
||||
};
|
||||
|
||||
// Cast float32 to bfloat16 inline
|
||||
auto cast_f32_to_bf16 = [](float f) {
|
||||
uint32_t bits;
|
||||
std::memcpy(&bits, &f, sizeof(bits));
|
||||
return static_cast<uint16_t>(bits >> 16);
|
||||
};
|
||||
|
||||
// Quantization pack lambda (channelwise QA8DX)
|
||||
auto quant_pack_8bit_channelwise =
|
||||
[&](size_t M, size_t K, const uint16_t* src_bf16, int8_t* dst_qa8dx) {
|
||||
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
|
||||
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
|
||||
|
||||
const size_t dst_stride =
|
||||
K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
|
||||
for (size_t i = 0; i < M; ++i) {
|
||||
const uint16_t* row_ptr = src_bf16 + i * K;
|
||||
// find min/max
|
||||
float mn = FLT_MAX, mx = -FLT_MAX;
|
||||
for (size_t j = 0; j < K; ++j) {
|
||||
float v = cast_bf16_to_f32(row_ptr[j]);
|
||||
mn = std::min(mn, v);
|
||||
mx = std::max(mx, v);
|
||||
}
|
||||
float rmin = std::min(0.0f, mn);
|
||||
float rmax = std::max(0.0f, mx);
|
||||
constexpr float qmin = static_cast<float>(kI8Min);
|
||||
constexpr float qmax = static_cast<float>(kI8Max);
|
||||
float scale = (rmin == rmax) ? 1.f : (qmax - qmin) / (rmax - rmin);
|
||||
float recip = scale ? 1.0f / scale : 0.0f;
|
||||
int32_t zp;
|
||||
float des_min = rmin * scale;
|
||||
float des_max = rmax * scale;
|
||||
float err_min = qmin + des_min;
|
||||
float err_max = qmax + des_max;
|
||||
float zp_f =
|
||||
(err_min + err_max) > 0 ? qmin - des_min : qmax - des_max;
|
||||
zp_f = std::clamp(zp_f, qmin, qmax);
|
||||
zp = std::lrintf(zp_f);
|
||||
int8_t* out_ptr = dst_qa8dx + i * dst_stride;
|
||||
// store header
|
||||
*reinterpret_cast<float*>(out_ptr) = recip;
|
||||
*reinterpret_cast<int32_t*>(out_ptr + sizeof(float)) = -zp;
|
||||
out_ptr += sizeof(float) + sizeof(int32_t);
|
||||
// quantize
|
||||
for (size_t j = 0; j < K; ++j) {
|
||||
float v = cast_bf16_to_f32(row_ptr[j]);
|
||||
int32_t q = static_cast<int32_t>(std::round(v * scale)) + zp;
|
||||
q = std::clamp(
|
||||
q, static_cast<int32_t>(kI8Min), static_cast<int32_t>(kI8Max));
|
||||
*out_ptr++ = static_cast<int8_t>(q);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// MatMul lambda (MXN x MXK -> MNXK BF16)
|
||||
auto matmul_kernel = [&](size_t M,
|
||||
size_t N,
|
||||
size_t K,
|
||||
const int8_t* lhs,
|
||||
const uint8_t* rhs,
|
||||
const float* scales,
|
||||
uint16_t* dst,
|
||||
float lo,
|
||||
float hi) {
|
||||
const size_t lhs_stride =
|
||||
K * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
|
||||
const size_t rhs_stride = roundup(K, 2) / 2;
|
||||
for (size_t i = 0; i < M; ++i) {
|
||||
const int8_t* lhs_row = lhs + i * lhs_stride;
|
||||
for (size_t j = 0; j < N; ++j) {
|
||||
int32_t acc = 0;
|
||||
const int8_t* lptr = lhs_row;
|
||||
const uint8_t* rptr = rhs + j * rhs_stride;
|
||||
float lhs_scale = *reinterpret_cast<const float*>(lptr);
|
||||
int32_t lhs_off =
|
||||
*reinterpret_cast<const int32_t*>(lptr + sizeof(float));
|
||||
lptr += sizeof(float) + sizeof(int32_t);
|
||||
for (size_t t = 0; t < K; ++t) {
|
||||
int32_t lv = static_cast<int32_t>(lptr[t]);
|
||||
uint8_t bv = rptr[t / 2];
|
||||
int32_t rv = ((t & 1) == 0) ? (static_cast<int32_t>(bv & 0xF) - 8)
|
||||
: (static_cast<int32_t>(bv >> 4) - 8);
|
||||
acc += lv * rv + lhs_off * rv;
|
||||
}
|
||||
float res = static_cast<float>(acc) * scales[j] * lhs_scale;
|
||||
if (bias) {
|
||||
res += bias[j];
|
||||
}
|
||||
res = std::clamp(res, lo, hi);
|
||||
*dst++ = cast_f32_to_bf16(res);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// allocate and run
|
||||
std::unique_ptr<int8_t[]> packed(
|
||||
new int8_t[m * (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t))]);
|
||||
quant_pack_8bit_channelwise(m, k, lhs_bf16, packed.get());
|
||||
matmul_kernel(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
packed.get(),
|
||||
rhs_qs4cx,
|
||||
rhs_scales,
|
||||
dst_bf16,
|
||||
scalar_min,
|
||||
scalar_max);
|
||||
}
|
||||
|
||||
/**
|
||||
* The Int4 quantized weights must be represented as a uint8 tensor
|
||||
* For matrix multiplication with a weight shape of (N x K)
|
||||
@ -819,21 +953,21 @@ void dyn_quant_pack_4bit_weight_kernel(
|
||||
#if AT_KLEIDIAI_ENABLED()
|
||||
if (can_use_kleidiai(scales_zeros, K, block_size)) {
|
||||
const int64_t weight_packed_size =
|
||||
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
|
||||
kleidiai::kai_pack_rhs_int4_size(N, K, block_size, weights.scalar_type());
|
||||
packed_weights.resize_({weight_packed_size});
|
||||
kleidiai::kai_pack_int4_rhs(
|
||||
packed_weights, weights, scales_zeros, bias, N, K, block_size);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
TORCH_CHECK(
|
||||
bias.has_value() == 0,
|
||||
__func__,
|
||||
" : Bias is unsupported in reference implementation");
|
||||
packed_weights = packed_weights.to(kFloat);
|
||||
auto weight_reshaped = weights.view({-1}).to(kFloat);
|
||||
auto scales_zeros_reshaped = scales_zeros.view({-1}).to(kFloat);
|
||||
auto res = at::cat({weight_reshaped, scales_zeros_reshaped}, 0);
|
||||
auto weight_reshaped = weights.reshape({-1}).to(kFloat);
|
||||
auto scales_zeros_reshaped = scales_zeros.reshape({-1}).to(kFloat);
|
||||
std::vector<at::Tensor> tensors_to_cat = {weight_reshaped, scales_zeros_reshaped};
|
||||
if (bias.has_value()) {
|
||||
tensors_to_cat.push_back(bias.value().view({-1}).to(kFloat));
|
||||
}
|
||||
auto res = at::cat(tensors_to_cat, 0);
|
||||
packed_weights.resize_(res.sizes()).copy_(res);
|
||||
}
|
||||
}
|
||||
@ -847,7 +981,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
const float* rhs_scales_f32,
|
||||
float* dst_f32,
|
||||
float scalar_min,
|
||||
float scalar_max) {
|
||||
float scalar_max,
|
||||
const float* bias) {
|
||||
const size_t input_size_8bit = m * (k + sizeof(int32_t) + sizeof(float));
|
||||
|
||||
auto lhs_qa8dx_buffer = std::make_unique<uint8_t[]>(input_size_8bit);
|
||||
@ -857,6 +992,9 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
// required format for matmul
|
||||
auto input_quant_pack_8bit_channelwise =
|
||||
[&](size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) {
|
||||
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
|
||||
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
|
||||
|
||||
const size_t dst_stride =
|
||||
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
|
||||
|
||||
@ -877,8 +1015,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
}
|
||||
|
||||
// Maximum/minimum int8 values
|
||||
const float qmin = (float)INT8_MIN;
|
||||
const float qmax = (float)INT8_MAX;
|
||||
constexpr float qmin = static_cast<float>(kI8Min);
|
||||
constexpr float qmax = static_cast<float>(kI8Max);
|
||||
|
||||
const float rmin0 = std::min(0.0f, min0);
|
||||
const float rmax0 = std::max(0.0f, max0);
|
||||
@ -904,7 +1042,7 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
zero_point0 = std::min(zero_point0, qmax);
|
||||
|
||||
// Round to nearest integer
|
||||
const int32_t nudged_zero_point0 = lrintf(zero_point0);
|
||||
const int32_t nudged_zero_point0 = std::lrintf(zero_point0);
|
||||
|
||||
int8_t* dst_ptr = lhs_qa8dx + m_idx * dst_stride;
|
||||
|
||||
@ -922,8 +1060,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
|
||||
|
||||
v0_s32 = v0_s32 + nudged_zero_point0;
|
||||
v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT8_MIN));
|
||||
v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT8_MAX));
|
||||
v0_s32 = std::max(v0_s32, static_cast<int32_t>(kI8Min));
|
||||
v0_s32 = std::min(v0_s32, static_cast<int32_t>(kI8Max));
|
||||
dst_ptr[0] = (int8_t)v0_s32;
|
||||
dst_ptr += sizeof(int8_t);
|
||||
}
|
||||
@ -987,6 +1125,10 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
|
||||
main_acc = main_acc * lhs_scale;
|
||||
|
||||
if (bias) {
|
||||
main_acc += bias[n_idx];
|
||||
}
|
||||
|
||||
// Clamp (min-max) operation
|
||||
main_acc = std::max(main_acc, scalar_min);
|
||||
main_acc = std::min(main_acc, scalar_max);
|
||||
@ -1007,12 +1149,16 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
const float* rhs_scales_fp32,
|
||||
float* dst_f32,
|
||||
float scalar_min,
|
||||
float scalar_max) {
|
||||
float scalar_max,
|
||||
const float* bias) {
|
||||
// Lambda for LHS quantization
|
||||
auto lhs_quant_pack = [&](size_t m,
|
||||
size_t k,
|
||||
const float* lhs_f32,
|
||||
int8_t* lhs_qa8dx) {
|
||||
constexpr int8_t kI8Min = std::numeric_limits<std::int8_t>::lowest();
|
||||
constexpr int8_t kI8Max = std::numeric_limits<std::int8_t>::max();
|
||||
|
||||
const size_t dst_stride =
|
||||
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
|
||||
|
||||
@ -1028,8 +1174,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
min0 = std::min(src0_0, min0);
|
||||
}
|
||||
|
||||
const float qmin = (float)INT8_MIN;
|
||||
const float qmax = (float)INT8_MAX;
|
||||
constexpr float qmin = static_cast<float>(kI8Min);
|
||||
constexpr float qmax = static_cast<float>(kI8Max);
|
||||
|
||||
const float rmin0 = std::min(0.0f, min0);
|
||||
const float rmax0 = std::max(0.0f, max0);
|
||||
@ -1046,7 +1192,7 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
|
||||
zero_point0 = std::max(zero_point0, qmin);
|
||||
zero_point0 = std::min(zero_point0, qmax);
|
||||
const int32_t nudged_zero_point0 = lrintf(zero_point0);
|
||||
const int32_t nudged_zero_point0 = std::lrintf(zero_point0);
|
||||
|
||||
int8_t* dst_ptr = lhs_qa8dx + row_idx * dst_stride;
|
||||
|
||||
@ -1059,9 +1205,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
const float src0_0 = src_ptr[k_idx];
|
||||
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
|
||||
v0_s32 = std::max(
|
||||
std::min(
|
||||
v0_s32 + nudged_zero_point0, static_cast<int32_t>(INT8_MAX)),
|
||||
static_cast<int32_t>(INT8_MIN));
|
||||
std::min(v0_s32 + nudged_zero_point0, static_cast<int32_t>(kI8Max)),
|
||||
static_cast<int32_t>(kI8Min));
|
||||
dst_ptr[0] = (int8_t)v0_s32;
|
||||
dst_ptr += sizeof(int8_t);
|
||||
}
|
||||
@ -1118,6 +1263,11 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
}
|
||||
|
||||
main_acc = main_acc * lhs_scale;
|
||||
|
||||
if (bias) {
|
||||
main_acc += bias[col_idx];
|
||||
}
|
||||
|
||||
main_acc = std::max(main_acc, scalar_min);
|
||||
main_acc = std::min(main_acc, scalar_max);
|
||||
|
||||
@ -1128,28 +1278,27 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
}
|
||||
|
||||
/**
|
||||
* Dynamic Input Quant 4 bit weights matmul execution flow
|
||||
(INT4 Weights + FP scales + FP32 Bias)
|
||||
FP32 Input Packed Buffer
|
||||
| |
|
||||
Quantize Cast
|
||||
to INT8 to INT8
|
||||
| |
|
||||
v v
|
||||
INT8 Input INT8 Weights
|
||||
\ /
|
||||
\ /
|
||||
\ /
|
||||
INT8 Matrix Multiplication
|
||||
|
|
||||
v
|
||||
FP32 Dequantized and Accumulate in FP32
|
||||
|
|
||||
v
|
||||
FP32 Final Output
|
||||
|
||||
* The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires
|
||||
* Float32 Scales. If not provided, we will use fallback implementation.
|
||||
* Dynamic INT4 weight-only MatMul with per-row input quantization.
|
||||
*
|
||||
* Execution Flow:
|
||||
*
|
||||
* (INT4 Weights + FP Scales [+ optional Bias])
|
||||
*
|
||||
* Input (FP32 or BF16) Packed Weight Buffer
|
||||
* | |
|
||||
* Row-wise Quantization (INT8) |
|
||||
* | |
|
||||
* INT8 Input Activation INT4 Quantized Weights + Scales
|
||||
* \ /
|
||||
* \ /
|
||||
* Quantized Matrix Multiply
|
||||
* |
|
||||
* Output Tensor (BF16 or FP32)
|
||||
*
|
||||
* Notes:
|
||||
* - Groupwise kernels expect BF16 scales
|
||||
* - Channelwise kernels expect FP32 scales
|
||||
* - Bias is currently unsupported in fallback path
|
||||
*/
|
||||
void dyn_quant_matmul_4bit_kernel(
|
||||
const Tensor& output,
|
||||
@ -1161,65 +1310,75 @@ void dyn_quant_matmul_4bit_kernel(
|
||||
const int64_t block_size) {
|
||||
#if AT_KLEIDIAI_ENABLED()
|
||||
const int64_t weight_packed_size =
|
||||
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
|
||||
kleidiai::kai_pack_rhs_int4_size(N, K, block_size, inp.scalar_type());
|
||||
if (weight_packed_size == packed_weights.numel()) {
|
||||
// KleidiAI interface internally handles the Channelwise and groupwise
|
||||
// distinction
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm(
|
||||
output, inp, packed_weights, M, N, K, block_size);
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm(output, inp, packed_weights, M, N, K, block_size);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
float* lhs_f32 = reinterpret_cast<float*>(inp.data_ptr());
|
||||
const auto weights_size = N * K / 2;
|
||||
// The weights needs to be in uint8_t data type after quantization
|
||||
auto extracted_weights =
|
||||
(packed_weights.narrow(0, 0, weights_size)).to(kByte);
|
||||
auto float32_scales =
|
||||
(packed_weights.narrow(
|
||||
0, weights_size, packed_weights.size(0) - weights_size))
|
||||
.to(kFloat);
|
||||
uint8_t* rhs_4bit =
|
||||
reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
|
||||
float* rhs_scales_f32 = reinterpret_cast<float*>(float32_scales.data_ptr());
|
||||
float* dst_f32 = reinterpret_cast<float*>(output.data_ptr());
|
||||
if (block_size == K) {
|
||||
ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
lhs_f32,
|
||||
rhs_4bit,
|
||||
rhs_scales_f32,
|
||||
dst_f32,
|
||||
-FLT_MAX,
|
||||
FLT_MAX);
|
||||
} else if (!(block_size % 32) && !(K % block_size)) {
|
||||
ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
block_size,
|
||||
lhs_f32,
|
||||
rhs_4bit,
|
||||
rhs_scales_f32,
|
||||
dst_f32,
|
||||
-FLT_MAX,
|
||||
FLT_MAX);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
block_size == K || (!(block_size % 32) && !(K % block_size)),
|
||||
__func__,
|
||||
": Group size should be multiple 32 or in_features [",
|
||||
K,
|
||||
"]. Provided ",
|
||||
block_size);
|
||||
{
|
||||
void* input = inp.data_ptr();
|
||||
void* dst = output.data_ptr();
|
||||
|
||||
// Extract weights, sclaes and biases form from packed tensor
|
||||
const int weights_elements = N * K / 2;
|
||||
const int scale_elements = N * (K / block_size);
|
||||
TORCH_CHECK(packed_weights.numel() >= (weights_elements + scale_elements), "Invalid packed weight tensor size");
|
||||
|
||||
auto extracted_weights = packed_weights.narrow(0, 0, weights_elements).to(kByte);
|
||||
auto extracted_scales_and_bias = packed_weights.narrow(0, weights_elements, packed_weights.size(0) - weights_elements).to(kFloat);
|
||||
auto float32_scales = extracted_scales_and_bias.narrow(0, 0, scale_elements);
|
||||
|
||||
int bias_elements = packed_weights.numel() - (weights_elements + scale_elements);
|
||||
float* weight_scales = float32_scales.data_ptr<float>();
|
||||
|
||||
void* bias_data = nullptr;
|
||||
if (bias_elements) {
|
||||
auto float32_bias = extracted_scales_and_bias.narrow(0, scale_elements, bias_elements);
|
||||
TORCH_CHECK(float32_bias.size(0) == N, "Expected bias length to match output dimension");
|
||||
bias_data = float32_bias.data_ptr();
|
||||
|
||||
}
|
||||
// 2 elements of 4 bit weights are packed into 1 uint8 packet
|
||||
uint8_t* weights_4bit = reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
|
||||
|
||||
// Dispatch to reference kernels
|
||||
if (inp.scalar_type() == at::kBFloat16) {
|
||||
// BF16 input, BF16 output
|
||||
constexpr float BF16_MAX = 3.38953139e+38f;
|
||||
constexpr float BF16_MIN = -BF16_MAX;
|
||||
if (block_size == K) {
|
||||
ref_dyn_quant_matmul_4bit_channelwise_kernel_bf16(
|
||||
M, N, K,
|
||||
(uint16_t*)input, weights_4bit, weight_scales,
|
||||
(uint16_t*)dst, BF16_MIN, BF16_MAX, (float*)bias_data);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported block size for BF16 fallback");
|
||||
}
|
||||
} else if (inp.scalar_type() == at::kFloat) {
|
||||
// FP32 input, FP32 output
|
||||
if (block_size == K) {
|
||||
ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
M, N, K,
|
||||
(float*)input, weights_4bit, weight_scales,
|
||||
(float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data);
|
||||
} else if (!(block_size % 32) && !(K % block_size)) {
|
||||
ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
M, N, K, block_size,
|
||||
(float*)input, weights_4bit, weight_scales,
|
||||
(float*)dst, -FLT_MAX, FLT_MAX, (float*)bias_data);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported block size for FP32 fallback");
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported input/output dtype combination for int4mm kernel");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
}
|
||||
ALSO_REGISTER_AVX512_DISPATCH(weight_to_int4pack_stub, &weight_to_int4pack_kernel)
|
||||
ALSO_REGISTER_AVX512_DISPATCH(int4pack_mm_stub, &int4pack_mm_kernel)
|
||||
REGISTER_DISPATCH(dyn_quant_pack_4bit_weight_stub, &dyn_quant_pack_4bit_weight_kernel)
|
||||
|
||||
@ -5,69 +5,11 @@
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
|
||||
// ROCm 6.3 is planned to have these functions, but until then here they are.
|
||||
#if defined(USE_ROCM)
|
||||
#include <device_functions.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_bf16.h>
|
||||
|
||||
__device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) {
|
||||
#if (defined(__gfx942__)) && \
|
||||
__has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16)
|
||||
typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2;
|
||||
static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw));
|
||||
union {
|
||||
__hip_bfloat162_raw bf162_raw;
|
||||
vec_short2 vs2;
|
||||
} u{static_cast<__hip_bfloat162_raw>(value)};
|
||||
u.vs2 = __builtin_amdgcn_flat_atomic_fadd_v2bf16((vec_short2*)address, u.vs2);
|
||||
return static_cast<__hip_bfloat162>(u.bf162_raw);
|
||||
#else
|
||||
static_assert(sizeof(unsigned int) == sizeof(__hip_bfloat162_raw));
|
||||
union u_hold {
|
||||
__hip_bfloat162_raw h2r;
|
||||
unsigned int u32;
|
||||
};
|
||||
u_hold old_val, new_val;
|
||||
old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
do {
|
||||
new_val.h2r = __hadd2(old_val.h2r, value);
|
||||
} while (!__hip_atomic_compare_exchange_strong(
|
||||
(unsigned int*)address, &old_val.u32, new_val.u32,
|
||||
__ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
|
||||
return old_val.h2r;
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ inline __half2 preview_unsafeAtomicAdd(__half2* address, __half2 value) {
|
||||
#if (defined(__gfx942__)) && \
|
||||
__has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2f16)
|
||||
// The api expects an ext_vector_type of half
|
||||
typedef _Float16 __attribute__((ext_vector_type(2))) vec_fp162;
|
||||
static_assert(sizeof(vec_fp162) == sizeof(__half2_raw));
|
||||
union {
|
||||
__half2_raw h2r;
|
||||
vec_fp162 fp16;
|
||||
} u {static_cast<__half2_raw>(value)};
|
||||
u.fp16 = __builtin_amdgcn_flat_atomic_fadd_v2f16((vec_fp162*)address, u.fp16);
|
||||
return static_cast<__half2>(u.h2r);
|
||||
#else
|
||||
static_assert(sizeof(__half2_raw) == sizeof(unsigned int));
|
||||
union u_hold {
|
||||
__half2_raw h2r;
|
||||
unsigned int u32;
|
||||
};
|
||||
u_hold old_val, new_val;
|
||||
old_val.u32 = __hip_atomic_load((unsigned int*)address, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
|
||||
do {
|
||||
new_val.h2r = __hadd2(old_val.h2r, value);
|
||||
} while (!__hip_atomic_compare_exchange_strong(
|
||||
(unsigned int*)address, &old_val.u32, new_val.u32,
|
||||
__ATOMIC_RELAXED, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT));
|
||||
return old_val.h2r;
|
||||
#endif
|
||||
}
|
||||
#define ATOMICADD preview_unsafeAtomicAdd
|
||||
#define ATOMICADD unsafeAtomicAdd
|
||||
#define NATIVE_ZERO_BF16 __float2bfloat16(0.0f)
|
||||
#else
|
||||
#define ATOMICADD atomicAdd
|
||||
|
||||
@ -11,7 +11,7 @@ static inline std::ostream& operator<<(std::ostream& out, dim3 dim) {
|
||||
if (dim.y == 1 && dim.z == 1) {
|
||||
out << dim.x;
|
||||
} else {
|
||||
out << '[' << dim.x << ',' << dim.y << ',' << dim.z << ']';
|
||||
out << "[" << dim.x << "," << dim.y << "," << dim.z << "]";
|
||||
}
|
||||
return out;
|
||||
}
|
||||
@ -27,7 +27,7 @@ std::ostream& operator<<(std::ostream& out, const ReduceConfig& config) {
|
||||
out << "input_mult=[";
|
||||
for (int i = 0; i < 3; i++) {
|
||||
if (i != 0) {
|
||||
out << ',';
|
||||
out << ",";
|
||||
}
|
||||
out << config.input_mult[i];
|
||||
}
|
||||
@ -35,7 +35,7 @@ std::ostream& operator<<(std::ostream& out, const ReduceConfig& config) {
|
||||
out << "output_mult=[";
|
||||
for (int i = 0; i < 2; i++) {
|
||||
if (i != 0) {
|
||||
out << ',';
|
||||
out << ",";
|
||||
}
|
||||
out << config.output_mult[i];
|
||||
}
|
||||
@ -49,7 +49,7 @@ std::ostream& operator<<(std::ostream& out, const ReduceConfig& config) {
|
||||
out << "block=" << config.block() << ", ";
|
||||
out << "grid=" << config.grid() << ", ";
|
||||
out << "global_memory_size=" << config.global_memory_size();
|
||||
out << ')';
|
||||
out << ")";
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
@ -1101,6 +1101,19 @@ _scaled_mxfp8_mxfp8(
|
||||
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
|
||||
}
|
||||
|
||||
void
|
||||
_check_mxfp4_support() {
|
||||
#ifndef USE_ROCM
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
// Only on B200 GPUs
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
// B200 = 10.0, B300 = 10.3
|
||||
dprops->major == 10,
|
||||
"MXFP4 scaling only supported in CUDA for B200/B300"
|
||||
);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
Tensor&
|
||||
_scaled_mxfp4_mxfp4(
|
||||
@ -1113,6 +1126,7 @@ _scaled_mxfp4_mxfp4(
|
||||
#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI))
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
|
||||
#else
|
||||
_check_mxfp4_support();
|
||||
// Restrictions:
|
||||
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
|
||||
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
|
||||
|
||||
@ -364,9 +364,9 @@ void f8f8bf16_grouped_gemm_impl_sm90(
|
||||
// reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(
|
||||
// stride_output_h + group_count);
|
||||
|
||||
// std::cout << "PTRS " << mat_a.data_ptr() << ' ' << mat_b.data_ptr() << "
|
||||
// std::cout << "PTRS " << mat_a.data_ptr() << " " << mat_b.data_ptr() << "
|
||||
// "
|
||||
// << out.data_ptr() << ' ' << scale_a.data_ptr() << ' '
|
||||
// << out.data_ptr() << " " << scale_a.data_ptr() << " "
|
||||
// << scale_b.data_ptr() << "\n";
|
||||
// for (int i = 0; i < group_count; i++) {
|
||||
// std::cout << "A " << (void*)inputA_ptrs_h[i] << "\n";
|
||||
|
||||
@ -1057,14 +1057,14 @@ std::string generate_code(
|
||||
// TODO these arrays are potentially of the different types, use function
|
||||
// traits to determine the types
|
||||
declare_load_arrays << f_inputs_type << " arg" << std::to_string(i)
|
||||
<< '[' << std::to_string(thread_work_size) << "];\n";
|
||||
<< "[" << std::to_string(thread_work_size) << "];\n";
|
||||
}
|
||||
env.s("declare_load_arrays", declare_load_arrays.str());
|
||||
|
||||
std::stringstream declare_store_arrays;
|
||||
for (int i = 0; i < nOutputs; i++) {
|
||||
declare_store_arrays << result_type << " out" << std::to_string(i)
|
||||
<< '[' << std::to_string(thread_work_size) << "];\n";
|
||||
<< "[" << std::to_string(thread_work_size) << "];\n";
|
||||
}
|
||||
env.s("declare_store_arrays", declare_store_arrays.str());
|
||||
|
||||
@ -1217,7 +1217,7 @@ std::string generate_code(
|
||||
for (const auto i : c10::irange(nInputs)){
|
||||
auto i_string = std::to_string(i);
|
||||
vector_inputs << "auto * input" << i_string <<
|
||||
" = reinterpret_cast<const scalar_t*>(data[" << i_string << '+' << nOutputs << "])" <<
|
||||
" = reinterpret_cast<const scalar_t*>(data[" << i_string << "+" << nOutputs << "])" <<
|
||||
" + block_work_size * idx;\n";
|
||||
}
|
||||
env.s("vector_inputs", vector_inputs.str());
|
||||
@ -1543,17 +1543,17 @@ NvrtcFunction jit_pwise_function(
|
||||
|
||||
// Constructs file path by appending constructed cubin name to cache path
|
||||
std::stringstream ss;
|
||||
ss << *cache_dir << '/';
|
||||
ss << *cache_dir << "/";
|
||||
ss << kernel_name;
|
||||
#ifdef USE_ROCM
|
||||
ss << "_arch" << prop->gcnArchName;
|
||||
#else
|
||||
ss << "_arch" << cuda_major << '.' << cuda_minor;
|
||||
ss << "_arch" << cuda_major << "." << cuda_minor;
|
||||
#endif
|
||||
ss << "_nvrtc" << nvrtc_major << '.' << nvrtc_minor;
|
||||
ss << "_nvrtc" << nvrtc_major << "." << nvrtc_minor;
|
||||
ss << (compile_to_sass ? "_sass" : "_ptx");
|
||||
ss << '_' << code.length();
|
||||
ss << '_' << hash_code;
|
||||
ss << "_" << code.length();
|
||||
ss << "_" << hash_code;
|
||||
file_path = ss.str();
|
||||
|
||||
std::ifstream readin{file_path, std::ios::in | std::ifstream::binary};
|
||||
|
||||
@ -82,15 +82,15 @@ namespace native {
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const ConvolutionParams& params) {
|
||||
out << "ConvolutionParams \n"
|
||||
<< " memory_format = " << params.memory_format << '\n'
|
||||
<< " data_type = " << cudnnTypeToString(params.dataType) << '\n'
|
||||
<< " padding = " << ArrayRef<int>{params.padding} << '\n'
|
||||
<< " stride = " << ArrayRef<int>{params.stride} << '\n'
|
||||
<< " dilation = " << ArrayRef<int>{params.dilation} << '\n'
|
||||
<< " groups = " << params.groups << '\n'
|
||||
<< " memory_format = " << params.memory_format << "\n"
|
||||
<< " data_type = " << cudnnTypeToString(params.dataType) << "\n"
|
||||
<< " padding = " << ArrayRef<int>{params.padding} << "\n"
|
||||
<< " stride = " << ArrayRef<int>{params.stride} << "\n"
|
||||
<< " dilation = " << ArrayRef<int>{params.dilation} << "\n"
|
||||
<< " groups = " << params.groups << "\n"
|
||||
<< " deterministic = " << (params.deterministic ? "true" : "false")
|
||||
<< '\n'
|
||||
<< " allow_tf32 = " << (params.allow_tf32 ? "true" : "false") << '\n';
|
||||
<< "\n"
|
||||
<< " allow_tf32 = " << (params.allow_tf32 ? "true" : "false") << "\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
@ -173,16 +173,16 @@ std::string repro_from_args(const ConvolutionParams& params) {
|
||||
at::globalContext().float32Precision(
|
||||
at::Float32Backend::CUDA, at::Float32Op::MATMUL) ==
|
||||
at::Float32Precision::TF32)
|
||||
<< '\n';
|
||||
<< "\n";
|
||||
ss << "torch.backends.cudnn.benchmark = "
|
||||
<< pybool(at::globalContext().benchmarkCuDNN()) << '\n';
|
||||
<< pybool(at::globalContext().benchmarkCuDNN()) << "\n";
|
||||
ss << "torch.backends.cudnn.deterministic = " << pybool(params.deterministic)
|
||||
<< '\n';
|
||||
<< "\n";
|
||||
ss << "torch.backends.cudnn.allow_tf32 = " << pybool(params.allow_tf32)
|
||||
<< '\n';
|
||||
<< "\n";
|
||||
ss << "data = torch.randn(" << ArrayRef<int>(params.input_size, dim)
|
||||
<< ", dtype=" << full_dtype << ", ";
|
||||
ss << "device='cuda', requires_grad=True)" << to_channels_last << '\n';
|
||||
ss << "device='cuda', requires_grad=True)" << to_channels_last << "\n";
|
||||
ss << "net = torch.nn.Conv" << dim - 2 << "d(" << in_channels << ", "
|
||||
<< out_channels << ", ";
|
||||
ss << "kernel_size=" << ArrayRef<int>(¶ms.weight_size[2], dim - 2)
|
||||
@ -192,7 +192,7 @@ std::string repro_from_args(const ConvolutionParams& params) {
|
||||
ss << "dilation=" << ArrayRef<int>(params.dilation, dim - 2) << ", ";
|
||||
ss << "groups=" << params.groups << ")\n";
|
||||
ss << "net = net.cuda()." << partial_dtype << "()" << to_channels_last
|
||||
<< '\n';
|
||||
<< "\n";
|
||||
ss << "out = net(data)\n";
|
||||
ss << "out.backward(torch.randn_like(out))\n";
|
||||
ss << "torch.cuda.synchronize()\n\n";
|
||||
|
||||
@ -93,10 +93,11 @@ std::ostream& operator<<(std::ostream& out, const ConvolutionArgs& args) {
|
||||
<< "input: " << args.idesc // already has a trailing newline
|
||||
<< "output: " << args.odesc // already has a trailing newline
|
||||
<< "weight: " << args.wdesc // already has a trailing newline
|
||||
<< "Pointer addresses: " << '\n'
|
||||
<< " input: " << args.input.const_data_ptr() << '\n'
|
||||
<< " output: " << args.output.const_data_ptr() << '\n'
|
||||
<< " weight: " << args.weight.const_data_ptr() << '\n';
|
||||
<< "Pointer addresses: "
|
||||
<< "\n"
|
||||
<< " input: " << args.input.const_data_ptr() << "\n"
|
||||
<< " output: " << args.output.const_data_ptr() << "\n"
|
||||
<< " weight: " << args.weight.const_data_ptr() << "\n";
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -21,18 +21,27 @@ void kai_pack_int4_rhs(
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t bl) {
|
||||
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
|
||||
if (bl == k) {
|
||||
// Channelwise
|
||||
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
|
||||
auto& params = kernel_packet.rhs_pack_params;
|
||||
params.lhs_zero_point = 1;
|
||||
params.rhs_zero_point = 8;
|
||||
|
||||
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
|
||||
kernel_packet, weight_packed, weight, scales, bias, n, k);
|
||||
if (weight.scalar_type() == at::kBFloat16) {
|
||||
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod);
|
||||
auto& params = kernel_packet.rhs_pack_params;
|
||||
params.lhs_zero_point = 1;
|
||||
params.rhs_zero_point = 8;
|
||||
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_bf16_qa8dxp_qs4cxp>(
|
||||
kernel_packet, weight_packed, weight, scales, bias, n, k);
|
||||
} else {
|
||||
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
|
||||
auto& params = kernel_packet.rhs_pack_params;
|
||||
params.lhs_zero_point = 1;
|
||||
params.rhs_zero_point = 8;
|
||||
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
|
||||
kernel_packet, weight_packed, weight, scales, bias, n, k);
|
||||
}
|
||||
} else if (!(bl % 32) && !(k % bl)) {
|
||||
// Groupwise
|
||||
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
|
||||
@ -63,19 +72,29 @@ void kai_pack_int4_rhs(
|
||||
size_t kai_pack_rhs_int4_size(
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t bl) {
|
||||
const int64_t bl,
|
||||
at::ScalarType tensor_dtype) {
|
||||
size_t packed_size = n * k;
|
||||
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
|
||||
if (bl == k) {
|
||||
// Channelwise
|
||||
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
|
||||
const auto& ukernel = kernel_packet.ukernel;
|
||||
const size_t nr = ukernel.get_nr();
|
||||
const size_t kr = ukernel.get_kr();
|
||||
const size_t sr = ukernel.get_sr();
|
||||
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
|
||||
if (tensor_dtype == at::kBFloat16) {
|
||||
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod);
|
||||
const auto& ukernel = kernel_packet.ukernel;
|
||||
const size_t nr = ukernel.get_nr();
|
||||
const size_t kr = ukernel.get_kr();
|
||||
const size_t sr = ukernel.get_sr();
|
||||
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
|
||||
} else {
|
||||
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
|
||||
const auto& ukernel = kernel_packet.ukernel;
|
||||
const size_t nr = ukernel.get_nr();
|
||||
const size_t kr = ukernel.get_kr();
|
||||
const size_t sr = ukernel.get_sr();
|
||||
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
|
||||
}
|
||||
} else if (!(bl % 32) && !(k % bl)) {
|
||||
// Groupwise
|
||||
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
|
||||
@ -148,8 +167,7 @@ static void kai_quant_pack_lhs_int4_mm_groupwise(
|
||||
const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride;
|
||||
const int64_t m_idx = thread_id * vec_per_thread;
|
||||
auto lhs_packed_ptr = lhs_packed_base +
|
||||
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
|
||||
m_idx, k, mr, kr, sr);
|
||||
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
|
||||
const int64_t vec_num = (thread_id == num_threads - 1)
|
||||
? (m - vec_per_thread * thread_id)
|
||||
: vec_per_thread;
|
||||
@ -259,8 +277,7 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
|
||||
const auto lhs_src_ptr = lhs_native_mtx_f32 + thread_id * src_stride;
|
||||
const int64_t m_idx = thread_id * vec_per_thread;
|
||||
auto lhs_packed_ptr = lhs_packed_base +
|
||||
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
|
||||
m_idx, k, mr, kr, sr);
|
||||
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
|
||||
const int64_t vec_num = (thread_id == num_threads - 1)
|
||||
? (m - vec_per_thread * thread_id)
|
||||
: vec_per_thread;
|
||||
@ -320,19 +337,144 @@ static void kai_quant_pack_lhs_int4_mm_channelwise(
|
||||
});
|
||||
}
|
||||
|
||||
void kai_quant_pack_lhs_int4_mm(
|
||||
static void kai_quant_pack_lhs_int4_mm_bf16_channelwise(
|
||||
const Tensor& output,
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
const int64_t m,
|
||||
const int64_t n,
|
||||
const int64_t k) {
|
||||
// Kernel IDs for GEMM and GEMV
|
||||
constexpr kai_kernel_id gemm_id =
|
||||
kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm;
|
||||
constexpr kai_kernel_id gemv_id =
|
||||
kai_kernel_id::matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod;
|
||||
|
||||
// Get total threads and select kernel
|
||||
const int64_t total_threads = at::get_num_threads();
|
||||
auto kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemv_id);
|
||||
if (cpuinfo_has_arm_i8mm() && m > 1) {
|
||||
kernel_packet = kai_select_bf16_channelwise_matmul_ukernel(gemm_id);
|
||||
}
|
||||
|
||||
// Thread blocking parameters
|
||||
const int64_t n_step = kernel_packet.ukernel.get_n_step();
|
||||
const size_t mr = kernel_packet.ukernel.get_mr();
|
||||
const size_t kr = kernel_packet.ukernel.get_kr();
|
||||
const size_t sr = kernel_packet.ukernel.get_sr();
|
||||
|
||||
const size_t lhs_packed_size =
|
||||
kernel_packet.kai_get_lhs_packed_size(m, k, mr, kr, sr);
|
||||
auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size);
|
||||
uint8_t* dst_act_mtx_bf16 = reinterpret_cast<uint8_t*>(output.data_ptr());
|
||||
const uint8_t* lhs_native_mtx_bf16 =
|
||||
reinterpret_cast<const uint8_t*>(input.data_ptr());
|
||||
const uint8_t* rhs_packed_mtx_qs4cx =
|
||||
reinterpret_cast<const uint8_t*>(weight.data_ptr());
|
||||
uint8_t* lhs_packed_base = lhs_packed.get();
|
||||
|
||||
constexpr int32_t element_size = sizeof(uint16_t);
|
||||
const size_t lhs_stride = k * element_size;
|
||||
const size_t dst_stride = n * element_size;
|
||||
|
||||
// LHS quantization packing
|
||||
int64_t vec_per_thread = get_vec_per_thread(m, total_threads, mr);
|
||||
int64_t num_threads = (m + vec_per_thread - 1) / vec_per_thread;
|
||||
const size_t src_stride = vec_per_thread * lhs_stride;
|
||||
|
||||
auto lhs_quant_pack = [=, &kernel_packet](int64_t thread_id) {
|
||||
const auto lhs_src_ptr = lhs_native_mtx_bf16 + thread_id * src_stride;
|
||||
const int64_t m_idx = thread_id * vec_per_thread;
|
||||
auto lhs_packed_ptr = lhs_packed_base +
|
||||
kernel_packet.kai_get_lhs_quant_pack_offset(m_idx, k, mr, kr, sr);
|
||||
const int64_t vec_num = (thread_id == num_threads - 1)
|
||||
? (m - vec_per_thread * thread_id)
|
||||
: vec_per_thread;
|
||||
|
||||
kernel_packet.kai_run_lhs_quant_pack(
|
||||
vec_num,
|
||||
k,
|
||||
mr,
|
||||
kr,
|
||||
sr,
|
||||
0,
|
||||
(const uint16_t*)lhs_src_ptr,
|
||||
lhs_stride,
|
||||
lhs_packed_ptr);
|
||||
};
|
||||
|
||||
at::parallel_for(
|
||||
0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t thread_id = begin; thread_id < end; ++thread_id) {
|
||||
lhs_quant_pack(thread_id);
|
||||
}
|
||||
});
|
||||
|
||||
// Matrix multiplication
|
||||
vec_per_thread = get_vec_per_thread(n, total_threads, n_step);
|
||||
num_threads = (n + vec_per_thread - 1) / vec_per_thread;
|
||||
|
||||
auto mm = [=, &kernel_packet](int64_t thread_id) {
|
||||
const auto rhs_packed_ptr = rhs_packed_mtx_qs4cx +
|
||||
kernel_packet.ukernel.get_rhs_packed_offset(
|
||||
thread_id * vec_per_thread, k);
|
||||
auto dst_ptr = dst_act_mtx_bf16 +
|
||||
kernel_packet.ukernel.get_dst_offset(
|
||||
0, thread_id * vec_per_thread, dst_stride);
|
||||
const int64_t vec_num = (thread_id == num_threads - 1)
|
||||
? (n - vec_per_thread * thread_id)
|
||||
: vec_per_thread;
|
||||
|
||||
kernel_packet.ukernel.run_matmul(
|
||||
m,
|
||||
vec_num,
|
||||
k,
|
||||
lhs_packed_base,
|
||||
rhs_packed_ptr,
|
||||
(uint16_t*)dst_ptr,
|
||||
dst_stride,
|
||||
element_size, // dst_stride_col
|
||||
-FLT_MAX,
|
||||
FLT_MAX);
|
||||
};
|
||||
|
||||
at::parallel_for(
|
||||
0, num_threads, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t thread_id = begin; thread_id < end; ++thread_id) {
|
||||
mm(thread_id);
|
||||
}
|
||||
});
|
||||
}
|
||||
void kai_quant_pack_lhs_int4_mm(
|
||||
const at::Tensor& output,
|
||||
const at::Tensor& input,
|
||||
const at::Tensor& weight,
|
||||
const int64_t m,
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t bl) {
|
||||
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
|
||||
if (bl == k) {
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
|
||||
output, input, weight, m, n, k);
|
||||
} else if (!(bl % 32) && !(k % bl)) {
|
||||
const auto input_dtype = input.dtype();
|
||||
|
||||
if (input_dtype == at::kBFloat16) {
|
||||
if (cpuinfo_has_arm_bf16()) {
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm_bf16_channelwise(
|
||||
output, input, weight, m, n, k);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"BF16 Unsupported: CPU does not support BF16. Please use a CPU with BF16 support.");
|
||||
}
|
||||
} else if (input_dtype == at::kFloat) {
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
|
||||
output, input, weight, m, n, k);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Unsupported input data type: Only Bfloat16 and Float inputs are supported.");
|
||||
}
|
||||
} else if ((bl % 32 == 0) && (k % bl == 0)) {
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm_groupwise(
|
||||
output, input, weight, m, n, k, bl);
|
||||
}
|
||||
|
||||
@ -25,7 +25,8 @@ void kai_pack_int4_rhs(
|
||||
size_t kai_pack_rhs_int4_size(
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t bl);
|
||||
const int64_t bl,
|
||||
at::ScalarType tensor_dtype = at::kFloat);
|
||||
|
||||
/**
|
||||
* @brief Run 2 operations ( Input quantize and pack -> 4 bit Matmul )
|
||||
|
||||
@ -36,7 +36,8 @@ void kai_pack_rhs_groupwise_int4(
|
||||
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
|
||||
}
|
||||
|
||||
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
|
||||
float* bias_ptr =
|
||||
bias.has_value() ? bias.value().to(kFloat).data_ptr<float>() : NULL;
|
||||
auto& params = kernel.rhs_pack_params;
|
||||
|
||||
kernel.kai_run_rhs_pack(
|
||||
@ -73,7 +74,8 @@ void kai_pack_rhs_channelwise_int4(
|
||||
auto weight_packed_data =
|
||||
reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
|
||||
const auto weight_data = weight.data_ptr<uint8_t>();
|
||||
const auto scales_data = scales.data_ptr<float>();
|
||||
|
||||
const auto scales_data = scales.to(kFloat).data_ptr<float>();
|
||||
|
||||
if (weight_data == nullptr) {
|
||||
AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null");
|
||||
@ -83,7 +85,8 @@ void kai_pack_rhs_channelwise_int4(
|
||||
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
|
||||
}
|
||||
|
||||
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
|
||||
float* bias_ptr =
|
||||
bias.has_value() ? bias.value().to(kFloat).data_ptr<float>() : NULL;
|
||||
auto& params = kernel.rhs_pack_params;
|
||||
|
||||
kernel.kai_run_rhs_pack(
|
||||
|
||||
@ -68,5 +68,39 @@ kai_matmul_ukernel_f32_qa8dxp_qs4cxp kai_select_channelwise_matmul_ukernel(
|
||||
const kai_kernel_id id) {
|
||||
return channelwise_8bit_4bit_kernels.at(id);
|
||||
}
|
||||
|
||||
// Kernel Mapping - BF16 Channelwise
|
||||
std::unordered_map<kai_kernel_id, kai_matmul_ukernel_bf16_qa8dxp_qs4cxp>
|
||||
bf16_channelwise_8bit_4bit_kernels = {
|
||||
{kai_kernel_id::
|
||||
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
{{kai_get_m_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_n_step_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_mr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_nr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_kr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_sr_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_dst_offset_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_get_dst_size_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod,
|
||||
kai_run_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod}}},
|
||||
{kai_kernel_id::matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
{{kai_get_m_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_n_step_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_mr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_nr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_kr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_sr_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_lhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_rhs_packed_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_dst_offset_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_get_dst_size_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm,
|
||||
kai_run_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm}}}};
|
||||
|
||||
kai_matmul_ukernel_bf16_qa8dxp_qs4cxp kai_select_bf16_channelwise_matmul_ukernel(
|
||||
const kai_kernel_id id) {
|
||||
return bf16_channelwise_8bit_4bit_kernels.at(id);
|
||||
}
|
||||
} // namespace at::native::kleidiai
|
||||
#endif
|
||||
|
||||
@ -10,21 +10,32 @@
|
||||
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h>
|
||||
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h>
|
||||
#include <kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h>
|
||||
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod.h>
|
||||
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm.h>
|
||||
#include <kai/ukernels/matmul/matmul_clamp_bf16_qai8dxp_qsi4cxp/kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_interface.h>
|
||||
#include <kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h>
|
||||
#include <kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_bf16_neon.h>
|
||||
#include <kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h>
|
||||
#include <kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h>
|
||||
|
||||
namespace at::native::kleidiai {
|
||||
|
||||
enum class kai_kernel_id {
|
||||
// FP32 inputs, 4-bit weights, FP32 output
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod =
|
||||
0, // Groupwise 4 bit GEMV
|
||||
0, // Groupwise 4-bit GEMV (per-group scales, NEON DOTPROD)
|
||||
matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm =
|
||||
1, // Groupwise 4 bit GEMM
|
||||
1, // Groupwise 4-bit GEMM (per-group scales, NEON I8MM)
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod =
|
||||
2, // Channelwise 4 bit GEMV
|
||||
2, // Channelwise 4-bit GEMV (per-channel scales, NEON DOTPROD)
|
||||
matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm =
|
||||
3 // Channelwise 4 bit GEMM
|
||||
3, // Channelwise 4-bit GEMM (per-channel scales, NEON I8MM)
|
||||
|
||||
// BF16 inputs, 4-bit weights, BF16 output
|
||||
matmul_clamp_bf16_qai8dxp1x8_qsi4cxp8x8_1x8_neon_dotprod =
|
||||
4, // Channelwise 4-bit GEMV with BF16 input/output
|
||||
matmul_clamp_bf16_qai8dxp4x8_qsi4cxp8x8_8x8_neon_i8mm =
|
||||
5 // Channelwise 4-bit GEMM with BF16 input/output
|
||||
};
|
||||
|
||||
// Channelwise Kernel mapping
|
||||
@ -66,6 +77,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
|
||||
void* rhs_packed,
|
||||
size_t extra_bytes,
|
||||
const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
|
||||
size_t(*kai_get_lhs_quant_pack_offset)(
|
||||
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
|
||||
);
|
||||
|
||||
kai_matmul_ukernel_f32_qa8dxp_qs4cxp(
|
||||
const kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel& kernel)
|
||||
@ -75,12 +89,71 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
|
||||
kai_get_rhs_packed_size(
|
||||
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
|
||||
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
|
||||
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0) {}
|
||||
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
|
||||
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32){}
|
||||
};
|
||||
|
||||
struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp
|
||||
kai_select_channelwise_matmul_ukernel(const kai_kernel_id id);
|
||||
|
||||
// bf16 Channelwise Kernel mapping
|
||||
struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp {
|
||||
struct kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel ukernel;
|
||||
struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params rhs_pack_params;
|
||||
size_t (*kai_get_lhs_packed_size)(
|
||||
size_t m,
|
||||
size_t k,
|
||||
size_t mr,
|
||||
size_t kr,
|
||||
size_t sr);
|
||||
size_t (*kai_get_rhs_packed_size)(
|
||||
size_t n,
|
||||
size_t k,
|
||||
size_t nr,
|
||||
size_t kr,
|
||||
size_t sr);
|
||||
void (*kai_run_lhs_quant_pack)(
|
||||
size_t m,
|
||||
size_t k,
|
||||
size_t mr,
|
||||
size_t kr,
|
||||
size_t sr,
|
||||
size_t m_idx_start,
|
||||
const void* lhs,
|
||||
size_t lhs_stride,
|
||||
void* lhs_packed);
|
||||
void (*kai_run_rhs_pack)(
|
||||
size_t num_groups,
|
||||
size_t n,
|
||||
size_t k,
|
||||
size_t nr,
|
||||
size_t kr,
|
||||
size_t sr,
|
||||
const uint8_t* rhs,
|
||||
const float* bias,
|
||||
const float* scale,
|
||||
void* rhs_packed,
|
||||
size_t extra_bytes,
|
||||
const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
|
||||
size_t(*kai_get_lhs_quant_pack_offset)(
|
||||
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
|
||||
);
|
||||
|
||||
kai_matmul_ukernel_bf16_qa8dxp_qs4cxp(
|
||||
const kai_matmul_clamp_bf16_qai8dxp_qsi4cxp_ukernel& kernel)
|
||||
: ukernel(kernel),
|
||||
kai_get_lhs_packed_size(
|
||||
&kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_bf16_neon),
|
||||
kai_get_rhs_packed_size(
|
||||
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
|
||||
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_bf16_neon),
|
||||
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
|
||||
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_bf16_neon){}
|
||||
};
|
||||
|
||||
struct kai_matmul_ukernel_bf16_qa8dxp_qs4cxp
|
||||
kai_select_bf16_channelwise_matmul_ukernel(const kai_kernel_id id);
|
||||
|
||||
// Groupwise Kernel mapping
|
||||
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
|
||||
struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel;
|
||||
@ -125,6 +198,9 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
|
||||
void* rhs_packed,
|
||||
size_t extra_bytes,
|
||||
const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params);
|
||||
size_t(*kai_get_lhs_quant_pack_offset)(
|
||||
size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr
|
||||
);
|
||||
|
||||
kai_matmul_ukernel_f32_qa8dxp_qs4c32p(
|
||||
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& kernel)
|
||||
@ -134,7 +210,8 @@ struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
|
||||
kai_get_rhs_packed_size(
|
||||
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
|
||||
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
|
||||
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0) {}
|
||||
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
|
||||
kai_get_lhs_quant_pack_offset(&kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32) {}
|
||||
};
|
||||
|
||||
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel(
|
||||
|
||||
@ -115,7 +115,7 @@ std::ostream& operator<<(
|
||||
std::copy(
|
||||
strides.begin(), strides.end() - 1, std::ostream_iterator<int>(oss, ","));
|
||||
oss << sizes.back();
|
||||
output << oss.str() << '}';
|
||||
output << oss.str() << "}";
|
||||
return output;
|
||||
}
|
||||
|
||||
|
||||
@ -53,7 +53,7 @@ std::ostream& operator<<(std::ostream& out, const ConvParams& params) {
|
||||
<< " transposed = " << params.transposed
|
||||
<< " output_padding = " << IntArrayRef{params.output_padding}
|
||||
<< " groups = " << params.groups << " benchmark = " << params.benchmark
|
||||
<< " deterministic = " << params.deterministic << '}';
|
||||
<< " deterministic = " << params.deterministic << "}";
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
@ -91,25 +91,30 @@ static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
|
||||
#include <ATen/native/mps/Repeat_metallib.h>
|
||||
#endif
|
||||
|
||||
template <typename index_t>
|
||||
void computeRepeatIndices(const index_t* repeat_ptr,
|
||||
const int64_t* cumsum_ptr,
|
||||
index_t* result_ptr,
|
||||
int64_t size,
|
||||
int64_t result_size) {
|
||||
id<MTLBuffer> repeatBuffer = reinterpret_cast<id<MTLBuffer>>(repeat_ptr);
|
||||
id<MTLBuffer> cumsumBuffer = reinterpret_cast<id<MTLBuffer>>(cumsum_ptr);
|
||||
id<MTLBuffer> resultBuffer = reinterpret_cast<id<MTLBuffer>>(result_ptr);
|
||||
TORCH_CHECK(repeatBuffer && cumsumBuffer && resultBuffer);
|
||||
|
||||
Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output_size) {
|
||||
TORCH_CHECK(repeat.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
|
||||
std::string scalar_type;
|
||||
if constexpr (std::is_same_v<index_t, int32_t>) {
|
||||
if (repeat.scalar_type() == kInt) {
|
||||
scalar_type = "int32_t";
|
||||
} else if constexpr (std::is_same_v<index_t, int64_t>) {
|
||||
} else if (repeat.scalar_type() == kLong) {
|
||||
scalar_type = "int64_t";
|
||||
} else {
|
||||
TORCH_CHECK(false, "repeat_interleave: unsupported indexing data type");
|
||||
TORCH_CHECK(false, "repeats has to be Long or Int tensor");
|
||||
}
|
||||
if (repeat.size(0) == 0) {
|
||||
return at::empty_like(repeat, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
}
|
||||
Tensor repeat_ = repeat.contiguous();
|
||||
Tensor cumsum = repeat.cumsum(0);
|
||||
int64_t total = 0;
|
||||
if (output_size.has_value()) {
|
||||
total = output_size.value();
|
||||
} else {
|
||||
total = cumsum[-1].item<int64_t>();
|
||||
TORCH_CHECK((repeat >= 0).all().item<uint8_t>(), "repeats can not be negative");
|
||||
}
|
||||
|
||||
auto result = at::empty({total}, repeat.options());
|
||||
|
||||
MPSStream* mpsStream = getCurrentMPSStream();
|
||||
dispatch_sync(mpsStream->queue(), ^() {
|
||||
@ -121,20 +126,13 @@ void computeRepeatIndices(const index_t* repeat_ptr,
|
||||
getMPSProfiler().beginProfileKernel(pipelineState, "repeat_interleave:" + scalar_type, false);
|
||||
|
||||
[computeEncoder setComputePipelineState:pipelineState];
|
||||
mps::mtl_setArgs(computeEncoder, repeatBuffer, cumsumBuffer, resultBuffer, size);
|
||||
mps::mtl_dispatch1DJob(computeEncoder, pipelineState, size);
|
||||
mps::mtl_setArgs(computeEncoder, repeat_, cumsum, result, repeat.size(0));
|
||||
mps::mtl_dispatch1DJob(computeEncoder, pipelineState, repeat.size(0));
|
||||
|
||||
getMPSProfiler().endProfileKernel(pipelineState);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output_size) {
|
||||
Tensor output;
|
||||
AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() {
|
||||
output = repeat_interleave_common<index_t, computeRepeatIndices<index_t>>(repeat, output_size);
|
||||
});
|
||||
return output;
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <ATen/native/TensorCompare.h>
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <algorithm>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
@ -89,13 +90,21 @@ static void check_min_max_dims(const OptionalTensorRef clamp_opt, const Tensor&
|
||||
auto clamp_shape = clamp_opt->sizes();
|
||||
auto input_shape = input_t.sizes();
|
||||
|
||||
TORCH_CHECK(num_clamp_dims <= num_input_dims,
|
||||
op_name + ": clamp tensor number of dims must not be greater than that of input tensor")
|
||||
if (num_clamp_dims > num_input_dims) {
|
||||
auto leading_dims = num_clamp_dims - num_input_dims;
|
||||
for (int64_t i = 0; i < leading_dims; ++i) {
|
||||
TORCH_CHECK(clamp_shape[i] == 1,
|
||||
op_name + ": clamp tensor leading shape must be 1 to broadcast with input tensor");
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_clamp_dims; i++)
|
||||
auto clamp_idx = num_clamp_dims - 1;
|
||||
auto input_idx = num_input_dims - 1;
|
||||
auto common_dims = std::min(num_clamp_dims, num_input_dims);
|
||||
for (int64_t i = 0; i < common_dims; ++i)
|
||||
// One of the indices is allowed to be 1; will be handled by broadcast
|
||||
TORCH_CHECK(clamp_shape[num_clamp_dims - 1 - i] == input_shape[num_input_dims - 1 - i] ||
|
||||
clamp_shape[num_clamp_dims - 1 - i] == 1 || input_shape[num_input_dims - 1 - i] == 1,
|
||||
TORCH_CHECK(clamp_shape[clamp_idx - i] == input_shape[input_idx - i] || clamp_shape[clamp_idx - i] == 1 ||
|
||||
input_shape[input_idx - i] == 1,
|
||||
op_name + ": clamp tensor trailing shape must match input tensor")
|
||||
}
|
||||
}
|
||||
@ -136,9 +145,6 @@ static void clamp_tensor_out_mps(const Tensor& input_t,
|
||||
|
||||
auto result_type = output_t.scalar_type();
|
||||
|
||||
IntArrayRef new_min_shape;
|
||||
IntArrayRef new_max_shape;
|
||||
|
||||
auto num_min_dims = min_opt->dim();
|
||||
auto num_max_dims = max_opt->dim();
|
||||
auto num_input_dims = input_t.dim();
|
||||
@ -146,24 +152,32 @@ static void clamp_tensor_out_mps(const Tensor& input_t,
|
||||
std::vector<int64_t> new_min_arr(num_input_dims);
|
||||
std::vector<int64_t> new_max_arr(num_input_dims);
|
||||
|
||||
if (has_min && num_min_dims < num_input_dims) {
|
||||
fill_new_shape(num_input_dims, num_min_dims, new_min_arr.data(), min_opt->sizes());
|
||||
new_min_shape = IntArrayRef(new_min_arr);
|
||||
}
|
||||
|
||||
if (has_max && num_max_dims < num_input_dims) {
|
||||
fill_new_shape(num_input_dims, num_max_dims, new_max_arr.data(), max_opt->sizes());
|
||||
new_max_shape = IntArrayRef(new_max_arr);
|
||||
}
|
||||
|
||||
Tensor min_opt_tensor;
|
||||
Tensor max_opt_tensor;
|
||||
|
||||
auto reshape_clamp_tensor = [&](const OptionalTensorRef clamp_tensor_ref,
|
||||
int64_t num_clamp_dims,
|
||||
std::vector<int64_t>& new_shape_storage) -> Tensor {
|
||||
IntArrayRef clamp_shape = clamp_tensor_ref->sizes();
|
||||
bool requires_view = false;
|
||||
|
||||
if (num_clamp_dims > num_input_dims) {
|
||||
clamp_shape = clamp_shape.slice(num_clamp_dims - num_input_dims);
|
||||
requires_view = true;
|
||||
} else if (num_clamp_dims < num_input_dims) {
|
||||
fill_new_shape(num_input_dims, num_clamp_dims, new_shape_storage.data(), clamp_shape);
|
||||
clamp_shape = IntArrayRef(new_shape_storage);
|
||||
requires_view = true;
|
||||
}
|
||||
|
||||
return requires_view ? (*clamp_tensor_ref).view(clamp_shape) : *clamp_tensor_ref;
|
||||
};
|
||||
|
||||
if (has_min) {
|
||||
min_opt_tensor = (num_min_dims < num_input_dims) ? (*min_opt).view(new_min_shape) : *min_opt;
|
||||
min_opt_tensor = reshape_clamp_tensor(min_opt, num_min_dims, new_min_arr);
|
||||
}
|
||||
if (has_max) {
|
||||
max_opt_tensor = (num_max_dims < num_input_dims) ? (*max_opt).view(new_max_shape) : *max_opt;
|
||||
max_opt_tensor = reshape_clamp_tensor(max_opt, num_max_dims, new_max_arr);
|
||||
}
|
||||
|
||||
@autoreleasepool {
|
||||
|
||||
@ -301,12 +301,12 @@ class AvgPoolMicrokernelTester {
|
||||
ASSERT_NEAR(
|
||||
float(int32_t(y[i * yStride() + k])), yFP[i * kc() + k], 0.5001f)
|
||||
<< "at pixel " << i << ", channel " << k << ", n = " << n()
|
||||
<< ", ks = " << kh() << 'x' << kw() << " (" << ks()
|
||||
<< ", ks = " << kh() << "x" << kw() << " (" << ks()
|
||||
<< "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k];
|
||||
ASSERT_EQ(
|
||||
uint32_t(yRef[i * kc() + k]), uint32_t(y[i * yStride() + k]))
|
||||
<< "at pixel " << i << ", channel " << k << ", n = " << n()
|
||||
<< ", ks = " << kh() << 'x' << kw() << " (" << ks()
|
||||
<< ", ks = " << kh() << "x" << kw() << " (" << ks()
|
||||
<< "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k];
|
||||
}
|
||||
}
|
||||
@ -396,12 +396,12 @@ class AvgPoolMicrokernelTester {
|
||||
ASSERT_NEAR(
|
||||
float(int32_t(y[i * yStride() + k])), yFP[i * kc() + k], 0.5001f)
|
||||
<< "at pixel " << i << ", channel " << k << ", n = " << n()
|
||||
<< ", ks = " << kh() << 'x' << kw() << " (" << ks()
|
||||
<< ", ks = " << kh() << "x" << kw() << " (" << ks()
|
||||
<< "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k];
|
||||
ASSERT_EQ(
|
||||
uint32_t(yRef[i * kc() + k]), uint32_t(y[i * yStride() + k]))
|
||||
<< "at pixel " << i << ", channel " << k << ", n = " << n()
|
||||
<< ", ks = " << kh() << 'x' << kw() << " (" << ks()
|
||||
<< ", ks = " << kh() << "x" << kw() << " (" << ks()
|
||||
<< "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k];
|
||||
}
|
||||
}
|
||||
|
||||
@ -232,7 +232,7 @@ class MaxPoolMicrokernelTester {
|
||||
ASSERT_EQ(
|
||||
uint32_t(yRef[i * kc() + k]), uint32_t(y[i * yStride() + k]))
|
||||
<< "at pixel " << i << ", channel " << k << ", n = " << n()
|
||||
<< ", ks = " << kh() << 'x' << kw() << " (" << ks()
|
||||
<< ", ks = " << kh() << "x" << kw() << " (" << ks()
|
||||
<< "), kc = " << kc();
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,7 +17,7 @@ inline std::vector<T> _expand_param_if_needed(
|
||||
std::ostringstream ss;
|
||||
ss << "expected " << param_name << " to be a single integer value or a "
|
||||
<< "list of " << expected_dim << " values to match the convolution "
|
||||
<< "dimensions, but got " << param_name << '=' << list_param;
|
||||
<< "dimensions, but got " << param_name << "=" << list_param;
|
||||
TORCH_CHECK(false, ss.str());
|
||||
} else {
|
||||
return list_param.vec();
|
||||
|
||||
@ -358,9 +358,9 @@ std::string Adapter::stringize() const {
|
||||
std::string device_type = get_device_type_str(properties.deviceType);
|
||||
VkPhysicalDeviceLimits limits = properties.limits;
|
||||
|
||||
ss << '{' << std::endl;
|
||||
ss << "{" << std::endl;
|
||||
ss << " Physical Device Info {" << std::endl;
|
||||
ss << " apiVersion: " << v_major << '.' << v_minor << std::endl;
|
||||
ss << " apiVersion: " << v_major << "." << v_minor << std::endl;
|
||||
ss << " driverversion: " << properties.driverVersion << std::endl;
|
||||
ss << " deviceType: " << device_type << std::endl;
|
||||
ss << " deviceName: " << properties.deviceName << std::endl;
|
||||
@ -371,7 +371,7 @@ std::string Adapter::stringize() const {
|
||||
|
||||
#define PRINT_LIMIT_PROP_VEC3(name) \
|
||||
ss << " " << std::left << std::setw(36) << #name << limits.name[0] \
|
||||
<< ',' << limits.name[1] << ',' << limits.name[2] << std::endl;
|
||||
<< "," << limits.name[1] << "," << limits.name[2] << std::endl;
|
||||
|
||||
ss << " Physical Device Limits {" << std::endl;
|
||||
PRINT_LIMIT_PROP(maxImageDimension1D);
|
||||
@ -425,7 +425,7 @@ std::string Adapter::stringize() const {
|
||||
;
|
||||
}
|
||||
ss << " ]" << std::endl;
|
||||
ss << '}';
|
||||
ss << "}";
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
@ -33,7 +33,7 @@ std::ostream& operator<<(std::ostream& out, const VkResult result) {
|
||||
VK_RESULT_CASE(VK_ERROR_FORMAT_NOT_SUPPORTED)
|
||||
VK_RESULT_CASE(VK_ERROR_FRAGMENTED_POOL)
|
||||
default:
|
||||
out << "VK_ERROR_UNKNOWN (VkResult " << result << ')';
|
||||
out << "VK_ERROR_UNKNOWN (VkResult " << result << ")";
|
||||
break;
|
||||
}
|
||||
return out;
|
||||
@ -46,7 +46,7 @@ std::ostream& operator<<(std::ostream& out, const VkResult result) {
|
||||
//
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const SourceLocation& loc) {
|
||||
out << loc.function << " at " << loc.file << ':' << loc.line;
|
||||
out << loc.function << " at " << loc.file << ":" << loc.line;
|
||||
return out;
|
||||
}
|
||||
|
||||
@ -66,7 +66,7 @@ Error::Error(SourceLocation source_location, const char* cond, std::string msg)
|
||||
: msg_(std::move(msg)), source_location_{source_location} {
|
||||
std::ostringstream oss;
|
||||
oss << "Exception raised from " << source_location_ << ": ";
|
||||
oss << '(' << cond << ") is false! ";
|
||||
oss << "(" << cond << ") is false! ";
|
||||
oss << msg_;
|
||||
what_ = oss.str();
|
||||
}
|
||||
|
||||
@ -173,8 +173,8 @@ void QueryPool::extract_results() {
|
||||
|
||||
static std::string stringize(const VkExtent3D& extents) {
|
||||
std::stringstream ss;
|
||||
ss << '{' << extents.width << ", " << extents.height << ", " << extents.depth
|
||||
<< '}';
|
||||
ss << "{" << extents.width << ", " << extents.height << ", " << extents.depth
|
||||
<< "}";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
|
||||
@ -149,7 +149,7 @@ VKAPI_ATTR VkBool32 VKAPI_CALL debug_report_callback_fn(
|
||||
(void)flags;
|
||||
|
||||
std::stringstream stream;
|
||||
stream << layer_prefix << ' ' << message_code << ' ' << message << std::endl;
|
||||
stream << layer_prefix << " " << message_code << " " << message << std::endl;
|
||||
const std::string log = stream.str();
|
||||
|
||||
std::cout << log;
|
||||
|
||||
@ -253,7 +253,7 @@ using vec4 = vec<4u>;
|
||||
|
||||
// uvec3 is the type representing tensor extents. Useful for debugging.
|
||||
inline std::ostream& operator<<(std::ostream& os, const uvec3& v) {
|
||||
os << '(' << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ')';
|
||||
os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ")";
|
||||
return os;
|
||||
}
|
||||
|
||||
|
||||
@ -61,6 +61,7 @@ list(APPEND ATen_CUDA_TEST_SRCS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_math_test.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_complex_test.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_cub_test.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_cublas_handle_pool_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_device_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_distributions_test.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_dlconvertor_test.cpp
|
||||
|
||||
@ -246,7 +246,7 @@ void TestToCFloat() {
|
||||
void TestToString() {
|
||||
Tensor b = ones({3, 7}) * .0000001f;
|
||||
std::stringstream s;
|
||||
s << b << '\n';
|
||||
s << b << "\n";
|
||||
std::string expect = "1e-07 *";
|
||||
ASSERT_EQ_RESOLVED(s.str().substr(0, expect.size()), expect);
|
||||
}
|
||||
|
||||
77
aten/src/ATen/test/cuda_cublas_handle_pool_test.cpp
Normal file
77
aten/src/ATen/test/cuda_cublas_handle_pool_test.cpp
Normal file
@ -0,0 +1,77 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDACachingAllocator.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
// Test concurrent access to getCurrentCUDABlasHandle and getCUDABlasLtWorkspace
|
||||
// to verify that the data race fix is working correctly
|
||||
|
||||
TEST(CUDABlasHandlePoolTest, ConcurrentGetAndClearWorkspaces) {
|
||||
if (!at::cuda::is_available()) {
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr int num_accessor_threads = 15;
|
||||
constexpr int num_clear_threads = 5;
|
||||
constexpr int iterations_per_thread = 50;
|
||||
|
||||
std::atomic<bool> stop{false};
|
||||
std::atomic<int> error_count{0};
|
||||
std::vector<std::thread> threads;
|
||||
threads.reserve(num_accessor_threads + num_clear_threads);
|
||||
|
||||
// Launch accessor threads
|
||||
for (int i = 0; i < num_accessor_threads; ++i) {
|
||||
threads.emplace_back([&stop, &error_count]() {
|
||||
try {
|
||||
at::cuda::CUDAGuard device_guard(0);
|
||||
|
||||
while (!stop.load(std::memory_order_relaxed)) {
|
||||
const auto handle = at::cuda::getCurrentCUDABlasHandle();
|
||||
const auto workspace = at::cuda::getCUDABlasLtWorkspace();
|
||||
|
||||
if (handle == nullptr || workspace == nullptr) {
|
||||
error_count++;
|
||||
}
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
error_count++;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Launch threads that clear workspaces
|
||||
for (int i = 0; i < num_clear_threads; ++i) {
|
||||
threads.emplace_back([&error_count]() {
|
||||
try {
|
||||
for (int j = 0; j < iterations_per_thread; ++j) {
|
||||
at::cuda::clearCublasWorkspaces();
|
||||
std::this_thread::yield();
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
error_count++;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Let them run for a bit
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
stop.store(true, std::memory_order_relaxed);
|
||||
|
||||
for (auto& thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
|
||||
EXPECT_EQ(error_count.load(), 0);
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
c10::cuda::CUDACachingAllocator::init(1);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
@ -33,7 +33,7 @@ struct Foo {
|
||||
static void apply(Tensor a, Tensor b) {
|
||||
scalar_type s = 1;
|
||||
std::stringstream ss;
|
||||
ss << "hello, dispatch: " << a.toString() << s << '\n';
|
||||
ss << "hello, dispatch: " << a.toString() << s << "\n";
|
||||
auto data = (scalar_type*)a.data_ptr();
|
||||
(void)data;
|
||||
}
|
||||
@ -73,8 +73,8 @@ TEST(TestScalar, TestScalar) {
|
||||
Scalar bar = 3.0;
|
||||
Half h = bar.toHalf();
|
||||
Scalar h2 = h;
|
||||
cout << "H2: " << h2.toDouble() << ' ' << what.toFloat() << ' '
|
||||
<< bar.toDouble() << ' ' << what.isIntegral(false) << '\n';
|
||||
cout << "H2: " << h2.toDouble() << " " << what.toFloat() << " "
|
||||
<< bar.toDouble() << " " << what.isIntegral(false) << "\n";
|
||||
auto gen = at::detail::getDefaultCPUGenerator();
|
||||
{
|
||||
// See Note [Acquire lock when using random generators]
|
||||
@ -84,7 +84,7 @@ TEST(TestScalar, TestScalar) {
|
||||
}
|
||||
if (at::hasCUDA()) {
|
||||
auto t2 = zeros({4, 4}, at::kCUDA);
|
||||
cout << &t2 << '\n';
|
||||
cout << &t2 << "\n";
|
||||
}
|
||||
auto t = ones({4, 4});
|
||||
|
||||
@ -129,7 +129,7 @@ TEST(TestScalar, TestScalar) {
|
||||
std::stringstream ss;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
|
||||
ASSERT_NO_THROW(
|
||||
ss << "hello, dispatch" << x.toString() << s << '\n');
|
||||
ss << "hello, dispatch" << x.toString() << s << "\n");
|
||||
auto data = (scalar_t*)x.data_ptr();
|
||||
(void)data;
|
||||
});
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
int main() {
|
||||
std::cout << at::ones({3,4}, at::CPU(at::kFloat)) << '\n';
|
||||
std::cout << at::ones({3,4}, at::CPU(at::kFloat)) << "\n";
|
||||
}
|
||||
|
||||
@ -1828,9 +1828,9 @@ namespace {
|
||||
#endif
|
||||
|
||||
EXPECT_EQ(u16, c10::detail::fp16_ieee_from_fp32_value(f32s[i]))
|
||||
<< "Test failed for float to uint16 " << f32s[i] << '\n';
|
||||
<< "Test failed for float to uint16 " << f32s[i] << "\n";
|
||||
EXPECT_EQ(x, c10::detail::fp16_ieee_to_fp32_value(u16))
|
||||
<< "Test failed for uint16 to float " << u16 << '\n';
|
||||
<< "Test failed for uint16 to float " << u16 << "\n";
|
||||
}
|
||||
}
|
||||
TEST(FP8E4M3Test, FP8E4M3ConversionFloat) {
|
||||
@ -1848,10 +1848,10 @@ namespace {
|
||||
EXPECT_TRUE(std::isnan(f32));
|
||||
} else {
|
||||
EXPECT_EQ(f32, c10::detail::fp8e4m3fn_to_fp32_value(input))
|
||||
<< "Test failed for u8 to float " << input << '\n';
|
||||
<< "Test failed for u8 to float " << input << "\n";
|
||||
}
|
||||
EXPECT_EQ(u8, c10::detail::fp8e4m3fn_from_fp32_value(f32))
|
||||
<< "Test failed for float to u8 " << f32 << '\n';
|
||||
<< "Test failed for float to u8 " << f32 << "\n";
|
||||
}
|
||||
}
|
||||
TEST(FP8E4M3Test, FP8E4M3BinaryAdd) {
|
||||
@ -2015,10 +2015,10 @@ namespace {
|
||||
EXPECT_TRUE(std::isnan(f32));
|
||||
} else {
|
||||
EXPECT_EQ(f32, c10::detail::fp8e5m2_to_fp32_value(input))
|
||||
<< "Test failed for u8 to float " << input << '\n';
|
||||
<< "Test failed for u8 to float " << input << "\n";
|
||||
}
|
||||
EXPECT_EQ(u8, c10::detail::fp8e5m2_from_fp32_value(f32))
|
||||
<< "Test failed for float to u8 " << f32 << '\n';
|
||||
<< "Test failed for float to u8 " << f32 << "\n";
|
||||
}
|
||||
}
|
||||
TEST(FP8E5M2Test, FP8E5M2BinaryAdd) {
|
||||
|
||||
@ -19,7 +19,7 @@ TEST(Vitals, Basic) {
|
||||
c10::utils::set_env("TORCH_VITAL", "1");
|
||||
TORCH_VITAL_DEFINE(Testing);
|
||||
TORCH_VITAL(Testing, Attribute0) << 1;
|
||||
TORCH_VITAL(Testing, Attribute1) << '1';
|
||||
TORCH_VITAL(Testing, Attribute1) << "1";
|
||||
TORCH_VITAL(Testing, Attribute2) << 1.0f;
|
||||
TORCH_VITAL(Testing, Attribute3) << 1.0;
|
||||
auto t = at::ones({1, 1});
|
||||
|
||||
@ -129,14 +129,14 @@ void showRtol(const at::Tensor& a, const at::Tensor& b) {
|
||||
std::cout << "Max Diff allowed: " << maxDiff << std::endl;
|
||||
if (diff.sizes().size() == 2) {
|
||||
for (const auto y : c10::irange(diff.sizes()[0])) {
|
||||
std::cout << y << ':';
|
||||
std::cout << y << ":";
|
||||
for (const auto x : c10::irange(diff.sizes()[1])) {
|
||||
float diff_xy = diff[y][x].item<float>();
|
||||
if (diff_xy > maxDiff) {
|
||||
std::cout << std::setw(5) << x;
|
||||
}
|
||||
else {
|
||||
std::cout << std::setw(5) << ' ';
|
||||
std::cout << std::setw(5) << " ";
|
||||
}
|
||||
}
|
||||
std::cout << std::endl;
|
||||
@ -3276,7 +3276,7 @@ TEST_F(VulkanAPITest, masked_fill_invalidinputs_exceptions) {
|
||||
|
||||
void print_shape(const std::vector<int64_t>& shape) {
|
||||
for (const auto& num : shape) {
|
||||
std::cout << num << ' ';
|
||||
std::cout << num << " ";
|
||||
}
|
||||
}
|
||||
|
||||
@ -3367,7 +3367,7 @@ void test_masked_fill_scalar(
|
||||
print_shape(tmp_curr_input_shape);
|
||||
std::cout << "], and mask of shape [";
|
||||
print_shape(tmp_curr_mask_shape);
|
||||
std::cout << ']' << std::endl;
|
||||
std::cout << "]" << std::endl;
|
||||
}
|
||||
|
||||
ASSERT_TRUE(check);
|
||||
@ -4542,9 +4542,9 @@ void test_softmax(const at::IntArrayRef shape, bool log_softmax = false) {
|
||||
if (!check) {
|
||||
std::cout << "Softmax test failed on axis " << dim << "for tensor dims {";
|
||||
for (uint32_t place = 0; place < shape.size() - 1; place++) {
|
||||
std::cout << shape[place] << ' ';
|
||||
std::cout << shape[place] << " ";
|
||||
}
|
||||
std::cout << shape.back() << '}' << std::endl;
|
||||
std::cout << shape.back() << "}" << std::endl;
|
||||
showRtol(out_cpu, out_vulkan.cpu());
|
||||
}
|
||||
ASSERT_TRUE(check);
|
||||
|
||||
@ -95,7 +95,7 @@ void showRtol(
|
||||
std::cout << "Max Diff found is: " << diff.max().item<double>() << std::endl;
|
||||
if (diff.sizes().size() == 2) {
|
||||
for (const auto y : c10::irange(diff.sizes()[0])) {
|
||||
std::cout << y << ':';
|
||||
std::cout << y << ":";
|
||||
for (const auto x : c10::irange(diff.sizes()[1])) {
|
||||
double diff_xy = diff[y][x].item<double>();
|
||||
if (diff_xy > maxDiff) {
|
||||
@ -109,7 +109,7 @@ void showRtol(
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::cout << std::setw(5) << ' ';
|
||||
std::cout << std::setw(5) << " ";
|
||||
}
|
||||
}
|
||||
std::cout << std::endl;
|
||||
@ -148,19 +148,19 @@ using at::native::vulkan::api::utils::ivec4;
|
||||
using at::native::vulkan::api::utils::vec4;
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const vec4& v) {
|
||||
os << '(' << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
|
||||
<< v.data[3u] << ')';
|
||||
os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
|
||||
<< v.data[3u] << ")";
|
||||
return os;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const ivec3& v) {
|
||||
os << '(' << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ')';
|
||||
os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ")";
|
||||
return os;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, const ivec4& v) {
|
||||
os << '(' << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
|
||||
<< v.data[3u] << ')';
|
||||
os << "(" << v.data[0u] << ", " << v.data[1u] << ", " << v.data[2u] << ", "
|
||||
<< v.data[3u] << ")";
|
||||
return os;
|
||||
}
|
||||
|
||||
@ -3379,51 +3379,51 @@ bool _test_quantized_linear(
|
||||
showRtol(out_cpu_dequant, out_vk_to_cpu_dequant);
|
||||
}
|
||||
if (xpos != -1 && ypos != -1) {
|
||||
std::cout << "\nFailure caused on row/col: " << ypos << '/' << xpos
|
||||
<< '\n';
|
||||
std::cout << "\nFailure caused on row/col: " << ypos << "/" << xpos
|
||||
<< "\n";
|
||||
std::cout << "Input tensor scale: " << scale << " zerop: " << zero_point
|
||||
<< '\n';
|
||||
std::cout << "Input tensor row " << ypos << '\n';
|
||||
<< "\n";
|
||||
std::cout << "Input tensor row " << ypos << "\n";
|
||||
for (int i = 0; i < input_cpu.sizes()[1]; i++) {
|
||||
std::cout << input_cpu[ypos][i].item<double>() << ", ";
|
||||
}
|
||||
std::cout << '\n';
|
||||
std::cout << "\n";
|
||||
|
||||
std::cout << "Weight tensor scale: " << w_scale
|
||||
<< " zerop: " << w_zero_point << '\n';
|
||||
std::cout << "Weight tensor col " << xpos << '\n';
|
||||
<< " zerop: " << w_zero_point << "\n";
|
||||
std::cout << "Weight tensor col " << xpos << "\n";
|
||||
for (int i = 0; i < weight.sizes()[1]; i++) {
|
||||
std::cout << weight[xpos][i].item<double>() << ", ";
|
||||
}
|
||||
std::cout << '\n';
|
||||
std::cout << "\n";
|
||||
|
||||
std::cout << "Input tensor quantized row " << ypos << " with dtype "
|
||||
<< (input_quant_dtype_int8 ? "QInt8" : "QUInt8") << '\n';
|
||||
<< (input_quant_dtype_int8 ? "QInt8" : "QUInt8") << "\n";
|
||||
for (int i = 0; i < input_cpu.sizes()[1]; i++) {
|
||||
std::cout << input_cpu_quantized[ypos][i].item<double>() << ", ";
|
||||
}
|
||||
std::cout << '\n';
|
||||
std::cout << "\n";
|
||||
|
||||
std::cout << "Weight tensor quantized col " << xpos << " with dtype "
|
||||
<< (weight_quant_dtype_int8 ? "QInt8" : "QUInt8") << '\n';
|
||||
<< (weight_quant_dtype_int8 ? "QInt8" : "QUInt8") << "\n";
|
||||
for (int i = 0; i < weight.sizes()[1]; i++) {
|
||||
std::cout << weight_cpu_quantized[xpos][i].item<double>() << ", ";
|
||||
}
|
||||
std::cout << '\n';
|
||||
std::cout << "\n";
|
||||
|
||||
std::cout << "bias tensor\n";
|
||||
for (int i = 0; i < bias.sizes()[0]; i++) {
|
||||
std::cout << bias[i].item<double>() << ", ";
|
||||
}
|
||||
std::cout << '\n';
|
||||
std::cout << "\n";
|
||||
|
||||
std::cout << "out_scale: " << out_scale
|
||||
<< " out_zero_point: " << out_zero_point << '\n';
|
||||
<< " out_zero_point: " << out_zero_point << "\n";
|
||||
|
||||
std::cout << "cpu unmatched output: "
|
||||
<< out_cpu_dequant[ypos][xpos].item<double>() << '\n';
|
||||
<< out_cpu_dequant[ypos][xpos].item<double>() << "\n";
|
||||
std::cout << "vk unmatched output: "
|
||||
<< out_vk_to_cpu_dequant[ypos][xpos].item<double>() << '\n';
|
||||
<< out_vk_to_cpu_dequant[ypos][xpos].item<double>() << "\n";
|
||||
}
|
||||
}
|
||||
return check;
|
||||
|
||||
@ -10,6 +10,13 @@
|
||||
...
|
||||
}
|
||||
|
||||
{
|
||||
ignore_empty_generic_uninitialised_conditional_jump
|
||||
Memcheck:Cond
|
||||
fun:_ZN2at6detail13empty_genericEN3c108ArrayRefIlEEPNS1_9AllocatorENS1_14DispatchKeySetENS1_10ScalarTypeESt8optionalINS1_12MemoryFormatEE
|
||||
...
|
||||
}
|
||||
|
||||
{
|
||||
Cond_cuda
|
||||
Memcheck:Cond
|
||||
|
||||
@ -176,7 +176,7 @@ std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) {
|
||||
os << k;
|
||||
first = false;
|
||||
}
|
||||
os << ')';
|
||||
os << ")";
|
||||
return os;
|
||||
}
|
||||
|
||||
|
||||
@ -44,7 +44,7 @@ struct C10_API SafePyObject {
|
||||
(*other.pyinterpreter_)->incref(other.data_);
|
||||
}
|
||||
if (data_ != nullptr) {
|
||||
(*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
|
||||
(*pyinterpreter_)->decref(data_);
|
||||
}
|
||||
data_ = other.data_;
|
||||
pyinterpreter_ = other.pyinterpreter_;
|
||||
@ -53,7 +53,7 @@ struct C10_API SafePyObject {
|
||||
|
||||
~SafePyObject() {
|
||||
if (data_ != nullptr) {
|
||||
(*pyinterpreter_)->decref(data_, /*has_pyobj_slot*/ false);
|
||||
(*pyinterpreter_)->decref(data_);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -48,6 +48,30 @@ void warnDeprecatedDataPtr() {
|
||||
TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid.");
|
||||
}
|
||||
|
||||
void StorageImpl::incref_pyobject() const {
|
||||
// Because intrusive_ptr incref uses relaxed memory order, we need to
|
||||
// do an acquire fence to ensure that the kHasPyObject bit was
|
||||
// observed before the load of the PyObject* below.
|
||||
// NB: This is a no-op on x86/x86-64
|
||||
std::atomic_thread_fence(std::memory_order_acquire);
|
||||
|
||||
PyObject* obj = pyobj_slot_.load_pyobj();
|
||||
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
|
||||
}
|
||||
|
||||
void StorageImpl::decref_pyobject() const {
|
||||
PyObject* obj = pyobj_slot_.load_pyobj();
|
||||
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
|
||||
}
|
||||
|
||||
bool StorageImpl::try_incref_pyobject() const {
|
||||
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
|
||||
if (C10_UNLIKELY(!interp)) {
|
||||
return false;
|
||||
}
|
||||
return (*interp)->try_incref(pyobj_slot_);
|
||||
}
|
||||
|
||||
void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
|
||||
// Allowlist verification.
|
||||
// Only if the devicetype is in the allowlist,
|
||||
|
||||
@ -105,6 +105,12 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
|
||||
data_ptr_.clear();
|
||||
}
|
||||
|
||||
void incref_pyobject() const override final;
|
||||
|
||||
void decref_pyobject() const override final;
|
||||
|
||||
bool try_incref_pyobject() const override final;
|
||||
|
||||
size_t nbytes() const {
|
||||
// OK to do this instead of maybe_as_int as nbytes is guaranteed positive
|
||||
TORCH_CHECK(!size_bytes_is_heap_allocated_);
|
||||
@ -370,4 +376,18 @@ C10_API c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
|
||||
bool resizable,
|
||||
std::optional<at::Device> device_opt);
|
||||
|
||||
namespace detail {
|
||||
|
||||
#ifndef C10_MOBILE
|
||||
template <class T>
|
||||
struct TargetTraits<
|
||||
T,
|
||||
std::enable_if_t<
|
||||
std::is_base_of_v<c10::StorageImpl, std::remove_cv_t<T>>>> {
|
||||
static constexpr bool can_have_pyobject = true;
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace detail
|
||||
|
||||
} // namespace c10
|
||||
|
||||
@ -277,7 +277,6 @@ void TensorImpl::release_resources() {
|
||||
if (storage_) {
|
||||
storage_ = {};
|
||||
}
|
||||
pyobj_slot_.maybe_destroy_pyobj();
|
||||
}
|
||||
|
||||
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
|
||||
@ -989,6 +988,30 @@ void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) {
|
||||
}
|
||||
}
|
||||
|
||||
void TensorImpl::incref_pyobject() const {
|
||||
// Because intrusive_ptr incref uses relaxed memory order, we need to
|
||||
// do an acquire fence to ensure that the kHasPyObject bit was
|
||||
// observed before the load of the PyObject* below.
|
||||
// NB: This is a no-op on x86/x86-64
|
||||
std::atomic_thread_fence(std::memory_order_acquire);
|
||||
|
||||
PyObject* obj = pyobj_slot_.load_pyobj();
|
||||
(*pyobj_slot_.pyobj_interpreter())->incref(obj);
|
||||
}
|
||||
|
||||
void TensorImpl::decref_pyobject() const {
|
||||
PyObject* obj = pyobj_slot_.load_pyobj();
|
||||
(*pyobj_slot_.pyobj_interpreter())->decref(obj);
|
||||
}
|
||||
|
||||
bool TensorImpl::try_incref_pyobject() const {
|
||||
c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter();
|
||||
if (C10_UNLIKELY(!interp)) {
|
||||
return false;
|
||||
}
|
||||
return (*interp)->try_incref(pyobj_slot_);
|
||||
}
|
||||
|
||||
namespace impl {
|
||||
|
||||
namespace {
|
||||
|
||||
@ -2178,6 +2178,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
return &pyobj_slot_;
|
||||
}
|
||||
|
||||
void incref_pyobject() const override final;
|
||||
|
||||
void decref_pyobject() const override final;
|
||||
|
||||
bool try_incref_pyobject() const override final;
|
||||
|
||||
private:
|
||||
// See NOTE [std::optional operator usage in CUDA]
|
||||
// We probably don't want to expose this publicly until
|
||||
@ -3079,6 +3085,19 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
friend class C10_TensorImpl_Size_Check_Dummy_Class;
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
#ifndef C10_MOBILE
|
||||
template <class T>
|
||||
struct TargetTraits<
|
||||
T,
|
||||
std::enable_if_t<std::is_base_of_v<c10::TensorImpl, std::remove_cv_t<T>>>> {
|
||||
static constexpr bool can_have_pyobject = true;
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Note [TensorImpl size constraints]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// Changed the size of TensorImpl? If the size went down, good for
|
||||
|
||||
@ -33,7 +33,7 @@ std::ostream& operator<<(std::ostream& stream, const TensorOptions& options) {
|
||||
} else {
|
||||
stream << "(nullopt)";
|
||||
}
|
||||
stream << ')';
|
||||
stream << ")";
|
||||
|
||||
return stream;
|
||||
}
|
||||
|
||||
@ -11,8 +11,11 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
|
||||
|
||||
void incref(PyObject* pyobj) const override {} // do nothing
|
||||
|
||||
void decref(PyObject* pyobj, bool has_pyobj_slot) const override {
|
||||
} // do nothing
|
||||
void decref(PyObject* pyobj) const override {} // do nothing
|
||||
|
||||
bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const override {
|
||||
return false;
|
||||
}
|
||||
|
||||
#define PANIC(m) \
|
||||
TORCH_INTERNAL_ASSERT( \
|
||||
@ -20,6 +23,10 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
|
||||
"attempted to call " #m \
|
||||
" on a Tensor with nontrivial PyObject after corresponding interpreter died")
|
||||
|
||||
size_t refcnt(PyObject* pyobj) const override {
|
||||
PANIC(refcnt);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<TensorImpl> detach(const TensorImpl* self) const override {
|
||||
PANIC(detach);
|
||||
}
|
||||
|
||||
@ -18,6 +18,9 @@ namespace c10 {
|
||||
struct IValue;
|
||||
class OperatorHandle;
|
||||
struct TensorImpl;
|
||||
namespace impl {
|
||||
struct PyObjectSlot;
|
||||
} // namespace impl
|
||||
} // namespace c10
|
||||
|
||||
namespace torch::jit {
|
||||
@ -126,9 +129,12 @@ struct C10_API PyInterpreterVTable {
|
||||
|
||||
// Run Py_INCREF on a PyObject.
|
||||
virtual void incref(PyObject* pyobj) const = 0;
|
||||
// Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call
|
||||
// See NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg]
|
||||
virtual void decref(PyObject* pyobj, bool has_pyobj_slot) const = 0;
|
||||
// Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call.
|
||||
virtual void decref(PyObject* pyobj) const = 0;
|
||||
// Run PyUnstable_TryIncRef on a PyObject if it's not NULL.
|
||||
virtual bool try_incref(const c10::impl::PyObjectSlot& pyobj_slot) const = 0;
|
||||
// Run Py_REFCNT on a PyObject.
|
||||
virtual size_t refcnt(PyObject* pyobj) const = 0;
|
||||
|
||||
// Perform a detach by deferring to the __torch_dispatch__ implementation of
|
||||
// detach, which will also arrange for the PyObject to get copied in this
|
||||
|
||||
@ -1,56 +0,0 @@
|
||||
#include <c10/core/impl/PyObjectSlot.h>
|
||||
|
||||
namespace c10::impl {
|
||||
|
||||
PyObjectSlot::PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
|
||||
|
||||
PyObjectSlot::~PyObjectSlot() {
|
||||
maybe_destroy_pyobj();
|
||||
}
|
||||
|
||||
void PyObjectSlot::maybe_destroy_pyobj() {
|
||||
if (owns_pyobj()) {
|
||||
TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr);
|
||||
TORCH_INTERNAL_ASSERT(pyobj_ != nullptr);
|
||||
(*pyobj_interpreter_.load(std::memory_order_acquire))
|
||||
->decref(_unchecked_untagged_pyobj(), /*has_pyobj_slot*/ true);
|
||||
// NB: this destructor can only be entered when there are no
|
||||
// references to this C++ object (obviously), NOR any references
|
||||
// to the PyObject (if there are references to the PyObject,
|
||||
// then the PyObject holds an owning reference to the tensor).
|
||||
// So it is OK to clear pyobj_ here as it is impossible for it to
|
||||
// be used again (modulo weak reference races)
|
||||
pyobj_ = nullptr; // for safety
|
||||
}
|
||||
}
|
||||
|
||||
PyInterpreter* PyObjectSlot::pyobj_interpreter() {
|
||||
return pyobj_interpreter_.load(std::memory_order_acquire);
|
||||
}
|
||||
|
||||
PyObject* PyObjectSlot::_unchecked_untagged_pyobj() const {
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
return reinterpret_cast<PyObject*>(
|
||||
reinterpret_cast<uintptr_t>(pyobj_) & ~0x1ULL);
|
||||
}
|
||||
|
||||
PyInterpreter& PyObjectSlot::load_pyobj_interpreter() const {
|
||||
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
|
||||
if (interpreter) {
|
||||
return *interpreter;
|
||||
}
|
||||
TORCH_CHECK(false, "cannot access PyObject for Tensor - no interpreter set");
|
||||
}
|
||||
|
||||
bool PyObjectSlot::owns_pyobj() {
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
return reinterpret_cast<uintptr_t>(pyobj_) & 1;
|
||||
}
|
||||
|
||||
void PyObjectSlot::set_owns_pyobj(bool b) {
|
||||
// NOLINTNEXTLINE(performance-no-int-to-ptr)
|
||||
pyobj_ = reinterpret_cast<PyObject*>(
|
||||
reinterpret_cast<uintptr_t>(_unchecked_untagged_pyobj()) | b);
|
||||
}
|
||||
|
||||
} // namespace c10::impl
|
||||
@ -8,117 +8,58 @@
|
||||
|
||||
#include <atomic>
|
||||
|
||||
namespace torch::utils {
|
||||
class PyObjectPreservation;
|
||||
}
|
||||
|
||||
namespace c10::impl {
|
||||
|
||||
struct C10_API PyObjectSlot {
|
||||
public:
|
||||
PyObjectSlot();
|
||||
|
||||
~PyObjectSlot();
|
||||
|
||||
void maybe_destroy_pyobj();
|
||||
|
||||
// Associate the TensorImpl with the specified PyObject, and, if necessary,
|
||||
// also tag the interpreter.
|
||||
//
|
||||
// NB: This lives in a header so that we can inline away the switch on status
|
||||
//
|
||||
// NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after
|
||||
// PyObject if necessary!
|
||||
void init_pyobj(PyObject* pyobj) {
|
||||
pyobj_interpreter_.store(
|
||||
getGlobalPyInterpreter(), std::memory_order_relaxed);
|
||||
pyobj_ = pyobj;
|
||||
}
|
||||
PyObjectSlot() : pyobj_interpreter_(nullptr), pyobj_(nullptr) {}
|
||||
|
||||
// Query the PyObject interpreter. This may return null if there is no
|
||||
// interpreter. This is racy!
|
||||
PyInterpreter* pyobj_interpreter();
|
||||
|
||||
PyObject* _unchecked_untagged_pyobj() const;
|
||||
|
||||
// Test the interpreter tag. If tagged for the current interpreter, return
|
||||
// a non-nullopt (but possibly null) PyObject. If (possibly) untagged,
|
||||
// returns a nullopt. If it is definitely invalid, raises an error.
|
||||
//
|
||||
// If `ignore_hermetic_tls` is false and this function is called from a
|
||||
// hermetic context (ie, `HermeticPyObjectTLS::get_state()` is true), then
|
||||
// nullopt is returned. If `ignore_hermetic_tls` is true, then the hermetic
|
||||
// context is ignored, allowing you to check the interpreter tag of a
|
||||
// nonhermetic PyObject from within a hermetic context. This is necessary
|
||||
// because there are some cases where the deallocator function of a
|
||||
// nonhermetic PyObject is called from within a hermetic context, so it must
|
||||
// be properly treated as a nonhermetic PyObject.
|
||||
//
|
||||
// NB: this lives in header so that we can avoid actually creating the
|
||||
// std::optional
|
||||
|
||||
// @todo alban: I'm not too sure what's going on here, we can probably delete
|
||||
// it but it's worthwhile making sure
|
||||
std::optional<PyObject*> check_pyobj(bool ignore_hermetic_tls = false) const {
|
||||
impl::PyInterpreter* interpreter =
|
||||
pyobj_interpreter_.load(std::memory_order_acquire);
|
||||
if (interpreter == nullptr) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (!ignore_hermetic_tls && c10::impl::HermeticPyObjectTLS::get_state()) {
|
||||
return std::nullopt;
|
||||
} else {
|
||||
return _unchecked_untagged_pyobj();
|
||||
}
|
||||
// interpreter.
|
||||
PyInterpreter* pyobj_interpreter() const {
|
||||
return pyobj_interpreter_.load(std::memory_order_acquire);
|
||||
}
|
||||
|
||||
PyInterpreter& load_pyobj_interpreter() const;
|
||||
PyInterpreter& load_pyobj_interpreter() const {
|
||||
auto interpreter = pyobj_interpreter_.load(std::memory_order_acquire);
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
interpreter, "cannot access PyObject for Tensor - no interpreter set");
|
||||
return *interpreter;
|
||||
}
|
||||
|
||||
bool owns_pyobj();
|
||||
PyObject* load_pyobj() const {
|
||||
return pyobj_.load(std::memory_order_acquire);
|
||||
}
|
||||
|
||||
void set_owns_pyobj(bool b);
|
||||
void store_pyobj(PyObject* obj) {
|
||||
pyobj_.store(obj, std::memory_order_release);
|
||||
}
|
||||
|
||||
bool has_unique_reference() const {
|
||||
PyObject* pyobj = load_pyobj();
|
||||
return pyobj != nullptr && load_pyobj_interpreter()->refcnt(pyobj) == 1;
|
||||
}
|
||||
|
||||
void clear() {
|
||||
pyobj_.store(nullptr, std::memory_order_relaxed);
|
||||
pyobj_interpreter_.store(nullptr, std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
private:
|
||||
// This field contains the interpreter tag for this object. See
|
||||
// Note [Python interpreter tag] for general context
|
||||
//
|
||||
// Note [Memory ordering on Python interpreter tag]
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
// What memory_order do we need when accessing this atomic? We don't
|
||||
// need a single total modification order (as provided by
|
||||
// memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only
|
||||
// transition from -1 to some positive integer and never changes afterwards.
|
||||
// Because there is only one modification, it trivially already has a total
|
||||
// modification order (e.g., we don't need fences or locked instructions on
|
||||
// x86)
|
||||
//
|
||||
// In fact, one could make a reasonable argument that relaxed reads are OK,
|
||||
// due to the presence of external locking (GIL) to ensure that interactions
|
||||
// with other data structures are still correctly synchronized, so that
|
||||
// we fall in the "Single-Location Data Structures" case as described in
|
||||
// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
|
||||
// However, on x86, it doesn't matter if I use acquire or relaxed on the load
|
||||
// as I get the same assembly in both cases. So I just use the more
|
||||
// conservative acquire (which will impede compiler optimizations but I don't
|
||||
// care)
|
||||
// This is now always the global interpreter if the PyObject is set.
|
||||
// Maybe we can remove this field some day...
|
||||
std::atomic<PyInterpreter*> pyobj_interpreter_;
|
||||
|
||||
// This field contains a reference to a PyObject representing this Tensor.
|
||||
// If pyobj is nullptr, when we transfer Tensor to Python, we allocate a new
|
||||
// PyObject for it and set this field. This field does not have to be
|
||||
// protected by an atomic as it is only allowed to be accessed when you hold
|
||||
// the GIL, or during destruction of the tensor.
|
||||
//
|
||||
// When a PyObject dies, you are obligated to clear this field
|
||||
// (otherwise, you will try to use-after-free the pyobj); this currently
|
||||
// occurs in THPVariable_clear in torch/csrc/autograd/python_variable.cpp
|
||||
//
|
||||
// NB: Ordinarily, this should not be a strong reference, as if the
|
||||
// PyObject owns the Tensor, this would create a reference cycle.
|
||||
// However, sometimes this ownership flips. To track who owns
|
||||
// who, this has a single pointer tag indicating whether or not the
|
||||
// C++ object owns the PyObject (the common case, zero, means PyObject
|
||||
// owns the C++ object); see _unchecked_untagged_pyobj for raw access
|
||||
// or check_pyobj for checked access. See references to PyObject
|
||||
// resurrection in torch/csrc/autograd/python_variable.cpp
|
||||
PyObject* pyobj_;
|
||||
// The PyObject representing this Tensor or nullptr. Ownership is managed
|
||||
// by intrusive_ptr. By the time the PyObjectSlot is destroyed, this
|
||||
// reference is already dead.
|
||||
std::atomic<PyObject*> pyobj_;
|
||||
|
||||
friend class torch::utils::PyObjectPreservation;
|
||||
};
|
||||
|
||||
} // namespace c10::impl
|
||||
|
||||
@ -136,7 +136,7 @@ std::string c10_retrieve_device_side_assertion_info() {
|
||||
// Something failed, let's talk about that
|
||||
oss << failures_found
|
||||
<< " CUDA device-side assertion failures were found on GPU #"
|
||||
<< device_num << '!' << std::endl;
|
||||
<< device_num << "!" << std::endl;
|
||||
if (assertion_data_for_device.assertion_count >
|
||||
C10_CUDA_DSA_ASSERTION_COUNT) {
|
||||
oss << "But at least " << assertion_data_for_device.assertion_count
|
||||
@ -151,17 +151,17 @@ std::string c10_retrieve_device_side_assertion_info() {
|
||||
oss << "Assertion failure " << i << std::endl;
|
||||
oss << " GPU assertion failure message = " << self.assertion_msg
|
||||
<< std::endl;
|
||||
oss << " File containing assertion = " << self.filename << ':'
|
||||
oss << " File containing assertion = " << self.filename << ":"
|
||||
<< self.line_number << std::endl;
|
||||
oss << " Device function containing assertion = " << self.function_name
|
||||
<< std::endl;
|
||||
oss << " Thread ID that failed assertion = [" << self.thread_id[0] << ','
|
||||
<< self.thread_id[1] << ',' << self.thread_id[2] << ']' << std::endl;
|
||||
oss << " Block ID that failed assertion = [" << self.block_id[0] << ','
|
||||
<< self.block_id[1] << ',' << self.block_id[2] << ']' << std::endl;
|
||||
oss << " Thread ID that failed assertion = [" << self.thread_id[0] << ","
|
||||
<< self.thread_id[1] << "," << self.thread_id[2] << "]" << std::endl;
|
||||
oss << " Block ID that failed assertion = [" << self.block_id[0] << ","
|
||||
<< self.block_id[1] << "," << self.block_id[2] << "]" << std::endl;
|
||||
if (launch_info.generation_number == self.caller) {
|
||||
oss << " File containing kernel launch = "
|
||||
<< launch_info.launch_filename << ':' << launch_info.launch_linenum
|
||||
<< launch_info.launch_filename << ":" << launch_info.launch_linenum
|
||||
<< std::endl;
|
||||
oss << " Function containing kernel launch = "
|
||||
<< launch_info.launch_function << std::endl;
|
||||
@ -175,7 +175,7 @@ std::string c10_retrieve_device_side_assertion_info() {
|
||||
if (launch_registry.gather_launch_stacktrace) {
|
||||
oss << "Launch stacktracing disabled." << std::endl;
|
||||
} else {
|
||||
oss << '\n' << launch_info.launch_stacktrace << std::endl;
|
||||
oss << "\n" << launch_info.launch_stacktrace << std::endl;
|
||||
}
|
||||
} else {
|
||||
oss << " CPU launch site info: Unavailable, the circular queue wrapped around. Increase `CUDAKernelLaunchRegistry::max_size`."
|
||||
|
||||
@ -435,7 +435,7 @@ TEST(DispatchKeySet, TestFunctionalityDispatchKeyToString) {
|
||||
if (i > 0) {
|
||||
ASSERT_TRUE(res.find("Unknown") == std::string::npos)
|
||||
<< i << " (before is " << toString(static_cast<DispatchKey>(i - 1))
|
||||
<< ')';
|
||||
<< ")";
|
||||
} else {
|
||||
ASSERT_TRUE(res.find("Unknown") == std::string::npos) << i;
|
||||
}
|
||||
|
||||
@ -96,10 +96,10 @@ TEST(HalfConversionTest, TestPorableConversion) {
|
||||
for (auto x : inputs) {
|
||||
auto target = c10::detail::fp16_ieee_to_fp32_value(x);
|
||||
EXPECT_EQ(halfbits2float(x), target)
|
||||
<< "Test failed for uint16 to float " << x << '\n';
|
||||
<< "Test failed for uint16 to float " << x << "\n";
|
||||
EXPECT_EQ(
|
||||
float2halfbits(target), c10::detail::fp16_ieee_from_fp32_value(target))
|
||||
<< "Test failed for float to uint16" << target << '\n';
|
||||
<< "Test failed for float to uint16" << target << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -98,7 +98,7 @@ struct Noncopyable {
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const Noncopyable& nc) {
|
||||
out << "Noncopyable(" << nc.x << ')';
|
||||
out << "Noncopyable(" << nc.x << ")";
|
||||
return out;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user