mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 08:24:57 +08:00
Compare commits
128 Commits
viable/str
...
cpp-docs-d
| Author | SHA1 | Date | |
|---|---|---|---|
| df1268c311 | |||
| 84f9f1541d | |||
| 27c0c126bf | |||
| 670873155a | |||
| 923737c510 | |||
| 13d5b14a73 | |||
| a35a42b21c | |||
| 15956bc1e8 | |||
| b319ea1111 | |||
| ce4c68a5f6 | |||
| c6da4a59a3 | |||
| 53f75cd5ba | |||
| 527b1109a8 | |||
| 3144713325 | |||
| eefa16342c | |||
| d02f68f484 | |||
| 68eb55c4b2 | |||
| 8d4b8ab430 | |||
| afd50bdd29 | |||
| 56dfd4c74b | |||
| 24db5c4451 | |||
| cc8bfd1206 | |||
| c45b156605 | |||
| 8fff7e36b4 | |||
| 82fa2aa269 | |||
| 09e0285608 | |||
| d980d8dc79 | |||
| c7d00de115 | |||
| d3cf90ada5 | |||
| 0e1a88904f | |||
| 3232caa078 | |||
| a6c6acea9d | |||
| 55be1cc739 | |||
| 344cebda52 | |||
| ba72c6b981 | |||
| 888efcc453 | |||
| 24aa9a2ef7 | |||
| f70faf2b9a | |||
| 167e64ba1a | |||
| 875b18d53c | |||
| eec3749c44 | |||
| 40133fe966 | |||
| f288433d3e | |||
| 864633fca0 | |||
| c21868b435 | |||
| a0a8eca01a | |||
| 0958f307d9 | |||
| 7551507c41 | |||
| f92834d477 | |||
| e1fc01bef8 | |||
| 22a745737a | |||
| ee708ea96c | |||
| 64819e3701 | |||
| 79ff2c66c8 | |||
| 665a411351 | |||
| 5c89bdb461 | |||
| 7b64ad906c | |||
| d944279def | |||
| 5048e4701d | |||
| 616314cfd5 | |||
| 2b7e4c3ef2 | |||
| 6c98657239 | |||
| 86b2d82e84 | |||
| eea8ff2d34 | |||
| 11f73d78c8 | |||
| 7d1b976146 | |||
| 27cfdd9e77 | |||
| 01d8d8584b | |||
| b8855e7b0b | |||
| 6725ee89c8 | |||
| 3a38ec78e1 | |||
| 77b9399d83 | |||
| 83cd626365 | |||
| 5125872aeb | |||
| c10975d2e6 | |||
| 68e31e2f81 | |||
| ee1bc3f0d5 | |||
| 612ead1619 | |||
| 3af1f7bbf4 | |||
| 71a2e93547 | |||
| c76199980d | |||
| e3bd7bd1f4 | |||
| aa4a8c9b92 | |||
| fa0fd6be13 | |||
| 2f3f88f445 | |||
| d67d807270 | |||
| bcad4f2e68 | |||
| 5b17ef30d0 | |||
| 7b2992685b | |||
| f3fa560dec | |||
| 984b096d10 | |||
| 104b868618 | |||
| 94f2657c4b | |||
| 3f6538febd | |||
| f33abae695 | |||
| 73da7a40b6 | |||
| 335b5c7d4b | |||
| 76bb27e248 | |||
| a2da69385a | |||
| d177900723 | |||
| 61bcc8d75a | |||
| 1656b253c5 | |||
| 5d6230779d | |||
| a4077b568f | |||
| ae038f871b | |||
| defac66e39 | |||
| 061fa73c97 | |||
| 9501405de6 | |||
| e0791fc11d | |||
| e1d011d6eb | |||
| 3f5401020b | |||
| 5a3930abbc | |||
| a5f00077fc | |||
| 69fb3ebb5d | |||
| 1c4ced2eaf | |||
| 392acee68a | |||
| fee1ac927d | |||
| 4a7fefd7c7 | |||
| 3b4315940d | |||
| 3eddf04922 | |||
| 7c203b8420 | |||
| 3ca216ae17 | |||
| 9c22bbb2dc | |||
| 6268883f9c | |||
| 16212f0d6b | |||
| c8adc08b3b | |||
| 23b57a445c | |||
| 6c7cad6972 |
@ -13,3 +13,4 @@ exclude:
|
||||
- "**/benchmarks/**"
|
||||
- "**/test_*.py"
|
||||
- "**/*_test.py"
|
||||
- "tools/**"
|
||||
|
||||
@ -149,7 +149,7 @@ FROM cpu_final as rocm_final
|
||||
ARG ROCM_VERSION=6.0
|
||||
ARG PYTORCH_ROCM_ARCH
|
||||
ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}
|
||||
ARG DEVTOOLSET_VERSION=11
|
||||
ARG DEVTOOLSET_VERSION=13
|
||||
ENV LDFLAGS="-Wl,-rpath=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib64 -Wl,-rpath=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib"
|
||||
# Somewhere in ROCm stack, we still use non-existing /opt/rocm/hip path,
|
||||
# below workaround helps avoid error
|
||||
|
||||
@ -1,15 +1,11 @@
|
||||
sphinx==5.3.0
|
||||
sphinx==7.2.6
|
||||
#Description: This is used to generate PyTorch docs
|
||||
#Pinned versions: 5.3.0
|
||||
#Pinned versions: 7.2.6
|
||||
|
||||
standard-imghdr==3.13.0; python_version >= "3.13"
|
||||
#Description: This is needed by Sphinx, so it needs to be added here.
|
||||
# The reasons are as follows:
|
||||
# 1) This module has been removed from the Python standard library since Python 3.13(https://peps.python.org/pep-0594/#imghdr);
|
||||
# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
|
||||
# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.
|
||||
pytorch_sphinx_theme2==0.2.0
|
||||
#Description: This is needed to generate PyTorch docs
|
||||
#Pinned versions: 0.2.0
|
||||
|
||||
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
|
||||
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
|
||||
# but it doesn't seem to work and hangs around idly. The initial thought that it is probably
|
||||
# something related to Docker setup. We can investigate this later.
|
||||
@ -36,17 +32,17 @@ tensorboard==2.18.0 ; python_version >= "3.13"
|
||||
#Description: This is used to generate PyTorch docs
|
||||
#Pinned versions: 2.13.0
|
||||
|
||||
breathe==4.34.0
|
||||
breathe==4.36.0
|
||||
#Description: This is used to generate PyTorch C++ docs
|
||||
#Pinned versions: 4.34.0
|
||||
#Pinned versions: 4.36.0
|
||||
|
||||
exhale==0.2.3
|
||||
exhale==0.3.7
|
||||
#Description: This is used to generate PyTorch C++ docs
|
||||
#Pinned versions: 0.2.3
|
||||
#Pinned versions: 0.3.7
|
||||
|
||||
docutils==0.16
|
||||
docutils==0.20
|
||||
#Description: This is used to generate PyTorch C++ docs
|
||||
#Pinned versions: 0.16
|
||||
#Pinned versions: 0.20
|
||||
|
||||
bs4==0.0.1
|
||||
#Description: This is used to generate PyTorch C++ docs
|
||||
@ -56,13 +52,13 @@ IPython==8.12.0
|
||||
#Description: This is used to generate PyTorch functorch docs
|
||||
#Pinned versions: 8.12.0
|
||||
|
||||
myst-nb==0.17.2
|
||||
myst-nb==1.3.0
|
||||
#Description: This is used to generate PyTorch functorch and torch.compile docs.
|
||||
#Pinned versions: 0.17.2
|
||||
#Pinned versions: 1.3.0
|
||||
|
||||
# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
|
||||
python-etcd==0.4.5
|
||||
sphinx-copybutton==0.5.0
|
||||
sphinx-design==0.4.0
|
||||
sphinx-design==0.6.1
|
||||
sphinxcontrib-mermaid==1.0.0
|
||||
myst-parser==0.18.1
|
||||
myst-parser==4.0.1
|
||||
|
||||
@ -89,23 +89,39 @@ if [ "$is_main_doc" = true ]; then
|
||||
|
||||
make coverage
|
||||
# Now we have the coverage report, we need to make sure it is empty.
|
||||
# Count the number of lines in the file and turn that number into a variable
|
||||
# $lines. The `cut -f1 ...` is to only parse the number, not the filename
|
||||
# Skip the report header by subtracting 2: the header will be output even if
|
||||
# there are no undocumented items.
|
||||
# Sphinx 7.2.6+ format: python.txt contains a statistics table with a TOTAL row
|
||||
# showing the undocumented count in the third column.
|
||||
# Example: | TOTAL | 99.83% | 2 |
|
||||
#
|
||||
# Also: see docs/source/conf.py for "coverage_ignore*" items, which should
|
||||
# be documented then removed from there.
|
||||
lines=$(wc -l build/coverage/python.txt 2>/dev/null |cut -f1 -d' ')
|
||||
undocumented=$((lines - 2))
|
||||
if [ $undocumented -lt 0 ]; then
|
||||
|
||||
# Extract undocumented count from TOTAL row in Sphinx 7.2.6 statistics table
|
||||
# The table format is: | Module | Coverage | Undocumented |
|
||||
# Extract the third column (undocumented count) from the TOTAL row
|
||||
undocumented=$(grep "| TOTAL" build/coverage/python.txt | awk -F'|' '{print $4}' | tr -d ' ')
|
||||
|
||||
if [ -z "$undocumented" ] || ! [[ "$undocumented" =~ ^[0-9]+$ ]]; then
|
||||
echo coverage output not found
|
||||
exit 1
|
||||
elif [ $undocumented -gt 0 ]; then
|
||||
echo undocumented objects found:
|
||||
cat build/coverage/python.txt
|
||||
elif [ "$undocumented" -gt 0 ]; then
|
||||
echo ""
|
||||
echo "====================="
|
||||
echo "UNDOCUMENTED OBJECTS:"
|
||||
echo "====================="
|
||||
echo ""
|
||||
# Find the line number of the TOTAL row and print only what comes after it
|
||||
total_line=$(grep -n "| TOTAL" build/coverage/python.txt | cut -d: -f1)
|
||||
if [ -n "$total_line" ]; then
|
||||
# Print only the detailed list (skip the statistics table)
|
||||
tail -n +$((total_line + 2)) build/coverage/python.txt
|
||||
else
|
||||
# Fallback to showing entire file if TOTAL line not found
|
||||
cat build/coverage/python.txt
|
||||
fi
|
||||
echo ""
|
||||
echo "Make sure you've updated relevant .rsts in docs/source!"
|
||||
echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"
|
||||
echo "You can reproduce locally by running 'cd docs && make coverage && tail -n +\$((grep -n \"| TOTAL\" build/coverage/python.txt | cut -d: -f1) + 2)) build/coverage/python.txt'"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
|
||||
@ -337,7 +337,7 @@ test_python() {
|
||||
|
||||
test_python_smoke() {
|
||||
# Smoke tests for H100/B200
|
||||
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
|
||||
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
|
||||
assert_git_not_dirty
|
||||
}
|
||||
|
||||
@ -1653,7 +1653,7 @@ test_operator_microbenchmark() {
|
||||
|
||||
cd "${TEST_DIR}"/benchmarks/operator_benchmark
|
||||
|
||||
for OP_BENCHMARK_TESTS in matmul mm addmm bmm; do
|
||||
for OP_BENCHMARK_TESTS in matmul mm addmm bmm conv; do
|
||||
$TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \
|
||||
--output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}_compile.json" \
|
||||
--benchmark-name "PyTorch operator microbenchmark" --use-compile
|
||||
|
||||
@ -60,9 +60,11 @@ performance-*,
|
||||
readability-container-size-empty,
|
||||
readability-delete-null-pointer,
|
||||
readability-duplicate-include,
|
||||
readability-named-parameter,
|
||||
readability-misplaced-array-index,
|
||||
readability-redundant*,
|
||||
readability-simplify-subscript-expr,
|
||||
readability-static-definition-in-anonymous-namespace
|
||||
readability-string-compare,
|
||||
-readability-redundant-access-specifiers,
|
||||
-readability-redundant-control-flow,
|
||||
|
||||
319
.claude/skills/add-uint-support/SKILL.md
Normal file
319
.claude/skills/add-uint-support/SKILL.md
Normal file
@ -0,0 +1,319 @@
|
||||
---
|
||||
name: add-uint-support
|
||||
description: Add unsigned integer (uint) type support to PyTorch operators by updating AT_DISPATCH macros. Use when adding support for uint16, uint32, uint64 types to operators, kernels, or when user mentions enabling unsigned types, barebones unsigned types, or uint support.
|
||||
---
|
||||
|
||||
# Add Unsigned Integer (uint) Support to Operators
|
||||
|
||||
This skill helps add support for unsigned integer types (uint16, uint32, uint64) to PyTorch operators by updating their AT_DISPATCH macros.
|
||||
|
||||
## When to use this skill
|
||||
|
||||
Use this skill when:
|
||||
- Adding uint16, uint32, or uint64 support to an operator
|
||||
- User mentions "unsigned types", "uint support", "barebones unsigned types"
|
||||
- Enabling support for kUInt16, kUInt32, kUInt64 in kernels
|
||||
- Working with operator implementations that need expanded type coverage
|
||||
|
||||
## Quick reference
|
||||
|
||||
**Add unsigned types to existing dispatch:**
|
||||
```cpp
|
||||
// Before
|
||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
|
||||
kernel<scalar_t>();
|
||||
}), AT_EXPAND(AT_ALL_TYPES));
|
||||
|
||||
// After (method 1: add unsigned types explicitly)
|
||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
|
||||
kernel<scalar_t>();
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
|
||||
|
||||
// After (method 2: use V2 integral types if AT_INTEGRAL_TYPES present)
|
||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
|
||||
kernel<scalar_t>();
|
||||
}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES));
|
||||
```
|
||||
|
||||
## Type group reference
|
||||
|
||||
**Unsigned type groups:**
|
||||
- `AT_BAREBONES_UNSIGNED_TYPES`: kUInt16, kUInt32, kUInt64
|
||||
- `AT_INTEGRAL_TYPES_V2`: AT_INTEGRAL_TYPES + AT_BAREBONES_UNSIGNED_TYPES
|
||||
|
||||
**Relationship:**
|
||||
```cpp
|
||||
AT_INTEGRAL_TYPES // kByte, kChar, kInt, kLong, kShort
|
||||
AT_BAREBONES_UNSIGNED_TYPES // kUInt16, kUInt32, kUInt64
|
||||
AT_INTEGRAL_TYPES_V2 // INTEGRAL_TYPES + BAREBONES_UNSIGNED_TYPES
|
||||
```
|
||||
|
||||
## Instructions
|
||||
|
||||
### Step 1: Determine if conversion to V2 is needed
|
||||
|
||||
Check if the file uses AT_DISPATCH_V2:
|
||||
|
||||
**If using old AT_DISPATCH:**
|
||||
- First convert to AT_DISPATCH_V2 using the at-dispatch-v2 skill
|
||||
- Then proceed with adding uint support
|
||||
|
||||
**If already using AT_DISPATCH_V2:**
|
||||
- Proceed directly to Step 2
|
||||
|
||||
### Step 2: Analyze the current dispatch macro
|
||||
|
||||
Identify what type groups are currently in use:
|
||||
|
||||
```cpp
|
||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
|
||||
// body
|
||||
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
Current type coverage
|
||||
```
|
||||
|
||||
Common patterns:
|
||||
- `AT_EXPAND(AT_ALL_TYPES)` → includes AT_INTEGRAL_TYPES + AT_FLOATING_TYPES
|
||||
- `AT_EXPAND(AT_INTEGRAL_TYPES)` → signed integers only
|
||||
- `AT_EXPAND(AT_FLOATING_TYPES)` → floating point types
|
||||
|
||||
### Step 3: Choose the uint addition method
|
||||
|
||||
Two approaches:
|
||||
|
||||
**Method 1: Add AT_BAREBONES_UNSIGNED_TYPES explicitly**
|
||||
- Use when: You want to be explicit about adding uint support
|
||||
- Add `AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)` to the type list
|
||||
|
||||
**Method 2: Substitute AT_INTEGRAL_TYPES with AT_INTEGRAL_TYPES_V2**
|
||||
- Use when: The dispatch already uses `AT_EXPAND(AT_INTEGRAL_TYPES)`
|
||||
- More concise: replaces one type group with its superset
|
||||
- Only applicable if AT_INTEGRAL_TYPES is present
|
||||
|
||||
### Step 4: Apply the transformation
|
||||
|
||||
**Method 1 example:**
|
||||
```cpp
|
||||
// Before
|
||||
AT_DISPATCH_V2(
|
||||
dtype,
|
||||
"min_values_cuda",
|
||||
AT_WRAP([&]() {
|
||||
kernel_impl<scalar_t>(iter);
|
||||
}),
|
||||
AT_EXPAND(AT_ALL_TYPES),
|
||||
kBFloat16, kHalf, kBool
|
||||
);
|
||||
|
||||
// After (add unsigned types)
|
||||
AT_DISPATCH_V2(
|
||||
dtype,
|
||||
"min_values_cuda",
|
||||
AT_WRAP([&]() {
|
||||
kernel_impl<scalar_t>(iter);
|
||||
}),
|
||||
AT_EXPAND(AT_ALL_TYPES),
|
||||
AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
|
||||
kBFloat16, kHalf, kBool
|
||||
);
|
||||
```
|
||||
|
||||
**Method 2 example:**
|
||||
```cpp
|
||||
// Before
|
||||
AT_DISPATCH_V2(
|
||||
dtype,
|
||||
"integral_op",
|
||||
AT_WRAP([&]() {
|
||||
kernel<scalar_t>();
|
||||
}),
|
||||
AT_EXPAND(AT_INTEGRAL_TYPES)
|
||||
);
|
||||
|
||||
// After (substitute with V2)
|
||||
AT_DISPATCH_V2(
|
||||
dtype,
|
||||
"integral_op",
|
||||
AT_WRAP([&]() {
|
||||
kernel<scalar_t>();
|
||||
}),
|
||||
AT_EXPAND(AT_INTEGRAL_TYPES_V2)
|
||||
);
|
||||
```
|
||||
|
||||
### Step 5: Handle AT_ALL_TYPES vs individual type groups
|
||||
|
||||
If the dispatch uses `AT_EXPAND(AT_ALL_TYPES)`:
|
||||
- `AT_ALL_TYPES` = `AT_INTEGRAL_TYPES` + `AT_FLOATING_TYPES`
|
||||
- To add uint: add `AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)` to the list
|
||||
|
||||
If the dispatch separately lists INTEGRAL and FLOATING:
|
||||
```cpp
|
||||
// Before
|
||||
AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
|
||||
|
||||
// After (Method 2 preferred)
|
||||
AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES)
|
||||
```
|
||||
|
||||
### Step 6: Verify all dispatch sites
|
||||
|
||||
Check the file for ALL dispatch macros that need uint support:
|
||||
- Some operators have multiple dispatch sites (CPU, CUDA, different functions)
|
||||
- Apply the transformation consistently across all sites
|
||||
- Ensure each gets the same type coverage updates
|
||||
|
||||
### Step 7: Validate the changes
|
||||
|
||||
Check that:
|
||||
- [ ] AT_DISPATCH_V2 format is used (not old AT_DISPATCH)
|
||||
- [ ] Unsigned types are added via one of the two methods
|
||||
- [ ] All relevant dispatch sites in the file are updated
|
||||
- [ ] Type groups use `AT_EXPAND()`
|
||||
- [ ] Arguments are properly formatted and comma-separated
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Pattern 1: AT_ALL_TYPES + extras
|
||||
|
||||
```cpp
|
||||
// Before
|
||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
|
||||
kernel<scalar_t>();
|
||||
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
|
||||
|
||||
// After
|
||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
|
||||
kernel<scalar_t>();
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
|
||||
```
|
||||
|
||||
### Pattern 2: Separate INTEGRAL + FLOATING
|
||||
|
||||
```cpp
|
||||
// Before
|
||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
|
||||
kernel<scalar_t>();
|
||||
}), AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
|
||||
|
||||
// After
|
||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
|
||||
kernel<scalar_t>();
|
||||
}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES));
|
||||
```
|
||||
|
||||
### Pattern 3: Old dispatch needs conversion first
|
||||
|
||||
```cpp
|
||||
// Before (needs v2 conversion first)
|
||||
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "op", [&]() {
|
||||
kernel<scalar_t>();
|
||||
});
|
||||
|
||||
// After v2 conversion
|
||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
|
||||
kernel<scalar_t>();
|
||||
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
|
||||
|
||||
// After adding uint support
|
||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
|
||||
kernel<scalar_t>();
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
|
||||
```
|
||||
|
||||
## Multiple dispatch sites example
|
||||
|
||||
For a file with multiple functions:
|
||||
|
||||
```cpp
|
||||
void min_values_kernel_cuda(TensorIterator& iter) {
|
||||
AT_DISPATCH_V2(iter.dtype(), "min_values_cuda", AT_WRAP([&]() {
|
||||
impl<scalar_t>(iter);
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf);
|
||||
// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
// Added uint support
|
||||
}
|
||||
|
||||
void min_launch_kernel(TensorIterator &iter) {
|
||||
AT_DISPATCH_V2(iter.input_dtype(), "min_cuda", AT_WRAP([&]() {
|
||||
gpu_reduce_kernel<scalar_t>(iter);
|
||||
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf);
|
||||
// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
// Added uint support here too
|
||||
}
|
||||
```
|
||||
|
||||
## Decision tree
|
||||
|
||||
Use this decision tree to determine the approach:
|
||||
|
||||
```
|
||||
Is the file using AT_DISPATCH_V2?
|
||||
├─ No → Use at-dispatch-v2 skill first, then continue
|
||||
└─ Yes
|
||||
└─ Does it use AT_EXPAND(AT_INTEGRAL_TYPES)?
|
||||
├─ Yes → Replace with AT_EXPAND(AT_INTEGRAL_TYPES_V2)
|
||||
└─ No → Add AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES) to type list
|
||||
```
|
||||
|
||||
## Edge cases
|
||||
|
||||
### Case 1: Dispatch with only floating types
|
||||
|
||||
If the operator only supports floating point types, don't add uint support:
|
||||
|
||||
```cpp
|
||||
// Leave as-is - floating point only operator
|
||||
AT_DISPATCH_V2(dtype, "float_op", AT_WRAP([&]() {
|
||||
kernel<scalar_t>();
|
||||
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf);
|
||||
```
|
||||
|
||||
### Case 2: Complex types present
|
||||
|
||||
Unsigned types work alongside complex types:
|
||||
|
||||
```cpp
|
||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
|
||||
kernel<scalar_t>();
|
||||
}), AT_EXPAND(AT_ALL_TYPES),
|
||||
AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
|
||||
AT_EXPAND(AT_COMPLEX_TYPES),
|
||||
kHalf, kBFloat16);
|
||||
```
|
||||
|
||||
### Case 3: Already has uint support
|
||||
|
||||
Check if uint types are already present:
|
||||
- If `AT_INTEGRAL_TYPES_V2` is used → already has uint support
|
||||
- If `AT_BAREBONES_UNSIGNED_TYPES` is already in list → already has uint support
|
||||
- Skip the file if uint support is already present
|
||||
|
||||
## Workflow
|
||||
|
||||
When asked to add uint support:
|
||||
|
||||
1. Read the target file
|
||||
2. Check if using AT_DISPATCH_V2:
|
||||
- If not → use at-dispatch-v2 skill first
|
||||
3. Identify all dispatch macro sites
|
||||
4. For each dispatch:
|
||||
- Analyze current type groups
|
||||
- Choose method (add BAREBONES_UNSIGNED or upgrade to V2)
|
||||
- Apply transformation with Edit tool
|
||||
5. Show the user the changes
|
||||
6. Explain what was modified
|
||||
|
||||
## Important notes
|
||||
|
||||
- Always check if v2 conversion is needed first
|
||||
- Apply changes consistently across all dispatch sites in the file
|
||||
- Method 2 (AT_INTEGRAL_TYPES_V2) is cleaner when applicable
|
||||
- Method 1 (explicit AT_BAREBONES_UNSIGNED_TYPES) is more explicit
|
||||
- Unsigned types are: kUInt16, kUInt32, kUInt64 (not kByte which is uint8)
|
||||
- Some operators may not semantically support unsigned types - use judgment
|
||||
|
||||
## Testing
|
||||
|
||||
After adding uint support, the operator should accept uint16, uint32, and uint64 tensors. The user is responsible for functional testing.
|
||||
305
.claude/skills/at-dispatch-v2/SKILL.md
Normal file
305
.claude/skills/at-dispatch-v2/SKILL.md
Normal file
@ -0,0 +1,305 @@
|
||||
---
|
||||
name: at-dispatch-v2
|
||||
description: Convert PyTorch AT_DISPATCH macros to AT_DISPATCH_V2 format in ATen C++ code. Use when porting AT_DISPATCH_ALL_TYPES_AND*, AT_DISPATCH_FLOATING_TYPES*, or other dispatch macros to the new v2 API. For ATen kernel files, CUDA kernels, and native operator implementations.
|
||||
---
|
||||
|
||||
# AT_DISPATCH to AT_DISPATCH_V2 Converter
|
||||
|
||||
This skill helps convert PyTorch's legacy AT_DISPATCH macros to the new AT_DISPATCH_V2 format, as defined in `aten/src/ATen/Dispatch_v2.h`.
|
||||
|
||||
## When to use this skill
|
||||
|
||||
Use this skill when:
|
||||
- Converting AT_DISPATCH_* macros to AT_DISPATCH_V2
|
||||
- Porting ATen kernels to use the new dispatch API
|
||||
- Working with files in `aten/src/ATen/native/` that use dispatch macros
|
||||
- User mentions "AT_DISPATCH", "dispatch v2", "Dispatch_v2.h", or macro conversion
|
||||
|
||||
## Quick reference
|
||||
|
||||
**Old format:**
|
||||
```cpp
|
||||
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, dtype, "kernel_name", [&]() {
|
||||
// lambda body
|
||||
});
|
||||
```
|
||||
|
||||
**New format:**
|
||||
```cpp
|
||||
AT_DISPATCH_V2(dtype, "kernel_name", AT_WRAP([&]() {
|
||||
// lambda body
|
||||
}), AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool);
|
||||
```
|
||||
|
||||
## Key transformations
|
||||
|
||||
1. **Reorder arguments**: `scalar_type` and `name` come first, then lambda, then types
|
||||
2. **Wrap the lambda**: Use `AT_WRAP(lambda)` to handle internal commas
|
||||
3. **Expand type groups**: Use `AT_EXPAND(AT_ALL_TYPES)` instead of implicit expansion
|
||||
4. **List individual types**: Add extra types (kHalf, kBFloat16, etc.) after expanded groups
|
||||
5. **Add include**: `#include <ATen/Dispatch_v2.h>` near other Dispatch includes
|
||||
|
||||
## Instructions
|
||||
|
||||
### Step 1: Add the Dispatch_v2.h include
|
||||
|
||||
Add the v2 header near the existing `#include <ATen/Dispatch.h>`:
|
||||
|
||||
```cpp
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/Dispatch_v2.h>
|
||||
```
|
||||
|
||||
Keep the old Dispatch.h include for now (other code may still need it).
|
||||
|
||||
### Step 2: Identify the old dispatch pattern
|
||||
|
||||
Common patterns to convert:
|
||||
|
||||
- `AT_DISPATCH_ALL_TYPES_AND{2,3,4}(type1, type2, ..., scalar_type, name, lambda)`
|
||||
- `AT_DISPATCH_FLOATING_TYPES_AND{2,3}(type1, type2, ..., scalar_type, name, lambda)`
|
||||
- `AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND{2,3}(type1, ..., scalar_type, name, lambda)`
|
||||
- `AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND{2,3}(type1, ..., scalar_type, name, lambda)`
|
||||
|
||||
### Step 3: Map the old macro to type groups
|
||||
|
||||
Identify which type group macro corresponds to the base types:
|
||||
|
||||
| Old macro base | AT_DISPATCH_V2 type group |
|
||||
|----------------|---------------------------|
|
||||
| `ALL_TYPES` | `AT_EXPAND(AT_ALL_TYPES)` |
|
||||
| `FLOATING_TYPES` | `AT_EXPAND(AT_FLOATING_TYPES)` |
|
||||
| `INTEGRAL_TYPES` | `AT_EXPAND(AT_INTEGRAL_TYPES)` |
|
||||
| `COMPLEX_TYPES` | `AT_EXPAND(AT_COMPLEX_TYPES)` |
|
||||
| `ALL_TYPES_AND_COMPLEX` | `AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX)` |
|
||||
|
||||
For combined patterns, use multiple `AT_EXPAND()` entries:
|
||||
```cpp
|
||||
// Old: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(...)
|
||||
// New: AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), type1, type2
|
||||
```
|
||||
|
||||
### Step 4: Extract the individual types
|
||||
|
||||
From `AT_DISPATCH_*_AND2(type1, type2, ...)` or `AT_DISPATCH_*_AND3(type1, type2, type3, ...)`, extract the individual types (type1, type2, etc.).
|
||||
|
||||
These become the trailing arguments after the type group:
|
||||
```cpp
|
||||
AT_DISPATCH_V2(..., AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
Individual types from AND3
|
||||
```
|
||||
|
||||
### Step 5: Transform to AT_DISPATCH_V2
|
||||
|
||||
Apply the transformation:
|
||||
|
||||
**Pattern:**
|
||||
```cpp
|
||||
AT_DISPATCH_V2(
|
||||
scalar_type, // 1st: The dtype expression
|
||||
"name", // 2nd: The debug string
|
||||
AT_WRAP(lambda), // 3rd: The lambda wrapped in AT_WRAP
|
||||
type_groups, // 4th+: Type groups with AT_EXPAND()
|
||||
individual_types // Last: Individual types
|
||||
)
|
||||
```
|
||||
|
||||
**Example transformation:**
|
||||
```cpp
|
||||
// BEFORE
|
||||
AT_DISPATCH_ALL_TYPES_AND3(
|
||||
kBFloat16, kHalf, kBool,
|
||||
iter.dtype(),
|
||||
"min_values_cuda",
|
||||
[&]() {
|
||||
min_values_kernel_cuda_impl<scalar_t>(iter);
|
||||
}
|
||||
);
|
||||
|
||||
// AFTER
|
||||
AT_DISPATCH_V2(
|
||||
iter.dtype(),
|
||||
"min_values_cuda",
|
||||
AT_WRAP([&]() {
|
||||
min_values_kernel_cuda_impl<scalar_t>(iter);
|
||||
}),
|
||||
AT_EXPAND(AT_ALL_TYPES),
|
||||
kBFloat16, kHalf, kBool
|
||||
);
|
||||
```
|
||||
|
||||
### Step 6: Handle multi-line lambdas
|
||||
|
||||
For lambdas with internal commas or complex expressions, AT_WRAP is essential:
|
||||
|
||||
```cpp
|
||||
AT_DISPATCH_V2(
|
||||
dtype,
|
||||
"complex_kernel",
|
||||
AT_WRAP([&]() {
|
||||
gpu_reduce_kernel<scalar_t, scalar_t>(
|
||||
iter,
|
||||
MinOps<scalar_t>{},
|
||||
thrust::pair<scalar_t, int64_t>(upper_bound(), 0) // Commas inside!
|
||||
);
|
||||
}),
|
||||
AT_EXPAND(AT_ALL_TYPES)
|
||||
);
|
||||
```
|
||||
|
||||
### Step 7: Verify the conversion
|
||||
|
||||
Check that:
|
||||
- [ ] `AT_WRAP()` wraps the entire lambda
|
||||
- [ ] Type groups use `AT_EXPAND()`
|
||||
- [ ] Individual types don't have `AT_EXPAND()` (just `kBFloat16`, not `AT_EXPAND(kBFloat16)`)
|
||||
- [ ] Argument order is: scalar_type, name, lambda, types
|
||||
- [ ] Include added: `#include <ATen/Dispatch_v2.h>`
|
||||
|
||||
## Type group reference
|
||||
|
||||
Available type group macros (use with `AT_EXPAND()`):
|
||||
|
||||
```cpp
|
||||
AT_INTEGRAL_TYPES // kByte, kChar, kInt, kLong, kShort
|
||||
AT_FLOATING_TYPES // kDouble, kFloat
|
||||
AT_COMPLEX_TYPES // kComplexDouble, kComplexFloat
|
||||
AT_QINT_TYPES // kQInt8, kQUInt8, kQInt32
|
||||
AT_ALL_TYPES // INTEGRAL_TYPES + FLOATING_TYPES
|
||||
AT_ALL_TYPES_AND_COMPLEX // ALL_TYPES + COMPLEX_TYPES
|
||||
AT_INTEGRAL_TYPES_V2 // INTEGRAL_TYPES + unsigned types
|
||||
AT_BAREBONES_UNSIGNED_TYPES // kUInt16, kUInt32, kUInt64
|
||||
AT_FLOAT8_TYPES // Float8 variants
|
||||
```
|
||||
|
||||
## Common patterns
|
||||
|
||||
### Pattern: AT_DISPATCH_ALL_TYPES_AND2
|
||||
|
||||
```cpp
|
||||
// Before
|
||||
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "op", [&]() {
|
||||
kernel<scalar_t>(data);
|
||||
});
|
||||
|
||||
// After
|
||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
|
||||
kernel<scalar_t>(data);
|
||||
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
|
||||
```
|
||||
|
||||
### Pattern: AT_DISPATCH_FLOATING_TYPES_AND3
|
||||
|
||||
```cpp
|
||||
// Before
|
||||
AT_DISPATCH_FLOATING_TYPES_AND3(kHalf, kBFloat16, kFloat8_e4m3fn,
|
||||
tensor.scalar_type(), "float_op", [&] {
|
||||
process<scalar_t>(tensor);
|
||||
});
|
||||
|
||||
// After
|
||||
AT_DISPATCH_V2(tensor.scalar_type(), "float_op", AT_WRAP([&] {
|
||||
process<scalar_t>(tensor);
|
||||
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn);
|
||||
```
|
||||
|
||||
### Pattern: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2
|
||||
|
||||
```cpp
|
||||
// Before
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
|
||||
kComplexHalf, kHalf,
|
||||
self.scalar_type(),
|
||||
"complex_op",
|
||||
[&] {
|
||||
result = compute<scalar_t>(self);
|
||||
}
|
||||
);
|
||||
|
||||
// After
|
||||
AT_DISPATCH_V2(
|
||||
self.scalar_type(),
|
||||
"complex_op",
|
||||
AT_WRAP([&] {
|
||||
result = compute<scalar_t>(self);
|
||||
}),
|
||||
AT_EXPAND(AT_ALL_TYPES),
|
||||
AT_EXPAND(AT_COMPLEX_TYPES),
|
||||
kComplexHalf,
|
||||
kHalf
|
||||
);
|
||||
```
|
||||
|
||||
## Edge cases
|
||||
|
||||
### Case 1: No extra types (rare)
|
||||
|
||||
```cpp
|
||||
// Before
|
||||
AT_DISPATCH_ALL_TYPES(dtype, "op", [&]() { kernel<scalar_t>(); });
|
||||
|
||||
// After
|
||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
|
||||
kernel<scalar_t>();
|
||||
}), AT_EXPAND(AT_ALL_TYPES));
|
||||
```
|
||||
|
||||
### Case 2: Many individual types (AND4, AND5, etc.)
|
||||
|
||||
```cpp
|
||||
// Before
|
||||
AT_DISPATCH_FLOATING_TYPES_AND4(kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2,
|
||||
dtype, "float8_op", [&]() { kernel<scalar_t>(); });
|
||||
|
||||
// After
|
||||
AT_DISPATCH_V2(dtype, "float8_op", AT_WRAP([&]() {
|
||||
kernel<scalar_t>();
|
||||
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2);
|
||||
```
|
||||
|
||||
### Case 3: Lambda with no captures
|
||||
|
||||
```cpp
|
||||
// Before
|
||||
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, dtype, "op", []() {
|
||||
static_kernel<scalar_t>();
|
||||
});
|
||||
|
||||
// After
|
||||
AT_DISPATCH_V2(dtype, "op", AT_WRAP([]() {
|
||||
static_kernel<scalar_t>();
|
||||
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBool);
|
||||
```
|
||||
|
||||
## Benefits of AT_DISPATCH_V2
|
||||
|
||||
1. **No arity in macro name**: Don't need different macros for AND2, AND3, AND4
|
||||
2. **Composable type sets**: Mix and match type groups with `AT_EXPAND()`
|
||||
3. **Extensible**: Easy to add more types without hitting macro limits
|
||||
4. **Clearer**: Type groups are explicit, not implicit in macro name
|
||||
|
||||
## Important notes
|
||||
|
||||
- Keep `#include <ATen/Dispatch.h>` - other code may need it
|
||||
- The `AT_WRAP()` is mandatory - prevents comma parsing issues in the lambda
|
||||
- Type groups need `AT_EXPAND()`, individual types don't
|
||||
- The v2 API is in `aten/src/ATen/Dispatch_v2.h` - refer to it for full docs
|
||||
- See the header file for the Python script to regenerate the macro implementation
|
||||
|
||||
## Workflow
|
||||
|
||||
When asked to convert AT_DISPATCH macros:
|
||||
|
||||
1. Read the file to identify all AT_DISPATCH uses
|
||||
2. Add `#include <ATen/Dispatch_v2.h>` if not present
|
||||
3. For each dispatch macro:
|
||||
- Identify the pattern and extract components
|
||||
- Map the base type group
|
||||
- Extract individual types
|
||||
- Construct the AT_DISPATCH_V2 call
|
||||
- Apply with Edit tool
|
||||
4. Show the user the complete converted file
|
||||
5. Explain what was changed
|
||||
|
||||
Do NOT compile or test the code - focus on accurate conversion only.
|
||||
2
.github/ci_commit_pins/xla.txt
vendored
2
.github/ci_commit_pins/xla.txt
vendored
@ -1 +1 @@
|
||||
df6798dfb931ce7c7fe5bed2447cd1092a5981af
|
||||
c8b09f5f77d6bf6fb7ed7a9aa83e5d8156b3a5e9
|
||||
|
||||
@ -28,7 +28,7 @@ CUDA_ARCHES_FULL_VERSION = {
|
||||
"12.6": "12.6.3",
|
||||
"12.8": "12.8.1",
|
||||
"12.9": "12.9.1",
|
||||
"13.0": "13.0.2",
|
||||
"13.0": "13.0.0",
|
||||
}
|
||||
CUDA_ARCHES_CUDNN_VERSION = {
|
||||
"12.6": "9",
|
||||
|
||||
1
.github/workflows/docker-release.yml
vendored
1
.github/workflows/docker-release.yml
vendored
@ -8,6 +8,7 @@ on:
|
||||
- docker.Makefile
|
||||
- .github/workflows/docker-release.yml
|
||||
- .github/scripts/generate_docker_release_matrix.py
|
||||
- .github/scripts/generate_binary_build_matrix.py
|
||||
push:
|
||||
branches:
|
||||
- nightly
|
||||
|
||||
3
.github/workflows/inductor-rocm.yml
vendored
3
.github/workflows/inductor-rocm.yml
vendored
@ -1,9 +1,10 @@
|
||||
name: inductor-rocm
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: 0 * * * *
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- release/*
|
||||
tags:
|
||||
- ciflow/inductor-rocm/*
|
||||
|
||||
8
.github/workflows/inductor-unittest.yml
vendored
8
.github/workflows/inductor-unittest.yml
vendored
@ -115,10 +115,10 @@ jobs:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
|
||||
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
|
||||
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
|
||||
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
|
||||
14
.github/workflows/inductor.yml
vendored
14
.github/workflows/inductor.yml
vendored
@ -84,13 +84,13 @@ jobs:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
||||
{ config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" },
|
||||
]}
|
||||
build-additional-packages: "vision audio torchao"
|
||||
|
||||
15
.github/workflows/lint.yml
vendored
15
.github/workflows/lint.yml
vendored
@ -76,11 +76,12 @@ jobs:
|
||||
|
||||
# NOTE: mypy needs its own job because it depends on --all-files, without assessing all files it sometimes
|
||||
# fails to find types when it should
|
||||
lintrunner-mypy:
|
||||
# NOTE: We should be able to disable this and consolidate with Pyrefly
|
||||
lintrunner-pyrefly:
|
||||
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
|
||||
name: lintrunner-mypy-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }}
|
||||
name: lintrunner-pyrefly-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }}
|
||||
needs: [get-label-type, get-changed-files]
|
||||
# Only run if there are changed files relevant to mypy
|
||||
# Only run if there are changed files relevant to pyrefly
|
||||
if: |
|
||||
github.repository_owner == 'pytorch' && (
|
||||
needs.get-changed-files.outputs.changed-files == '*' ||
|
||||
@ -98,8 +99,8 @@ jobs:
|
||||
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||
script: |
|
||||
CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
|
||||
echo "Running mypy"
|
||||
ADDITIONAL_LINTRUNNER_ARGS="--take MYPY,MYPYSTRICT --all-files" .github/scripts/lintrunner.sh
|
||||
echo "Running pyrefly"
|
||||
ADDITIONAL_LINTRUNNER_ARGS="--take PYREFLY --all-files" .github/scripts/lintrunner.sh
|
||||
|
||||
lintrunner-noclang:
|
||||
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
|
||||
@ -118,9 +119,9 @@ jobs:
|
||||
CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
|
||||
echo "Running all other linters"
|
||||
if [ "$CHANGED_FILES" = '*' ]; then
|
||||
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY --all-files" .github/scripts/lintrunner.sh
|
||||
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,PYREFLY --all-files" .github/scripts/lintrunner.sh
|
||||
else
|
||||
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh
|
||||
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh
|
||||
fi
|
||||
|
||||
quick-checks:
|
||||
|
||||
2
.github/workflows/nightly.yml
vendored
2
.github/workflows/nightly.yml
vendored
@ -41,7 +41,7 @@ jobs:
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge"
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-py3.10-gcc11
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-py3.10-gcc11
|
||||
secrets: inherit
|
||||
|
||||
8
.github/workflows/pull.yml
vendored
8
.github/workflows/pull.yml
vendored
@ -66,10 +66,10 @@ jobs:
|
||||
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
|
||||
{ config: "docs_test", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
|
||||
{ config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
|
||||
{ config: "backwards_compat", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
|
||||
{ config: "backwards_compat", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
{ config: "distributed", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
|
||||
{ config: "distributed", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
|
||||
{ config: "numpy_2_x", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
|
||||
{ config: "numpy_2_x", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
@ -167,8 +167,8 @@ jobs:
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-py3-clang12-onnx
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
|
||||
{ config: "default", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
|
||||
{ config: "default", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
{ config: "default", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
|
||||
2
.github/workflows/rocm.yml
vendored
2
.github/workflows/rocm.yml
vendored
@ -3,13 +3,13 @@ name: rocm
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- release/*
|
||||
tags:
|
||||
- ciflow/rocm/*
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: 29 8 * * * # about 1:29am PDT
|
||||
- cron: 0 * * * *
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
|
||||
|
||||
3
.github/workflows/trunk.yml
vendored
3
.github/workflows/trunk.yml
vendored
@ -204,6 +204,7 @@ jobs:
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
|
||||
{ config: "distributed", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.4" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
@ -221,7 +222,7 @@ jobs:
|
||||
build-environment: linux-jammy-rocm-py3.10
|
||||
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
|
||||
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor"
|
||||
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor distributed/test_c10d_common distributed/test_c10d_nccl"
|
||||
secrets: inherit
|
||||
|
||||
inductor-build:
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@ -127,6 +127,7 @@ torch/test/
|
||||
torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
|
||||
torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
|
||||
torch/version.py
|
||||
torch/_inductor/kernel/vendored_templates/*
|
||||
minifier_launcher.py
|
||||
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
|
||||
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*
|
||||
@ -398,3 +399,4 @@ CLAUDE.local.md
|
||||
/test_*.py
|
||||
/debug_*.py
|
||||
CLAUDE_CONTEXT/
|
||||
/.claude/settings.local.json
|
||||
|
||||
@ -121,94 +121,6 @@ command = [
|
||||
]
|
||||
is_formatter = true
|
||||
|
||||
[[linter]]
|
||||
code = 'MYPY'
|
||||
include_patterns = [
|
||||
'setup.py',
|
||||
'functorch/dim/**/*.py',
|
||||
'torch/**/*.py',
|
||||
'torch/**/*.pyi',
|
||||
'caffe2/**/*.py',
|
||||
'caffe2/**/*.pyi',
|
||||
'test/test_bundled_images.py',
|
||||
'test/test_bundled_inputs.py',
|
||||
'test/test_complex.py',
|
||||
'test/test_datapipe.py',
|
||||
'test/test_futures.py',
|
||||
'test/test_numpy_interop.py',
|
||||
'test/test_torch.py',
|
||||
'test/test_type_hints.py',
|
||||
'test/test_type_info.py',
|
||||
'test/test_utils.py',
|
||||
]
|
||||
exclude_patterns = [
|
||||
'**/fb/**',
|
||||
]
|
||||
command = [
|
||||
'python3',
|
||||
'tools/linter/adapters/mypy_linter.py',
|
||||
'--config=mypy.ini',
|
||||
'--',
|
||||
'@{{PATHSFILE}}'
|
||||
]
|
||||
init_command = [
|
||||
'python3',
|
||||
'tools/linter/adapters/pip_init.py',
|
||||
'--dry-run={{DRYRUN}}',
|
||||
'numpy==1.26.4 ; python_version >= "3.10" and python_version <= "3.11"',
|
||||
'numpy==2.1.0 ; python_version >= "3.12"',
|
||||
'expecttest==0.3.0',
|
||||
'mypy==1.16.0',
|
||||
'sympy==1.13.3',
|
||||
'types-requests==2.27.25',
|
||||
'types-pyyaml==6.0.2',
|
||||
'types-tabulate==0.8.8',
|
||||
'types-protobuf==5.29.1.20250403',
|
||||
'types-setuptools==79.0.0.20250422',
|
||||
'types-jinja2==2.11.9',
|
||||
'types-colorama==0.4.6',
|
||||
'filelock==3.18.0',
|
||||
'junitparser==2.1.1',
|
||||
'rich==14.1.0',
|
||||
'pyyaml==6.0.2',
|
||||
'optree==0.13.0',
|
||||
'dataclasses-json==0.6.7',
|
||||
'pandas==2.2.3',
|
||||
]
|
||||
|
||||
[[linter]]
|
||||
code = 'MYPYSTRICT'
|
||||
include_patterns = [
|
||||
'.github/**/*.py',
|
||||
'benchmarks/instruction_counts/**/*.py',
|
||||
'tools/**/*.py',
|
||||
'torchgen/**/*.py',
|
||||
'torch/utils/_pytree.py',
|
||||
'torch/utils/_cxx_pytree.py',
|
||||
'torch/utils/benchmark/utils/common.py',
|
||||
'torch/utils/benchmark/utils/timer.py',
|
||||
'torch/utils/benchmark/utils/valgrind_wrapper/**/*.py',
|
||||
]
|
||||
exclude_patterns = [
|
||||
# (linbinyu) copied from internal repo
|
||||
'**/fb/**',
|
||||
'tools/code_analyzer/gen_operators_yaml.py',
|
||||
'tools/dynamo/verify_dynamo.py',
|
||||
'tools/gen_vulkan_spv.py',
|
||||
'tools/test/gen_operators_yaml_test.py',
|
||||
'tools/test/gen_oplist_test.py',
|
||||
'tools/test/test_selective_build.py',
|
||||
'tools/experimental/torchfuzz/**',
|
||||
]
|
||||
command = [
|
||||
'python3',
|
||||
'tools/linter/adapters/mypy_linter.py',
|
||||
'--config=mypy-strict.ini',
|
||||
'--code=MYPYSTRICT',
|
||||
'--',
|
||||
'@{{PATHSFILE}}'
|
||||
]
|
||||
|
||||
|
||||
[[linter]]
|
||||
code = 'PYREFLY'
|
||||
@ -230,6 +142,7 @@ init_command = [
|
||||
'python3',
|
||||
'tools/linter/adapters/pip_init.py',
|
||||
'--dry-run={{DRYRUN}}',
|
||||
'numpy==1.26.4 ; python_version >= "3.10" and python_version <= "3.11"',
|
||||
'numpy==2.1.0 ; python_version >= "3.12"',
|
||||
'expecttest==0.3.0',
|
||||
'pyrefly==0.36.2',
|
||||
@ -298,7 +211,6 @@ exclude_patterns = [
|
||||
'**/*pb.h',
|
||||
'**/*inl.h',
|
||||
'aten/src/ATen/cpu/FlushDenormal.cpp',
|
||||
'aten/src/ATen/cpu/Utils.cpp',
|
||||
'aten/src/ATen/cpu/vml.h',
|
||||
'aten/src/ATen/CPUFixedAllocator.h',
|
||||
'aten/src/ATen/Parallel*.h',
|
||||
@ -317,8 +229,6 @@ exclude_patterns = [
|
||||
'c10/util/win32-headers.h',
|
||||
'c10/test/**/*.h',
|
||||
'third_party/**/*',
|
||||
'torch/csrc/api/include/torch/nn/modules/common.h',
|
||||
'torch/csrc/api/include/torch/linalg.h',
|
||||
'torch/csrc/autograd/generated/**',
|
||||
'torch/csrc/distributed/**/*.cu',
|
||||
'torch/csrc/distributed/c10d/WinSockUtils.hpp',
|
||||
@ -330,7 +240,6 @@ exclude_patterns = [
|
||||
'torch/csrc/utils/generated_serialization_types.h',
|
||||
'torch/csrc/utils/pythoncapi_compat.h',
|
||||
'torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h',
|
||||
'aten/src/ATen/ExpandBase.h',
|
||||
]
|
||||
init_command = [
|
||||
'python3',
|
||||
|
||||
@ -11,7 +11,6 @@ aspects of contributing to PyTorch.
|
||||
<!-- toc -->
|
||||
|
||||
- [Developing PyTorch](#developing-pytorch)
|
||||
- [Setup the development environment](#setup-the-development-environment)
|
||||
- [Tips and Debugging](#tips-and-debugging)
|
||||
- [Nightly Checkout & Pull](#nightly-checkout--pull)
|
||||
- [Codebase structure](#codebase-structure)
|
||||
@ -67,23 +66,6 @@ aspects of contributing to PyTorch.
|
||||
|
||||
Follow the instructions for [installing PyTorch from source](https://github.com/pytorch/pytorch#from-source). If you get stuck when developing PyTorch on your machine, check out the [tips and debugging](#tips-and-debugging) section below for common solutions.
|
||||
|
||||
### Setup the development environment
|
||||
|
||||
First, you need to [fork the PyTorch project on GitHub](https://github.com/pytorch/pytorch/fork) and follow the instructions at [Connecting to GitHub with SSH](https://docs.github.com/en/authentication/connecting-to-github-with-ssh) to setup your SSH authentication credentials.
|
||||
|
||||
Then clone the PyTorch project and setup the development environment:
|
||||
|
||||
```bash
|
||||
git clone git@github.com:<USERNAME>/pytorch.git
|
||||
cd pytorch
|
||||
git remote add upstream git@github.com:pytorch/pytorch.git
|
||||
|
||||
make setup-env
|
||||
# Or run `make setup-env-cuda` for pre-built CUDA binaries
|
||||
# Or run `make setup-env-rocm` for pre-built ROCm binaries
|
||||
source venv/bin/activate # or `. .\venv\Scripts\activate` on Windows
|
||||
```
|
||||
|
||||
### Tips and Debugging
|
||||
|
||||
* If you want to have no-op incremental rebuilds (which are fast), see [Make no-op build fast](#make-no-op-build-fast) below.
|
||||
|
||||
20
SECURITY.md
20
SECURITY.md
@ -1,7 +1,7 @@
|
||||
# Security Policy
|
||||
|
||||
- [**Reporting a Vulnerability**](#reporting-a-vulnerability)
|
||||
- [**Using Pytorch Securely**](#using-pytorch-securely)
|
||||
- [**Using PyTorch Securely**](#using-pytorch-securely)
|
||||
- [Untrusted models](#untrusted-models)
|
||||
- [TorchScript models](#torchscript-models)
|
||||
- [Untrusted inputs](#untrusted-inputs)
|
||||
@ -10,28 +10,28 @@
|
||||
- [**CI/CD security principles**](#cicd-security-principles)
|
||||
## Reporting Security Issues
|
||||
|
||||
Beware that none of the topics under [Using Pytorch Securely](#using-pytorch-securely) are considered vulnerabilities of Pytorch.
|
||||
Beware that none of the topics under [Using PyTorch Securely](#using-pytorch-securely) are considered vulnerabilities of PyTorch.
|
||||
|
||||
However, if you believe you have found a security vulnerability in PyTorch, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem.
|
||||
|
||||
Please report security issues using https://github.com/pytorch/pytorch/security/advisories/new
|
||||
|
||||
All reports submitted thru the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
|
||||
All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
|
||||
|
||||
Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported:
|
||||
|
||||
https://www.facebook.com/whitehat
|
||||
|
||||
|
||||
## Using Pytorch Securely
|
||||
**Pytorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
|
||||
## Using PyTorch Securely
|
||||
**PyTorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
|
||||
|
||||
### Untrusted models
|
||||
Be careful when running untrusted models. This classification includes models created by unknown developers or utilizing data obtained from unknown sources[^data-poisoning-sources].
|
||||
|
||||
**Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing).
|
||||
|
||||
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
|
||||
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [Safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
|
||||
|
||||
Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out of bound reads/writes) and thus we recommend extensive validation of any untrusted inputs.
|
||||
|
||||
@ -43,7 +43,7 @@ Important Note: The trustworthiness of a model is not binary. You must always de
|
||||
|
||||
### TorchScript models
|
||||
|
||||
TorchScript models should treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
|
||||
TorchScript models should be treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
|
||||
|
||||
### Untrusted inputs during training and prediction
|
||||
|
||||
@ -59,9 +59,9 @@ If applicable, prepare your model against bad inputs and prompt injections. Some
|
||||
|
||||
### Data privacy
|
||||
|
||||
**Take special security measures if your model if you train models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
|
||||
- Do not feed sensitive data to untrusted model (even if runs in a sandboxed environment)
|
||||
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if model overfits).
|
||||
**Take special security measures if you train your models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
|
||||
- Do not feed sensitive data to an untrusted model (even if runs in a sandboxed environment)
|
||||
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if the model overfits).
|
||||
|
||||
### Using distributed features
|
||||
|
||||
|
||||
@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
|
||||
if(USE_CUDA)
|
||||
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
|
||||
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
|
||||
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
|
||||
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped|f4f4bf16).*")
|
||||
file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
|
||||
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
|
||||
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")
|
||||
|
||||
@ -181,7 +181,7 @@ c10::intrusive_ptr<c10::TensorImpl> CPUGeneratorImpl::get_state() const {
|
||||
static const size_t size = sizeof(CPUGeneratorImplState);
|
||||
static_assert(std::is_standard_layout_v<CPUGeneratorImplState>, "CPUGeneratorImplState is not a PODType");
|
||||
|
||||
auto state_tensor = at::detail::empty_cpu({(int64_t)size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
|
||||
auto state_tensor = at::detail::empty_cpu({static_cast<int64_t>(size)}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
|
||||
auto rng_state = state_tensor.data_ptr();
|
||||
|
||||
// accumulate generator data to be copied into byte tensor
|
||||
|
||||
@ -23,8 +23,6 @@ C10_DIAGNOSTIC_POP()
|
||||
#endif
|
||||
namespace at {
|
||||
|
||||
namespace {
|
||||
|
||||
/*
|
||||
These const variables defined the fp32 precisions for different backend
|
||||
We have "generic", "cuda", "mkldnn" backend now and we can choose fp32
|
||||
@ -41,16 +39,6 @@ namespace {
|
||||
->rnn
|
||||
*/
|
||||
|
||||
C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){
|
||||
TORCH_WARN_ONCE(
|
||||
"Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' "
|
||||
"or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, "
|
||||
"torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see "
|
||||
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"
|
||||
);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
Float32Backend str2backend(const std::string& name) {
|
||||
if (name == "generic")
|
||||
return Float32Backend::GENERIC;
|
||||
@ -206,7 +194,6 @@ bool Context::allowTF32CuDNN(std::optional<Float32Op> op) const {
|
||||
} else {
|
||||
return float32Precision(Float32Backend::CUDA, op.value()) == Float32Precision::TF32;
|
||||
}
|
||||
warn_deprecated_fp32_precision_api();
|
||||
return allow_tf32_cudnn;
|
||||
}
|
||||
|
||||
@ -214,7 +201,6 @@ void Context::setAllowTF32CuDNN(bool b) {
|
||||
setFloat32Precision(Float32Backend::CUDA, Float32Op::RNN, b ? Float32Precision::TF32 : Float32Precision::NONE);
|
||||
setFloat32Precision(Float32Backend::CUDA, Float32Op::CONV, b ? Float32Precision::TF32 : Float32Precision::NONE);
|
||||
allow_tf32_cudnn = b;
|
||||
warn_deprecated_fp32_precision_api();
|
||||
}
|
||||
|
||||
void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) {
|
||||
@ -223,7 +209,7 @@ void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) {
|
||||
"setSDPPriority order expected ", sdp_priority_order.size() - 1, " but got ",
|
||||
at::num_sdp_backends, " unique backends specified in priority order.");
|
||||
for (uint32_t i = 0; i < order.size(); i++) {
|
||||
sdp_priority_order[i] = (at::SDPBackend) order[i];
|
||||
sdp_priority_order[i] = static_cast<at::SDPBackend>(order[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -325,7 +311,6 @@ bool Context::allowTF32CuBLAS() const {
|
||||
"Current status indicate that you have used mix of the legacy and new APIs to set the TF32 status for cublas matmul. ",
|
||||
"We suggest only using the new API to set the TF32 flag. See also: ",
|
||||
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
|
||||
warn_deprecated_fp32_precision_api();
|
||||
return allow_tf32_new;
|
||||
}
|
||||
|
||||
@ -349,7 +334,6 @@ Float32MatmulPrecision Context::float32MatmulPrecision() const {
|
||||
"Current status indicate that you have used mix of the legacy and new APIs to set the matmul precision. ",
|
||||
"We suggest only using the new API for matmul precision. See also: ",
|
||||
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
|
||||
warn_deprecated_fp32_precision_api();
|
||||
return float32_matmul_precision;
|
||||
}
|
||||
|
||||
@ -377,7 +361,6 @@ Float32Precision Context::float32Precision(Float32Backend backend, Float32Op op)
|
||||
|
||||
void Context::setFloat32MatmulPrecision(const std::string &s) {
|
||||
auto match = [this](const std::string & s_) {
|
||||
warn_deprecated_fp32_precision_api();
|
||||
// TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
|
||||
if (s_ == "highest") {
|
||||
float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
|
||||
|
||||
@ -197,6 +197,7 @@ inline at::ScalarType scalar_type(at::ScalarType s) {
|
||||
/* don't use TYPE again in case it is an expensive or side-effect op */ \
|
||||
at::ScalarType _st = ::detail::scalar_type(the_type); \
|
||||
RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st); \
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") \
|
||||
switch (_st) { \
|
||||
__VA_ARGS__ \
|
||||
default: \
|
||||
@ -208,6 +209,7 @@ inline at::ScalarType scalar_type(at::ScalarType s) {
|
||||
toString(_st), \
|
||||
"'"); \
|
||||
} \
|
||||
C10_DIAGNOSTIC_POP() \
|
||||
}()
|
||||
|
||||
#define AT_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
|
||||
@ -252,13 +252,13 @@ MapAllocator::MapAllocator(WithFd /*unused*/, std::string_view filename, int fd,
|
||||
if (!(flags_ & ALLOCATOR_MAPPED_FROMFD)) {
|
||||
if (flags_ & ALLOCATOR_MAPPED_SHARED) {
|
||||
// NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
|
||||
if ((fd = open(filename_.c_str(), flags, (mode_t)0600)) == -1) {
|
||||
if ((fd = open(filename_.c_str(), flags, static_cast<mode_t>(0600))) == -1) {
|
||||
TORCH_CHECK(false, "unable to open file <", filename_, "> in read-write mode: ", c10::utils::str_error(errno), " (", errno, ")");
|
||||
}
|
||||
} else if (flags_ & ALLOCATOR_MAPPED_SHAREDMEM) {
|
||||
#ifdef HAVE_SHM_OPEN
|
||||
// NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
|
||||
if((fd = shm_open(filename_.c_str(), flags, (mode_t)0600)) == -1) {
|
||||
if((fd = shm_open(filename_.c_str(), flags, static_cast<mode_t>(0600))) == -1) {
|
||||
TORCH_CHECK(false, "unable to open shared memory object <", filename_, "> in read-write mode: ", c10::utils::str_error(errno), " (", errno, ")");
|
||||
}
|
||||
#else
|
||||
@ -503,7 +503,7 @@ RefcountedMapAllocator::RefcountedMapAllocator(WithFd /*unused*/, const char *fi
|
||||
|
||||
void RefcountedMapAllocator::initializeAlloc() {
|
||||
TORCH_CHECK(base_ptr_, "base_ptr_ is null");
|
||||
MapInfo *map_info = (MapInfo*)base_ptr_;
|
||||
MapInfo *map_info = static_cast<MapInfo*>(base_ptr_);
|
||||
|
||||
#ifdef _WIN32
|
||||
ReleaseContext* r_ctx = new ReleaseContext;
|
||||
@ -539,7 +539,7 @@ void RefcountedMapAllocator::close() {
|
||||
}
|
||||
#else /* _WIN32 */
|
||||
|
||||
MapInfo *info = (MapInfo*)(data);
|
||||
MapInfo *info = static_cast<MapInfo*>(data);
|
||||
if (--info->refcount == 0) {
|
||||
#ifdef HAVE_SHM_UNLINK
|
||||
if (shm_unlink(filename_.c_str()) == -1) {
|
||||
|
||||
@ -862,7 +862,7 @@ void TensorIteratorBase::narrow(int dim, int64_t start, int64_t size) {
|
||||
shape_[dim] = size;
|
||||
view_offsets_[dim] += start;
|
||||
for (auto& op : operands_) {
|
||||
op.data = ((char*)op.data) + op.stride_bytes[dim] * start;
|
||||
op.data = (static_cast<char*>(op.data)) + op.stride_bytes[dim] * start;
|
||||
}
|
||||
if (size == 1 && !is_reduction_) {
|
||||
coalesce_dimensions();
|
||||
@ -873,7 +873,7 @@ void TensorIteratorBase::select_all_keeping_dim(int start_dim, IntArrayRef indic
|
||||
TORCH_INTERNAL_ASSERT(start_dim <= ndim());
|
||||
for (const auto i : c10::irange(start_dim, ndim())) {
|
||||
for (auto& op : operands_) {
|
||||
op.data = ((char*)op.data) + op.stride_bytes[i] * indices[i - start_dim];
|
||||
op.data = (static_cast<char*>(op.data)) + op.stride_bytes[i] * indices[i - start_dim];
|
||||
}
|
||||
shape_[i] = 1;
|
||||
}
|
||||
|
||||
@ -41,7 +41,7 @@ inline void serial_for_each(
|
||||
IntArrayRef strides,
|
||||
char** base_ptrs,
|
||||
size_t ntensors,
|
||||
typename TensorIteratorBase::loop2d_t loop,
|
||||
TensorIteratorBase::loop2d_t loop,
|
||||
Range range) {
|
||||
const auto ndim = shape.size();
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
||||
|
||||
@ -72,10 +72,16 @@ TORCH_LIBRARY_IMPL(aten, VmapMode, m) {
|
||||
m.impl("random_", unsupportedRandomOp_<Tensor&, std::optional<Generator>>);
|
||||
|
||||
m.impl("rand_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, std::optional<MemoryFormat>>);
|
||||
m.impl("rand_like.generator", unsupportedRandomOp<const Tensor&, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);
|
||||
m.impl("randn_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, std::optional<MemoryFormat>>);
|
||||
m.impl("randn_like.generator", unsupportedRandomOp<const Tensor&, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);
|
||||
|
||||
m.impl("randint_like", unsupportedRandomOp<const Tensor&, int64_t, TENSOROPTIONS, std::optional<MemoryFormat>>);
|
||||
m.impl("randint_like.Tensor", unsupportedRandomOp<const Tensor&, const Tensor&, TENSOROPTIONS, std::optional<MemoryFormat>>);
|
||||
m.impl("randint_like.low_dtype", unsupportedRandomOp<const Tensor&, int64_t, int64_t, TENSOROPTIONS, std::optional<MemoryFormat>>);
|
||||
m.impl("randint_like.generator", unsupportedRandomOp<const Tensor&, int64_t, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);
|
||||
m.impl("randint_like.Tensor_generator", unsupportedRandomOp<const Tensor&, const Tensor&, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);
|
||||
m.impl("randint_like.low_generator_dtype", unsupportedRandomOp<const Tensor&, int64_t, int64_t, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);
|
||||
|
||||
m.impl("rand", unsupportedRandomOp<IntArrayRef, TENSOROPTIONS>);
|
||||
m.impl("rand.generator", unsupportedRandomOp<IntArrayRef, std::optional<Generator>, TENSOROPTIONS>);
|
||||
|
||||
@ -190,12 +190,14 @@ class IListRef;
|
||||
* it to a function (e.g. `ImplT::<dispatch-function>(this_)`).
|
||||
*/
|
||||
#define TORCH_ILISTREF_UNWRAP(TAG, BODY) \
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") \
|
||||
switch (TAG) { \
|
||||
TORCH_ILISTREF_FORALL_TAGS(TORCH_ILISTREF_UNWRAP_CASE, BODY) \
|
||||
break; \
|
||||
default: \
|
||||
TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag."); \
|
||||
}
|
||||
} \
|
||||
C10_DIAGNOSTIC_POP()
|
||||
|
||||
enum class IListRefTag {
|
||||
#define DEFINE_TAG(tag, ...) tag,
|
||||
|
||||
@ -56,7 +56,7 @@ C10_HOST_DEVICE inline T uniform_int_full_range(V val) {
|
||||
* in this overloaded version
|
||||
*/
|
||||
template <typename T, typename V>
|
||||
C10_HOST_DEVICE inline std::enable_if_t<!(std::is_floating_point_v<T>), T>uniform_int(V val) {
|
||||
C10_HOST_DEVICE inline std::enable_if_t<!std::is_floating_point_v<T>, T>uniform_int(V val) {
|
||||
if constexpr (std::is_same_v<T, bool>) {
|
||||
return static_cast<bool>(val & 1);
|
||||
} else if constexpr (std::is_same_v<T, int64_t>) {
|
||||
|
||||
@ -114,25 +114,25 @@ inline typename remove_symint<T>::type unpackSymInt(T x) {
|
||||
}
|
||||
|
||||
template <>
|
||||
inline typename remove_symint<c10::SymInt>::type unpackSymInt(c10::SymInt x) {
|
||||
inline remove_symint<c10::SymInt>::type unpackSymInt(c10::SymInt x) {
|
||||
return x.guard_int(__FILE__, __LINE__);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline typename remove_symint<c10::SymIntArrayRef>::type unpackSymInt(
|
||||
inline remove_symint<c10::SymIntArrayRef>::type unpackSymInt(
|
||||
c10::SymIntArrayRef x) {
|
||||
return C10_AS_INTARRAYREF_SLOW(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline typename remove_symint<std::optional<c10::SymInt>>::type unpackSymInt(
|
||||
inline remove_symint<std::optional<c10::SymInt>>::type unpackSymInt(
|
||||
std::optional<c10::SymInt> x) {
|
||||
return x.has_value() ? std::make_optional(x->guard_int(__FILE__, __LINE__))
|
||||
: std::nullopt;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline typename remove_symint<at::OptionalSymIntArrayRef>::type unpackSymInt(
|
||||
inline remove_symint<at::OptionalSymIntArrayRef>::type unpackSymInt(
|
||||
at::OptionalSymIntArrayRef x) {
|
||||
return x.has_value() ? std::make_optional(C10_AS_INTARRAYREF_SLOW(*x))
|
||||
: std::nullopt;
|
||||
|
||||
@ -631,8 +631,8 @@ call_functor_with_args_from_stack_(
|
||||
Stack* stack,
|
||||
std::index_sequence<ivalue_arg_indices...> /*unused*/,
|
||||
guts::typelist::typelist<ArgTypes...>* /*unused*/) {
|
||||
(void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would
|
||||
// be unused and we have to silence the compiler warning.
|
||||
(void)stack; // when sizeof...(ivalue_arg_indices) == 0, this argument would
|
||||
// be unused and we have to silence the compiler warning.
|
||||
|
||||
// We're explicitly filtering out DispatchKeySet from the argument list.
|
||||
// Some kernels take a DispatchKeySet as their first argument in order to
|
||||
|
||||
@ -18,6 +18,7 @@ struct TORCH_API EnumType : public NamedType {
|
||||
TypePtr value,
|
||||
std::vector<EnumNameValue> enum_names_values,
|
||||
std::weak_ptr<::torch::jit::CompilationUnit> cu) {
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum")
|
||||
switch (value->kind()) {
|
||||
case TypeKind::IntType:
|
||||
case TypeKind::FloatType:
|
||||
@ -34,6 +35,7 @@ struct TORCH_API EnumType : public NamedType {
|
||||
value->str(),
|
||||
"', only int, float and string are supported");
|
||||
}
|
||||
C10_DIAGNOSTIC_POP()
|
||||
}
|
||||
|
||||
std::string str() const override {
|
||||
|
||||
@ -601,8 +601,8 @@ std::ostream& IValue::repr(
|
||||
double d = v.toDouble();
|
||||
int c = std::fpclassify(d);
|
||||
if ((c == FP_NORMAL || c == FP_ZERO ) && std::abs(d) < 1e10) {
|
||||
int64_t i = int64_t(d);
|
||||
if (double(i) == d) {
|
||||
int64_t i = static_cast<int64_t>(d);
|
||||
if (static_cast<double>(i) == d) {
|
||||
// -0.0 (signed zero) needs to be parsed as -0.
|
||||
if (i == 0 && std::signbit(d)) {
|
||||
return out << "-" << i << ".";
|
||||
@ -799,8 +799,8 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
|
||||
double d = v.toDouble();
|
||||
int c = std::fpclassify(d);
|
||||
if (c == FP_NORMAL || c == FP_ZERO) {
|
||||
int64_t i = int64_t(d);
|
||||
if (double(i) == d) {
|
||||
int64_t i = static_cast<int64_t>(d);
|
||||
if (static_cast<double>(i) == d) {
|
||||
return out << i << ".";
|
||||
}
|
||||
}
|
||||
|
||||
@ -41,7 +41,7 @@ void standardizeVectorForUnion(std::vector<TypePtr>* to_flatten);
|
||||
inline bool is_contiguous_strides(
|
||||
const IntArrayRef sizes,
|
||||
const IntArrayRef strides) {
|
||||
int n_dim = static_cast<int>(sizes.size());
|
||||
size_t n_dim = sizes.size();
|
||||
if (n_dim == 0) {
|
||||
return true;
|
||||
}
|
||||
@ -50,7 +50,7 @@ inline bool is_contiguous_strides(
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = n_dim - 2; i >= 0; i--) {
|
||||
for (int i = static_cast<int>(n_dim) - 2; i >= 0; i--) {
|
||||
if (strides[i] != strides[i + 1] * sizes[i + 1]) {
|
||||
return false;
|
||||
}
|
||||
@ -922,6 +922,7 @@ struct TORCH_API DictType : public SharedType {
|
||||
if (auto dyn = key->castRaw<DynamicType>()) {
|
||||
kind = dyn->dynamicKind();
|
||||
}
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum")
|
||||
switch (kind) {
|
||||
case TypeKind::AnyType:
|
||||
case TypeKind::IntType:
|
||||
@ -938,6 +939,7 @@ struct TORCH_API DictType : public SharedType {
|
||||
key->str(),
|
||||
"', only int, float, complex, Tensor, device and string keys are supported");
|
||||
}
|
||||
C10_DIAGNOSTIC_POP()
|
||||
}
|
||||
|
||||
// aligned with the format in FunctionSchema
|
||||
@ -2371,7 +2373,7 @@ private:
|
||||
};
|
||||
|
||||
template<>
|
||||
inline typename detail::CastReturnType<NamedType>::type Type::cast<NamedType>() {
|
||||
inline detail::CastReturnType<NamedType>::type Type::cast<NamedType>() {
|
||||
if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType ||
|
||||
kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) {
|
||||
return std::static_pointer_cast<NamedType>(static_cast<NamedType *>(this)->shared_from_this());
|
||||
@ -2380,7 +2382,7 @@ inline typename detail::CastReturnType<NamedType>::type Type::cast<NamedType>()
|
||||
}
|
||||
|
||||
template<>
|
||||
inline typename detail::CastConstReturnType<NamedType>::type Type::cast<NamedType>() const {
|
||||
inline detail::CastConstReturnType<NamedType>::type Type::cast<NamedType>() const {
|
||||
if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType ||
|
||||
kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) {
|
||||
return std::static_pointer_cast<const NamedType>(static_cast<const NamedType *>(this)->shared_from_this());
|
||||
|
||||
@ -191,22 +191,37 @@ inline void convert(const at::Half* src, bool* dst, int64_t n) {
|
||||
}
|
||||
|
||||
#endif
|
||||
#ifdef __ARM_FEATURE_BF16
|
||||
CONVERT_TEMPLATE(bfloat16_t, uint8_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int8_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int16_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int32_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, int64_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(bfloat16_t, float)
|
||||
CONVERT_TEMPLATE(bfloat16_t, double)
|
||||
CONVERT_TEMPLATE(uint8_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int8_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int16_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int32_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(int64_t, bfloat16_t)
|
||||
CONVERT_TEMPLATE(float, bfloat16_t)
|
||||
CONVERT_TEMPLATE(double, bfloat16_t)
|
||||
|
||||
template <typename to_type>
|
||||
inline void convertFromBf16Impl(
|
||||
const c10::BFloat16* __restrict src,
|
||||
to_type* __restrict dst,
|
||||
int64_t n) {
|
||||
const uint16_t* srcPtr = reinterpret_cast<const uint16_t*>(src);
|
||||
uint64_t len = static_cast<uint64_t>(n);
|
||||
for (uint64_t i = 0; i < len; i++) {
|
||||
uint32_t tmp = static_cast<uint32_t>(srcPtr[i]) << 16;
|
||||
float tmpF;
|
||||
__builtin_memcpy(&tmpF, &tmp, sizeof(float));
|
||||
dst[i] = static_cast<to_type>(tmpF);
|
||||
}
|
||||
}
|
||||
#define CONVERT_FROM_BF16_TEMPLATE(to_type) \
|
||||
template <> \
|
||||
inline void convert(const c10::BFloat16* src, to_type* dst, int64_t n) { \
|
||||
return convertFromBf16Impl<to_type>(src, dst, n); \
|
||||
}
|
||||
|
||||
CONVERT_FROM_BF16_TEMPLATE(uint8_t)
|
||||
CONVERT_FROM_BF16_TEMPLATE(int8_t)
|
||||
CONVERT_FROM_BF16_TEMPLATE(int16_t)
|
||||
CONVERT_FROM_BF16_TEMPLATE(int32_t)
|
||||
CONVERT_FROM_BF16_TEMPLATE(int64_t)
|
||||
CONVERT_FROM_BF16_TEMPLATE(float)
|
||||
CONVERT_FROM_BF16_TEMPLATE(double)
|
||||
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||
CONVERT_FROM_BF16_TEMPLATE(float16_t)
|
||||
#endif
|
||||
|
||||
inline void convertBoolToBfloat16Impl(
|
||||
const bool* __restrict src,
|
||||
@ -247,8 +262,6 @@ inline void convert(const c10::BFloat16* src, bool* dst, int64_t n) {
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
template <typename src_t>
|
||||
struct VecConvert<
|
||||
float,
|
||||
|
||||
@ -514,7 +514,7 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
|
||||
|
||||
using float_vec_return_type = std::array<Vectorized<float>, kFloatNumVecs>;
|
||||
using int_vec_return_type = std::array<Vectorized<c10::qint32>, kIntNumVecs>;
|
||||
using value_type = typename c10::qint8::underlying;
|
||||
using value_type = c10::qint8::underlying;
|
||||
|
||||
public:
|
||||
using Vectorizedqi::Vectorizedqi;
|
||||
@ -727,7 +727,7 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
|
||||
|
||||
using float_vec_return_type = std::array<Vectorized<float>, kFloatNumVecs>;
|
||||
using int_vec_return_type = std::array<Vectorized<c10::qint32>, kIntNumVecs>;
|
||||
using value_type = typename c10::quint8::underlying;
|
||||
using value_type = c10::quint8::underlying;
|
||||
|
||||
public:
|
||||
using Vectorizedqi::Vectorizedqi;
|
||||
|
||||
@ -567,7 +567,7 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
|
||||
|
||||
using float_vec_return_type = std::array<Vectorized<float>, 4>;
|
||||
using int_vec_return_type = std::array<Vectorized<c10::qint32>, 4>;
|
||||
using value_type = typename c10::qint8::underlying;
|
||||
using value_type = c10::qint8::underlying;
|
||||
|
||||
public:
|
||||
using Vectorizedqi::Vectorizedqi;
|
||||
@ -804,7 +804,7 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
|
||||
|
||||
using float_vec_return_type = std::array<Vectorized<float>, 4>;
|
||||
using int_vec_return_type = std::array<Vectorized<c10::qint32>, 4>;
|
||||
using value_type = typename c10::quint8::underlying;
|
||||
using value_type = c10::quint8::underlying;
|
||||
|
||||
public:
|
||||
using Vectorizedqi::Vectorizedqi;
|
||||
|
||||
@ -672,7 +672,7 @@ struct Vectorized {
|
||||
return map(std::sqrt);
|
||||
}
|
||||
Vectorized<T> reciprocal() const {
|
||||
return map([](T x) { return (T)(1) / x; });
|
||||
return map([](T x) { return (T)1 / x; });
|
||||
}
|
||||
Vectorized<T> rsqrt() const {
|
||||
return map([](T x) { return (T)1 / std::sqrt(x); });
|
||||
|
||||
@ -46,7 +46,7 @@ inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) {
|
||||
parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) {
|
||||
map(
|
||||
[](const Vectorized<scalar_t>& x) {
|
||||
return Vectorized<scalar_t>((scalar_t)(1)) / x.sqrt();
|
||||
return Vectorized<scalar_t>((scalar_t)1) / x.sqrt();
|
||||
},
|
||||
out + begin,
|
||||
in + begin,
|
||||
|
||||
@ -194,8 +194,8 @@ void CUDAGeneratorState::unregister_graph(cuda::CUDAGraph* graph) {
|
||||
void CUDAGeneratorState::capture_prologue() {
|
||||
capturing_ = true;
|
||||
offset_intragraph_ = 0;
|
||||
seed_extragraph_.fill_(int64_t(seed_));
|
||||
offset_extragraph_.fill_(int64_t(0));
|
||||
seed_extragraph_.fill_(static_cast<int64_t>(seed_));
|
||||
offset_extragraph_.fill_(0);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -216,8 +216,8 @@ void CUDAGeneratorState::replay_prologue(uint64_t wholegraph_increment) {
|
||||
at::cuda::assertNotCapturing(
|
||||
"Cannot prepare for replay during capturing stage.");
|
||||
if (wholegraph_increment) {
|
||||
seed_extragraph_.fill_(int64_t(seed_));
|
||||
offset_extragraph_.fill_(int64_t(philox_offset_per_thread_));
|
||||
seed_extragraph_.fill_(static_cast<int64_t>(seed_));
|
||||
offset_extragraph_.fill_(static_cast<int64_t>(philox_offset_per_thread_));
|
||||
// Applies the total increment achieved during previous captures to update the
|
||||
// offset.
|
||||
increase(wholegraph_increment);
|
||||
@ -329,7 +329,7 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
|
||||
constexpr size_t offset_size = sizeof(int64_t);
|
||||
constexpr size_t total_size = seed_size + offset_size;
|
||||
|
||||
auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
|
||||
auto state_tensor = at::detail::empty_cpu({static_cast<int64_t>(total_size)}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
|
||||
auto rng_state = state_tensor.data_ptr<uint8_t>();
|
||||
auto current_seed = this->current_seed();
|
||||
auto offset = static_cast<int64_t>(this->philox_offset_per_thread()); // Note that old THCGeneratorState had offset as std::atomic<int64_t>
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
#include <ATen/cuda/CUDAGreenContext.h>
|
||||
|
||||
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
||||
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
||||
#include <c10/cuda/driver_api.h>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
|
||||
@ -155,8 +155,8 @@ size_t parseChosenWorkspaceSize() {
|
||||
while (next != end) {
|
||||
std::smatch match = *next;
|
||||
TORCH_CHECK(match.size() == 3, "Expected CUBLAS_WORKSPACE_SPACE_CONFIG match of size 3 (Format :SIZE:COUNT)");
|
||||
size_t curr_size = (size_t) std::stoi(match.str(1));
|
||||
size_t count = (size_t) std::stoi(match.str(2));
|
||||
size_t curr_size = std::stoull(match.str(1));
|
||||
size_t count = std::stoull(match.str(2));
|
||||
total_size += curr_size * 1024 * count;
|
||||
next++;
|
||||
}
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
#include <array>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
@ -136,9 +137,9 @@ void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_fo
|
||||
"Weight strides: ", t.strides(), "\n",
|
||||
"cuDNN suggested memory_format: ", memory_format);
|
||||
|
||||
int size[CUDNN_DIM_MAX];
|
||||
std::array<int, CUDNN_DIM_MAX> size;
|
||||
for (const auto i : c10::irange(dim)) {
|
||||
size[i] = (int) t.size(i);
|
||||
size[i] = static_cast<int>(t.size(i));
|
||||
}
|
||||
for (const auto i : c10::irange(dim, pad)) {
|
||||
size[i] = 1;
|
||||
@ -156,7 +157,7 @@ void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_fo
|
||||
default:
|
||||
TORCH_INTERNAL_ASSERT(false, "unsupported memory_format for cuDNN filters");
|
||||
}
|
||||
set(getDataType(t), static_cast<int>(dim), size, filter_format);
|
||||
set(getDataType(t), static_cast<int>(dim), size.data(), filter_format);
|
||||
}
|
||||
|
||||
std::string cudnnMemoryFormatToString(cudnnTensorFormat_t tformat) {
|
||||
|
||||
@ -9,8 +9,8 @@
|
||||
|
||||
#include <c10/core/Allocator.h>
|
||||
|
||||
#include <c10/util/python_stub.h>
|
||||
#include <ATen/detail/AcceleratorHooksInterface.h>
|
||||
#include <c10/util/python_stub.h>
|
||||
|
||||
#include <string>
|
||||
namespace at {
|
||||
@ -26,8 +26,7 @@ constexpr const char* MTIA_HELP =
|
||||
struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
|
||||
// this fails the implementation if MTIAHooks functions are called, but
|
||||
// MTIA backend is not present.
|
||||
#define FAIL_MTIAHOOKS_FUNC(func) \
|
||||
TORCH_CHECK(false, "Cannot execute ", func, "() without MTIA backend.");
|
||||
#define FAIL_MTIAHOOKS_FUNC(func) TORCH_CHECK(false, "Cannot execute ", func, "() without MTIA backend.");
|
||||
|
||||
~MTIAHooksInterface() override = default;
|
||||
|
||||
@ -92,7 +91,7 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
|
||||
return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
|
||||
}
|
||||
|
||||
virtual void setCurrentStream(const c10::Stream& /*stream*/ ) const {
|
||||
virtual void setCurrentStream(const c10::Stream& /*stream*/) const {
|
||||
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||
}
|
||||
|
||||
@ -124,11 +123,9 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
|
||||
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||
}
|
||||
|
||||
|
||||
virtual void recordMemoryHistory(
|
||||
const std::optional<std::string>& /*enabled*/,
|
||||
const std::string& /*stacks*/,
|
||||
size_t /*max_entries*/) const {
|
||||
virtual void recordMemoryHistory(const std::optional<std::string>& /*enabled*/,
|
||||
const std::string& /*stacks*/,
|
||||
size_t /*max_entries*/) const {
|
||||
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||
}
|
||||
|
||||
@ -159,6 +156,10 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
|
||||
return -1;
|
||||
}
|
||||
|
||||
virtual void mtiagraphDestroy(int64_t handle) const {
|
||||
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||
}
|
||||
|
||||
virtual void mtiagraphCaptureBegin(int64_t handle, MempoolId_t pool) const {
|
||||
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||
}
|
||||
@ -187,8 +188,7 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
|
||||
struct TORCH_API MTIAHooksArgs {};
|
||||
|
||||
TORCH_DECLARE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs);
|
||||
#define REGISTER_MTIA_HOOKS(clsname) \
|
||||
C10_REGISTER_CLASS(MTIAHooksRegistry, clsname, clsname)
|
||||
#define REGISTER_MTIA_HOOKS(clsname) C10_REGISTER_CLASS(MTIAHooksRegistry, clsname, clsname)
|
||||
|
||||
namespace detail {
|
||||
TORCH_API const MTIAHooksInterface& getMTIAHooks();
|
||||
|
||||
@ -198,7 +198,7 @@ static void autogradBasedTransformSendToNext(
|
||||
}
|
||||
|
||||
// Step 6
|
||||
stack->erase(stack->end() - std::ptrdiff_t(args_size + ret_size), stack->end() - std::ptrdiff_t(ret_size));
|
||||
stack->erase(stack->end() - static_cast<std::ptrdiff_t>(args_size + ret_size), stack->end() - static_cast<std::ptrdiff_t>(ret_size));
|
||||
}
|
||||
|
||||
void GradInterpreterPtr::processImpl(
|
||||
|
||||
@ -443,14 +443,14 @@ static bool has_same_shape(
|
||||
if (!tensor.defined()) {
|
||||
return true;
|
||||
}
|
||||
if (rankWithoutBatchDim(tensor, tensor_bdim) != (int64_t) normalized_shape.size()) {
|
||||
if (rankWithoutBatchDim(tensor, tensor_bdim) != static_cast<int64_t>(normalized_shape.size())) {
|
||||
return false;
|
||||
}
|
||||
const auto tensor_shape = tensor.sizes();
|
||||
for (const auto i : c10::irange(normalized_shape.size())) {
|
||||
auto j = i;
|
||||
// (0, 1, 2), 1 -> (0, 2, 3)
|
||||
if (tensor_bdim.has_value() && (int64_t)i >= tensor_bdim.value()) {
|
||||
if (tensor_bdim.has_value() && static_cast<int64_t>(i) >= tensor_bdim.value()) {
|
||||
j = j + 1;
|
||||
}
|
||||
if (normalized_shape[i] != tensor_shape[j]) {
|
||||
|
||||
@ -135,7 +135,7 @@ static void boxed_reduction_batch_rule(const c10::OperatorHandle& op, torch::jit
|
||||
reduction_case = ReductionCase::DimArray;
|
||||
dims = arguments[dim_arg_pos].toIntList().vec();
|
||||
if (dims.empty()) {
|
||||
auto all_dims = range(0, std::max((int64_t)1, logical_dim));
|
||||
auto all_dims = range(0, std::max(static_cast<int64_t>(1), logical_dim));
|
||||
dims = std::vector<int64_t>(all_dims.begin(), all_dims.end());
|
||||
}
|
||||
} else if (arguments[dim_arg_pos].isInt()) {
|
||||
|
||||
@ -432,7 +432,7 @@ namespace {
|
||||
// Eg. Given `indexed_shape.size()` is 5 and
|
||||
// shape of `values` is (N, 2, 3), then following block
|
||||
// will reshape `values` to (N, 1, 1, 2, 3).
|
||||
if ( (int64_t) indexed_shape.size() > values_.dim()) {
|
||||
if ( static_cast<int64_t>(indexed_shape.size()) > values_.dim()) {
|
||||
auto values_sizes = values_.sym_sizes();
|
||||
|
||||
// number of unit dims (for broadcasting value to indexed_shape)
|
||||
|
||||
@ -109,7 +109,7 @@ std::tuple<Tensor, std::optional<int64_t>> repeat_batch_rule(
|
||||
SymDimVector sizes_with_bdim = { sizes.begin(), sizes.end() };
|
||||
sizes_with_bdim.insert(sizes_with_bdim.begin(), 1);
|
||||
auto self_ = moveBatchDimToFront(self, self_bdim);
|
||||
while (self_.dim() < (int64_t)sizes_with_bdim.size()) {
|
||||
while (self_.dim() < static_cast<int64_t>(sizes_with_bdim.size())) {
|
||||
self_ = self_.unsqueeze(1);
|
||||
}
|
||||
return std::make_tuple(self_.repeat_symint(sizes_with_bdim), 0);
|
||||
|
||||
@ -191,7 +191,7 @@ static void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, t
|
||||
// simplicity. When that is not the case, this code should be updated.
|
||||
const auto& argument = (*stack)[arguments_begin + arg_idx];
|
||||
if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
|
||||
|| (int64_t)arg_idx != *batched_tensor_inputs_pos_iter) {
|
||||
|| static_cast<int64_t>(arg_idx) != *batched_tensor_inputs_pos_iter) {
|
||||
// argument isn't a BatchedTensor
|
||||
torch::jit::push(stack, argument);
|
||||
continue;
|
||||
@ -345,7 +345,7 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Sta
|
||||
// simplicity. When that is not the case, this code should be updated.
|
||||
const auto& argument = (*stack)[arguments_begin + arg_idx];
|
||||
if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
|
||||
|| (int64_t)arg_idx != *batched_tensor_inputs_pos_iter) {
|
||||
|| static_cast<int64_t>(arg_idx) != *batched_tensor_inputs_pos_iter) {
|
||||
// argument isn't a BatchedTensor
|
||||
torch::jit::push(stack, argument);
|
||||
continue;
|
||||
@ -473,7 +473,7 @@ void batchedNestedTensorForLoopFallback(const c10::OperatorHandle& op, torch::ji
|
||||
// simplicity. When that is not the case, this code should be updated.
|
||||
const auto& argument = (*stack)[arguments_begin + arg_idx];
|
||||
if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
|
||||
|| (int64_t)arg_idx != *batched_tensor_inputs_pos_iter) {
|
||||
|| static_cast<int64_t>(arg_idx) != *batched_tensor_inputs_pos_iter) {
|
||||
// argument isn't a BatchedTensor
|
||||
torch::jit::push(stack, argument);
|
||||
continue;
|
||||
|
||||
@ -157,7 +157,7 @@ Tensor& squeeze__batching_rule(Tensor& self) {
|
||||
const auto physical_shape = batched->value().sizes();
|
||||
auto how_many_dims_of_size_1_before_bdim = 0;
|
||||
for (const auto i : c10::irange(0, physical_shape.size())) {
|
||||
if ((int64_t)i == bdim) {
|
||||
if (static_cast<int64_t>(i) == bdim) {
|
||||
break;
|
||||
}
|
||||
if (physical_shape[i] == 1) {
|
||||
@ -573,7 +573,7 @@ Tensor cat_batching_rule(const ITensorListRef& tensors, int64_t dim) {
|
||||
}
|
||||
|
||||
auto new_dim = bdim_size.has_value() ? dim + 1 : dim;
|
||||
std::optional<int64_t> new_bdim = bdim_size.has_value() ? std::make_optional((int64_t)0) : std::nullopt;
|
||||
std::optional<int64_t> new_bdim = bdim_size.has_value() ? std::make_optional(static_cast<int64_t>(0)) : std::nullopt;
|
||||
auto result = at::cat(tensors_to_cat, new_dim);
|
||||
return makeBatched(result, new_bdim, get_current_level());
|
||||
}
|
||||
|
||||
@ -198,9 +198,9 @@ void avg_pool3d_out_frame(
|
||||
int64_t hend = std::min(hstart + kH, iheight + padH);
|
||||
int64_t wend = std::min(wstart + kW, iwidth + padW);
|
||||
int64_t pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart);
|
||||
tstart = std::max(tstart, (int64_t) 0);
|
||||
hstart = std::max(hstart, (int64_t) 0);
|
||||
wstart = std::max(wstart, (int64_t) 0);
|
||||
tstart = std::max(tstart, static_cast<int64_t>(0));
|
||||
hstart = std::max(hstart, static_cast<int64_t>(0));
|
||||
wstart = std::max(wstart, static_cast<int64_t>(0));
|
||||
tend = std::min(tend, itime);
|
||||
hend = std::min(hend, iheight);
|
||||
wend = std::min(wend, iwidth);
|
||||
@ -377,9 +377,9 @@ void avg_pool3d_backward_out_frame(
|
||||
int64_t hend = std::min(hstart + kH, iheight + padH);
|
||||
int64_t wend = std::min(wstart + kW, iwidth + padW);
|
||||
int64_t pool_size = (tend -tstart) * (hend - hstart) * (wend - wstart);
|
||||
tstart = std::max(tstart, (int64_t) 0);
|
||||
hstart = std::max(hstart, (int64_t) 0);
|
||||
wstart = std::max(wstart, (int64_t) 0);
|
||||
tstart = std::max(tstart, static_cast<int64_t>(0));
|
||||
hstart = std::max(hstart, static_cast<int64_t>(0));
|
||||
wstart = std::max(wstart, static_cast<int64_t>(0));
|
||||
tend = std::min(tend, itime);
|
||||
hend = std::min(hend, iheight);
|
||||
wend = std::min(wend, iwidth);
|
||||
|
||||
@ -2917,9 +2917,7 @@ static Tensor& linalg_eig_make_complex_eigenvectors(Tensor& complex_vectors, con
|
||||
DEFINE_DISPATCH(linalg_eig_stub);
|
||||
|
||||
static std::tuple<Tensor&, Tensor&> linalg_eig_out_info(const Tensor& input, Tensor& values, Tensor& vectors, Tensor& infos, bool compute_eigenvectors) {
|
||||
// MAGMA doesn't have GPU interface for GEEV routine, it requires inputs to be on CPU
|
||||
// therefore we create all intermediate tensors on CPU
|
||||
auto options = input.options().device(at::kCPU);
|
||||
auto options = input.options();
|
||||
|
||||
// These internal asserts make explicit the assumptions in the implementation
|
||||
// Error check with the actual error messages are done on the higher level of the hierarchy of calls
|
||||
@ -2928,16 +2926,13 @@ static std::tuple<Tensor&, Tensor&> linalg_eig_out_info(const Tensor& input, Ten
|
||||
|
||||
// for real-valued 'input', eigenvalues can be real-valued or complex-valued
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY((toComplexType(input.scalar_type()) == values.scalar_type()) || (input.scalar_type() == values.scalar_type()));
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.device() == at::kCPU);
|
||||
|
||||
// for real-valued 'input', eigenvectors can be real-valued or complex-valued
|
||||
if (compute_eigenvectors) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY((toComplexType(input.scalar_type()) == vectors.scalar_type()) || (input.scalar_type() == vectors.scalar_type()));
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(vectors.device() == at::kCPU);
|
||||
}
|
||||
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.scalar_type() == at::kInt);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.device() == at::kCPU);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.numel() == std::max<int64_t>(1, batchCount(input)));
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.is_contiguous());
|
||||
|
||||
@ -2986,15 +2981,7 @@ static std::tuple<Tensor&, Tensor&> linalg_eig_out_info(const Tensor& input, Ten
|
||||
}
|
||||
}
|
||||
|
||||
// MAGMA uses a hybrid CPU-GPU algorithm that performs well only for large matrices
|
||||
// See: https://github.com/pytorch/pytorch/pull/52491#issuecomment-795685687
|
||||
// Here we call CPU path for matrices smaller than 2048x2048
|
||||
// that should be in general significantly faster than calling MAGMA
|
||||
if (input.size(-1) <= 2048) {
|
||||
linalg_eig_stub(at::kCPU, real_imag_values, maybe_complex_vectors, infos, input.to(kCPU), compute_eigenvectors);
|
||||
} else {
|
||||
linalg_eig_stub(input.device().type(), real_imag_values, maybe_complex_vectors, infos, input, compute_eigenvectors);
|
||||
}
|
||||
linalg_eig_stub(input.device().type(), real_imag_values, maybe_complex_vectors, infos, input, compute_eigenvectors);
|
||||
|
||||
// if input is not complex we need to do some post-processing
|
||||
if (!input.is_complex()) {
|
||||
@ -3019,7 +3006,14 @@ static std::tuple<Tensor&, Tensor&> linalg_eig_out_info(const Tensor& input, Ten
|
||||
}
|
||||
if (compute_eigenvectors) {
|
||||
if (vectors.is_complex()) {
|
||||
vectors = linalg_eig_make_complex_eigenvectors(vectors, values, maybe_complex_vectors);
|
||||
// We move to the CPU because linalg_eig_make_complex_eigenvectors requires it.
|
||||
// Performance note: this function could be implemented via a TensorIterator,
|
||||
// which would avoid an explicit host-device synchronization.
|
||||
auto vectors_cpu = vectors.cpu();
|
||||
auto values_cpu = values.cpu();
|
||||
auto maybe_complex_vectors_cpu = maybe_complex_vectors.cpu();
|
||||
vectors_cpu = linalg_eig_make_complex_eigenvectors(vectors_cpu, values_cpu, maybe_complex_vectors_cpu);
|
||||
vectors.copy_(vectors_cpu);
|
||||
} else {
|
||||
TORCH_CHECK(false, "torch.linalg.eig: imaginary part of eigenvectors is non-zero, can't safely cast eigenvectors to non-complex dtype.")
|
||||
}
|
||||
@ -3039,8 +3033,7 @@ std::tuple<Tensor&, Tensor&> linalg_eig_out(const Tensor& input, Tensor& values,
|
||||
checkSameDevice("torch.linalg.eig", values, input, "eigenvalues");
|
||||
checkSameDevice("torch.linalg.eig", vectors, input, "eigenvectors");
|
||||
|
||||
// MAGMA doesn't have GPU interface for GEEV routine, it requires inputs to be on CPU
|
||||
auto options = input.options().device(at::kCPU);
|
||||
auto options = input.options();
|
||||
auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, options.dtype(kInt));
|
||||
|
||||
// if result is not empty and not in batched column major format we have to allocate a temporary tensor
|
||||
@ -3129,8 +3122,7 @@ Tensor& linalg_eigvals_out(const Tensor& input, Tensor& values) {
|
||||
checkLinalgCompatibleDtype("torch.linalg.eigvals", values.scalar_type(), toComplexType(input.scalar_type()), "eigenvalues");
|
||||
checkSameDevice("torch.linalg.eigvals", values, input, "eigenvalues");
|
||||
|
||||
// MAGMA doesn't have GPU interface for GEEV routine, it requires inputs to be on CPU
|
||||
auto options = input.options().device(at::kCPU);
|
||||
auto options = input.options();
|
||||
auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, options.dtype(kInt));
|
||||
|
||||
bool values_expected_type = (values.scalar_type() == toComplexType(input.scalar_type()));
|
||||
@ -3159,6 +3151,7 @@ Tensor& linalg_eigvals_out(const Tensor& input, Tensor& values) {
|
||||
}
|
||||
|
||||
Tensor vectors;
|
||||
vectors = at::empty({0}, input.options());
|
||||
if (values_tmp_needed) {
|
||||
Tensor values_tmp = at::empty({0}, options.dtype(values_type));
|
||||
std::tie(values_tmp, std::ignore) = linalg_eig_out_info(input, values_tmp, vectors, infos, /*compute_eigenvectors=*/false);
|
||||
|
||||
@ -946,10 +946,10 @@ void apply_lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& in
|
||||
}
|
||||
};
|
||||
// avoid overflow
|
||||
float matrix_rank = float(std::min(m, n));
|
||||
auto matrix_rank = std::min(m, n);
|
||||
// A heuristic tested on a 32 core/socket ICX system
|
||||
// https://github.com/pytorch/pytorch/pull/93037#discussion_r1090112948
|
||||
int64_t chunk_size_per_thread = int64_t(
|
||||
int64_t chunk_size_per_thread = static_cast<int64_t>(
|
||||
std::min(1.0, 3200.0 / (matrix_rank * matrix_rank * matrix_rank)));
|
||||
int64_t grain_size = chunk_size_per_thread * at::get_num_threads();
|
||||
at::parallel_for(0, batch_size, grain_size, loop);
|
||||
|
||||
@ -267,7 +267,7 @@ _scaled_mm_out_cpu_emulated(const Tensor& mat1, const Tensor& mat2,
|
||||
|
||||
float input_scale = scale_a.item<float>();
|
||||
float weight_scale = scale_b.item<float>();
|
||||
float output_scale = float(1.0);
|
||||
float output_scale = 1.0f;
|
||||
if (scale_result.has_value() &&
|
||||
(*out_dtype == ScalarType::Float8_e4m3fn ||
|
||||
*out_dtype == ScalarType::Float8_e5m2)) {
|
||||
|
||||
@ -331,7 +331,7 @@ bool gemv_use_fast_path<double>(
|
||||
[[maybe_unused]] double beta,
|
||||
int64_t incy) {
|
||||
return gemv_use_fast_path<float>(
|
||||
trans, m, n, (float)alpha, lda, incx, (float)beta, incy);
|
||||
trans, m, n, static_cast<float>(alpha), lda, incx, static_cast<float>(beta), incy);
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -523,8 +523,8 @@ static inline void scal(int64_t n, scalar_t a, scalar_t *x, int64_t incx)
|
||||
if (n == 1) incx = 1;
|
||||
#if AT_BUILD_WITH_BLAS()
|
||||
if (blas_impl::scal_use_fast_path<scalar_t>(n, incx)) {
|
||||
int i_n = (int)n;
|
||||
int i_incx = (int)incx;
|
||||
int i_n = static_cast<int>(n);
|
||||
int i_incx = static_cast<int>(incx);
|
||||
blas_impl::scal_fast_path<scalar_t>(&i_n, &a, x, &i_incx);
|
||||
return;
|
||||
}
|
||||
@ -545,11 +545,11 @@ void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t *a, i
|
||||
#if AT_BUILD_WITH_BLAS()
|
||||
if (blas_impl::gemv_use_fast_path<scalar_t>(trans, m, n, alpha, lda, incx, beta, incy)) {
|
||||
TORCH_CHECK(lda >= std::max<int64_t>(1L, m), "lda should be at least max(1,", m, "), but have ", lda);
|
||||
int i_m = (int)m;
|
||||
int i_n = (int)n;
|
||||
int i_lda = (int)lda;
|
||||
int i_incx = (int)incx;
|
||||
int i_incy = (int)incy;
|
||||
int i_m = static_cast<int>(m);
|
||||
int i_n = static_cast<int>(n);
|
||||
int i_lda = static_cast<int>(lda);
|
||||
int i_incx = static_cast<int>(incx);
|
||||
int i_incy = static_cast<int>(incy);
|
||||
blas_impl::gemv_fast_path<scalar_t>(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy);
|
||||
return;
|
||||
}
|
||||
|
||||
@ -680,9 +680,9 @@ void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t
|
||||
#if AT_BUILD_WITH_BLAS()
|
||||
if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
|
||||
{
|
||||
int i_n = (int)n;
|
||||
int i_incx = (int)incx;
|
||||
int i_incy = (int)incy;
|
||||
int i_n = static_cast<int>(n);
|
||||
int i_incx = static_cast<int>(incx);
|
||||
int i_incy = static_cast<int>(incy);
|
||||
#if C10_IOS
|
||||
cblas_daxpy(i_n, a, x, i_incx, y, i_incy);
|
||||
#else
|
||||
@ -705,9 +705,9 @@ void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t in
|
||||
#if AT_BUILD_WITH_BLAS()
|
||||
if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
|
||||
{
|
||||
int i_n = (int)n;
|
||||
int i_incx = (int)incx;
|
||||
int i_incy = (int)incy;
|
||||
int i_n = static_cast<int>(n);
|
||||
int i_incx = static_cast<int>(incx);
|
||||
int i_incy = static_cast<int>(incy);
|
||||
#if C10_IOS
|
||||
cblas_saxpy(i_n, a, x, i_incx, y, i_incy);
|
||||
#else
|
||||
@ -730,9 +730,9 @@ void axpy(int64_t n, c10::complex<double> a, const c10::complex<double> *x, int6
|
||||
#if AT_BUILD_WITH_BLAS()
|
||||
if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
|
||||
{
|
||||
int i_n = (int)n;
|
||||
int i_incx = (int)incx;
|
||||
int i_incy = (int)incy;
|
||||
int i_n = static_cast<int>(n);
|
||||
int i_incx = static_cast<int>(incx);
|
||||
int i_incy = static_cast<int>(incy);
|
||||
#if C10_IOS
|
||||
cblas_zaxpy(i_n, &a, x, i_incx, y, i_incy);
|
||||
#else
|
||||
@ -755,9 +755,9 @@ void axpy(int64_t n, c10::complex<float> a, const c10::complex<float> *x, int64_
|
||||
#if AT_BUILD_WITH_BLAS()
|
||||
if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
|
||||
{
|
||||
int i_n = (int)n;
|
||||
int i_incx = (int)incx;
|
||||
int i_incy = (int)incy;
|
||||
int i_n = static_cast<int>(n);
|
||||
int i_incx = static_cast<int>(incx);
|
||||
int i_incy = static_cast<int>(incy);
|
||||
#if C10_IOS
|
||||
cblas_caxpy(i_n, &a, x, i_incx, y, i_incy);
|
||||
#else
|
||||
@ -781,9 +781,9 @@ void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy) {
|
||||
}
|
||||
#if AT_BUILD_WITH_BLAS()
|
||||
if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) {
|
||||
int i_n = (int)n;
|
||||
int i_incx = (int)incx;
|
||||
int i_incy = (int)incy;
|
||||
int i_n = static_cast<int>(n);
|
||||
int i_incx = static_cast<int>(incx);
|
||||
int i_incy = static_cast<int>(incy);
|
||||
#if C10_IOS
|
||||
cblas_dcopy(i_n, x, i_incx, y, i_incy);
|
||||
#else
|
||||
@ -805,9 +805,9 @@ void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy) {
|
||||
}
|
||||
#if AT_BUILD_WITH_BLAS()
|
||||
if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) {
|
||||
int i_n = (int)n;
|
||||
int i_incx = (int)incx;
|
||||
int i_incy = (int)incy;
|
||||
int i_n = static_cast<int>(n);
|
||||
int i_incx = static_cast<int>(incx);
|
||||
int i_incy = static_cast<int>(incy);
|
||||
#if C10_IOS
|
||||
cblas_scopy(i_n, x, i_incx, y, i_incy);
|
||||
#else
|
||||
@ -829,9 +829,9 @@ void copy(int64_t n, const c10::complex<double> *x, int64_t incx, c10::complex<d
|
||||
}
|
||||
#if AT_BUILD_WITH_BLAS()
|
||||
if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) {
|
||||
int i_n = (int)n;
|
||||
int i_incx = (int)incx;
|
||||
int i_incy = (int)incy;
|
||||
int i_n = static_cast<int>(n);
|
||||
int i_incx = static_cast<int>(incx);
|
||||
int i_incy = static_cast<int>(incy);
|
||||
#if C10_IOS
|
||||
cblas_zcopy(i_n, x, i_incx, y, i_incy);
|
||||
#else
|
||||
@ -853,9 +853,9 @@ void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<fl
|
||||
}
|
||||
#if AT_BUILD_WITH_BLAS()
|
||||
if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) {
|
||||
int i_n = (int)n;
|
||||
int i_incx = (int)incx;
|
||||
int i_incy = (int)incy;
|
||||
int i_n = static_cast<int>(n);
|
||||
int i_incx = static_cast<int>(incx);
|
||||
int i_incy = static_cast<int>(incy);
|
||||
#if C10_IOS
|
||||
cblas_ccopy(i_n, &x, i_incx, y, i_incy);
|
||||
#else
|
||||
@ -1082,7 +1082,7 @@ struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
int64_t(1),
|
||||
1,
|
||||
ld_a,
|
||||
ld_b,
|
||||
ld_c,
|
||||
@ -1096,7 +1096,7 @@ struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
int64_t(1),
|
||||
1,
|
||||
ld_a,
|
||||
ld_b,
|
||||
ld_c,
|
||||
|
||||
@ -487,17 +487,17 @@ static Tensor _grid_sampler_2d_cpu_quantized(
|
||||
int64_t out_sC = output.stride(1);
|
||||
int64_t out_sH = output.stride(2);
|
||||
int64_t out_sW = output.stride(3);
|
||||
uint8_t* inp_ptr = (uint8_t*)input.data_ptr<quint8>();
|
||||
uint8_t* out_ptr = (uint8_t*)output.data_ptr<quint8>();
|
||||
float* grid_ptr = grid.data_ptr<float>();
|
||||
const uint8_t* inp_ptr = input.const_data_ptr<uint8_t>();
|
||||
uint8_t* out_ptr = output.data_ptr<uint8_t>();
|
||||
const float* grid_ptr = grid.const_data_ptr<float>();
|
||||
at::parallel_for(0, N, 0, [&](int64_t start, int64_t end) {
|
||||
for (const auto n : c10::irange(start, end)) {
|
||||
float* grid_ptr_N = grid_ptr + n * grid_sN;
|
||||
uint8_t* inp_ptr_N = inp_ptr + n * inp_sN;
|
||||
const float* grid_ptr_N = grid_ptr + n * grid_sN;
|
||||
const uint8_t* inp_ptr_N = inp_ptr + n * inp_sN;
|
||||
for (const auto h : c10::irange(out_H)) {
|
||||
for (const auto w : c10::irange(out_W)) {
|
||||
// get the corresponding input x, y, z coordinates from grid
|
||||
float* grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW;
|
||||
const float* grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW;
|
||||
float x = *grid_ptr_NHW;
|
||||
float y = grid_ptr_NHW[grid_sCoor];
|
||||
|
||||
@ -527,7 +527,7 @@ static Tensor _grid_sampler_2d_cpu_quantized(
|
||||
float se = (ix - ix_nw) * (iy - iy_nw);
|
||||
|
||||
// calculate bilinear weighted pixel value and set output pixel
|
||||
uint8_t* inp_ptr_NC = inp_ptr_N;
|
||||
const uint8_t* inp_ptr_NC = inp_ptr_N;
|
||||
uint8_t* out_ptr_NCHW =
|
||||
out_ptr + n * out_sN + h * out_sH + w * out_sW;
|
||||
for (int64_t c = 0; c < C;
|
||||
|
||||
@ -318,7 +318,7 @@ static std::vector<Tensor>& histogramdd_bin_edges_out(const Tensor& self, IntArr
|
||||
|
||||
const int64_t N = self.size(-1);
|
||||
const int64_t M = std::accumulate(self.sizes().begin(), self.sizes().end() - 1,
|
||||
(int64_t)1, std::multiplies<int64_t>());
|
||||
static_cast<int64_t>(1), std::multiplies<int64_t>());
|
||||
Tensor reshaped_self = self.reshape({ M, N });
|
||||
|
||||
auto outer_bin_edges = select_outer_bin_edges(reshaped_self, range);
|
||||
|
||||
@ -40,7 +40,7 @@ Tensor do_trapezoid(const Tensor& y, const Tensor& dx, int64_t dim) {
|
||||
// When dx is constant, the above formula simplifies
|
||||
// to dx * [(\sum_{i=1}^n y_i) - (y_1 + y_n)/2]
|
||||
Tensor do_trapezoid(const Tensor& y, double dx, int64_t dim) {
|
||||
return (y.sum(dim) - (y.select(dim, 0) + y.select(dim, -1)) * (0.5)) * dx;
|
||||
return (y.sum(dim) - (y.select(dim, 0) + y.select(dim, -1)) * 0.5) * dx;
|
||||
}
|
||||
|
||||
Tensor zeros_like_except(const Tensor& y, int64_t dim) {
|
||||
|
||||
@ -201,7 +201,7 @@ static Tensor sumproduct_pair(const Tensor& left_, const Tensor& right_, IntArra
|
||||
out_size.reserve(out_num_dim);
|
||||
for (auto& d : lro) out_size.push_back(left.sym_size(d));
|
||||
for (auto& d : lo) out_size.push_back(left.sym_size(d));
|
||||
for (auto& d : sum_dims_) { out_size.emplace_back(1); (void)(d); }; // avoid warning about not using d
|
||||
for (auto& d : sum_dims_) { out_size.emplace_back(1); (void)d; }; // avoid warning about not using d
|
||||
for (auto& d : ro) out_size.push_back(right.sym_size(d));
|
||||
|
||||
std::vector<int64_t> lpermutation(lro);
|
||||
@ -640,7 +640,7 @@ Tensor einsum(std::string_view equation, TensorList operands, at::OptionalIntArr
|
||||
}
|
||||
}
|
||||
|
||||
return ops[0];
|
||||
return std::move(ops[0]);
|
||||
}
|
||||
|
||||
// _trilinear computes a trilinear einstein sum with an unrolled dimension
|
||||
@ -805,7 +805,7 @@ Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1,
|
||||
std::vector<SymInt> rsizes; // rsizes: sizes of the result
|
||||
p1.reserve(input1.dim());
|
||||
p2.reserve(input2.dim());
|
||||
rsizes.reserve(input1.dim() + input2.dim() - (int64_t) dims1.size());
|
||||
rsizes.reserve(input1.dim() + input2.dim() - static_cast<int64_t>(dims1.size()));
|
||||
SymInt size1 = 1; // number of non-contracted elements in input1
|
||||
SymInt size2 = 1; // number of non-contracted elements in input2
|
||||
|
||||
|
||||
@ -1655,7 +1655,7 @@ static inline void baddbmm_cpu_kernel(const Tensor& result, const Tensor& self,
|
||||
auto s0 = self.accessor<const scalar_t, 3>();
|
||||
auto m0 = mat2.accessor<const scalar_t, 3>();
|
||||
|
||||
int64_t grain_size = std::max(internal::GRAIN_SIZE / (is * js * ks), (int64_t)1);
|
||||
int64_t grain_size = std::max(internal::GRAIN_SIZE / (is * js * ks), static_cast<int64_t>(1));
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
parallel_for(0, bs, grain_size, [&](int64_t b_begin, int64_t b_end) {
|
||||
for (const auto b : c10::irange(b_begin, b_end)) {
|
||||
|
||||
@ -235,7 +235,7 @@ void nll_loss_out_frame(
|
||||
|
||||
constexpr int64_t cascade_sum_num_levels = 8;
|
||||
const int64_t level_power =
|
||||
std::max(int64_t(4), utils::CeilLog2(batch_size) / cascade_sum_num_levels);
|
||||
std::max(static_cast<int64_t>(4), utils::CeilLog2(batch_size) / cascade_sum_num_levels);
|
||||
const int64_t level_step = (1 << level_power);
|
||||
const int64_t level_mask = level_step - 1;
|
||||
|
||||
|
||||
@ -129,7 +129,7 @@ void nll_loss2d_forward_out_frame(
|
||||
for (const auto b : c10::irange(start, end)) {
|
||||
for (const auto h : c10::irange(H)) {
|
||||
for (const auto w : c10::irange(W)) {
|
||||
const int64_t cur_target = (int64_t)target_acc[b][h][w];
|
||||
const int64_t cur_target = target_acc[b][h][w];
|
||||
|
||||
if (cur_target == ignore_index) {
|
||||
output_acc[b][h][w] = static_cast<scalar_t>(0);
|
||||
@ -188,7 +188,7 @@ void nll_loss2d_forward_out_frame(
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||
scalar_t loss_partial_sums[cascade_sum_num_levels] = {0};
|
||||
const int64_t level_power =
|
||||
std::max(int64_t(4), utils::CeilLog2(numiter) / cascade_sum_num_levels);
|
||||
std::max(static_cast<int64_t>(4), utils::CeilLog2(numiter) / cascade_sum_num_levels);
|
||||
const int64_t level_step = (1 << level_power);
|
||||
const int64_t level_mask = level_step - 1;
|
||||
|
||||
|
||||
@ -192,7 +192,7 @@ Date: February 1996
|
||||
x = x - (std::erf(x) - y) / ((static_cast<T>(2.0)/static_cast<T>(std::sqrt(c10::pi<double>)))*std::exp(-x*x));
|
||||
x = x - (std::erf(x) - y) / ((static_cast<T>(2.0)/static_cast<T>(std::sqrt(c10::pi<double>)))*std::exp(-x*x));
|
||||
|
||||
return(x);
|
||||
return x;
|
||||
}
|
||||
|
||||
#undef CENTRAL_RANGE
|
||||
@ -3819,7 +3819,7 @@ inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, int64_t n)
|
||||
|
||||
if ((n > 6) && (std::abs(x + x - T(1.0)) < T(1.0))) {
|
||||
if (std::sin(std::acos(x + x - T(1.0)) / T(2.0)) != T(1.0)) {
|
||||
return std::cos(((n) + T(0.5)) * std::acos(x + x - T(1.0))) / std::cos(std::acos(x + x - T(1.0)) / T(2.0));
|
||||
return std::cos((n + T(0.5)) * std::acos(x + x - T(1.0))) / std::cos(std::acos(x + x - T(1.0)) / T(2.0));
|
||||
}
|
||||
|
||||
if (n % 2 == 0) {
|
||||
|
||||
@ -193,22 +193,22 @@ Tensor _nnpack_spatial_convolution(
|
||||
const size_t input_channels = input.size(1);
|
||||
const size_t output_channels = weight.size(0);
|
||||
const struct nnp_size input_size = {
|
||||
.width = (size_t)input.size(3),
|
||||
.height = (size_t)input.size(2),
|
||||
.width = static_cast<size_t>(input.size(3)),
|
||||
.height = static_cast<size_t>(input.size(2)),
|
||||
};
|
||||
const struct nnp_padding input_padding = {
|
||||
.top = (size_t)padding[0],
|
||||
.right = (size_t)padding[1],
|
||||
.bottom = (size_t)padding[0],
|
||||
.left = (size_t)padding[1],
|
||||
.top = static_cast<size_t>(padding[0]),
|
||||
.right = static_cast<size_t>(padding[1]),
|
||||
.bottom = static_cast<size_t>(padding[0]),
|
||||
.left = static_cast<size_t>(padding[1]),
|
||||
};
|
||||
const struct nnp_size kernel_size = {
|
||||
.width = (size_t)weight.size(3),
|
||||
.height = (size_t)weight.size(2),
|
||||
.width = static_cast<size_t>(weight.size(3)),
|
||||
.height = static_cast<size_t>(weight.size(2)),
|
||||
};
|
||||
const struct nnp_size output_size = {
|
||||
.width = (size_t)output.size(3),
|
||||
.height = (size_t)output.size(2),
|
||||
.width = static_cast<size_t>(output.size(3)),
|
||||
.height = static_cast<size_t>(output.size(2)),
|
||||
};
|
||||
const nnp_size output_subsample = {
|
||||
.width = static_cast<std::size_t>(stride[1]),
|
||||
|
||||
@ -248,8 +248,8 @@ void slow_conv_transpose3d_out_cpu_template(
|
||||
Tensor weight = weight_.contiguous();
|
||||
Tensor bias = bias_.defined() ? bias_.contiguous() : bias_;
|
||||
|
||||
const int n_input_plane = (int)weight.size(0);
|
||||
const int n_output_plane = (int)weight.size(1);
|
||||
const auto n_input_plane = weight.size(0);
|
||||
const auto n_output_plane = weight.size(1);
|
||||
|
||||
bool is_batch = false;
|
||||
if (input.dim() == 4) {
|
||||
|
||||
@ -84,8 +84,8 @@ static std::vector<int64_t> aligned_size(
|
||||
DimnameList aligned_names,
|
||||
bool is_aligning_two_tensors) {
|
||||
std::vector<int64_t> expanded_sizes(aligned_names.size(), 1);
|
||||
ptrdiff_t dim = (ptrdiff_t)tensor_sizes.size() - 1;
|
||||
ptrdiff_t idx = (ptrdiff_t)aligned_names.size() - 1;
|
||||
ptrdiff_t dim = static_cast<ptrdiff_t>(tensor_sizes.size()) - 1;
|
||||
ptrdiff_t idx = static_cast<ptrdiff_t>(aligned_names.size()) - 1;
|
||||
for (; idx >= 0 && dim >= 0; --idx) {
|
||||
if (tensor_names[dim] != aligned_names[idx]) {
|
||||
continue;
|
||||
|
||||
@ -25,7 +25,7 @@ std::tuple<Tensor, Tensor> _rowwise_prune_helper(
|
||||
auto mask_contig = mask.contiguous();
|
||||
auto mask_data = mask_contig.data_ptr<bool>();
|
||||
for (const auto i : c10::irange(mask.numel())) {
|
||||
num_non_masked_rows += (((mask_data[i] == true)) ? 1 : 0);
|
||||
num_non_masked_rows += ((mask_data[i] == true) ? 1 : 0);
|
||||
}
|
||||
int num_cols = weights.size(1);
|
||||
auto pruned_2d_tensor = at::empty({num_non_masked_rows, num_cols},
|
||||
|
||||
@ -176,7 +176,7 @@ void host_softmax(
|
||||
scalar_t* input_data_base = input.data_ptr<scalar_t>();
|
||||
scalar_t* output_data_base = output.data_ptr<scalar_t>();
|
||||
bool* mask_data_base = mask;
|
||||
int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, (int64_t)1);
|
||||
int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, static_cast<int64_t>(1));
|
||||
parallel_for(
|
||||
0, outer_size * inner_size, grain_size,
|
||||
[&](int64_t begin, int64_t end) {
|
||||
@ -265,7 +265,7 @@ void host_softmax_backward(
|
||||
scalar_t* output_data_base = output.data_ptr<scalar_t>();
|
||||
scalar_t* gradOutput_data_base = grad.data_ptr<scalar_t>();
|
||||
bool* mask_data_base = mask;
|
||||
int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, (int64_t)1);
|
||||
int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, static_cast<int64_t>(1));
|
||||
parallel_for(
|
||||
0, outer_size * inner_size, grain_size, [&](int64_t begin, int64_t end) {
|
||||
for (const auto i : c10::irange(begin, end)) {
|
||||
|
||||
@ -1701,13 +1701,13 @@ Tensor& index_select_out_cpu_(
|
||||
TORCH_CHECK_INDEX(
|
||||
(self_i >= 0) && (self_i < self_dim_size),
|
||||
"index out of range in self");
|
||||
auto self_data = static_cast<const char*>(selfSlice_data) +
|
||||
auto self_data = const_cast<char*>(static_cast<const char*>(
|
||||
selfSlice_data)) +
|
||||
self_i * self_stride_bytes;
|
||||
auto result_data = static_cast<char*>(resultSlice_data) +
|
||||
i * result_stride_bytes;
|
||||
sub_iter.unsafe_replace_operand(0, result_data);
|
||||
sub_iter.unsafe_replace_operand(
|
||||
1, const_cast<char*>(self_data));
|
||||
sub_iter.unsafe_replace_operand(1, self_data);
|
||||
copy_stub(sub_iter.device_type(), sub_iter, false);
|
||||
};
|
||||
});
|
||||
|
||||
@ -11,6 +11,7 @@
|
||||
#include <ATen/SparseCsrTensorUtils.h>
|
||||
#include <ATen/TensorOperators.h>
|
||||
#include <ATen/TracerMode.h>
|
||||
#include <ATen/core/Generator.h>
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/native/UnaryOps.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
@ -1089,6 +1090,7 @@ Tensor& rand_out(
|
||||
|
||||
Tensor rand_like(
|
||||
const Tensor& self,
|
||||
std::optional<Generator> generator,
|
||||
std::optional<ScalarType> dtype,
|
||||
std::optional<Layout> layout,
|
||||
std::optional<Device> device,
|
||||
@ -1100,7 +1102,24 @@ Tensor rand_like(
|
||||
pin_memory);
|
||||
|
||||
auto result = at::empty_like(self, options, optional_memory_format);
|
||||
return result.uniform_(0, 1, std::nullopt);
|
||||
return result.uniform_(0, 1, std::move(generator));
|
||||
}
|
||||
|
||||
Tensor rand_like(
|
||||
const Tensor& self,
|
||||
std::optional<ScalarType> dtype,
|
||||
std::optional<Layout> layout,
|
||||
std::optional<Device> device,
|
||||
std::optional<bool> pin_memory,
|
||||
std::optional<c10::MemoryFormat> optional_memory_format) {
|
||||
return native::rand_like(
|
||||
self,
|
||||
static_cast<std::optional<Generator>>(std::nullopt),
|
||||
dtype,
|
||||
layout,
|
||||
device,
|
||||
pin_memory,
|
||||
optional_memory_format);
|
||||
}
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ randint ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
@ -1197,7 +1216,9 @@ Tensor& randint_out(
|
||||
|
||||
Tensor randint_like(
|
||||
const Tensor& self,
|
||||
int64_t low,
|
||||
int64_t high,
|
||||
std::optional<Generator> generator,
|
||||
std::optional<ScalarType> dtype,
|
||||
std::optional<Layout> layout,
|
||||
std::optional<Device> device,
|
||||
@ -1209,7 +1230,71 @@ Tensor randint_like(
|
||||
pin_memory);
|
||||
|
||||
auto result = at::empty_like(self, options, optional_memory_format);
|
||||
return result.random_(0, high, std::nullopt);
|
||||
return result.random_(low, high, std::move(generator));
|
||||
}
|
||||
|
||||
Tensor randint_like(
|
||||
const Tensor& self,
|
||||
int64_t low,
|
||||
int64_t high,
|
||||
std::optional<ScalarType> dtype,
|
||||
std::optional<Layout> layout,
|
||||
std::optional<Device> device,
|
||||
std::optional<bool> pin_memory,
|
||||
std::optional<c10::MemoryFormat> optional_memory_format) {
|
||||
return native::randint_like(
|
||||
self,
|
||||
low,
|
||||
high,
|
||||
static_cast<std::optional<Generator>>(std::nullopt),
|
||||
dtype,
|
||||
layout,
|
||||
device,
|
||||
pin_memory,
|
||||
optional_memory_format);
|
||||
}
|
||||
|
||||
Tensor randint_like(
|
||||
const Tensor& self,
|
||||
int64_t high,
|
||||
std::optional<ScalarType> dtype,
|
||||
std::optional<Layout> layout,
|
||||
std::optional<Device> device,
|
||||
std::optional<bool> pin_memory,
|
||||
std::optional<c10::MemoryFormat> optional_memory_format) {
|
||||
// See [Note: hacky wrapper removal for TensorOptions]
|
||||
return native::randint_like(
|
||||
self,
|
||||
0,
|
||||
high,
|
||||
static_cast<std::optional<Generator>>(std::nullopt),
|
||||
dtype,
|
||||
layout,
|
||||
device,
|
||||
pin_memory,
|
||||
optional_memory_format);
|
||||
}
|
||||
|
||||
Tensor randint_like(
|
||||
const Tensor& self,
|
||||
int64_t high,
|
||||
std::optional<Generator> generator,
|
||||
std::optional<ScalarType> dtype,
|
||||
std::optional<Layout> layout,
|
||||
std::optional<Device> device,
|
||||
std::optional<bool> pin_memory,
|
||||
std::optional<c10::MemoryFormat> optional_memory_format) {
|
||||
// See [Note: hacky wrapper removal for TensorOptions]
|
||||
return native::randint_like(
|
||||
self,
|
||||
0,
|
||||
high,
|
||||
generator,
|
||||
dtype,
|
||||
layout,
|
||||
device,
|
||||
pin_memory,
|
||||
optional_memory_format);
|
||||
}
|
||||
|
||||
Tensor randint_like(
|
||||
@ -1226,7 +1311,9 @@ Tensor randint_like(
|
||||
int64_t high_scalar = high.item<int64_t>();
|
||||
return at::native::randint_like(
|
||||
self,
|
||||
0,
|
||||
high_scalar,
|
||||
static_cast<std::optional<Generator>>(std::nullopt),
|
||||
dtype,
|
||||
layout,
|
||||
device,
|
||||
@ -1236,20 +1323,27 @@ Tensor randint_like(
|
||||
|
||||
Tensor randint_like(
|
||||
const Tensor& self,
|
||||
int64_t low,
|
||||
int64_t high,
|
||||
const Tensor& high,
|
||||
std::optional<Generator> generator,
|
||||
std::optional<ScalarType> dtype,
|
||||
std::optional<Layout> layout,
|
||||
std::optional<Device> device,
|
||||
std::optional<bool> pin_memory,
|
||||
std::optional<c10::MemoryFormat> optional_memory_format) {
|
||||
// See [Note: hacky wrapper removal for TensorOptions]
|
||||
TensorOptions options =
|
||||
TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(
|
||||
pin_memory);
|
||||
|
||||
auto result = at::empty_like(self, options, optional_memory_format);
|
||||
return result.random_(low, high, std::nullopt);
|
||||
TORCH_CHECK(
|
||||
high.numel() == 1 && high.ndimension() == 0 && high.device().is_cpu(),
|
||||
"high must be a scalar tensor and on CPU");
|
||||
int64_t high_scalar = high.item<int64_t>();
|
||||
return at::native::randint_like(
|
||||
self,
|
||||
0,
|
||||
high_scalar,
|
||||
generator,
|
||||
dtype,
|
||||
layout,
|
||||
device,
|
||||
pin_memory,
|
||||
optional_memory_format);
|
||||
}
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ randn ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
@ -1327,6 +1421,7 @@ Tensor& normal_out(
|
||||
|
||||
Tensor randn_like(
|
||||
const Tensor& self,
|
||||
std::optional<Generator> generator,
|
||||
std::optional<ScalarType> dtype,
|
||||
std::optional<Layout> layout,
|
||||
std::optional<Device> device,
|
||||
@ -1338,7 +1433,24 @@ Tensor randn_like(
|
||||
pin_memory);
|
||||
|
||||
auto result = at::empty_like(self, options, optional_memory_format);
|
||||
return result.normal_(0, 1, std::nullopt);
|
||||
return result.normal_(0, 1, std::move(generator));
|
||||
}
|
||||
|
||||
Tensor randn_like(
|
||||
const Tensor& self,
|
||||
std::optional<ScalarType> dtype,
|
||||
std::optional<Layout> layout,
|
||||
std::optional<Device> device,
|
||||
std::optional<bool> pin_memory,
|
||||
std::optional<c10::MemoryFormat> optional_memory_format) {
|
||||
return native::randn_like(
|
||||
self,
|
||||
static_cast<std::optional<Generator>>(std::nullopt),
|
||||
dtype,
|
||||
layout,
|
||||
device,
|
||||
pin_memory,
|
||||
optional_memory_format);
|
||||
}
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ randperm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
@ -1382,7 +1494,7 @@ void randperm_cpu(Tensor& result, int64_t n, CPUGeneratorImpl* generator) {
|
||||
// use no-initialization Fischer-Yates variant
|
||||
// https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_.22inside-out.22_algorithm
|
||||
for (int64_t i = 0; i < n; i++) {
|
||||
int64_t z = (int64_t)(generator->random64() % (i + 1));
|
||||
int64_t z = static_cast<int64_t>(generator->random64() % (i + 1));
|
||||
r__data[i * r__stride_0] = i;
|
||||
r__data[i * r__stride_0] = r__data[z * r__stride_0];
|
||||
r__data[z * r__stride_0] = i;
|
||||
|
||||
@ -40,7 +40,7 @@ at::Tensor PackedLinearWeightQnnp::apply_dynamic_impl<false>(
|
||||
"quantized_sparse_linear(): Input tensor rank should be >= 2");
|
||||
|
||||
const auto rows_input = c10::multiply_integers(input.sizes().begin(), input.sizes().end() - 1);
|
||||
const auto cols_input = static_cast<int64_t>(input.size(input.dim() - 1));
|
||||
const auto cols_input = input.size(input.dim() - 1);
|
||||
TORCH_CHECK(
|
||||
cols_input == input_channels_,
|
||||
"quantized_sparse_linear: Input tensor's last and weight tensor's"
|
||||
|
||||
@ -65,8 +65,8 @@ LinearPackedSerializationType PackedLinearWeight::unpack() {
|
||||
#ifdef USE_PYTORCH_QNNPACK
|
||||
|
||||
LinearPackedSerializationType PackedLinearWeightQnnp::unpack() {
|
||||
const int64_t N = static_cast<int64_t>(output_channels_);
|
||||
const int64_t K = static_cast<int64_t>(input_channels_);
|
||||
const int64_t N = output_channels_;
|
||||
const int64_t K = input_channels_;
|
||||
|
||||
float* w_scales_ptr = w_scales_.data_ptr<float>();
|
||||
|
||||
|
||||
@ -998,7 +998,7 @@ void softplus_backward_kernel(TensorIteratorBase& iter, const Scalar& beta_, con
|
||||
auto threshold = threshold_.to<float>();
|
||||
const Vec beta_vec(beta);
|
||||
const Vec threshold_vec(threshold);
|
||||
const Vec one_vec(static_cast<float>(1.0));
|
||||
const Vec one_vec(1.0f);
|
||||
cpu_kernel_vec(
|
||||
iter,
|
||||
[beta, threshold](scalar_t a, scalar_t b) -> scalar_t {
|
||||
|
||||
@ -17,7 +17,7 @@ static inline void cpu_atomic_add_float(float* dst, float fvalue)
|
||||
} uf32_t;
|
||||
|
||||
uf32_t new_value, old_value;
|
||||
std::atomic<unsigned>* dst_intV = (std::atomic<unsigned>*)(dst);
|
||||
std::atomic<unsigned>* dst_intV = (std::atomic<unsigned>*)dst;
|
||||
|
||||
old_value.floatV = *dst;
|
||||
new_value.floatV = old_value.floatV + fvalue;
|
||||
|
||||
@ -851,7 +851,7 @@ void sigmoid_backward_kernel(TensorIteratorBase& iter) {
|
||||
});
|
||||
});
|
||||
} else if (iter.dtype() == kBFloat16) {
|
||||
auto one_vec = Vectorized<float>((float)(1));
|
||||
auto one_vec = Vectorized<float>((float)1);
|
||||
cpu_kernel_vec(
|
||||
iter,
|
||||
[=](BFloat16 a, BFloat16 b) -> BFloat16 {
|
||||
|
||||
@ -77,9 +77,7 @@ static void reduced_float_copy_kernel(TensorIteratorBase &iter, bool requires_ne
|
||||
|
||||
int64_t grain_size = at::internal::GRAIN_SIZE;
|
||||
|
||||
auto loop = [strides_in, requires_neg](char** base, const int64_t* strides, int64_t size0, int64_t size1) {
|
||||
std::array<char*, 2> data;
|
||||
std::copy_n(base, 2, data.data());
|
||||
auto loop = [strides_in, requires_neg](char** data, const int64_t* strides, int64_t size0, int64_t size1) {
|
||||
const int64_t *outer_strides = &strides[2];
|
||||
|
||||
for ([[maybe_unused]] const auto it : c10::irange(size1)) {
|
||||
@ -146,9 +144,7 @@ static void reduced_float_copy_kernel(TensorIteratorBase &iter, bool requires_ne
|
||||
|
||||
int64_t grain_size = at::internal::GRAIN_SIZE;
|
||||
|
||||
auto loop = [strides_in, requires_neg](char** base, const int64_t* strides, int64_t size0, int64_t size1) {
|
||||
std::array<char*, 2> data;
|
||||
std::copy_n(base, 2, data.data());
|
||||
auto loop = [strides_in, requires_neg](char** data, const int64_t* strides, int64_t size0, int64_t size1) {
|
||||
const int64_t *outer_strides = &strides[2];
|
||||
|
||||
for ([[maybe_unused]] const auto it : c10::irange(size1)) {
|
||||
|
||||
@ -493,40 +493,33 @@ void cpu_hflip_vec(at::TensorIterator& iter) {
|
||||
|
||||
for ([[maybe_unused]] const auto j : c10::irange(size1)) {
|
||||
// vectorized loop with negative stride for output
|
||||
char** C10_RESTRICT data_ = data_arr.data();
|
||||
int64_t n = size0;
|
||||
|
||||
char* C10_RESTRICT data[ntensors];
|
||||
for (const auto arg : c10::irange(ntensors)) {
|
||||
data[arg] = data_[arg];
|
||||
}
|
||||
|
||||
int64_t i = 0;
|
||||
|
||||
// data[0] unaligned pre-pass
|
||||
// data_arr[0] unaligned pre-pass
|
||||
int64_t offset = (j * n + (n - i - Vec::size())) % 32;
|
||||
offset = (offset >= n) ? n : offset;
|
||||
for (; i < offset; i++) {
|
||||
scalar_t* out_ptr = (scalar_t*)(data[0] - i * stride);
|
||||
*out_ptr = c10::load((scalar_t *)(data[1] + i * stride));
|
||||
scalar_t* out_ptr = (scalar_t*)(data_arr[0] - i * stride);
|
||||
*out_ptr = c10::load((scalar_t *)(data_arr[1] + i * stride));
|
||||
}
|
||||
// Empirically found that it is faster to process 3 data items together vs 2 or 4
|
||||
for (; i <= n - 3 * Vec::size(); i += 3 * Vec::size()) {
|
||||
auto out1 = Vec::loadu(data[1] + i * stride);
|
||||
auto out2 = Vec::loadu(data[1] + (i + Vec::size()) * stride);
|
||||
auto out3 = Vec::loadu(data[1] + (i + 2 * Vec::size()) * stride);
|
||||
auto out1 = Vec::loadu(data_arr[1] + i * stride);
|
||||
auto out2 = Vec::loadu(data_arr[1] + (i + Vec::size()) * stride);
|
||||
auto out3 = Vec::loadu(data_arr[1] + (i + 2 * Vec::size()) * stride);
|
||||
// flip the vector: 1234 -> 4321
|
||||
out1 = flip(out1);
|
||||
out2 = flip(out2);
|
||||
out3 = flip(out3);
|
||||
out1.store(data[0] - (i + Vec::size() - 1) * stride);
|
||||
out2.store(data[0] - (i + 2 * Vec::size() - 1) * stride);
|
||||
out3.store(data[0] - (i + 3 * Vec::size() - 1) * stride);
|
||||
out1.store(data_arr[0] - (i + Vec::size() - 1) * stride);
|
||||
out2.store(data_arr[0] - (i + 2 * Vec::size() - 1) * stride);
|
||||
out3.store(data_arr[0] - (i + 3 * Vec::size() - 1) * stride);
|
||||
}
|
||||
if (i < n) {
|
||||
for (; i < n; i++) {
|
||||
scalar_t* out_ptr = (scalar_t*)(data[0] - i * stride);
|
||||
*out_ptr = c10::load((scalar_t *)(data[1] + i * stride));
|
||||
scalar_t* out_ptr = (scalar_t*)(data_arr[0] - i * stride);
|
||||
*out_ptr = c10::load((scalar_t *)(data_arr[1] + i * stride));
|
||||
}
|
||||
}
|
||||
|
||||
@ -560,15 +553,8 @@ void cpu_vflip_memcpy(at::TensorIterator& iter) {
|
||||
const int64_t stride = strides[0];
|
||||
|
||||
for ([[maybe_unused]] const auto j : c10::irange(size1)) {
|
||||
char** C10_RESTRICT data_ = data_arr.data();
|
||||
int64_t n = size0;
|
||||
|
||||
char* C10_RESTRICT data[ntensors];
|
||||
for (const auto arg : c10::irange(ntensors)) {
|
||||
data[arg] = data_[arg];
|
||||
}
|
||||
|
||||
memcpy(data[0], data[1], n * stride);
|
||||
memcpy(data_arr[0], data_arr[1], n * stride);
|
||||
|
||||
// advance:
|
||||
for (const auto arg : c10::irange(data_arr.size())) {
|
||||
|
||||
@ -92,7 +92,8 @@ void addcdiv_cpu_kernel(TensorIteratorBase& iter, const Scalar& value) {
|
||||
|
||||
void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, double beta) {
|
||||
ScalarType dtype = iter.dtype(0);
|
||||
if (dtype == kBFloat16) {
|
||||
if (at::isReducedFloatingType(dtype)) {
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "smooth_l1_backward_cpu_out", [&]() {
|
||||
auto norm_val = norm.to<float>();
|
||||
float beta_val(beta);
|
||||
auto norm_val_vec = Vectorized<float>(norm_val);
|
||||
@ -101,9 +102,9 @@ void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, dou
|
||||
const auto zero_vec = Vectorized<float>(0);
|
||||
const auto pos_1_vec = Vectorized<float>(1);
|
||||
cpu_kernel_vec(iter,
|
||||
[=](BFloat16 input, BFloat16 target, BFloat16 grad_output) -> BFloat16 {
|
||||
[=](scalar_t input, scalar_t target, scalar_t grad_output) -> scalar_t {
|
||||
const auto x = float(input) - float(target);
|
||||
if (x <= -beta){
|
||||
if (x <= -beta) {
|
||||
return -norm_val * float(grad_output);
|
||||
}else if (x >= beta){
|
||||
return norm_val * float(grad_output);
|
||||
@ -112,14 +113,14 @@ void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, dou
|
||||
}
|
||||
},
|
||||
[norm_val_vec, beta_val_vec, neg_1_vec, zero_vec, pos_1_vec](
|
||||
Vectorized<BFloat16> input, Vectorized<BFloat16> target, Vectorized<BFloat16> grad_output) -> Vectorized<BFloat16> {
|
||||
Vectorized<scalar_t> input, Vectorized<scalar_t> target, Vectorized<scalar_t> grad_output) -> Vectorized<scalar_t> {
|
||||
// using two blendv calls to simulate the 3 cases
|
||||
// 1 if x >= beta
|
||||
// -1 if x <= -beta
|
||||
// x / beta if |x| < beta
|
||||
auto [input0, input1] = convert_bfloat16_float(input);
|
||||
auto [target0, target1] = convert_bfloat16_float(target);
|
||||
auto [grad_output0, grad_output1] = convert_bfloat16_float(grad_output);
|
||||
auto [input0, input1] = convert_to_float(input);
|
||||
auto [target0, target1] = convert_to_float(target);
|
||||
auto [grad_output0, grad_output1] = convert_to_float(grad_output);
|
||||
auto x = input0 - target0;
|
||||
auto pos_or_neg_1_vec = Vectorized<float>::blendv(
|
||||
neg_1_vec, pos_1_vec, x > zero_vec);
|
||||
@ -135,11 +136,12 @@ void smooth_l1_backward_cpu_kernel(TensorIterator& iter, const Scalar& norm, dou
|
||||
output = Vectorized<float>::blendv(
|
||||
x / beta_val_vec, pos_or_neg_1_vec, x_abs >= beta_val_vec);
|
||||
input1 = norm_val_vec * output * grad_output1;
|
||||
return convert_float_bfloat16(input0, input1);
|
||||
return convert_from_float<scalar_t>(input0, input1);
|
||||
}
|
||||
);
|
||||
});
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES_AND(kHalf, dtype, "smooth_l1_backward_cpu_out", [&] {
|
||||
AT_DISPATCH_ALL_TYPES(dtype, "smooth_l1_backward_cpu_out", [&] {
|
||||
auto norm_val = norm.to<scalar_t>();
|
||||
scalar_t beta_val(beta);
|
||||
auto norm_val_vec = Vectorized<scalar_t>(norm_val);
|
||||
|
||||
@ -298,7 +298,7 @@ void unfolded2d_copy(
|
||||
memcpy(
|
||||
dst + (size_t)y * output_width + x,
|
||||
src + (size_t)iy * input_width + ix,
|
||||
sizeof(scalar_t) * (1));
|
||||
sizeof(scalar_t) * 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -317,7 +317,7 @@ void unfolded2d_copy(
|
||||
memcpy(
|
||||
dst + (size_t)y * output_width + x,
|
||||
src + (size_t)iy * input_width + ix + x * dW,
|
||||
sizeof(scalar_t) * (1));
|
||||
sizeof(scalar_t) * 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -342,7 +342,7 @@ void upsample_avx_bilinear_bicubic_uint8(
|
||||
|
||||
if (need_horizontal) {
|
||||
int interp_dim = 3;
|
||||
auto stride = (skip_unpacking) ? num_channels : 4;
|
||||
auto stride = skip_unpacking ? num_channels : 4;
|
||||
std::tie(horiz_indices_weights, ksize_horiz, horiz_weights_precision) =
|
||||
F::compute_index_ranges_int16_weights(
|
||||
/*input_size=*/xin,
|
||||
@ -358,7 +358,7 @@ void upsample_avx_bilinear_bicubic_uint8(
|
||||
|
||||
if (need_vertical) {
|
||||
int interp_dim = 2;
|
||||
auto stride = (skip_unpacking) ? num_channels * xout : 4 * xout;
|
||||
auto stride = skip_unpacking ? num_channels * xout : 4 * xout;
|
||||
std::tie(vert_indices_weights, ksize_vert, vert_weights_precision) =
|
||||
F::compute_index_ranges_int16_weights(
|
||||
/*input_size=*/yin,
|
||||
@ -377,17 +377,17 @@ void upsample_avx_bilinear_bicubic_uint8(
|
||||
// horizontal-only or vertical-only interpolation, and if the tensor doesn't
|
||||
// need repacking
|
||||
if (need_horizontal && (need_vertical || !skip_packing)) {
|
||||
auto c = (skip_unpacking) ? num_channels : 4;
|
||||
auto c = skip_unpacking ? num_channels : 4;
|
||||
buffer_horiz = at::empty({c, yin, xout}, input.options());
|
||||
}
|
||||
if (need_vertical && !skip_packing) {
|
||||
auto c = (skip_unpacking) ? num_channels : 4;
|
||||
auto c = skip_unpacking ? num_channels : 4;
|
||||
buffer_vert = at::empty({c, yout, xout}, input.options());
|
||||
}
|
||||
|
||||
for (const auto i : c10::irange(batch_size)) {
|
||||
|
||||
at::Tensor unpacked_input = (skip_unpacking) ? input[i] : unpack_rgb(input[i]);
|
||||
at::Tensor unpacked_input = skip_unpacking ? input[i] : unpack_rgb(input[i]);
|
||||
at::Tensor unpacked_output;
|
||||
|
||||
if (need_horizontal) {
|
||||
@ -411,7 +411,7 @@ void upsample_avx_bilinear_bicubic_uint8(
|
||||
unpacked_output = unpacked_input = unpacked_output_temp;
|
||||
}
|
||||
if (need_vertical) {
|
||||
unpacked_output = (skip_packing) ? output[i] : buffer_vert;
|
||||
unpacked_output = skip_packing ? output[i] : buffer_vert;
|
||||
|
||||
ImagingResampleVertical(
|
||||
unpacked_output,
|
||||
@ -502,7 +502,7 @@ void ImagingResampleHorizontalConvolution8u4x(
|
||||
// RGBA: b4_delta = b4_delta_soft = 3
|
||||
// RGB : b4_delta = 5
|
||||
// RGB : b4_delta_soft = 4
|
||||
const auto b4_delta = (stride == 4) ? 3 : ((is_last_line) ? 5 : 4);
|
||||
const auto b4_delta = (stride == 4) ? 3 : (is_last_line ? 5 : 4);
|
||||
|
||||
// In block 2 (2 means we process 2 weights values together), we read input data
|
||||
// with _mm_loadl_epi64, i.e. 8 bytes, per one line:
|
||||
@ -515,7 +515,7 @@ void ImagingResampleHorizontalConvolution8u4x(
|
||||
// RGBA: b2_delta = b2_delta_soft = 1
|
||||
// RGB : b2_delta = 2
|
||||
// RGB : b2_delta_soft = 1
|
||||
const auto b2_delta = (stride == 4) ? 1 : ((is_last_line) ? 2 : 1);
|
||||
const auto b2_delta = (stride == 4) ? 1 : (is_last_line ? 2 : 1);
|
||||
|
||||
const auto max_out_x_strided = out_xsize * stride;
|
||||
const auto max_in_x_strided = in_xsize * stride;
|
||||
@ -819,7 +819,7 @@ void ImagingResampleHorizontalConvolution8u(
|
||||
// RGBA: b8_delta = b8_delta_soft = 7
|
||||
// RGB : b8_delta = 10
|
||||
// RGB : b8_delta_soft = 9
|
||||
const auto b8_delta = (stride == 4) ? 7 : ((is_last_line) ? 10 : 9);
|
||||
const auto b8_delta = (stride == 4) ? 7 : (is_last_line ? 10 : 9);
|
||||
|
||||
// In block 4 (4 means we process 4 weight values together), we read
|
||||
// 16 bytes of input data.
|
||||
@ -832,7 +832,7 @@ void ImagingResampleHorizontalConvolution8u(
|
||||
// RGBA: b4_delta = b4_delta_soft = 3
|
||||
// RGB : b4_delta = 5
|
||||
// RGB : b4_delta_soft = 4
|
||||
const auto b4_delta = (stride == 4) ? 3 : ((is_last_line) ? 5 : 4);
|
||||
const auto b4_delta = (stride == 4) ? 3 : (is_last_line ? 5 : 4);
|
||||
|
||||
// In block 2 (2 means we process 2 weight values together), we read
|
||||
// 8 bytes of input data.
|
||||
@ -845,7 +845,7 @@ void ImagingResampleHorizontalConvolution8u(
|
||||
// RGBA: b2_delta = b2_delta_soft = 1
|
||||
// RGB : b2_delta = 2
|
||||
// RGB : b2_delta_soft = 1
|
||||
const auto b2_delta = (stride == 4) ? 1 : ((is_last_line) ? 2 : 1);
|
||||
const auto b2_delta = (stride == 4) ? 1 : (is_last_line ? 2 : 1);
|
||||
|
||||
const auto max_out_x_strided = out_xsize * stride;
|
||||
const auto max_in_x_strided = in_xsize * stride;
|
||||
|
||||
@ -644,8 +644,8 @@ void weight_to_int4pack_kernel(
|
||||
int32_t val2 = src[(d + 32) * K + k];
|
||||
int32_t val3 = src[(d + 48) * K + k];
|
||||
|
||||
uint8_t packed02 = (((uint8_t)(val2) << 4)) | ((uint8_t)(val0));
|
||||
uint8_t packed13 = (((uint8_t)(val3) << 4)) | ((uint8_t)(val1));
|
||||
uint8_t packed02 = ((uint8_t)val2 << 4) | ((uint8_t)val0);
|
||||
uint8_t packed13 = ((uint8_t)val3 << 4) | ((uint8_t)val1);
|
||||
|
||||
dst[k * 32 + d] = packed02;
|
||||
dst[k * 32 + 16 + d] = packed13;
|
||||
@ -656,7 +656,7 @@ void weight_to_int4pack_kernel(
|
||||
int32_t val0 = src[n * K + k];
|
||||
int32_t val1 = src[n * K + K + k];
|
||||
|
||||
uint8_t packed = (((uint8_t)(val1) << 4)) | ((uint8_t)(val0));
|
||||
uint8_t packed = ((uint8_t)val1 << 4) | ((uint8_t)val0);
|
||||
dst[k * nb_size / 2 + n / 2] = packed;
|
||||
}
|
||||
}
|
||||
@ -667,7 +667,7 @@ void weight_to_int4pack_kernel(
|
||||
int32_t val0 = src[(d + 0) * K + k];
|
||||
int32_t val1 = src[(d + 16) * K + k];
|
||||
|
||||
uint8_t packed01 = (((uint8_t)(val1) << 4)) | ((uint8_t)(val0));
|
||||
uint8_t packed01 = ((uint8_t)val1 << 4) | ((uint8_t)val0);
|
||||
dst[k * 16 + d] = packed01;
|
||||
}
|
||||
} else {
|
||||
@ -676,7 +676,7 @@ void weight_to_int4pack_kernel(
|
||||
int32_t val0 = src[n * K + k];
|
||||
int32_t val1 = src[n * K + K + k];
|
||||
|
||||
uint8_t packed = (((uint8_t)(val1) << 4)) | ((uint8_t)(val0));
|
||||
uint8_t packed = ((uint8_t)val1 << 4) | ((uint8_t)val0);
|
||||
dst[k * nb_size / 2 + n / 2] = packed;
|
||||
}
|
||||
}
|
||||
@ -685,7 +685,7 @@ void weight_to_int4pack_kernel(
|
||||
int32_t val0 = src[n * K + k];
|
||||
int32_t val1 = src[n * K + K + k];
|
||||
|
||||
uint8_t packed = (((uint8_t)(val1) << 4)) | ((uint8_t)(val0));
|
||||
uint8_t packed = ((uint8_t)val1 << 4) | ((uint8_t)val0);
|
||||
dst[k * nb_size / 2 + n / 2] = packed;
|
||||
}
|
||||
#endif
|
||||
@ -872,16 +872,16 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
|
||||
const float src0_0 = src_ptr[k_idx];
|
||||
|
||||
max0 = (std::max)(src0_0, max0);
|
||||
min0 = (std::min)(src0_0, min0);
|
||||
max0 = std::max(src0_0, max0);
|
||||
min0 = std::min(src0_0, min0);
|
||||
}
|
||||
|
||||
// Maximum/minimum int8 values
|
||||
const float qmin = (float)INT8_MIN;
|
||||
const float qmax = (float)INT8_MAX;
|
||||
|
||||
const float rmin0 = (std::min)(0.0f, min0);
|
||||
const float rmax0 = (std::max)(0.0f, max0);
|
||||
const float rmin0 = std::min(0.0f, min0);
|
||||
const float rmax0 = std::max(0.0f, max0);
|
||||
|
||||
const float scale0 =
|
||||
rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
|
||||
@ -900,8 +900,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
? qmin - descaled_min0
|
||||
: qmax - descaled_max0;
|
||||
|
||||
zero_point0 = (std::max)(zero_point0, qmin);
|
||||
zero_point0 = (std::min)(zero_point0, qmax);
|
||||
zero_point0 = std::max(zero_point0, qmin);
|
||||
zero_point0 = std::min(zero_point0, qmax);
|
||||
|
||||
// Round to nearest integer
|
||||
const int32_t nudged_zero_point0 = lrintf(zero_point0);
|
||||
@ -909,9 +909,9 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
int8_t* dst_ptr = lhs_qa8dx + m_idx * dst_stride;
|
||||
|
||||
// LHS offset at the beginning of the row
|
||||
*((float*)(dst_ptr)) = recip_scale0;
|
||||
*((float*)dst_ptr) = recip_scale0;
|
||||
dst_ptr += sizeof(float);
|
||||
*((int32_t*)(dst_ptr)) = -nudged_zero_point0;
|
||||
*((int32_t*)dst_ptr) = -nudged_zero_point0;
|
||||
dst_ptr += sizeof(int32_t);
|
||||
|
||||
// Quantize the channels
|
||||
@ -922,8 +922,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
|
||||
|
||||
v0_s32 = v0_s32 + nudged_zero_point0;
|
||||
v0_s32 = (std::max)(v0_s32, static_cast<int32_t>(INT8_MIN));
|
||||
v0_s32 = (std::min)(v0_s32, static_cast<int32_t>(INT8_MAX));
|
||||
v0_s32 = std::max(v0_s32, static_cast<int32_t>(INT8_MIN));
|
||||
v0_s32 = std::min(v0_s32, static_cast<int32_t>(INT8_MAX));
|
||||
dst_ptr[0] = (int8_t)v0_s32;
|
||||
dst_ptr += sizeof(int8_t);
|
||||
}
|
||||
@ -988,8 +988,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
main_acc = main_acc * lhs_scale;
|
||||
|
||||
// Clamp (min-max) operation
|
||||
main_acc = (std::max)(main_acc, scalar_min);
|
||||
main_acc = (std::min)(main_acc, scalar_max);
|
||||
main_acc = std::max(main_acc, scalar_min);
|
||||
main_acc = std::min(main_acc, scalar_max);
|
||||
|
||||
dst_f32[0] = main_acc;
|
||||
dst_f32 += 1;
|
||||
@ -1024,15 +1024,15 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
|
||||
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
|
||||
const float src0_0 = src_ptr[k_idx];
|
||||
max0 = (std::max)(src0_0, max0);
|
||||
min0 = (std::min)(src0_0, min0);
|
||||
max0 = std::max(src0_0, max0);
|
||||
min0 = std::min(src0_0, min0);
|
||||
}
|
||||
|
||||
const float qmin = (float)INT8_MIN;
|
||||
const float qmax = (float)INT8_MAX;
|
||||
|
||||
const float rmin0 = (std::min)(0.0f, min0);
|
||||
const float rmax0 = (std::max)(0.0f, max0);
|
||||
const float rmin0 = std::min(0.0f, min0);
|
||||
const float rmax0 = std::max(0.0f, max0);
|
||||
const float scale0 =
|
||||
(rmin0 == rmax0) ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
|
||||
const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;
|
||||
@ -1044,22 +1044,22 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
? qmin - descaled_min0
|
||||
: qmax - descaled_max0;
|
||||
|
||||
zero_point0 = (std::max)(zero_point0, qmin);
|
||||
zero_point0 = (std::min)(zero_point0, qmax);
|
||||
zero_point0 = std::max(zero_point0, qmin);
|
||||
zero_point0 = std::min(zero_point0, qmax);
|
||||
const int32_t nudged_zero_point0 = lrintf(zero_point0);
|
||||
|
||||
int8_t* dst_ptr = lhs_qa8dx + row_idx * dst_stride;
|
||||
|
||||
*((float*)(dst_ptr)) = recip_scale0;
|
||||
*((float*)dst_ptr) = recip_scale0;
|
||||
dst_ptr += sizeof(float);
|
||||
*((int32_t*)(dst_ptr)) = -nudged_zero_point0;
|
||||
*((int32_t*)dst_ptr) = -nudged_zero_point0;
|
||||
dst_ptr += sizeof(int32_t);
|
||||
|
||||
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
|
||||
const float src0_0 = src_ptr[k_idx];
|
||||
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
|
||||
v0_s32 = (std::max)(
|
||||
(std::min)(
|
||||
v0_s32 = std::max(
|
||||
std::min(
|
||||
v0_s32 + nudged_zero_point0, static_cast<int32_t>(INT8_MAX)),
|
||||
static_cast<int32_t>(INT8_MIN));
|
||||
dst_ptr[0] = (int8_t)v0_s32;
|
||||
@ -1118,8 +1118,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
}
|
||||
|
||||
main_acc = main_acc * lhs_scale;
|
||||
main_acc = (std::max)(main_acc, scalar_min);
|
||||
main_acc = (std::min)(main_acc, scalar_max);
|
||||
main_acc = std::max(main_acc, scalar_min);
|
||||
main_acc = std::min(main_acc, scalar_max);
|
||||
|
||||
dst_f32[0] = main_acc;
|
||||
dst_f32 += 1;
|
||||
|
||||
@ -4,7 +4,6 @@
|
||||
#include <c10/util/SmallVector.h>
|
||||
#include <c10/core/Scalar.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/core/NamedTensor.h>
|
||||
@ -206,8 +205,8 @@ static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const
|
||||
// and the leading stride is at least max(1, other dim length), so we might
|
||||
// end up with contiguous cols but not rows (i.e. holes between different rows)
|
||||
// and vice versa.
|
||||
&& mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 &&
|
||||
mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 &&
|
||||
&& mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32
|
||||
&& mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32
|
||||
&& (
|
||||
// filter by dtype
|
||||
(scalar_type != at::ScalarType::Half && scalar_type != at::ScalarType::BFloat16) ||
|
||||
|
||||
@ -54,7 +54,6 @@ namespace {
|
||||
using DtypeScale = float;
|
||||
using DtypeAccum = float;
|
||||
using DtypeEpilogue = float;
|
||||
using DtypeOutput = cutlass::bfloat16_t;
|
||||
|
||||
using Multiply = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiplies,
|
||||
@ -68,12 +67,6 @@ using Add = cutlass::epilogue::fusion::Sm90Compute<
|
||||
DtypeEpilogue,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using Cast = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::epilogue::thread::Identity,
|
||||
DtypeOutput,
|
||||
DtypeEpilogue,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
template <bool LargeTile, bool FastAccum>
|
||||
struct Schedule;
|
||||
|
||||
@ -120,7 +113,8 @@ template <
|
||||
typename FastAccum,
|
||||
typename DtypeA,
|
||||
typename DtypeB,
|
||||
typename DtypeBias>
|
||||
typename DtypeBias,
|
||||
typename DtypeOutput>
|
||||
void f8f8bf16_rowwise_impl(
|
||||
at::Tensor XQ, // FP8
|
||||
at::Tensor WQ, // FP8
|
||||
@ -181,6 +175,11 @@ void f8f8bf16_rowwise_impl(
|
||||
WScale,
|
||||
cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>;
|
||||
|
||||
using Cast = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::epilogue::thread::Identity,
|
||||
DtypeOutput,
|
||||
DtypeEpilogue,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
|
||||
Cast,
|
||||
cutlass::epilogue::fusion::Sm90EVT<
|
||||
@ -313,7 +312,8 @@ template <
|
||||
typename FastAccum,
|
||||
typename DtypeA,
|
||||
typename DtypeB,
|
||||
typename DtypeBias>
|
||||
typename DtypeBias,
|
||||
typename DtypeOutput>
|
||||
void f8f8bf16_rowwise_impl_sm100_sm120(
|
||||
at::Tensor XQ, // FP8
|
||||
at::Tensor WQ, // FP8
|
||||
@ -372,6 +372,11 @@ void f8f8bf16_rowwise_impl_sm100_sm120(
|
||||
WScale,
|
||||
cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>;
|
||||
|
||||
using Cast = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::epilogue::thread::Identity,
|
||||
DtypeOutput,
|
||||
DtypeEpilogue,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
|
||||
Cast,
|
||||
cutlass::epilogue::fusion::Sm90EVT<
|
||||
@ -498,7 +503,8 @@ template <
|
||||
typename FastAccum,
|
||||
typename DtypeA,
|
||||
typename DtypeB,
|
||||
typename DtypeBias>
|
||||
typename DtypeBias,
|
||||
typename DtypeOutput>
|
||||
void f8f8bf16_rowwise_impl_sm89(
|
||||
at::Tensor XQ, // FP8
|
||||
at::Tensor WQ, // FP8
|
||||
@ -765,7 +771,8 @@ template <
|
||||
typename FastAccum,
|
||||
typename DtypeA,
|
||||
typename DtypeB,
|
||||
typename DtypeBias>
|
||||
typename DtypeBias,
|
||||
typename DtypeOutput>
|
||||
void handle_transposition(
|
||||
at::Tensor XQ,
|
||||
at::Tensor WQ,
|
||||
@ -782,7 +789,8 @@ void handle_transposition(
|
||||
FastAccum,
|
||||
DtypeA,
|
||||
DtypeB,
|
||||
DtypeBias>(XQ, WQ, x_scale, w_scale, bias, out, swizzle);
|
||||
DtypeBias,
|
||||
DtypeOutput>(XQ, WQ, x_scale, w_scale, bias, out, swizzle);
|
||||
} else {
|
||||
dispatch_fp8_rowwise_kernel_on_tile_size<
|
||||
ClusterShape,
|
||||
@ -791,7 +799,8 @@ void handle_transposition(
|
||||
FastAccum,
|
||||
DtypeB,
|
||||
DtypeA,
|
||||
DtypeBias>(WQ.t(), XQ.t(), w_scale.t(), x_scale.t(), bias, out.t(), swizzle);
|
||||
DtypeBias,
|
||||
DtypeOutput>(WQ.t(), XQ.t(), w_scale.t(), x_scale.t(), bias, out.t(), swizzle);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1027,11 +1036,19 @@ void dispatch_fp8_rowwise_kernel_on_bias_dtype(
|
||||
at::Tensor out) {
|
||||
if (bias.has_value() && bias->dtype() == at::kBFloat16) {
|
||||
dispatch_fp8_rowwise_kernel_on_input_dtypes<
|
||||
cutlass::bfloat16_t,
|
||||
cutlass::bfloat16_t>
|
||||
(XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
|
||||
} else if (bias.has_value() && bias->dtype() == at::kHalf){
|
||||
TORCH_CHECK(out.dtype() == at::kHalf, "Output should be Float16 when bias is Float16");
|
||||
dispatch_fp8_rowwise_kernel_on_input_dtypes<
|
||||
cutlass::half_t,
|
||||
cutlass::half_t>
|
||||
(XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
|
||||
} else {
|
||||
dispatch_fp8_rowwise_kernel_on_input_dtypes<
|
||||
float>
|
||||
float,
|
||||
cutlass::bfloat16_t>
|
||||
//Types...>
|
||||
(XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
|
||||
}
|
||||
@ -1073,14 +1090,14 @@ void check_inputs(
|
||||
|
||||
if (bias.has_value()) {
|
||||
TORCH_CHECK(bias->device() == b.device());
|
||||
TORCH_CHECK(bias->dtype() == at::kFloat || bias->dtype() == at::kBFloat16);
|
||||
TORCH_CHECK(bias->dtype() == at::kFloat || bias->dtype() == at::kBFloat16 || bias->dtype() == at::kHalf);
|
||||
TORCH_CHECK(bias->dim() == 1);
|
||||
TORCH_CHECK(bias->size(0) == b.size(1));
|
||||
TORCH_CHECK(bias->stride(0) == 1);
|
||||
}
|
||||
|
||||
TORCH_CHECK(out.device() == a.device());
|
||||
TORCH_CHECK(out.dtype() == at::kBFloat16);
|
||||
TORCH_CHECK(out.dtype() == at::kBFloat16 || out.dtype() == at::kHalf);
|
||||
TORCH_CHECK(out.dim() == 2);
|
||||
TORCH_CHECK(out.size(0) == a.size(0));
|
||||
TORCH_CHECK(out.size(1) == b.size(1));
|
||||
|
||||
@ -59,6 +59,24 @@
|
||||
// forward declare
|
||||
class cublasCommonArgs;
|
||||
|
||||
#ifndef _WIN32
|
||||
namespace fbgemm_gpu {
|
||||
|
||||
// NOTE(slayton58): FBGemm_GPU kernels come from <fbgemm_gpu/torch_ops.h> within the FBGemm repo.
|
||||
// To update supported ops means a submodule bump, which is.. painful. Instead, we
|
||||
// can simply forward-declare the methods we want to use.. Works at least as a short-term
|
||||
// thing, but should still be fixed somewhere/somehow.
|
||||
at::Tensor f4f4bf16(
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
std::optional<at::Tensor>,
|
||||
bool use_mx);
|
||||
|
||||
} // namespace fbgemm_gpu
|
||||
#endif
|
||||
|
||||
using at::blas::ScalingType;
|
||||
using at::blas::SwizzleType;
|
||||
|
||||
@ -591,7 +609,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
if ((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)
|
||||
// cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
|
||||
|| (dprops->major >= 10 && (!scale_a.sizes().empty() || !scale_b.sizes().empty()))) {
|
||||
TORCH_CHECK_VALUE(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
|
||||
TORCH_CHECK_VALUE(out.dtype() == kBFloat16 || out.dtype() == kHalf, "Only bf16 and fp16 high precision output types are supported for row-wise scaling.");
|
||||
return _scaled_rowwise_rowwise(
|
||||
mat1,
|
||||
mat2,
|
||||
@ -736,7 +754,7 @@ _scaled_rowwise_rowwise(
|
||||
if (((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900)
|
||||
// cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
|
||||
|| (dprops->major == 10 && (scale_a.sizes().size() || scale_b.sizes().size())))) {
|
||||
TORCH_CHECK_VALUE(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling.");
|
||||
TORCH_CHECK_VALUE(out.dtype() == kBFloat16 || out.dtype() == kHalf, "Only bf16 and fp16 high precision output types are supported for row-wise scaling.");
|
||||
at::cuda::detail::f8f8bf16_rowwise(
|
||||
mat_a,
|
||||
mat_b,
|
||||
@ -767,33 +785,6 @@ _scaled_rowwise_rowwise(
|
||||
return out;
|
||||
}
|
||||
|
||||
// Check the shapes & sizes of scales for deepseek-style (1x128, 128x128) scaling.
|
||||
// Wraps check_size_stride for easier integration, correctly handles cases where a dimension of the scale == 1,
|
||||
// and strides become somewhat meaningless
|
||||
void _check_deepseek_scale_stride(const Tensor& scale, const Tensor& t, const ScalingType scale_type) {
|
||||
if (scale_type == ScalingType::BlockWise1x128) {
|
||||
TORCH_CHECK_VALUE(check_size_stride(scale, 0, t.size(0), 1),
|
||||
"at dim=0 scale should have ", t.size(0), "elements and stride(0) ", 1, "if ", t.size(0), " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
auto expected_size = ceil_div<int64_t>(t.size(1), 128);
|
||||
TORCH_CHECK_VALUE(check_size_stride(scale, 1, expected_size, t.size(0)),
|
||||
"at dim=1 scale should have ", expected_size, "elements and stride ", t.size(0), "if ", expected_size, " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
} else if (scale_type == ScalingType::BlockWise128x128) {
|
||||
TORCH_CHECK_VALUE(check_size_stride(
|
||||
scale,
|
||||
0,
|
||||
ceil_div<int64_t>(t.size(0), 128),
|
||||
ceil_div<int64_t>(t.size(1), 128)),
|
||||
"at dim=0 scale should have ", ceil_div<int64_t>(t.size(0), 128), "elements and stride(0) ", ceil_div<int64_t>(t.size(1), 128), "if ", ceil_div<int64_t>(t.size(0), 128), " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
TORCH_CHECK(check_size_stride(
|
||||
scale, 1, ceil_div<int64_t>(t.size(1), 128), 1),
|
||||
"at dim=1 scale should have ", ceil_div<int64_t>(t.size(1), 128), "elements and stride(1) ", 1, "if ", ceil_div<int64_t>(t.size(1), 128), " > 1 - Got: ",
|
||||
"shape=", scale.sizes(), ", stride=", scale.strides());
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
_check_deepseek_support() {
|
||||
#ifndef USE_ROCM
|
||||
@ -806,7 +797,7 @@ _check_deepseek_support() {
|
||||
}
|
||||
// Only in cublasLt >= 12.9
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900,
|
||||
CUBLAS_VERSION >= 120900 && cublasLtGetVersion() >= 120900,
|
||||
"DeepSeek style (1x128, 128x128) scaling requires cublasLt >= 12.9"
|
||||
);
|
||||
#endif
|
||||
@ -823,23 +814,61 @@ _scaled_block1x128_block1x128(
|
||||
#ifndef USE_ROCM
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, shape K//128
|
||||
// CUDA: Only Hopper GPUs
|
||||
// As: [M x K // 128], stride: [1, M]
|
||||
// Bs: [N x K // 128], stride: [1, N]
|
||||
_check_deepseek_support();
|
||||
|
||||
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", mat_a.sizes()[0], " x ", mat_a.sizes()[1] / 128, " Float elements, got ", scale_a.sizes())
|
||||
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
|
||||
// check types
|
||||
TORCH_CHECK_VALUE(
|
||||
isFloat8Type(mat_a.scalar_type()) &&
|
||||
isFloat8Type(mat_b.scalar_type()),
|
||||
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
|
||||
);
|
||||
|
||||
const int64_t M = mat_a.sizes()[0];
|
||||
const int64_t K = mat_a.sizes()[1];
|
||||
const int64_t N = mat_b.sizes()[1];
|
||||
|
||||
// scale_a shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.size(0) == M &&
|
||||
scale_a.size(1) == ceil_div<int64_t>(K, 128) &&
|
||||
scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", M, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_a.sizes()
|
||||
);
|
||||
// scale_a stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.stride(0) == 1 &&
|
||||
(
|
||||
scale_a.stride(1) == M ||
|
||||
(scale_a.size(1) == 1 && scale_b.stride(1) == 1)
|
||||
),
|
||||
"scale_a strides must be (", 1, ", ", M, "); got: ", scale_a.strides()
|
||||
);
|
||||
|
||||
// scale_b shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.size(0) == N &&
|
||||
scale_b.size(1) == ceil_div<int64_t>(K, 128) &&
|
||||
scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", N, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_b.sizes()
|
||||
);
|
||||
// scale_b stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.stride(0) == 1 &&
|
||||
(
|
||||
scale_b.stride(1) == N ||
|
||||
(
|
||||
scale_b.size(1) == 1 &&
|
||||
scale_b.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_b strides must be (", 1, ", ", N, "); got: ", scale_a.strides()
|
||||
);
|
||||
|
||||
auto scaling_choice_a = ScalingType::BlockWise1x128;
|
||||
auto scaling_choice_b = ScalingType::BlockWise1x128;
|
||||
|
||||
// Check scale strides (including stride=1 small cases)
|
||||
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
|
||||
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
|
||||
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
@ -861,24 +890,65 @@ _scaled_block128x128_block1x128(
|
||||
Tensor& out) {
|
||||
#ifndef USE_ROCM
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, shape K//128
|
||||
// CUDA: Only Hopper GPUs
|
||||
_check_deepseek_support();
|
||||
|
||||
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div<int64_t>(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div<int64_t>(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", ceil_div<int64_t>(mat_a.sizes()[0], 128), " x ", ceil_div<int64_t>(mat_a.sizes()[1], 128), " Float elements, got ", scale_a.sizes())
|
||||
TORCH_CHECK_VALUE(scale_b.sizes()[0] == ceil_div<int64_t>(mat_b.sizes()[0], 128) && scale_b.sizes()[1] == mat_b.sizes()[1] && scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", ceil_div<int64_t>(mat_b.sizes()[0], 128), " x ", mat_b.sizes()[1], " Float elements, got ", scale_b.sizes())
|
||||
// A: [M, K], B: [K, N] are FP8, scales are fp32
|
||||
// As: [round_up(K // 128, 4), M // 128], stride: [M // 128, 1]
|
||||
// Bs: [N x K // 128], stride: [1, N]
|
||||
TORCH_CHECK_VALUE(
|
||||
isFloat8Type(mat_a.scalar_type()) &&
|
||||
isFloat8Type(mat_b.scalar_type()),
|
||||
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
|
||||
);
|
||||
|
||||
const int64_t M = mat_a.sizes()[0];
|
||||
const int64_t K = mat_a.sizes()[1];
|
||||
const int64_t N = mat_b.sizes()[1];
|
||||
|
||||
// scale_a shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.size(0) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) &&
|
||||
scale_a.size(1) == ceil_div<int64_t>(M, 128) &&
|
||||
scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), " x ",
|
||||
ceil_div<int64_t>(M, 128), " Float elements, got ", scale_a.sizes()
|
||||
);
|
||||
// scale_a stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.stride(0) == 1 &&
|
||||
(
|
||||
scale_a.stride(1) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) ||
|
||||
(
|
||||
scale_a.size(1) == 1 &&
|
||||
scale_a.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_a must have strides (1, ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), "); got ", scale_b.strides()
|
||||
);
|
||||
|
||||
// scale_b shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.size(0) == N &&
|
||||
scale_b.size(1) == ceil_div<int64_t>(K, 128) &&
|
||||
scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", N, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_b.sizes()
|
||||
);
|
||||
// scale_b stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.stride(0) == 1 &&
|
||||
(
|
||||
scale_b.stride(1) == N ||
|
||||
(
|
||||
scale_b.size(1) == 1 &&
|
||||
scale_b.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_b must have strides (1, ", N, "); got ", scale_b.strides()
|
||||
);
|
||||
|
||||
auto scaling_choice_a = ScalingType::BlockWise128x128;
|
||||
auto scaling_choice_b = ScalingType::BlockWise1x128;
|
||||
|
||||
// Check scale strides (including stride=1 small cases)
|
||||
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
|
||||
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
|
||||
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
@ -900,24 +970,62 @@ _scaled_block1x128_block128x128(
|
||||
Tensor& out) {
|
||||
#ifndef USE_ROCM
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, A: shape K//128, B: K//128, N//128
|
||||
// CUDA: Only Hopper GPUs
|
||||
_check_deepseek_support();
|
||||
// A: [M, K], B: [K, N] are FP8, scales are fp32
|
||||
// As: [M x K // 128], stride: [1, M]
|
||||
// Bs: [round_up(K // 128, 4) x N // 128], stride: [1, N // 128]
|
||||
TORCH_CHECK_VALUE(
|
||||
isFloat8Type(mat_a.scalar_type()) &&
|
||||
isFloat8Type(mat_b.scalar_type()),
|
||||
"mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()
|
||||
);
|
||||
|
||||
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", mat_a.sizes()[0], " x ", mat_a.sizes()[1] / 128, " Float elements, got ", scale_a.sizes())
|
||||
TORCH_CHECK_VALUE(scale_b.sizes()[0] == mat_b.sizes()[0] / 128 && scale_b.sizes()[1] == mat_b.sizes()[1] / 128 && scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", mat_b.sizes()[0] / 128, " x ", mat_b.sizes()[1] / 128, " Float elements, got ", scale_b.sizes())
|
||||
int64_t M = mat_a.size(0);
|
||||
int64_t K = mat_a.size(1);
|
||||
int64_t N = mat_b.size(1);
|
||||
|
||||
// scale_a shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.size(0) == M &&
|
||||
scale_a.size(1) == ceil_div<int64_t>(K, 128) &&
|
||||
scale_a.scalar_type() == kFloat,
|
||||
"scale_a must have shape ", M, " x ", ceil_div<int64_t>(K, 128), " Float elements, got ", scale_a.sizes()
|
||||
);
|
||||
// scale_a stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_a.stride(0) == 1 &&
|
||||
(
|
||||
scale_a.stride(1) == M ||
|
||||
(
|
||||
scale_a.size(1) == 1 &&
|
||||
scale_a.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_a must have strides (1, ", M, "); got ", scale_b.strides()
|
||||
);
|
||||
// scale_b shape
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.size(0) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) &&
|
||||
scale_b.size(1) == ceil_div<int64_t>(N, 128) &&
|
||||
scale_b.scalar_type() == kFloat,
|
||||
"scale_b must have shape ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), " x ", ceil_div<int64_t>(N, 128), " Float elements, got ", scale_b.sizes()
|
||||
);
|
||||
// scale_b stride
|
||||
TORCH_CHECK_VALUE(
|
||||
scale_b.stride(0) == 1 &&
|
||||
(
|
||||
scale_b.stride(1) == round_up<int64_t>(ceil_div<int64_t>(K, 128), 4) ||
|
||||
(
|
||||
scale_b.size(1) == 1 &&
|
||||
scale_b.stride(1) == 1
|
||||
)
|
||||
),
|
||||
"scale_b must have strides (1, ", round_up<int64_t>(ceil_div<int64_t>(K, 128), 4), "); got ", scale_b.strides()
|
||||
);
|
||||
|
||||
auto scaling_choice_a = ScalingType::BlockWise1x128;
|
||||
auto scaling_choice_b = ScalingType::BlockWise128x128;
|
||||
|
||||
// Check scale strides (including stride=1 small cases)
|
||||
_check_deepseek_scale_stride(scale_a, mat_a, scaling_choice_a);
|
||||
_check_deepseek_scale_stride(scale_b.t(), mat_b.t(), scaling_choice_b);
|
||||
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
@ -997,26 +1105,47 @@ _scaled_mxfp4_mxfp4(
|
||||
const std::optional<Tensor>& bias,
|
||||
const c10::ScalarType out_dtype,
|
||||
Tensor& out) {
|
||||
#ifndef USE_ROCM
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only");
|
||||
#endif
|
||||
#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI))
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only");
|
||||
#else
|
||||
// Restrictions:
|
||||
// A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32
|
||||
TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
|
||||
auto scale_a_elems = ceil_div<int64_t>(2 * mat_a.size(0), 32) * mat_a.size(1);
|
||||
auto scale_b_elems = ceil_div<int64_t>(2 * mat_b.size(1), 32) * mat_b.size(0);
|
||||
// Packed FP4 format means actual-K = 2 * reported-K -- adjust
|
||||
auto K_multiplier = 2;
|
||||
#ifdef USE_ROCM
|
||||
// AMD
|
||||
auto scale_a_elems = ceil_div<int64_t>(K_multiplier * mat_a.size(0), 32) * mat_a.size(1);
|
||||
auto scale_b_elems = ceil_div<int64_t>(K_multiplier * mat_b.size(1), 32) * mat_b.size(0);
|
||||
#else
|
||||
// NVIDIA
|
||||
auto scale_a_elems = round_up<int64_t>(mat_a.size(0), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_a.size(1), 32), 4);
|
||||
auto scale_b_elems = round_up<int64_t>(mat_b.size(1), 128) * round_up<int64_t>(ceil_div<int64_t>(K_multiplier * mat_b.size(0), 32), 4);
|
||||
#endif
|
||||
TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(),
|
||||
"For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel());
|
||||
TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(),
|
||||
"For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel());
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// AMD
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE, "scale_a must not be swizzled (NO_SWIZZLE format)");
|
||||
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::NO_SWIZZLE, "scale_b must not be swizzled (NO_SWIZZLE format)");
|
||||
#else
|
||||
// NVIDIA
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format");
|
||||
TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format");
|
||||
#endif
|
||||
|
||||
TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(),
|
||||
"For Blockwise scaling both scales should be contiguous");
|
||||
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out_dtype);
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// AMD
|
||||
auto scaling_choice_a = ScalingType::BlockWise1x32;
|
||||
auto scaling_choice_b = ScalingType::BlockWise1x32;
|
||||
|
||||
@ -1031,11 +1160,30 @@ _scaled_mxfp4_mxfp4(
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 ||
|
||||
out.scalar_type() == ScalarType::Half,
|
||||
"Block-wise scaling only supports BFloat16 or Half output types");
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later");
|
||||
#endif
|
||||
|
||||
return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out);
|
||||
#else
|
||||
// NVIDIA
|
||||
// NOTE(slayton58): fbgemm_gpu::f4f4bf16 does *not* allow passing an output tensor,
|
||||
// but we have one we need to use. Two clear options are to copy into
|
||||
// our output (slow), or use a move-assignment-operator (faster).
|
||||
// However, the compiler can complain about the explicit move preventing
|
||||
// copy elision because the return from f4f4bf16 is a temporary object.
|
||||
// So we don't explicitly move, and trust the compiler here...
|
||||
// In the longer term this should be fixed on the FBGemm side.
|
||||
out = fbgemm_gpu::f4f4bf16(
|
||||
mat_a,
|
||||
mat_b.transpose(-2, -1),
|
||||
scale_a,
|
||||
scale_b,
|
||||
std::nullopt, /* global_scale */
|
||||
true /* use_mx */
|
||||
);
|
||||
|
||||
return out;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor&
|
||||
@ -1160,17 +1308,20 @@ _scaled_mm_cuda_v2_out(
|
||||
mat_a.size(0), "x", mat_a.size(1), " and ", mat_b.size(0), "x", mat_b.size(1), ")");
|
||||
}
|
||||
|
||||
// Handle fp4 packed-K dimension
|
||||
int K_multiplier = (mat_a.scalar_type() == ScalarType::Float4_e2m1fn_x2) ? 2 : 1;
|
||||
|
||||
TORCH_CHECK_VALUE(!bias || bias->numel() == mat_b.sizes()[1], "Bias must be size ", mat_b.sizes()[1],
|
||||
" but got ", bias->numel());
|
||||
TORCH_CHECK_VALUE(
|
||||
mat_a.sizes()[1] % 16 == 0,
|
||||
K_multiplier * mat_a.sizes()[1] % 16 == 0,
|
||||
"Expected trailing dimension of mat1 to be divisible by 16 ",
|
||||
"but got mat1 shape: (",
|
||||
mat_a.sizes()[0],
|
||||
"x",
|
||||
mat_a.sizes()[1],
|
||||
K_multiplier * mat_a.sizes()[1],
|
||||
").");
|
||||
TORCH_CHECK_VALUE(mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
|
||||
TORCH_CHECK_VALUE(K_multiplier * mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x",
|
||||
mat_b.sizes()[1], ") must be divisible by 16");
|
||||
|
||||
// TODO(slayton): Existing checks, not sure if they should really be here.
|
||||
|
||||
@ -1881,6 +1881,8 @@ void geqrf_kernel(const Tensor& input, const Tensor& tau) {
|
||||
|
||||
REGISTER_CUDA_DISPATCH(geqrf_stub, &geqrf_kernel)
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eigh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
template <typename scalar_t>
|
||||
static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
|
||||
#if !AT_MAGMA_ENABLED()
|
||||
@ -1955,8 +1957,6 @@ static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const
|
||||
#endif
|
||||
}
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eigh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
// This is a type dispatch function for 'apply_magma_eigh'
|
||||
// For small inputs result is computed on CPU
|
||||
void linalg_eigh_magma(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) {
|
||||
@ -2019,10 +2019,10 @@ This is an in-place routine, content of 'input', 'values', 'vectors' is overwrit
|
||||
For more information see MAGMA's documentation for GEEV routine.
|
||||
*/
|
||||
template <typename scalar_t>
|
||||
void apply_linalg_eig(Tensor& values, Tensor& vectors, Tensor& input, Tensor& infos, bool compute_eigenvectors) {
|
||||
void apply_magma_eig(Tensor& values, Tensor& vectors, Tensor& input, Tensor& infos, bool compute_eigenvectors) {
|
||||
#if !AT_MAGMA_ENABLED()
|
||||
TORCH_CHECK(false, "Calling torch.linalg.eig on a CUDA tensor requires compiling PyTorch with MAGMA. "
|
||||
"Either transfer the tensor to the CPU before calling torch.linalg.eig or recompile with MAGMA.");
|
||||
TORCH_CHECK(false, "Calling torch.linalg.eig with MAGMA requires compiling PyTorch with MAGMA. "
|
||||
"Either transfer the tensor to the CPU before calling torch.linalg.eig or use cuSolver.");
|
||||
#else
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.device() == at::kCPU);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.device() == at::kCPU);
|
||||
@ -2076,22 +2076,44 @@ TORCH_CHECK(false, "Calling torch.linalg.eig on a CUDA tensor requires compiling
|
||||
#endif
|
||||
}
|
||||
|
||||
// This is a type dispatching helper function for 'apply_linalg_eig'
|
||||
// MAGMA wrapper: transfers tensors to CPU, calls apply_magma_eig, then copies results back.
|
||||
void linalg_eig_magma(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, const Tensor& input, bool compute_eigenvectors){
|
||||
// MAGMA doesn't have GPU interface for the eigendecomposition, and it forces us to transfer to CPU
|
||||
auto eigenvalues_cpu = eigenvalues.cpu();
|
||||
auto eigenvectors_cpu = eigenvectors.cpu();
|
||||
auto infos_cpu = infos.cpu();
|
||||
|
||||
Tensor input_cpu = at::empty(input.sizes(), input.options().device(kCPU));
|
||||
input_cpu.transpose_(-2, -1);
|
||||
input_cpu.copy_(input);
|
||||
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "linalg_eig_out_cuda", [&]{
|
||||
apply_magma_eig<scalar_t>(eigenvalues_cpu, eigenvectors_cpu, input_cpu, infos_cpu, compute_eigenvectors);
|
||||
});
|
||||
|
||||
eigenvalues.copy_(eigenvalues_cpu);
|
||||
eigenvectors.copy_(eigenvectors_cpu);
|
||||
infos.copy_(infos_cpu);
|
||||
}
|
||||
void linalg_eig_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, const Tensor& input, bool compute_eigenvectors) {
|
||||
// This function calculates the non-symmetric eigendecomposition in-place
|
||||
// tensors should be in batched column major memory format
|
||||
// the content of eigenvalues, eigenvectors and infos is overwritten by 'apply_linalg_eig'
|
||||
// the content of eigenvalues, eigenvectors and infos is overwritten by 'linalg_eig_magma' or
|
||||
// 'linalg_eig_cusolver_xgeev' both geev routines modify the provided input matrix in-place, therefore we need a copy
|
||||
|
||||
// apply_linalg_eig modifies the provided input matrix in-place, therefore we need a copy
|
||||
// MAGMA doesn't have GPU interface for the eigendecomposition and it forces us to transfer 'input' to CPU
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.is_cuda());
|
||||
Tensor input_working_copy = at::empty(input.sizes(), input.options().device(kCPU));
|
||||
input_working_copy.transpose_(-2, -1); // make input_working_copy to have Fortran contiguous memory layout
|
||||
input_working_copy.copy_(input);
|
||||
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "linalg_eig_out_cuda", [&]{
|
||||
apply_linalg_eig<scalar_t>(eigenvalues, eigenvectors, input_working_copy, infos, compute_eigenvectors);
|
||||
});
|
||||
#if defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
|
||||
auto preferred_backend = at::globalContext().linalgPreferredBackend();
|
||||
switch (preferred_backend) {
|
||||
case at::LinalgBackend::Cusolver:
|
||||
default:
|
||||
linalg_eig_cusolver_xgeev(eigenvalues, eigenvectors, input, infos, compute_eigenvectors);
|
||||
return;
|
||||
case at::LinalgBackend::Magma:
|
||||
break; // MAGMA path handled below
|
||||
}
|
||||
#endif
|
||||
linalg_eig_magma(eigenvalues, eigenvectors, infos, input, compute_eigenvectors);
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(linalg_eig_stub, &linalg_eig_kernel)
|
||||
|
||||
@ -753,8 +753,8 @@ static void apply_cholesky_cusolver_potrf_looped(const Tensor& self_working_copy
|
||||
handle, params, uplo, n, datatype,
|
||||
self_working_copy_ptr + i * matrix_stride,
|
||||
lda, datatype,
|
||||
(char*)workdata_device_ptr + i * worksize_device, worksize_device,
|
||||
(char*)workdata_host_ptr + i * worksize_host, worksize_host,
|
||||
static_cast<char*>(workdata_device_ptr) + i * worksize_device, worksize_device,
|
||||
static_cast<char*>(workdata_host_ptr) + i * worksize_host, worksize_host,
|
||||
infos_ptr + i
|
||||
);
|
||||
}
|
||||
@ -1625,6 +1625,126 @@ void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors,
|
||||
#endif
|
||||
}
|
||||
|
||||
// cuSOLVER Xgeev (requires cuSOLVER >= 11.7.2, i.e. CUDA 12.8+)
|
||||
#if defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
|
||||
|
||||
template <typename scalar_t>
|
||||
void apply_xgeev(const Tensor& values, const Tensor& vectors, const Tensor& input, const Tensor& infos, bool compute_eigenvectors) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.is_cuda());
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(vectors.is_cuda());
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.is_cuda());
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.is_cuda());
|
||||
|
||||
int n = cuda_int_cast(input.size(-1), "n");
|
||||
int lda = std::max<int>(1, n);
|
||||
auto batch_size = batchCount(input);
|
||||
|
||||
if (n == 0 || batch_size == 0) {
|
||||
// XGeev crashes on empty input, explicitly handle empty input
|
||||
auto values_shape = IntArrayRef(input.sizes().data(), input.dim() - 1);
|
||||
values.resize_(values_shape, MemoryFormat::Contiguous);
|
||||
values.zero_();
|
||||
|
||||
if (compute_eigenvectors) {
|
||||
vectors.resize_(input.sizes(), MemoryFormat::Contiguous);
|
||||
vectors.zero_();
|
||||
} else {
|
||||
vectors.resize_({0});
|
||||
}
|
||||
|
||||
infos.resize_({std::max<int64_t>(1, batch_size)}, MemoryFormat::Contiguous);
|
||||
infos.zero_();
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t vectors_stride = 0;
|
||||
if (compute_eigenvectors){
|
||||
vectors_stride = matrixStride(vectors);
|
||||
}
|
||||
|
||||
auto values_stride = values.size(-1);
|
||||
auto vectors_data = vectors.data_ptr<scalar_t>();
|
||||
auto values_data = values.data_ptr<scalar_t>();
|
||||
auto infos_data = infos.data_ptr<int>();
|
||||
|
||||
cusolverDnParams_t params = nullptr;
|
||||
TORCH_CUSOLVER_CHECK(cusolverDnCreateParams(¶ms));
|
||||
|
||||
Tensor A_fortran = input.mT().contiguous();
|
||||
auto* A_data = A_fortran.data_ptr<scalar_t>();
|
||||
const auto A_stride = matrixStride(A_fortran);
|
||||
auto handle = at::cuda::getCurrentCUDASolverDnHandle();
|
||||
|
||||
const int ldvl = 1; // ldvl >= 1 if jobvl = CUSOLVER_EIG_MODE_NOVECTOR
|
||||
cusolverEigMode_t jobvl = CUSOLVER_EIG_MODE_NOVECTOR;
|
||||
|
||||
cusolverEigMode_t jobvr;
|
||||
int ldvr;
|
||||
if (compute_eigenvectors) {
|
||||
ldvr = n; // ldvr >= n if jobvr = CUSOLVER_EIG_MODE_VECTOR
|
||||
jobvr = CUSOLVER_EIG_MODE_VECTOR;
|
||||
}
|
||||
else {
|
||||
ldvr = 1; // ldvr >= 1 if jobvr = CUSOLVER_EIG_MODE_NOVECTOR
|
||||
jobvr = CUSOLVER_EIG_MODE_NOVECTOR;
|
||||
}
|
||||
|
||||
scalar_t* W = values.data_ptr<scalar_t>();
|
||||
scalar_t* VL = nullptr;
|
||||
scalar_t* VR = vectors.data_ptr<scalar_t>();
|
||||
|
||||
const scalar_t* A_const = A_data;
|
||||
const scalar_t* W_const = W;
|
||||
const scalar_t* VL_const = VL;
|
||||
const scalar_t* VR_const = VR;
|
||||
|
||||
size_t ws_dev = 0, ws_host = 0;
|
||||
at::cuda::solver::xgeev_bufferSize<scalar_t>(
|
||||
handle, params,
|
||||
jobvl, jobvr,
|
||||
n,
|
||||
A_const, lda,
|
||||
W_const,
|
||||
VL_const, ldvl,
|
||||
VR_const, ldvr,
|
||||
&ws_dev, &ws_host);
|
||||
|
||||
auto& device_allocator = *at::cuda::getCUDADeviceAllocator();
|
||||
auto work_device_data = device_allocator.allocate(ws_dev);
|
||||
// use pinned memory for best performance.
|
||||
auto& host_allocator = *at::cuda::getPinnedMemoryAllocator();
|
||||
auto work_host_data = host_allocator.allocate(ws_host);
|
||||
|
||||
for (decltype(batch_size) i = 0; i < batch_size; ++i) {
|
||||
scalar_t* Ai = A_data + i * A_stride;
|
||||
scalar_t* Wi = values_data + i * values_stride;
|
||||
scalar_t* VLi = nullptr; // xgeev does not support computing left evs
|
||||
scalar_t* VRi = compute_eigenvectors ? (vectors_data + i * vectors_stride) : nullptr;
|
||||
int* info = infos_data + i;
|
||||
|
||||
at::cuda::solver::xgeev<scalar_t>(
|
||||
handle, params,
|
||||
jobvl, jobvr,
|
||||
n,
|
||||
Ai, lda,
|
||||
Wi,
|
||||
VLi, ldvl,
|
||||
VRi, ldvr,
|
||||
static_cast<scalar_t*>(work_device_data.get()), ws_dev,
|
||||
static_cast<scalar_t*>(work_host_data.get()), ws_host,
|
||||
info);
|
||||
}
|
||||
TORCH_CUSOLVER_CHECK(cusolverDnDestroyParams(params));
|
||||
}
|
||||
|
||||
void linalg_eig_cusolver_xgeev(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& input, const Tensor& infos, bool compute_eigenvectors) {
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(eigenvectors.scalar_type(), "linalg_eig_cuda", [&] {
|
||||
apply_xgeev<scalar_t>(eigenvalues, eigenvectors, input, infos, compute_eigenvectors);
|
||||
});
|
||||
}
|
||||
|
||||
#endif // defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
|
||||
|
||||
// The 'apply_' word is used for templated by dtype functions that call an API routine
|
||||
// underneath. Since the cusolver API has a slightly different structure we do not prepend
|
||||
// apply_ to this function.
|
||||
|
||||
@ -73,6 +73,11 @@ void ormqr_cusolver(const Tensor& input, const Tensor& tau, const Tensor& other,
|
||||
Tensor& orgqr_helper_cusolver(Tensor& result, const Tensor& tau);
|
||||
|
||||
void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors);
|
||||
|
||||
void linalg_eig_cusolver_xgeev(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& input, const Tensor& infos, bool compute_eigenvectors);
|
||||
|
||||
|
||||
|
||||
void lu_solve_looped_cusolver(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose);
|
||||
|
||||
void lu_factor_looped_cusolver(const Tensor& self, const Tensor& pivots, const Tensor& infos, bool get_pivots);
|
||||
|
||||
@ -1954,6 +1954,336 @@ void xsyevd<c10::complex<double>, double>(
|
||||
workspaceInBytesOnHost,
|
||||
info));
|
||||
}
|
||||
|
||||
// cuSOLVER Xgeev bindings (requires cuSOLVER >= 11.7.2, i.e. CUDA 12.8+)
|
||||
#if defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
|
||||
|
||||
template <>
|
||||
void xgeev_bufferSize<float>(
|
||||
cusolverDnHandle_t handle,
|
||||
cusolverDnParams_t params,
|
||||
cusolverEigMode_t jobvl,
|
||||
cusolverEigMode_t jobvr,
|
||||
int64_t n,
|
||||
const float* A,
|
||||
int64_t lda,
|
||||
const float* W,
|
||||
const float* VL,
|
||||
int64_t ldvl,
|
||||
const float* VR,
|
||||
int64_t ldvr,
|
||||
size_t* workspaceInBytesOnDevice,
|
||||
size_t* workspaceInBytesOnHost) {
|
||||
TORCH_CUSOLVER_CHECK(cusolverDnXgeev_bufferSize(
|
||||
handle, params, jobvl, jobvr, n,
|
||||
CUDA_R_32F,
|
||||
reinterpret_cast<const void*>(A),
|
||||
lda,
|
||||
CUDA_R_32F,
|
||||
reinterpret_cast<const void*>(W),
|
||||
CUDA_R_32F,
|
||||
reinterpret_cast<const void*>(VL),
|
||||
ldvl,
|
||||
CUDA_R_32F,
|
||||
reinterpret_cast<const void*>(VR),
|
||||
ldvr,
|
||||
CUDA_R_32F,
|
||||
workspaceInBytesOnDevice,
|
||||
workspaceInBytesOnHost));
|
||||
}
|
||||
|
||||
template <>
|
||||
void xgeev_bufferSize<double>(
|
||||
cusolverDnHandle_t handle,
|
||||
cusolverDnParams_t params,
|
||||
cusolverEigMode_t jobvl,
|
||||
cusolverEigMode_t jobvr,
|
||||
int64_t n,
|
||||
const double* A,
|
||||
int64_t lda,
|
||||
const double* W,
|
||||
const double* VL,
|
||||
int64_t ldvl,
|
||||
const double* VR,
|
||||
int64_t ldvr,
|
||||
size_t* workspaceInBytesOnDevice,
|
||||
size_t* workspaceInBytesOnHost) {
|
||||
TORCH_CUSOLVER_CHECK(cusolverDnXgeev_bufferSize(
|
||||
handle, params, jobvl, jobvr, n,
|
||||
CUDA_R_64F,
|
||||
reinterpret_cast<const void*>(A),
|
||||
lda,
|
||||
CUDA_R_64F,
|
||||
reinterpret_cast<const void*>(W),
|
||||
CUDA_R_64F,
|
||||
reinterpret_cast<const void*>(VL),
|
||||
ldvl,
|
||||
CUDA_R_64F,
|
||||
reinterpret_cast<const void*>(VR),
|
||||
ldvr,
|
||||
CUDA_R_64F,
|
||||
workspaceInBytesOnDevice,
|
||||
workspaceInBytesOnHost));
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
void xgeev_bufferSize<c10::complex<float>>(
|
||||
cusolverDnHandle_t handle,
|
||||
cusolverDnParams_t params,
|
||||
cusolverEigMode_t jobvl,
|
||||
cusolverEigMode_t jobvr,
|
||||
int64_t n,
|
||||
const c10::complex<float>* A,
|
||||
int64_t lda,
|
||||
const c10::complex<float>* W,
|
||||
const c10::complex<float>* VL,
|
||||
int64_t ldvl,
|
||||
const c10::complex<float>* VR,
|
||||
int64_t ldvr,
|
||||
size_t* workspaceInBytesOnDevice,
|
||||
size_t* workspaceInBytesOnHost) {
|
||||
TORCH_CUSOLVER_CHECK(cusolverDnXgeev_bufferSize(
|
||||
handle, params, jobvl, jobvr, n,
|
||||
CUDA_C_32F,
|
||||
reinterpret_cast<const void*>(A),
|
||||
lda,
|
||||
CUDA_C_32F,
|
||||
reinterpret_cast<const void*>(W),
|
||||
CUDA_C_32F,
|
||||
reinterpret_cast<const void*>(VL),
|
||||
ldvl,
|
||||
CUDA_C_32F,
|
||||
reinterpret_cast<const void*>(VR),
|
||||
ldvr,
|
||||
CUDA_C_32F,
|
||||
workspaceInBytesOnDevice,
|
||||
workspaceInBytesOnHost));
|
||||
}
|
||||
|
||||
template <>
|
||||
void xgeev_bufferSize<c10::complex<double>>(
|
||||
cusolverDnHandle_t handle,
|
||||
cusolverDnParams_t params,
|
||||
cusolverEigMode_t jobvl,
|
||||
cusolverEigMode_t jobvr,
|
||||
int64_t n,
|
||||
const c10::complex<double>* A,
|
||||
int64_t lda,
|
||||
const c10::complex<double>* W,
|
||||
const c10::complex<double>* VL,
|
||||
int64_t ldvl,
|
||||
const c10::complex<double>* VR,
|
||||
int64_t ldvr,
|
||||
size_t* workspaceInBytesOnDevice,
|
||||
size_t* workspaceInBytesOnHost) {
|
||||
TORCH_CUSOLVER_CHECK(cusolverDnXgeev_bufferSize(
|
||||
handle, params, jobvl, jobvr, n,
|
||||
CUDA_C_64F,
|
||||
reinterpret_cast<const void*>(A),
|
||||
lda,
|
||||
CUDA_C_64F,
|
||||
reinterpret_cast<const void*>(W),
|
||||
CUDA_C_64F,
|
||||
reinterpret_cast<const void*>(VL),
|
||||
ldvl,
|
||||
CUDA_C_64F,
|
||||
reinterpret_cast<const void*>(VR),
|
||||
ldvr,
|
||||
CUDA_C_64F,
|
||||
workspaceInBytesOnDevice,
|
||||
workspaceInBytesOnHost));
|
||||
}
|
||||
|
||||
template <>
|
||||
void xgeev<float>(
|
||||
cusolverDnHandle_t handle,
|
||||
cusolverDnParams_t params,
|
||||
cusolverEigMode_t jobvl,
|
||||
cusolverEigMode_t jobvr,
|
||||
int64_t n,
|
||||
float* A,
|
||||
int64_t lda,
|
||||
float* W,
|
||||
float* VL,
|
||||
int64_t ldvl,
|
||||
float* VR,
|
||||
int64_t ldvr,
|
||||
float* bufferOnDevice,
|
||||
size_t workspaceInBytesOnDevice,
|
||||
float* bufferOnHost,
|
||||
size_t workspaceInBytesOnHost,
|
||||
int* info) {
|
||||
|
||||
TORCH_CUSOLVER_CHECK(cusolverDnXgeev(
|
||||
handle,
|
||||
params,
|
||||
jobvl,
|
||||
jobvr,
|
||||
n,
|
||||
CUDA_R_32F,
|
||||
reinterpret_cast<void*>(A),
|
||||
lda,
|
||||
CUDA_R_32F,
|
||||
reinterpret_cast<void*>(W),
|
||||
CUDA_R_32F,
|
||||
reinterpret_cast<void*>(VL),
|
||||
ldvl,
|
||||
CUDA_R_32F,
|
||||
reinterpret_cast<void*>(VR),
|
||||
ldvr,
|
||||
CUDA_R_32F,
|
||||
reinterpret_cast<void*>(bufferOnDevice),
|
||||
workspaceInBytesOnDevice,
|
||||
reinterpret_cast<void*>(bufferOnHost),
|
||||
workspaceInBytesOnHost,
|
||||
info));
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template <>
|
||||
void xgeev<double>(
|
||||
cusolverDnHandle_t handle,
|
||||
cusolverDnParams_t params,
|
||||
cusolverEigMode_t jobvl,
|
||||
cusolverEigMode_t jobvr,
|
||||
int64_t n,
|
||||
double* A,
|
||||
int64_t lda,
|
||||
double* W,
|
||||
double* VL,
|
||||
int64_t ldvl,
|
||||
double* VR,
|
||||
int64_t ldvr,
|
||||
double* bufferOnDevice,
|
||||
size_t workspaceInBytesOnDevice,
|
||||
double* bufferOnHost,
|
||||
size_t workspaceInBytesOnHost,
|
||||
int* info) {
|
||||
|
||||
TORCH_CUSOLVER_CHECK(cusolverDnXgeev(
|
||||
handle,
|
||||
params,
|
||||
jobvl,
|
||||
jobvr,
|
||||
n,
|
||||
CUDA_R_64F,
|
||||
reinterpret_cast<void*>(A),
|
||||
lda,
|
||||
CUDA_R_64F,
|
||||
reinterpret_cast<void*>(W),
|
||||
CUDA_R_64F,
|
||||
reinterpret_cast<void*>(VL),
|
||||
ldvl,
|
||||
CUDA_R_64F,
|
||||
reinterpret_cast<void*>(VR),
|
||||
ldvr,
|
||||
CUDA_R_64F,
|
||||
reinterpret_cast<void*>(bufferOnDevice),
|
||||
workspaceInBytesOnDevice,
|
||||
reinterpret_cast<void*>(bufferOnHost),
|
||||
workspaceInBytesOnHost,
|
||||
info));
|
||||
|
||||
}
|
||||
|
||||
template <>
|
||||
void xgeev<c10::complex<float>>(
|
||||
cusolverDnHandle_t handle,
|
||||
cusolverDnParams_t params,
|
||||
cusolverEigMode_t jobvl,
|
||||
cusolverEigMode_t jobvr,
|
||||
int64_t n,
|
||||
c10::complex<float>* A,
|
||||
int64_t lda,
|
||||
c10::complex<float>* W,
|
||||
c10::complex<float>* VL,
|
||||
int64_t ldvl,
|
||||
c10::complex<float>* VR,
|
||||
int64_t ldvr,
|
||||
c10::complex<float>* bufferOnDevice,
|
||||
size_t workspaceInBytesOnDevice,
|
||||
c10::complex<float>* bufferOnHost,
|
||||
size_t workspaceInBytesOnHost,
|
||||
int* info) {
|
||||
|
||||
TORCH_CUSOLVER_CHECK(cusolverDnXgeev(
|
||||
handle,
|
||||
params,
|
||||
jobvl,
|
||||
jobvr,
|
||||
n,
|
||||
CUDA_C_32F,
|
||||
reinterpret_cast<void*>(A),
|
||||
lda,
|
||||
CUDA_C_32F,
|
||||
reinterpret_cast<void*>(W),
|
||||
CUDA_C_32F,
|
||||
reinterpret_cast<void*>(VL),
|
||||
ldvl,
|
||||
CUDA_C_32F,
|
||||
reinterpret_cast<void*>(VR),
|
||||
ldvr,
|
||||
CUDA_C_32F,
|
||||
reinterpret_cast<void*>(bufferOnDevice),
|
||||
workspaceInBytesOnDevice,
|
||||
reinterpret_cast<void*>(bufferOnHost),
|
||||
workspaceInBytesOnHost,
|
||||
info));
|
||||
}
|
||||
|
||||
template <>
|
||||
void xgeev<c10::complex<double>>(
|
||||
cusolverDnHandle_t handle,
|
||||
cusolverDnParams_t params,
|
||||
cusolverEigMode_t jobvl,
|
||||
cusolverEigMode_t jobvr,
|
||||
int64_t n,
|
||||
c10::complex<double>* A,
|
||||
int64_t lda,
|
||||
c10::complex<double>* W,
|
||||
c10::complex<double>* VL,
|
||||
int64_t ldvl,
|
||||
c10::complex<double>* VR,
|
||||
int64_t ldvr,
|
||||
c10::complex<double>* bufferOnDevice,
|
||||
size_t workspaceInBytesOnDevice,
|
||||
c10::complex<double>* bufferOnHost,
|
||||
size_t workspaceInBytesOnHost,
|
||||
int* info) {
|
||||
|
||||
TORCH_CUSOLVER_CHECK(cusolverDnXgeev(
|
||||
handle,
|
||||
params,
|
||||
jobvl,
|
||||
jobvr,
|
||||
n,
|
||||
CUDA_C_64F,
|
||||
reinterpret_cast<void*>(A),
|
||||
lda,
|
||||
CUDA_C_64F,
|
||||
reinterpret_cast<void*>(W),
|
||||
CUDA_C_64F,
|
||||
reinterpret_cast<void*>(VL),
|
||||
ldvl,
|
||||
CUDA_C_64F,
|
||||
reinterpret_cast<void*>(VR),
|
||||
ldvr,
|
||||
CUDA_C_64F,
|
||||
reinterpret_cast<void*>(bufferOnDevice),
|
||||
workspaceInBytesOnDevice,
|
||||
reinterpret_cast<void*>(bufferOnHost),
|
||||
workspaceInBytesOnHost,
|
||||
info));
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
#endif // defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
|
||||
|
||||
#endif // USE_CUSOLVER_64_BIT
|
||||
|
||||
#ifdef USE_CUSOLVER_64_BIT_XSYEV_BATCHED
|
||||
|
||||
@ -674,6 +674,66 @@ template <>
|
||||
void xsyevd<c10::complex<double>, double>(
|
||||
CUDASOLVER_XSYEVD_ARGTYPES(c10::complex<double>, double));
|
||||
|
||||
|
||||
|
||||
// cuSOLVER Xgeev (non-Hermitian eigen decomposition, CUDA >= 12.8)
|
||||
#if defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
|
||||
|
||||
#define CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(scalar_t) \
|
||||
cusolverDnHandle_t handle, cusolverDnParams_t params, \
|
||||
cusolverEigMode_t jobvl, cusolverEigMode_t jobvr, int64_t n, \
|
||||
const scalar_t* A, int64_t lda, const scalar_t* W, \
|
||||
const scalar_t* VL, int64_t ldvl, const scalar_t* VR, int64_t ldvr, \
|
||||
size_t* workspaceInBytesOnDevice, size_t* workspaceInBytesOnHost
|
||||
|
||||
template <class scalar_t>
|
||||
void xgeev_bufferSize(
|
||||
CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(scalar_t)) {
|
||||
static_assert(false&&sizeof(scalar_t),
|
||||
"at::cuda::solver::xgeev_bufferSize: not implemented");
|
||||
}
|
||||
|
||||
template <>
|
||||
void xgeev_bufferSize<float>(CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(float));
|
||||
|
||||
template <>
|
||||
void xgeev_bufferSize<double>(CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(double));
|
||||
|
||||
template <>
|
||||
void xgeev_bufferSize<c10::complex<float>>(
|
||||
CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(c10::complex<float>));
|
||||
|
||||
template <>
|
||||
void xgeev_bufferSize<c10::complex<double>>(
|
||||
CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(c10::complex<double>));
|
||||
|
||||
#define CUDASOLVER_XGEEV_ARGTYPES(scalar_t) \
|
||||
cusolverDnHandle_t handle, cusolverDnParams_t params, \
|
||||
cusolverEigMode_t jobvl, cusolverEigMode_t jobvr, int64_t n, scalar_t *A, \
|
||||
int64_t lda, scalar_t *W, scalar_t *VL, int64_t ldvl, scalar_t *VR, int64_t ldvr,\
|
||||
scalar_t *bufferOnDevice, size_t workspaceInBytesOnDevice, scalar_t *bufferOnHost,\
|
||||
size_t workspaceInBytesOnHost, int *info
|
||||
|
||||
template <class scalar_t>
|
||||
void xgeev(CUDASOLVER_XGEEV_ARGTYPES(scalar_t)) {
|
||||
static_assert(false&&sizeof(scalar_t),
|
||||
"at::cuda::solver::xgeev: not implemented");
|
||||
}
|
||||
|
||||
template <>
|
||||
void xgeev<float>(CUDASOLVER_XGEEV_ARGTYPES(float));
|
||||
|
||||
template <>
|
||||
void xgeev<double>(CUDASOLVER_XGEEV_ARGTYPES(double));
|
||||
|
||||
template <>
|
||||
void xgeev<c10::complex<float>>(CUDASOLVER_XGEEV_ARGTYPES(c10::complex<float>));
|
||||
|
||||
template <>
|
||||
void xgeev<c10::complex<double>>(CUDASOLVER_XGEEV_ARGTYPES(c10::complex<double>));
|
||||
|
||||
#endif // defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702)
|
||||
|
||||
#endif // USE_CUSOLVER_64_BIT
|
||||
|
||||
#ifdef USE_CUSOLVER_64_BIT_XSYEV_BATCHED
|
||||
|
||||
@ -119,8 +119,8 @@ void setConvolutionParams(
|
||||
params->input_dim = input.dim();
|
||||
params->memory_format = memory_format;
|
||||
for (int i = 0; i != params->input_dim; ++i) {
|
||||
params->input_size[i] = (int)input.sizes()[i];
|
||||
params->weight_size[i] = (int)weight.sizes()[i];
|
||||
params->input_size[i] = static_cast<int>(input.sizes()[i]);
|
||||
params->weight_size[i] = static_cast<int>(weight.sizes()[i]);
|
||||
}
|
||||
// ASSERT(padding.size() == stride.size())
|
||||
// ASSERT(padding.size() == dilation.size())
|
||||
|
||||
@ -64,7 +64,7 @@
|
||||
// fastest algorithm combination with a sub optimal mathType.
|
||||
|
||||
constexpr size_t operator"" _TiB(unsigned long long n) {
|
||||
return size_t(n) * 1024 * 1024 * 1024 * 1024;
|
||||
return static_cast<size_t>(n) * 1024 * 1024 * 1024 * 1024;
|
||||
}
|
||||
|
||||
namespace at {
|
||||
|
||||
@ -46,7 +46,7 @@ namespace {
|
||||
|
||||
// TODO: remove duplicate code in Conv_v7.cpp
|
||||
constexpr int64_t operator"" _TiB(unsigned long long n) {
|
||||
return size_t(n) << 40;
|
||||
return static_cast<size_t>(n) << 40;
|
||||
}
|
||||
|
||||
uint8_t getAlignment(const Tensor& t) {
|
||||
@ -93,7 +93,10 @@ cudnn_frontend::Tensor getTensorDescriptorWithTypeVirtual(
|
||||
|
||||
std::vector<int64_t> strides_copy(std::begin(strides), std::end(strides));
|
||||
fixSizeOneDimStride<int64_t>(
|
||||
sizes.size(), &sizes[0], (int64_t*)&strides_copy[0], channels_last);
|
||||
sizes.size(),
|
||||
&sizes[0],
|
||||
static_cast<int64_t*>(&strides_copy[0]),
|
||||
channels_last);
|
||||
auto r = cudnn_frontend::TensorBuilder()
|
||||
.setDim(sizes.size(), sizes.data())
|
||||
.setStrides(strides_copy.size(), strides_copy.data())
|
||||
|
||||
@ -44,6 +44,7 @@ std::tuple<Tensor, Tensor> cudnn_grid_sampler_backward(
|
||||
#include <ATen/cudnn/Descriptors.h>
|
||||
#include <ATen/cudnn/Types.h>
|
||||
#include <ATen/cudnn/Utils.h>
|
||||
#include <array>
|
||||
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <c10/util/irange.h>
|
||||
@ -59,11 +60,11 @@ void setSamplerDescriptor(
|
||||
SpatialTransformerDescriptor& desc,
|
||||
cudnnDataType_t dataType,
|
||||
const at::Tensor& tensor) {
|
||||
int inputSize[4] = {0};
|
||||
std::array<int, 4> inputSize{0};
|
||||
for (const auto i : c10::irange(tensor.dim())) {
|
||||
inputSize[i] = (int)tensor.size(i);
|
||||
inputSize[i] = static_cast<int>(tensor.size(i));
|
||||
}
|
||||
desc.set(dataType, 4, inputSize);
|
||||
desc.set(dataType, 4, inputSize.data());
|
||||
}
|
||||
|
||||
void checkGridSize(CheckedFrom c, TensorArg grid, TensorArg input) {
|
||||
|
||||
@ -656,7 +656,8 @@ void add_projection_weights(
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
nb_dims <= min_dim, "nb_dims = ", nb_dims, "; min_dim = ", min_dim);
|
||||
auto elem_size = dataSize(getCudnnDataType(weight_buf));
|
||||
auto offset_bytes = (char*)matrix_pointer - (char*)weight_buf.data_ptr();
|
||||
auto offset_bytes = static_cast<const char*>(matrix_pointer) -
|
||||
static_cast<const char*>(weight_buf.data_ptr());
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
offset_bytes % elem_size == 0,
|
||||
"offset_bytes = ",
|
||||
@ -794,8 +795,8 @@ get_parameters(
|
||||
"; min_dim = ",
|
||||
min_dim);
|
||||
auto elem_size = dataSize(getCudnnDataType(weight_buf));
|
||||
auto offset_bytes =
|
||||
(char*)matrix_pointer - (char*)weight_buf.data_ptr();
|
||||
auto offset_bytes = static_cast<const char*>(matrix_pointer) -
|
||||
static_cast<const char*>(weight_buf.data_ptr());
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
offset_bytes % elem_size == 0,
|
||||
"offset_bytes = ",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user