Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-01 13:34:57 +08:00)
Compare commits
163 Commits
fix_fx_gra...ciflow/mps
| SHA1 | Author | Date | |
|---|---|---|---|
| cd6bc4df45 | |||
| 836debc10d | |||
| 98d640bb11 | |||
| 5d288bc3f7 | |||
| bfb47ec50e | |||
| 7a0cd8ed09 | |||
| 984e64b2cd | |||
| b9bcb37f40 | |||
| 7e3b9d105e | |||
| 45c3f02d69 | |||
| f5543e3741 | |||
| 5fc2c7a2a1 | |||
| 7692fa09cd | |||
| df71b70727 | |||
| 80ba6e458f | |||
| 0d50e5d8d4 | |||
| 99b05d1b78 | |||
| f911d64750 | |||
| 52db60170d | |||
| 56838bad5f | |||
| ad3a56ab98 | |||
| a7fd0b4001 | |||
| 181ee3bd42 | |||
| 0ec0549823 | |||
| 8221ee6db9 | |||
| b939de26d1 | |||
| 694db5f549 | |||
| 639a0b1239 | |||
| 398775a43e | |||
| fcd5f8c352 | |||
| 4acc66f119 | |||
| 8f40a0c634 | |||
| a5c3c08d10 | |||
| a553ea9ea4 | |||
| ba71e9ca9a | |||
| 694d205143 | |||
| 629293f568 | |||
| c37802a8c4 | |||
| 0a3ac47c0a | |||
| e83be7042e | |||
| fb545fb068 | |||
| 2df2c316e2 | |||
| 08b0a8f11a | |||
| 3f1824742c | |||
| bbb7d2270b | |||
| 6a5a436624 | |||
| ad559072db | |||
| ad02bd13df | |||
| 7563f61cc8 | |||
| fa8e073a4e | |||
| 95b5534773 | |||
| 9ee1afbf66 | |||
| f60751024e | |||
| 2de4cf2102 | |||
| 369f2d6951 | |||
| 32920926f0 | |||
| 39e5cdddf7 | |||
| 2829d48bd1 | |||
| f1af679270 | |||
| d46d8d6f54 | |||
| a5335263d3 | |||
| 79aee77381 | |||
| f5cb9a4c68 | |||
| f20bf77874 | |||
| 75f798e05b | |||
| 476b149a00 | |||
| 845da9c817 | |||
| 0918bf321c | |||
| 90519402c2 | |||
| 791ca80d3a | |||
| 5cbdade914 | |||
| 0187db88d4 | |||
| 311ea0dec0 | |||
| cf7756da38 | |||
| e380028a51 | |||
| b4403bfc62 | |||
| 12c12466b0 | |||
| f4d05feb7a | |||
| 7481622237 | |||
| b2a0f90501 | |||
| 14d4a77495 | |||
| 3d4ca228be | |||
| 6c476d7dd6 | |||
| 8e9b0409bd | |||
| 2ee56e1f3b | |||
| bbbbc14698 | |||
| f7a13a6dfc | |||
| 98aff4e90e | |||
| 6cdf27661c | |||
| ecd4542830 | |||
| d0b7578e17 | |||
| e22a5ecb45 | |||
| 9c8b449159 | |||
| 613b0adb13 | |||
| 12534b44b3 | |||
| 831ad1f70b | |||
| 1eed70b417 | |||
| ef230fcd2d | |||
| c283224b77 | |||
| 9bd3d28afa | |||
| 1c2ad5fe9d | |||
| 5c57109190 | |||
| e3b463d216 | |||
| e6bab78b38 | |||
| b9f7dd6e77 | |||
| 2b6c6e3d64 | |||
| e289d12b73 | |||
| bb2c89fc3b | |||
| d530a21122 | |||
| 27fb875a70 | |||
| a5ecc01ef8 | |||
| f7b574c862 | |||
| 5d6924f0a4 | |||
| cf7a543013 | |||
| 55043b3ada | |||
| 7b61d461ab | |||
| 0ff1ddecad | |||
| 0f9b4aae53 | |||
| d6886407ef | |||
| c37103c16d | |||
| d7ecbf7243 | |||
| 8080febac5 | |||
| 408ab373f8 | |||
| 23b1bbb810 | |||
| 8268316590 | |||
| a56e4e5fc7 | |||
| dac4caf77b | |||
| 1710c27341 | |||
| 92fde5bdc5 | |||
| eafc6437d6 | |||
| 4a90ec1387 | |||
| d9ab4e3ade | |||
| ce372a06e9 | |||
| 67f0059726 | |||
| 55473096e2 | |||
| 2a15541b5d | |||
| 16875b228b | |||
| 808af988ae | |||
| 61c298ed56 | |||
| 2d38369728 | |||
| 74bd12415e | |||
| a8c0c0263c | |||
| e3f14cdafa | |||
| b03cc2d9c8 | |||
| 99917e659b | |||
| def4476b6b | |||
| 84bb803719 | |||
| 8266849bda | |||
| f766c8ceea | |||
| b3cf7bc86d | |||
| 6a4a8b453d | |||
| cab14368a7 | |||
| 0b87f606ca | |||
| 25f68b6b5b | |||
| 334489bfd0 | |||
| ada5d90c25 | |||
| a1e3e2026b | |||
| 8dc4932b0c | |||
| 1505d7a461 | |||
| 1c842a9686 | |||
| 692827ad29 | |||
| 786d4646ee | |||
| 70df0aed59 |
@@ -40,11 +40,7 @@ EOF

# Default url values
rocm_baseurl="http://repo.radeon.com/rocm/apt/${ROCM_VERSION}"
amdgpu_baseurl="https://repo.radeon.com/amdgpu/${ROCM_VERSION}/ubuntu"

# Add amdgpu repository
UBUNTU_VERSION_NAME=`cat /etc/os-release | grep UBUNTU_CODENAME | awk -F= '{print $2}'`
echo "deb [arch=amd64] ${amdgpu_baseurl} ${UBUNTU_VERSION_NAME} main" > /etc/apt/sources.list.d/amdgpu.list

# Add rocm repository
wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -

@@ -6,7 +6,7 @@ dependencies = [
"GitPython==3.1.45",
"docker==7.1.0",
"pytest==7.3.2",
"uv==0.9.5"
"uv==0.9.6"
]

[tool.setuptools]

4 .github/actions/diskspace-cleanup/action.yml (vendored)
@@ -27,7 +27,9 @@ runs:
docker system prune -af
diskspace_new=$(df -H --output=pcent ${docker_root_dir} | sed -n 2p | sed 's/%//' | sed 's/ //')
if [[ "$diskspace_new" -gt "$diskspace_cutoff" ]] ; then
echo "Error: Available diskspace is less than $diskspace_cutoff percent. Not enough diskspace."
diskspace_cutoff_int=$((diskspace_cutoff + 0))
difference=$((100 - diskspace_cutoff_int))
echo "Error: Available diskspace is less than $difference percent. Not enough diskspace."
echo "$msg"
exit 1
else

2 .github/ci_commit_pins/audio.txt (vendored)
@@ -1 +1 @@
69bbe7363897764f9e758d851cd0340147d27f94
3b0e7a6f192ca2715e7e6cbe5db007aea7165fe2

1 .github/pytorch-probot.yml (vendored)
@@ -26,6 +26,7 @@ ciflow_push_tags:
- ciflow/nightly
- ciflow/op-benchmark
- ciflow/periodic
- ciflow/periodic-rocm-mi200
- ciflow/periodic-rocm-mi300
- ciflow/pull
- ciflow/quantization-periodic

84 .github/workflows/periodic-rocm-mi200.yml (vendored, new file)
@@ -0,0 +1,84 @@
name: periodic-rocm-mi200

on:
  schedule:
    # We have several schedules so jobs can check github.event.schedule to activate only for a fraction of the runs.
    # Also run less frequently on weekends.
    - cron: 45 0,8,16 * * 1-5
    - cron: 45 4 * * 0,6
    - cron: 45 4,12,20 * * 1-5
    - cron: 45 12 * * 0,6
    - cron: 29 8 * * * # about 1:29am PDT, for mem leak check and rerun disabled tests
  push:
    tags:
      - ciflow/periodic/*
      - ciflow/periodic-rocm-mi200/*
    branches:
      - release/*
  workflow_dispatch:

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}-${{ github.event.schedule }}
  cancel-in-progress: true

permissions:
  id-token: write
  contents: read

jobs:
  llm-td:
    if: github.repository_owner == 'pytorch'
    name: before-test
    uses: ./.github/workflows/llm_td_retrieval.yml
    permissions:
      id-token: write
      contents: read

  target-determination:
    name: before-test
    uses: ./.github/workflows/target_determination.yml
    needs: llm-td
    permissions:
      id-token: write
      contents: read

  get-label-type:
    name: get-label-type
    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
    if: (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch'
    with:
      triggering_actor: ${{ github.triggering_actor }}
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
      curr_branch: ${{ github.head_ref || github.ref_name }}
      curr_ref_type: ${{ github.ref_type }}

  linux-jammy-rocm-py3_10-build:
    name: linux-jammy-rocm-py3.10
    uses: ./.github/workflows/_linux-build.yml
    needs: get-label-type
    with:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build-environment: linux-jammy-rocm-py3.10
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
      test-matrix: |
        { include: [
          { config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.4", owners: ["module:rocm", "oncall:distributed"] },
          { config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.4", owners: ["module:rocm", "oncall:distributed"] },
          { config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.4", owners: ["module:rocm", "oncall:distributed"] },
        ]}
    secrets: inherit

  linux-jammy-rocm-py3_10-test:
    permissions:
      id-token: write
      contents: read
    name: linux-jammy-rocm-py3.10
    uses: ./.github/workflows/_rocm-test.yml
    needs:
      - linux-jammy-rocm-py3_10-build
      - target-determination
    with:
      build-environment: linux-jammy-rocm-py3.10
      docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
      test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
    secrets: inherit

31 .github/workflows/periodic.yml (vendored)
@@ -204,37 +204,6 @@ jobs:
      test-matrix: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc11-build.outputs.test-matrix }}
    secrets: inherit

  linux-jammy-rocm-py3_10-build:
    name: linux-jammy-rocm-py3.10
    uses: ./.github/workflows/_linux-build.yml
    needs: get-label-type
    with:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      build-environment: linux-jammy-rocm-py3.10
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
      test-matrix: |
        { include: [
          { config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.mi250.4", owners: ["module:rocm", "oncall:distributed"] },
          { config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.mi250.4", owners: ["module:rocm", "oncall:distributed"] },
          { config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.mi250.4", owners: ["module:rocm", "oncall:distributed"] },
        ]}
    secrets: inherit

  linux-jammy-rocm-py3_10-test:
    permissions:
      id-token: write
      contents: read
    name: linux-jammy-rocm-py3.10
    uses: ./.github/workflows/_rocm-test.yml
    needs:
      - linux-jammy-rocm-py3_10-build
      - target-determination
    with:
      build-environment: linux-jammy-rocm-py3.10
      docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
      test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
    secrets: inherit

  linux-jammy-cuda12_8-py3-gcc11-slow-gradcheck-build:
    name: linux-jammy-cuda12.8-py3-gcc11-slow-gradcheck
    uses: ./.github/workflows/_linux-build.yml

1 .github/workflows/upload-test-stats.yml (vendored)
@@ -6,6 +6,7 @@ on:
- pull
- trunk
- periodic
- periodic-rocm-mi200
- periodic-rocm-mi300
- inductor
- unstable

@@ -1198,12 +1198,6 @@ exclude_patterns = [
'torch/_inductor/fx_passes/serialized_patterns/**',
'torch/_inductor/autoheuristic/artifacts/**',
'torch/utils/model_dump/preact.mjs',
# These files are all grandfathered in, feel free to remove from this list
# as necessary
# NOTE: remove the patterns in the order they are listed
'aten/src/ATen/native/[a-pA-P]*/**',
'aten/src/ATen/[a-mA-M]*/**',
'test/**',
]
init_command = [
'python3',

@@ -374,7 +374,7 @@ cmake_dependent_option(
"Build the lazy Torchscript backend, not compatible with mobile builds" ON
"NOT INTERN_BUILD_MOBILE" OFF)
cmake_dependent_option(BUILD_FUNCTORCH "Build Functorch" ON "BUILD_PYTHON" OFF)
cmake_dependent_option(BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin fodler"
cmake_dependent_option(BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin folder"
OFF "USE_CUDA" OFF)
cmake_dependent_option(USE_KLEIDIAI "Use KleidiAI for the ARM CPU & AARCH64 architecture." ON
"CPU_AARCH64" OFF)

@@ -3,7 +3,7 @@

namespace at {

// Re-declaring 'DimVector' type and size inside 'at' namespace.
// Redeclaring 'DimVector' type and size inside 'at' namespace.
// This is done to avoid modifying every use into their 'c10'
// equivalent.

@@ -16,7 +16,7 @@ _GeneratorRegister::_GeneratorRegister(const GeneratorFuncType& func) {

TORCH_WARN_DEPRECATION(
"REGISTER_GENERATOR_PRIVATEUSE1 is deprecated. \
Please derive PrivateUse1HooksInterface to implememt getNewGenerator instead.")
Please derive PrivateUse1HooksInterface to implement getNewGenerator instead.")

TORCH_CHECK(
!GetGeneratorPrivate().has_value(),

@@ -149,7 +149,7 @@
* First, keep in mind that we assume that boxed containers will
* have to deal with `IValue` (e.g. `c10::List`). In this context,
* what may be happening is that `IValue` doesn't store internally
* your type `T`. Instead, it constructs a type new `T` everytime
* your type `T`. Instead, it constructs a type new `T` every time
* you try to get `T` for it (see `IListRef<at::OptinalTensorRef>`).
*/

@@ -186,7 +186,7 @@ class IListRef;
* This macro is useful because it allows us to handle different
* types (that correspond to different tags) to be implemented
* only once. We can do it even when the implementation of the
* different tags aren't syntatically the same, by dispatching
* different tags aren't syntactically the same, by dispatching
* it to a function (e.g. `ImplT::<dispatch-function>(this_)`).
*/
#define TORCH_ILISTREF_UNWRAP(TAG, BODY) \

@@ -42,7 +42,7 @@ class IListRefTagImplBase<IListRefTag::Unboxed, T, ListElemT> {
/*
* We have these function (besides the `unwrap`s above) because the
* implementation for both `IListRef::operator[]` and `IListRefIterator::operator*`
* weren't syntatically equal for the existing tags at the time
* weren't syntactically equal for the existing tags at the time
* (`Unboxed` and `Boxed`).
*/
static IListRefConstRef<T> front(const list_type& lst) {

@@ -12,7 +12,7 @@ namespace at {
// in order. This is most commonly used in autogenerated code,
// where it is convenient to have a function that can uniformly
// take arguments of different types. If your arguments
// are homogenous consider using a std::initializer_list instead.
// are homogeneous consider using a std::initializer_list instead.
//
// For examples of this in use, see torch/csrc/utils/variadic.h
template <typename F>

@@ -111,7 +111,7 @@ void Dispatcher::waitForDef(const FunctionSchema& schema) {
TORCH_INTERNAL_ASSERT(r,
"Expected main interpreter to define ", schema.operator_name(),
", but this didn't happen within timeout. Are you trying to load "
"different models in the same torchdeploy/multipy instance? You "
"different models in the same torchdeploy/multipy instance? You " // codespell:ignore
"must warmup each interpreter identically, e.g., import all "
"the same dependencies.");
}
@@ -129,7 +129,7 @@ void Dispatcher::waitForImpl(const OperatorName& op_name, std::optional<c10::Dis
TORCH_INTERNAL_ASSERT(r,
"Expected main interpreter to implement ", dk, " for ", op_name,
", but this didn't happen within timeout. Are you trying to load "
"different models in the same torchdeploy/multipy instance? You "
"different models in the same torchdeploy/multipy instance? You " // codespell:ignore
"must warmup each interpreter identically, e.g., import all "
"the same dependencies.");
}
@@ -442,8 +442,8 @@ RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, Ker

auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
TORCH_CHECK(idx >= 0 && static_cast<uint64_t>(idx) < backendFallbackKernels_.size(), "idx=", idx);
// NB: Perserve BC for registering fallback for AutogradPrivateUse1 multiple time,
// refer to https://github.com/pytorch/pytorch/issues/163979 for more informations.
// NB: Preserve BC for registering fallback for AutogradPrivateUse1 multiple time,
// refer to https://github.com/pytorch/pytorch/issues/163979 for more information.
TORCH_CHECK(
dispatchKey == DispatchKey::AutogradPrivateUse1 ||
!backendFallbackKernels_[idx].kernel.isValid(),

@@ -222,7 +222,8 @@ class TORCH_API Dispatcher final {
return backendFallbackKernels_[dispatch_ix].kernel.isValid();
}

// Used by torchdeploy/multipy for multiple interpreters racing.
// Used by torchdeploy/multipy for multiple // codespell:ignore: multipy
// interpreters racing.
void waitForDef(const FunctionSchema& schema);
void waitForImpl(
const OperatorName& op_name,
@@ -414,7 +415,7 @@ class TORCH_API Dispatcher final {
std::unique_ptr<detail::RegistrationListenerList> listeners_;

// This condition variable gets notified whenever we add a new def/impl to the
// dispatch table. This is primarily used by multipy/torchdeploy, when
// dispatch table. This is primarily used by multiply/torchdeploy, when
// we have multiple interpreters trying to register to the dispatch table.
// In this situation, whenever the non-primary interpreter would have tried
// to register to the dispatch table, instead it will check to see if the

@@ -990,7 +990,7 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
std::unique_lock<std::mutex> lock(mutex_);
if (completed_) {
// This should be rare and shouldn't cause log spew. Its important to
// log errors and thats why we have this log here.
// log errors and that's why we have this log here.
std::string msg = c10::str(
"Skipping setting following error on the Future since "
"it is already marked completed (this is not necessarily "

@@ -887,7 +887,7 @@ struct TORCH_API ListType
// this function will return the global singleton type pointer
// the type List<T>.
// The extra "identifier" argument is needed because we have multiple container types
// that all re-use this function (List<T>, array<T, N>, etc.)
// that all reuse this function (List<T>, array<T, N>, etc.)
static TypePtr get(const std::string& identifier, TypePtr inner);

// common cast List[Tensor]
@@ -983,7 +983,7 @@ struct TORCH_API DictType : public SharedType {
// this function will return the global singleton type pointer
// the type List<T>.
// The extra "identifier" argument is needed because we have multiple container types
// that all re-use this function (Dict<K, V> and unordered_map<K, V>)
// that all reuse this function (Dict<K, V> and unordered_map<K, V>)
static TypePtr get(const std::string& identifier, TypePtr key, TypePtr val);

private:

@@ -680,7 +680,7 @@ TORCH_API bool elementTypeCanBeInferredFromMembers(const TypePtr& elem_type) {
return false;
}
if (elem_type->kind() == AnyType::Kind) {
// List of Any can contains heterogenous types
// List of Any can contains heterogeneous types
return false;
}
return true;

@ -309,7 +309,7 @@ class Vectorized<float> {
|
||||
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1)
|
||||
// Implementation copied from Arm Optimized Routine
|
||||
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
|
||||
Vectorized<float> exp_u20() const {
|
||||
inline Vectorized<float> vexpq_f32_u20() const {
|
||||
// bail out to sleef if it's a special case:
|
||||
// i.e. there's an input s.t. |input| > 87.3....
|
||||
const float32x4_t special_bound = vdupq_n_f32(0x1.5d5e2ap+6f);
|
||||
@ -348,6 +348,9 @@ class Vectorized<float> {
|
||||
|
||||
return vfmaq_f32(scale, poly, scale);
|
||||
}
|
||||
Vectorized<float> exp_u20() const {
|
||||
return vexpq_f32_u20();
|
||||
}
|
||||
Vectorized<float> fexp_u20() const {
|
||||
return exp_u20();
|
||||
}
|
||||
@ -634,7 +637,7 @@ inline Vectorized<float> Vectorized<float>::erf() const {
|
||||
// - exp(- x * x)
|
||||
auto pow_2 = (*this) * (*this);
|
||||
auto neg_pow_2 = pow_2 ^ neg_zero_vec;
|
||||
auto tmp4 = neg_pow_2.exp();
|
||||
auto tmp4 = neg_pow_2.vexpq_f32_u20();
|
||||
auto tmp5 = tmp4 ^ neg_zero_vec;
|
||||
// erf(x) = sign(x) * (1 - r * t * exp(- x * x))
|
||||
auto tmp6 = t * tmp5;
|
||||
|
||||
@ -498,8 +498,8 @@ static inline Vectorized<T> binary_fp8_op_as_fp32(
|
||||
|
||||
// Refer to
|
||||
// https://github.com/pytorch/pytorch/pull/153364#discussion_r2086509353 FP8 +,
|
||||
// -, *, /, planed to be deleted in the future and here is just to make compiler
|
||||
// happy
|
||||
// -, *, /, planned to be deleted in the future and here is just to make
|
||||
// compiler happy
|
||||
Vectorized<Float8_e4m3fn> inline operator+(
|
||||
const Vectorized<Float8_e4m3fn>& a,
|
||||
const Vectorized<Float8_e4m3fn>& b) {
|
||||
@ -585,8 +585,8 @@ class Vectorized<Float8_e5m2> : public Vectorizedf8<Float8_e5m2> {
|
||||
|
||||
// Refer to
|
||||
// https://github.com/pytorch/pytorch/pull/153364#discussion_r2086509353 FP8 +,
|
||||
// -, *, /, planed to be deleted in the future and here is just to make compiler
|
||||
// happy
|
||||
// -, *, /, planned to be deleted in the future and here is just to make
|
||||
// compiler happy
|
||||
Vectorized<Float8_e5m2> inline operator+(
|
||||
const Vectorized<Float8_e5m2>& a,
|
||||
const Vectorized<Float8_e5m2>& b) {
|
||||
|
||||
@ -1,78 +1,90 @@
|
||||
#include <ATen/cuda/CUDAGreenContext.h>
|
||||
|
||||
namespace at::cuda {
|
||||
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
int driver_version;
|
||||
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
|
||||
TORCH_CHECK(
|
||||
driver_version >= 12080, "cuda driver too old to use green context!");
|
||||
CUcontext pctx = nullptr;
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
|
||||
if (C10_UNLIKELY(!pctx)) {
|
||||
TORCH_WARN(
|
||||
"Attempted to create a green context but"
|
||||
" there was no primary context! Creating a primary context...");
|
||||
|
||||
cudaFree(0);
|
||||
}
|
||||
|
||||
CUdevice device;
|
||||
device_id_ = device_id;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
|
||||
|
||||
// Get device resources
|
||||
CUdevResource device_resource;
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
|
||||
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
|
||||
|
||||
// Split resources
|
||||
std::vector<CUdevResource> result(1);
|
||||
auto result_data = result.data();
|
||||
unsigned int nb_groups = 1;
|
||||
CUdevResource remaining;
|
||||
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
|
||||
result_data,
|
||||
&nb_groups,
|
||||
&device_resource,
|
||||
&remaining,
|
||||
0, // default flags
|
||||
num_sms));
|
||||
|
||||
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
|
||||
|
||||
// Generate resource descriptor
|
||||
CUdevResourceDesc desc;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
|
||||
&desc, result_data, 1));
|
||||
|
||||
// Create green context
|
||||
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
|
||||
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
|
||||
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
|
||||
|
||||
// Convert to regular context
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
|
||||
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
|
||||
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
||||
#include <c10/cuda/driver_api.h>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#define HAS_CUDA_GREEN_CONTEXT() 1
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#define HAS_CUDA_GREEN_CONTEXT() 0
|
||||
// Suppress unused private field warnings as this class is not supposed to be called
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-private-field")
|
||||
#endif
|
||||
|
||||
namespace at::cuda {
|
||||
|
||||
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
|
||||
#if HAS_CUDA_GREEN_CONTEXT()
|
||||
int driver_version;
|
||||
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
|
||||
TORCH_CHECK(
|
||||
driver_version >= 12080, "cuda driver too old to use green context!");
|
||||
CUcontext pctx = nullptr;
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
|
||||
if (C10_UNLIKELY(!pctx)) {
|
||||
TORCH_WARN(
|
||||
"Attempted to create a green context but"
|
||||
" there was no primary context! Creating a primary context...");
|
||||
|
||||
cudaFree(0);
|
||||
}
|
||||
|
||||
CUdevice device;
|
||||
device_id_ = device_id;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
|
||||
|
||||
// Get device resources
|
||||
CUdevResource device_resource;
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
|
||||
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
|
||||
|
||||
// Split resources
|
||||
std::vector<CUdevResource> result(1);
|
||||
auto result_data = result.data();
|
||||
unsigned int nb_groups = 1;
|
||||
CUdevResource remaining;
|
||||
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
|
||||
result_data,
|
||||
&nb_groups,
|
||||
&device_resource,
|
||||
&remaining,
|
||||
0, // default flags
|
||||
num_sms));
|
||||
|
||||
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
|
||||
|
||||
// Generate resource descriptor
|
||||
CUdevResourceDesc desc;
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
|
||||
&desc, result_data, 1));
|
||||
|
||||
// Create green context
|
||||
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
|
||||
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
|
||||
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
|
||||
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
|
||||
|
||||
// Convert to regular context
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
|
||||
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
std::unique_ptr<GreenContext> GreenContext::create(
|
||||
uint32_t num_sms,
|
||||
std::optional<uint32_t> device_id) {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
#if HAS_CUDA_GREEN_CONTEXT()
|
||||
if (!device_id.has_value()) {
|
||||
device_id = at::cuda::current_device();
|
||||
}
|
||||
return std::make_unique<GreenContext>(device_id.value(), num_sms);
|
||||
return std::unique_ptr<GreenContext>(new GreenContext(device_id.value(), num_sms));
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
@ -80,7 +92,7 @@ namespace at::cuda {
|
||||
|
||||
// Implement move operations
|
||||
GreenContext::GreenContext(GreenContext&& other) noexcept{
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
#if HAS_CUDA_GREEN_CONTEXT()
|
||||
device_id_ = std::exchange(other.device_id_, -1);
|
||||
green_ctx_ = std::exchange(other.green_ctx_, nullptr);
|
||||
context_ = std::exchange(other.context_, nullptr);
|
||||
@ -91,7 +103,7 @@ namespace at::cuda {
|
||||
}
|
||||
|
||||
GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
#if HAS_CUDA_GREEN_CONTEXT()
|
||||
if (this != &other) {
|
||||
// Clean up current resources
|
||||
if (green_ctx_) {
|
||||
@ -120,7 +132,7 @@ namespace at::cuda {
|
||||
}
|
||||
|
||||
GreenContext::~GreenContext() noexcept{
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
#if HAS_CUDA_GREEN_CONTEXT()
|
||||
C10_CUDA_DRIVER_CHECK(
|
||||
c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
|
||||
#else
|
||||
@ -128,25 +140,9 @@ namespace at::cuda {
|
||||
#endif
|
||||
}
|
||||
|
||||
// Get the underlying CUDA context
|
||||
CUcontext GreenContext::getContext() const {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
return context_;
|
||||
#else
|
||||
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// Get the underlying green context
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
CUgreenCtx GreenContext::getGreenContext() const {
|
||||
return green_ctx_;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Make this context current
|
||||
void GreenContext::setContext() {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
#if HAS_CUDA_GREEN_CONTEXT()
|
||||
auto current_stream = c10::cuda::getCurrentCUDAStream();
|
||||
parent_stream_ = current_stream.stream();
|
||||
|
||||
@ -175,7 +171,7 @@ namespace at::cuda {
|
||||
}
|
||||
|
||||
void GreenContext::popContext() {
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
#if HAS_CUDA_GREEN_CONTEXT()
|
||||
// see above note about stream being hardcoded to the default stream
|
||||
at::cuda::CUDAEvent ev;
|
||||
ev.record(c10::cuda::getCurrentCUDAStream());
|
||||
|
||||
@ -1,53 +1,38 @@
|
||||
#pragma once
|
||||
#include <ATen/cuda/CUDAEvent.h>
|
||||
|
||||
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
||||
#include <c10/cuda/driver_api.h>
|
||||
#include <cuda.h>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#define CUDA_HAS_GREEN_CONTEXT 1
|
||||
#else
|
||||
#define CUDA_HAS_GREEN_CONTEXT 0
|
||||
#endif
|
||||
|
||||
// Forward declare green context as opaque ptr
|
||||
typedef struct CUgreenCtx_st* CUgreenCtx;
|
||||
|
||||
namespace at::cuda {
|
||||
|
||||
class TORCH_CUDA_CPP_API GreenContext {
|
||||
public:
|
||||
GreenContext(uint32_t device_id, uint32_t num_sms);
|
||||
|
||||
static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id);
|
||||
// Green context creation
|
||||
static std::unique_ptr<GreenContext> create(
|
||||
uint32_t num_sms,
|
||||
std::optional<uint32_t> device_id);
|
||||
~GreenContext() noexcept;
|
||||
|
||||
// Delete copy constructor and assignment
|
||||
GreenContext(const GreenContext&) = delete;
|
||||
GreenContext& operator=(const GreenContext&) = delete;
|
||||
|
||||
// Implement move operations
|
||||
GreenContext(GreenContext&& other) noexcept;
|
||||
GreenContext& operator=(GreenContext&& other) noexcept;
|
||||
~GreenContext() noexcept;
|
||||
|
||||
// Get the underlying CUDA context
|
||||
CUcontext getContext() const;
|
||||
|
||||
// Get the underlying green context
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
CUgreenCtx getGreenContext() const;
|
||||
#endif
|
||||
|
||||
// Make this context current
|
||||
void setContext();
|
||||
|
||||
void popContext();
|
||||
|
||||
private:
|
||||
#if CUDA_HAS_GREEN_CONTEXT
|
||||
GreenContext(uint32_t device_id, uint32_t num_sms);
|
||||
// Implement move operations
|
||||
GreenContext(GreenContext&& other) noexcept;
|
||||
GreenContext& operator=(GreenContext&& other) noexcept;
|
||||
|
||||
int32_t device_id_ = -1;
|
||||
CUgreenCtx green_ctx_ = nullptr;
|
||||
CUcontext context_ = nullptr;
|
||||
cudaStream_t parent_stream_ = nullptr;
|
||||
#endif
|
||||
};
|
||||
} // namespace at::cuda
|
||||
|
||||
@ -7,17 +7,6 @@
|
||||
#endif
|
||||
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
// hipSparse const API added in v2.4.0
|
||||
#if HIPSPARSE_VERSION >= 200400
|
||||
#define AT_USE_HIPSPARSE_GENERIC_API() 1
|
||||
#else
|
||||
#define AT_USE_HIPSPARSE_GENERIC_API() 1
|
||||
#endif
|
||||
#else // USE_ROCM
|
||||
#define AT_USE_HIPSPARSE_GENERIC_API() 0
|
||||
#endif // USE_ROCM
|
||||
|
||||
// cuSparse Generic API spsv function was added in CUDA 11.3.0
|
||||
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11500)
|
||||
#define AT_USE_CUSPARSE_GENERIC_SPSV() 1
|
||||
|
||||
@ -179,7 +179,7 @@ CuSparseSpMatCsrDescriptor::CuSparseSpMatCsrDescriptor(const Tensor& input, int6
|
||||
batch_offset * values_batch_stride * values.itemsize(),
|
||||
index_type, // data type of row offsets index
|
||||
index_type, // data type of col indices
|
||||
CUSPARSE_INDEX_BASE_ZERO, // base index of row offset and col indes
|
||||
CUSPARSE_INDEX_BASE_ZERO, // base index of row offset and col index
|
||||
value_type // data type of values
|
||||
));
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ namespace at::cuda {
|
||||
//
|
||||
// A caching allocator for CUDA host allocations (pinned memory).
|
||||
//
|
||||
// This provides a drop-in replacement for THCudaHostAllocator, which re-uses
|
||||
// This provides a drop-in replacement for THCudaHostAllocator, which reuses
|
||||
// freed pinned (page-locked) memory allocations. This avoids device
|
||||
// synchronizations due to cudaFreeHost calls.
|
||||
//
|
||||
@ -26,7 +26,7 @@ inline TORCH_CUDA_CPP_API at::HostAllocator* getCachingHostAllocator() {
|
||||
}
|
||||
|
||||
// Records an event in the specified stream. The allocation corresponding to the
|
||||
// input `ptr`/`ctx` will not be re-used until the event has occurred.
|
||||
// input `ptr`/`ctx` will not be reused until the event has occurred.
|
||||
C10_DEPRECATED_MESSAGE(
|
||||
"at::cuda::CachingHostAllocator_recordEvent(...) is deprecated. Please use at::getHostAllocator(at::kCUDA)->record_event(...) instead.")
|
||||
inline TORCH_CUDA_CPP_API bool CachingHostAllocator_recordEvent(
|
||||
|
||||
@ -93,7 +93,7 @@ struct IndexToOffset {
|
||||
}
|
||||
};
|
||||
|
||||
// Uses dynamic (runtime) instead of static (compiletime) dims
|
||||
// Uses dynamic (runtime) instead of static (compile time) dims
|
||||
template <typename T, typename IndexType>
|
||||
struct IndexToOffset<T, IndexType, -1> {
|
||||
static inline __host__ __device__ IndexType get(
|
||||
|
||||
@ -32,7 +32,7 @@ static inline void launch_jitted_vectorized_kernel_dynamic(
|
||||
|
||||
// Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
|
||||
// fn_ptr is set to the appropriate function based on the vec size and GPU used
|
||||
// TODO: Memory use can probably be optimized by re-using kernels across GPUs with
|
||||
// TODO: Memory use can probably be optimized by reusing kernels across GPUs with
|
||||
// the same compute capability
|
||||
|
||||
std::string f_inputs_type_str = at::cuda::jit::typeName(common_dtype);
|
||||
|
||||
@ -143,7 +143,7 @@ struct TORCH_API VmapPhysicalView {
|
||||
// mapping a physical tensor to a new logical tensor (BatchedTensor)
|
||||
VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;
|
||||
|
||||
// Maps a logical shape to a physical shape by pre-pending the batch
|
||||
// Maps a logical shape to a physical shape by prepending the batch
|
||||
// sizes to the logical shape.
|
||||
VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;
|
||||
SymDimVector getPhysicalShape(c10::SymIntArrayRef logical_shape) const;
|
||||
|
||||
@ -27,7 +27,7 @@ namespace at::functorch {
|
||||
//
|
||||
// There are alternative designs we could have chosen (e.g. each grad transform
|
||||
// stores a weak map of Tensor -> AutogradMeta); the benefit of the TensorWrapper
|
||||
// design is that we can re-use existing VariableType kernels (i.e. Autograd kernels)
|
||||
// design is that we can reuse existing VariableType kernels (i.e. Autograd kernels)
|
||||
// without much modification. Since a TensorWrapper looks like a regular Tensor,
|
||||
// the VariableType kernel can pull out the AutogradMeta struct from where it
|
||||
// expects and extend the autograd graph
|
||||
|
||||
@ -410,8 +410,8 @@ struct ConvParams {
|
||||
return false;
|
||||
}
|
||||
static long cudnn_version = detail::getCUDAHooks().versionCuDNN();
|
||||
// broken on cuDNN 9.8
|
||||
if (cudnn_version >= 90800) {
|
||||
// broken on cuDNN 9.8 - 9.14
|
||||
if (cudnn_version >= 90800 && cudnn_version < 91500) {
|
||||
if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous &&
|
||||
(input.scalar_type() == at::kBFloat16 || input.scalar_type() == at::kHalf) &&
|
||||
weight.dim() == 5) {
|
||||
|
||||
@ -1017,7 +1017,7 @@ struct HelperInterpBase {
|
||||
while (aligned_interp_size % sizeof(int32_t) != 0) {
|
||||
aligned_interp_size += 1;
|
||||
}
|
||||
// assert that we wont go out of bounds
|
||||
// assert that we won't go out of bounds
|
||||
TORCH_INTERNAL_ASSERT(aligned_interp_size * sizeof(int16_t) < interp_size * sizeof(double));
|
||||
}
|
||||
|
||||
|
||||
@ -655,7 +655,7 @@ void ImagingResampleHorizontalConvolution8u4x(
|
||||
// last element
|
||||
auto mmk = _mm256_set1_epi32(k[i]);
|
||||
// For num_channels == 3 (3 bytes = one pixel) we tolerate to read 4 bytes
|
||||
// lines 0, 1 and 2 wont go out of allocated memory bounds
|
||||
// lines 0, 1 and 2 won't go out of allocated memory bounds
|
||||
auto pix = _mm256_inserti128_si256(_mm256_castsi128_si256(
|
||||
mm_cvtepu8_epi32(lineIn0_min + stride * i, i32_aligned)),
|
||||
mm_cvtepu8_epi32(lineIn1_min + stride * i, i32_aligned), 1);
|
||||
@ -1312,7 +1312,7 @@ void ImagingResampleVerticalConvolution8u(
|
||||
|
||||
// Here we write 4 bytes to the output even if num_channels < 4, e.g o = {r,g,b,X} for num_channels=3
|
||||
// It is OK to write 4th byte (e.g. X) as on the next step we will overwrite it with new data.
|
||||
// We also wont go out of bounds of lineOut memory allocation
|
||||
// We also won't go out of bounds of lineOut memory allocation
|
||||
std::memcpy(lineOut + j, (uint8_t *) &o, 4);
|
||||
}
|
||||
|
||||
|
||||
@ -705,7 +705,7 @@ namespace {
|
||||
);
|
||||
} while (!done && max_threads);
|
||||
if (!done) {
|
||||
TORCH_INTERNAL_ASSERT(false, "Couldn't reduce launch bounds to accomodate sharedMemPerBlock limit");
|
||||
TORCH_INTERNAL_ASSERT(false, "Couldn't reduce launch bounds to accommodate sharedMemPerBlock limit");
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
@ -298,7 +298,7 @@ static void jitted_gpu_kernel_impl(
|
||||
at::opmath_type<f_inputs_type> scalar_val,
|
||||
const std::tuple<ExtraArgs...>& extra_args) {
|
||||
|
||||
// TODO: Memory use can probably be optimized by re-using kernels across GPUs with
|
||||
// TODO: Memory use can probably be optimized by reusing kernels across GPUs with
|
||||
// the same compute capability
|
||||
static std::mutex jiterator_mutex;
|
||||
static std::vector<JittedKernelVariantCache> device_caches(c10::cuda::device_count());
|
||||
|
||||
@ -75,7 +75,7 @@ fused_dropout_kernel_vec(at::cuda::detail::TensorInfo<const scalar_t, IndexType>
|
||||
// We'll use this to actually cause vectorized loads later
|
||||
LoadT *value = reinterpret_cast<LoadT*>(&src);
|
||||
|
||||
//curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for halfs, so generate float for everything
|
||||
//curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for Halfs, so generate float for everything
|
||||
// Note: need a new set of random values per 4 elements -- we'll handle VEC elements in this thread, so need ceil(VEC / 4)
|
||||
// sets of rand.
|
||||
if ((VEC >= 4) || (gridxvec_loop_state == 0)) {
|
||||
@ -159,7 +159,7 @@ fused_dropout_kernel(cuda::detail::TensorInfo<const scalar_t, IndexType> a,
|
||||
for (IndexType linearIndex = idx;
|
||||
linearIndex < rounded_size;
|
||||
linearIndex += gridDim.x * blockDim.x*UNROLL) {
|
||||
//curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for halfs, so generate float for everything
|
||||
//curand_uniform_double was pure evil anyway, not doing what it promises, and there's nothing for Halfs, so generate float for everything
|
||||
float4 rand = curand_uniform4(&state);
|
||||
scalar_t src[UNROLL];
|
||||
rand.x = rand.x < p;
|
||||
|
||||
@ -24,7 +24,7 @@ namespace at::native {
|
||||
namespace {
|
||||
|
||||
/* This code computes the sum of the weights in two-steps:
|
||||
1) Each GPU warp sums `NROWS_PER_THREAD` number of row given by `indeces`
|
||||
1) Each GPU warp sums `NROWS_PER_THREAD` number of row given by `indices`
|
||||
2) Each partial-sum from 1) are summed and scatter into `grad_weight`
|
||||
|
||||
Notice, `NROWS_PER_THREAD` impacts the Achieved Occupancy of the
|
||||
|
||||
@ -204,7 +204,7 @@ Scalar scalar_reciprocal(const Scalar& scalar) {
|
||||
return Scalar(1. / scalar.toComplexDouble());
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
false, "divison with ", scalar.type(), " not supported");
|
||||
false, "division with ", scalar.type(), " not supported");
|
||||
}
|
||||
|
||||
void foreach_tensor_div_scalar_kernel_cuda_(
|
||||
|
||||
@ -57,7 +57,7 @@ namespace {
|
||||
const index_t n = index / (out_H * out_W);
|
||||
const index_t grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
|
||||
|
||||
// get the corresponding input x, y co-ordinates from grid
|
||||
// get the corresponding input x, y coordinates from grid
|
||||
opmath_t x = grid.data[grid_offset];
|
||||
opmath_t y = grid.data[grid_offset + grid_sCoor];
|
||||
|
||||
@ -193,7 +193,7 @@ namespace {
|
||||
const index_t n = index / (out_D * out_H * out_W);
|
||||
const index_t grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;
|
||||
|
||||
// get the corresponding input x, y, z co-ordinates from grid
|
||||
// get the corresponding input x, y, z coordinates from grid
|
||||
opmath_t x = grid.data[grid_offset];
|
||||
opmath_t y = grid.data[grid_offset + grid_sCoor];
|
||||
opmath_t z = grid.data[grid_offset + 2 * grid_sCoor];
|
||||
@ -358,7 +358,7 @@ namespace {
|
||||
const index_t n = index / (out_H * out_W);
|
||||
const auto grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
|
||||
|
||||
// get the corresponding input x, y co-ordinates from grid
|
||||
// get the corresponding input x, y coordinates from grid
|
||||
scalar_t x = grid.data[grid_offset];
|
||||
scalar_t y = grid.data[grid_offset + grid_sCoor];
|
||||
|
||||
@ -572,7 +572,7 @@ namespace {
|
||||
const index_t n = index / (out_D * out_H * out_W);
|
||||
const auto grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;
|
||||
|
||||
// get the corresponding input x, y, z co-ordinates from grid
|
||||
// get the corresponding input x, y, z coordinates from grid
|
||||
scalar_t ix = grid.data[grid_offset];
|
||||
scalar_t iy = grid.data[grid_offset + grid_sCoor];
|
||||
scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor];
|
||||
|
||||
@ -8,7 +8,7 @@
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
|
||||
// Three warninngs in Cutlass included header files
|
||||
// Three warnings in Cutlass included header files
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used")
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-variable")
|
||||
|
||||
@ -213,9 +213,9 @@ _f4_f4_bf16_grouped_mm_fbgemm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const Tensor& global_scale_a,
|
||||
const std::optional<Tensor>& global_scale_a,
|
||||
const Tensor& scale_b,
|
||||
const Tensor& global_scale_b,
|
||||
const std::optional<Tensor>& global_scale_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
const std::optional<Tensor>& bias,
|
||||
Tensor& out) {
|
||||
@ -225,14 +225,28 @@ _f4_f4_bf16_grouped_mm_fbgemm(
|
||||
"mat_a must be Float4_e2n1fn_2, got: ", mat_a.scalar_type());
|
||||
TORCH_CHECK_VALUE(mat_b.scalar_type() == at::kFloat4_e2m1fn_x2,
|
||||
"mat_b must be Float4_e2n1fn_2, got: ", mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e4m3fn,
|
||||
"scale_a must be Float8_e4m3fn, got: ", scale_a.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_b.scalar_type() == at::kFloat8_e4m3fn,
|
||||
"scale_b must be Float8_e4m3fn, got: ", scale_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(global_scale_a.scalar_type() == at::kFloat,
|
||||
"global_scale_a must be Float, got: ", global_scale_a.scalar_type());
|
||||
TORCH_CHECK_VALUE(global_scale_b.scalar_type() == at::kFloat,
|
||||
"global_scale_b must be Float, got: ", global_scale_b.scalar_type());
|
||||
|
||||
std::optional<Tensor> combined_global_scale = std::nullopt;
|
||||
if (global_scale_a.has_value() || global_scale_b.has_value()) {
|
||||
// NVFP4
|
||||
TORCH_CHECK_VALUE(global_scale_a.has_value() && global_scale_b.has_value(),
|
||||
"For NVFP4 grouped gemm both of global_scale_{a,b} must have values")
|
||||
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e4m3fn,
|
||||
"scale_a must be Float8_e4m3fn, got: ", scale_a.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_b.scalar_type() == at::kFloat8_e4m3fn,
|
||||
"scale_b must be Float8_e4m3fn, got: ", scale_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(global_scale_a.value().scalar_type() == at::kFloat,
|
||||
"global_scale_a must be Float, got: ", global_scale_a.value().scalar_type());
|
||||
TORCH_CHECK_VALUE(global_scale_b.value().scalar_type() == at::kFloat,
|
||||
"global_scale_b must be Float, got: ", global_scale_b.value().scalar_type());
|
||||
combined_global_scale = global_scale_a.value().mul(global_scale_b.value());
|
||||
} else {
|
||||
// MXFP4
|
||||
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu,
|
||||
"scale_a must be Float8_e8m0fnu, got: ", scale_a.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_b.scalar_type() == at::kFloat8_e8m0fnu,
|
||||
"scale_b must be Float8_e8m0fnu, got: ", scale_b.scalar_type());
|
||||
}
|
||||
|
||||
auto o = fbgemm_gpu::f4f4bf16_grouped_mm(
|
||||
mat_a,
|
||||
@ -241,7 +255,7 @@ _f4_f4_bf16_grouped_mm_fbgemm(
|
||||
scale_b,
|
||||
offs.value(),
|
||||
out,
|
||||
global_scale_a.mul(global_scale_b)
|
||||
combined_global_scale
|
||||
);
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "nvfp4 grouped gemm is not supported without USE_FBGEMM_GENAI, and only for CUDA")
|
||||
@ -471,9 +485,10 @@ namespace {
|
||||
|
||||
using acceptance_fn = std::function<bool(c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&, c10::ScalarType, std::vector<ScalingType>&, ArrayRef<Tensor>&)>;
|
||||
|
||||
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 3> scale_grouped_kernel_dispatch = {{
|
||||
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 4> scale_grouped_kernel_dispatch = {{
|
||||
{ "rowwise_rowwise", scaled_blas::check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
|
||||
{ "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8},
|
||||
{ "mxfp4_mxfp4", scaled_blas::check_mxfp4_recipe, ScaledGemmImplementation::MXFP4_MXFP4},
|
||||
{ "nvfp4_nvfp4", scaled_blas::check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4}}};
|
||||
|
||||
} // anonymous namespace
|
||||
@ -510,7 +525,7 @@ _scaled_grouped_mm_cuda_v2(
|
||||
"Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ",
|
||||
mat_b.size(dim_b));
|
||||
// Note: only (-1, -2) is currently supported
|
||||
TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Curently contraction dims must be (-1, -2) only");
|
||||
TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Currently contraction dims must be (-1, -2) only");
|
||||
} else {
|
||||
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
|
||||
}
|
||||
@ -599,6 +614,21 @@ _scaled_grouped_mm_cuda_v2(
|
||||
offs.value(),
|
||||
out);
|
||||
}
|
||||
case ScaledGemmImplementation::MXFP4_MXFP4: {
|
||||
// scale shape checks
|
||||
_check_scales_blocked(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
|
||||
_check_scales_blocked(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
|
||||
return _f4_f4_bf16_grouped_mm_fbgemm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a[0], /* block-scale A */
|
||||
std::nullopt, /* global-scale A */
|
||||
scale_b[0], /* block-scale B */
|
||||
std::nullopt, /* global-scale B */
|
||||
offs.value(),
|
||||
std::nullopt, /* bias */
|
||||
out);
|
||||
}
|
||||
case ScaledGemmImplementation::NVFP4_NVFP4: {
|
||||
// scale shape checks
|
||||
_check_scales_blocked(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
|
||||
|
||||
@ -377,7 +377,7 @@ __noinline__ __host__ __device__ scalar_t calc_igammac(scalar_t a, scalar_t x) {
|
||||
* result at the boundary
|
||||
* - if a is large and a ~ x, then using Uniform Asymptotic Expansions for
|
||||
* Large Parameter (see DLMF 8.12.4 [igam1])
|
||||
* - if x > 1.1 and x < a, using the substraction from the regularized lower
|
||||
* - if x > 1.1 and x < a, using the subtraction from the regularized lower
|
||||
* incomplete gamma
|
||||
* - otherwise, calculate the series from [igam2] eq (5)
|
||||
*/
|
||||
@ -460,7 +460,7 @@ __noinline__ __host__ __device__ scalar_t calc_igamma(scalar_t a, scalar_t x) {
|
||||
* result at the boundary
|
||||
* - if a is large and a ~ x, then using Uniform Asymptotic Expansions for
|
||||
* Large Parameter (see DLMF 8.12.3 [igam1])
|
||||
* - if x > 1 and x > a, using the substraction from the regularized upper
|
||||
* - if x > 1 and x > a, using the subtraction from the regularized upper
|
||||
* incomplete gamma
|
||||
* - otherwise, calculate the series from [igam2] eq (4)
|
||||
*/
|
||||
|
||||
@ -332,7 +332,7 @@ void cuda_take_put_kernel(
|
||||
const auto offset_calc = make_offset_calculator<2>(iter);
|
||||
using uindex_t = std::make_unsigned_t<index_t>;
|
||||
|
||||
// OffsetCalculator needs the sizes and strides reveresed
|
||||
// OffsetCalculator needs the sizes and strides reversed
|
||||
const auto indexed_sizes = std::vector<int64_t>(indexed.sizes().rbegin(), indexed.sizes().rend());
|
||||
const auto indexed_strides = std::vector<int64_t>(indexed.strides().rbegin(), indexed.strides().rend());
|
||||
const auto* indexed_strides_data = indexed_strides.data();
|
||||
|
||||
@ -13,7 +13,7 @@ __global__ void vectorized_gather_kernel(char * out, char * inp, index_t * idx,
|
||||
if (allow_neg_indices) {
|
||||
ind = (ind < 0) ? ind + ind_dim_size : ind;
|
||||
}
|
||||
CUDA_KERNEL_ASSERT(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds");
|
||||
CUDA_KERNEL_ASSERT_VERBOSE(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds", "Expected 0 <= index < ind_dim_size(%ld), but got index = %ld", ind_dim_size, ind);
|
||||
int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; // off is guaranteed to be within int32 limits
|
||||
if (off >= slice_size) return;
|
||||
auto vec = at::native::memory::ld_vec<Alignment>(inp + ind * inp_stride + off);
|
||||
|
||||
@ -1611,7 +1611,7 @@ void index_select_out_cuda_impl(
|
||||
|
||||
// SmallIndexKernel is more performant when the number of indices is small, and pre-loading
|
||||
// the index reduces memory accesses. When the number of indices is large, we avoid that
|
||||
// and increase parallellism by calling gather_out which is a generalization of index_select
|
||||
// and increase parallelism by calling gather_out which is a generalization of index_select
|
||||
if (cuda::detail::canUse32BitIndexMath(out) &&
|
||||
cuda::detail::canUse32BitIndexMath(self) &&
|
||||
cuda::detail::canUse32BitIndexMath(index) &&
|
||||
|
||||
@ -269,7 +269,7 @@ __device__ __forceinline__ void opportunistic_fastAtomicAdd(
|
||||
|
||||
scalar_t* dst = self_ptr + index;
|
||||
|
||||
//pack coalseced bf16 and fp16
|
||||
//pack coalesced bf16 and fp16
|
||||
if constexpr (std::is_same<scalar_t, c10::BFloat16>::value || std::is_same<scalar_t, c10::Half>::value)
|
||||
{
|
||||
typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2;
|
||||
@ -312,7 +312,7 @@ __device__ __forceinline__ void opportunistic_fastAtomicAdd(
|
||||
}
|
||||
}
|
||||
|
||||
// not coalsced, so now let try to capture lane-matches...
|
||||
// not coalesced, so now let try to capture lane-matches...
|
||||
|
||||
if (numel > 16 /*<-hueristic threshold*/ * 64 ) {
|
||||
// well shucks, unlikely to capture same-dest atomics in a wave.
|
||||
|
||||
@ -343,7 +343,7 @@ ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
|
||||
if (input_length == 0)
|
||||
return;
|
||||
|
||||
// "first" row, the beta initialization before eq (10) (t=target_length - differes per batch)
|
||||
// "first" row, the beta initialization before eq (10) (t=target_length - differs per batch)
|
||||
for (int64_t block_s = 2*max_target_length - (2*max_target_length % blockDim.x); block_s >= 0; block_s -= blockDim.x) {
|
||||
int64_t s = threadIdx.x + block_s;
|
||||
scalar_t lb;
|
||||
|
||||
@ -816,7 +816,7 @@ const auto erfcx_string = jiterator_stringify(
|
||||
with the usual checks for overflow etcetera.
|
||||
|
||||
Performance-wise, it seems to be substantially faster than either
|
||||
the SLATEC DERFC function [or an erfcx function derived therefrom]
|
||||
the SLATEC DERFC function [or an erfcx function derived there from]
|
||||
or Cody's CALERF function (from netlib.org/specfun), while
|
||||
retaining near machine precision in accuracy.
|
||||
*/
|
||||
|
||||
@ -370,7 +370,7 @@ struct vectorized {
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// This is similar to vectorized policy above, but this one supports
|
||||
// heterogenous input tensor types as templated parameters.
|
||||
// heterogeneous input tensor types as templated parameters.
|
||||
// Its use should be limited to frequently used heterogeneous data types
|
||||
// as each instantiation will generate a separate kernel, leading to code
|
||||
// bloating if applied to all combinations supported in PyTorch. Assumption: all
|
||||
|
||||
@ -309,7 +309,7 @@ __global__ void sampleMultinomialOnce(
|
||||
} else {
|
||||
// This should address a rare bug where we don't select a valid index. This likely occurs when
|
||||
// due to floating point arithmetic rounding errors, our cumulative sum does not add up to 1, but
|
||||
// and our uniform sample is greater than this value. In this case we likely have unitialized memory
|
||||
// and our uniform sample is greater than this value. In this case we likely have uninitialized memory
|
||||
// in dest[curDist]. So basically we will loop through the distribution and pick the largest index
|
||||
// where the distribution is non-zero. This is obviously terribly inefficient, but due to the
|
||||
// rarity in which this occurs, this should not be an issue.
|
||||
|
||||
@ -1654,7 +1654,7 @@ at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
|
||||
const auto stride = input.sizes()[1];
|
||||
const auto reduction_size = input.numel() / stride;
|
||||
|
||||
// Input is guarunteed to be channels-last compatible
|
||||
// Input is guaranteed to be channels-last compatible
|
||||
at::Tensor grad_input = at::empty_like(input);
|
||||
|
||||
dim3 block;
|
||||
@ -1722,7 +1722,7 @@ at::Tensor batch_norm_backward_elemt_channels_last_cuda_template(
|
||||
const auto reduction_size = input.numel() / stride;
|
||||
auto norm_fct = 1.0 / reduction_size;
|
||||
|
||||
// Input is guarunteed to be channels-last compatible
|
||||
// Input is guaranteed to be channels-last compatible
|
||||
at::Tensor grad_input = at::empty_like(input);
|
||||
|
||||
dim3 block;
|
||||
|
||||
@ -37,7 +37,7 @@ namespace at::native {
|
||||
// threshold probability for having non-duplicate keys, then it can be proved that[1]
|
||||
// the number of bits required is: ceil(log2(n - (6 n^2 + 1) / (12 log(q))))
|
||||
//
|
||||
// Then after sort, we lauch a separate kernel that additionally shuffles any islands
|
||||
// Then after sort, we launch a separate kernel that additionally shuffles any islands
|
||||
// of values whose keys matched. The algorithm of this kernel is as follows:
|
||||
// Each thread reads its key and the keys of its neighbors to tell if it's part of an island.
|
||||
// For each island, the first thread in the island sees a key match at index i+1 but not index i-1.
|
||||
|
||||
@ -1086,12 +1086,12 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
|
||||
// load instructions.
|
||||
//
|
||||
// Case 1: "vectorize along input"
|
||||
// This case happens when we are reducing along fastest moving dimesion. In such case, threads
|
||||
// This case happens when we are reducing along fastest moving dimension. In such case, threads
|
||||
// with the same threadIdx.y works on the same reduction cooperatively and will produce results
|
||||
// for the same output. In such case, values in each loaded vector always correspond to the same output.
|
||||
//
|
||||
// Case 2: "vectorize along output"
|
||||
// This case happens when the fastest moving dimesion is not the dimension of reduction. In such case,
|
||||
// This case happens when the fastest moving dimension is not the dimension of reduction. In such case,
|
||||
// threads with different threadIdx.x are independent and will produce results for different outputs.
|
||||
// In such case, values in each loaded vector always correspond to different outputs.
|
||||
if (fastest_moving_stride == sizeof(scalar_t)) {
|
||||
|
||||
@ -273,7 +273,7 @@ __global__ void reflection_pad2d_backward_det_out_kernel(
|
||||
const int64_t dist_cols = ::abs(inp_col - (input_dim_x - 1));
|
||||
|
||||
// we were dist_rows after, now we want to be dist_rows before
|
||||
// we were dist_cols before, now we wnat to be dist_cols after
|
||||
// we were dist_cols before, now we want to be dist_cols after
|
||||
const int64_t reflect_tr_out_row = (corner_tr_out_row - dist_rows);
|
||||
const int64_t reflect_tr_out_col = (corner_tr_out_col + dist_cols);
|
||||
const int64_t reflect_tr_out =
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
// Two warninngs in Cutlass included header files
|
||||
// Two warnings in Cutlass included header files
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used")
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wmissing-field-initializers")
|
||||
|
||||
@ -794,6 +794,24 @@ void _check_deepseek_scale_stride(const Tensor& scale, const Tensor& t, const Sc
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
_check_deepseek_support() {
|
||||
#ifndef USE_ROCM
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
if (dprops->major != 9) {
|
||||
// Only on Hopper GPUs
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
dprops->major == 9,
|
||||
"DeepSeek style (1x128, 128x128) scaling only supported in CUDA for SM90")
|
||||
}
|
||||
// Only in cublasLt >= 12.9
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900,
|
||||
"DeepSeek style (1x128, 128x128) scaling requires cublasLt >= 12.9"
|
||||
);
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor&
|
||||
_scaled_block1x128_block1x128(
|
||||
const Tensor& mat_a, const Tensor& mat_b,
|
||||
@ -802,8 +820,12 @@ _scaled_block1x128_block1x128(
|
||||
const c10::ScalarType out_dtype,
|
||||
const bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
#ifndef USE_ROCM
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, shape K//128
|
||||
// CUDA: Only Hopper GPUs
|
||||
_check_deepseek_support();
|
||||
|
||||
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
|
||||
@ -821,6 +843,12 @@ _scaled_block1x128_block1x128(
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"1x128 and 128x128 scaling not available with ROCm"
|
||||
);
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor&
|
||||
@ -831,10 +859,12 @@ _scaled_block128x128_block1x128(
|
||||
const c10::ScalarType out_dtype,
|
||||
const bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
#ifndef USE_ROCM
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, shape K//128
|
||||
std::cout << "mat_b: " << mat_b.dim() << ", " << mat_b.sizes() << ", " << mat_b.strides() << std::endl;
|
||||
std::cout << "scale_b: " << scale_b.dim() << ", " << scale_b.sizes() << ", " << scale_b.strides() << std::endl;
|
||||
// CUDA: Only Hopper GPUs
|
||||
_check_deepseek_support();
|
||||
|
||||
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div<int64_t>(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div<int64_t>(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat,
|
||||
@ -852,6 +882,12 @@ _scaled_block128x128_block1x128(
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"1x128 and 128x128 scaling not available with ROCm"
|
||||
);
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor&
|
||||
@ -862,8 +898,12 @@ _scaled_block1x128_block128x128(
|
||||
const c10::ScalarType out_dtype,
|
||||
const bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
#ifndef USE_ROCM
|
||||
// Restrictions:
|
||||
// A, B are FP8, scales are fp32, A: shape K//128, B: K//128, N//128
|
||||
// CUDA: Only Hopper GPUs
|
||||
_check_deepseek_support();
|
||||
|
||||
TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ",
|
||||
mat_a.scalar_type(), mat_b.scalar_type());
|
||||
TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat,
|
||||
@ -881,6 +921,12 @@ _scaled_block1x128_block128x128(
|
||||
_scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out);
|
||||
|
||||
return out;
|
||||
#else
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"1x128 and 128x128 scaling not available with ROCm"
|
||||
);
|
||||
#endif
|
||||
}
|
||||
|
||||
Tensor&
|
||||
|
||||
@ -7,7 +7,7 @@
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
// Two warninngs in Cutlass included header files
|
||||
// Two warnings in Cutlass included header files
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used")
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-variable")
|
||||
|
||||
@ -160,8 +160,8 @@ struct _cuda_scatter_gather_internal_kernel {
|
||||
auto offsets = offset_calc.get(i);
|
||||
|
||||
int64_t idx_dim = *(index_t*)(index_ptr + offsets[2]);
|
||||
CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
|
||||
&& "scatter gather kernel index out of bounds");
|
||||
CUDA_KERNEL_ASSERT_VERBOSE(idx_dim >= 0 && idx_dim < index_size
|
||||
&& "scatter gather kernel index out of bounds", "Expected 0 <= idx_dim < index_size (%ld), but got idx_dim = %ld", index_size, idx_dim);
|
||||
|
||||
f(
|
||||
(scalar_t*)(self_ptr + offsets[0]),
|
||||
@ -406,9 +406,8 @@ struct _cuda_scatter_fill_internal_kernel {
|
||||
auto offsets = offset_calc.get(i);
|
||||
|
||||
int64_t idx_dim = *(index_t*)(index_ptr + offsets[1]);
|
||||
CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
|
||||
&& "index out of bounds"
|
||||
);
|
||||
CUDA_KERNEL_ASSERT_VERBOSE(idx_dim >= 0 && idx_dim < index_size
|
||||
&& "index out of bounds", "Expected 0 <= idx_dim < index_size (%ld), but got idx_dim = %ld", index_size, idx_dim);
|
||||
|
||||
f(
|
||||
(scalar_t*)(self_ptr + offsets[0]),
|
||||
|
||||
@ -460,7 +460,7 @@ __global__ void GammaBetaBackwardCUDAKernel2(
|
||||
}
|
||||
}
|
||||
|
||||
// Do warp reduce for the 2st 16 cols in the tile.
|
||||
// Do warp reduce for the 2nd 16 cols in the tile.
|
||||
sum1 = g_shared[threadIdx.x][threadIdx.y + blockDim.y];
|
||||
sum2 = b_shared[threadIdx.x][threadIdx.y + blockDim.y];
|
||||
sum1 = cuda_utils::WarpReduceSum<T_ACC>(sum1);
|
||||
|
||||
@ -1556,19 +1556,19 @@ NvrtcFunction jit_pwise_function(
|
||||
ss << "_" << hash_code;
|
||||
file_path = ss.str();
|
||||
|
||||
std::ifstream readin{file_path, std::ios::in | std::ifstream::binary};
|
||||
if (readin.fail()) {
|
||||
std::ifstream read_stream{file_path, std::ios::in | std::ifstream::binary};
|
||||
if (read_stream.fail()) {
|
||||
// NOTE: this does not warn because the file might not exist
|
||||
// TODO: consider if this should explicitly check for the file's existence or not to throw
|
||||
// an informative warning
|
||||
readin.close();
|
||||
read_stream.close();
|
||||
} else {
|
||||
// TODO: try passing the "mapped" file directly to cuModuleLoadCall instead of using an intermediate buffer
|
||||
std::vector<char> buffer(std::istreambuf_iterator<char>(readin), {});
|
||||
std::vector<char> buffer(std::istreambuf_iterator<char>(read_stream), {});
|
||||
AT_CUDA_DRIVER_CHECK(nvrtc.cuModuleLoadData(&(compiled_kernel_.module), buffer.data()));
|
||||
AT_CUDA_DRIVER_CHECK(
|
||||
nvrtc.cuModuleGetFunction(&(compiled_kernel_.function), compiled_kernel_.module, name.c_str()));
|
||||
readin.close();
|
||||
read_stream.close();
|
||||
return compiled_kernel_;
|
||||
}
|
||||
}
|
||||
|
||||
@ -141,7 +141,8 @@ WelfordDataLN cuWelfordOnlineSum(
|
||||
if constexpr (!rms_norm){
|
||||
U delta = val - curr_sum.mean;
|
||||
U new_count = curr_sum.count + 1.f;
|
||||
#if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
|
||||
//Due to low CU count, we run into accuracy issues on gfx90a with `__builtin_amdgcn_rcpf`
|
||||
#if defined(USE_ROCM) && !defined(__gfx90a__) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
|
||||
U new_mean = curr_sum.mean + delta * __builtin_amdgcn_rcpf(new_count);
|
||||
#else
|
||||
U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster
|
||||
@ -163,7 +164,8 @@ WelfordDataLN cuWelfordCombine(
|
||||
U count = dataA.count + dataB.count;
|
||||
U mean, sigma2;
|
||||
if (count > decltype(dataB.count){0}) {
|
||||
#if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
|
||||
//Due to low CU count, we run into accuracy issues on gfx90a with `__builtin_amdgcn_rcpf`
|
||||
#if defined(USE_ROCM) && !defined(__gfx90a__) && defined(USE_LAYERNORM_FAST_RECIPROCAL)
|
||||
auto coef = __builtin_amdgcn_rcpf(count);
|
||||
#else
|
||||
auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division
|
||||
@ -1050,7 +1052,7 @@ void launch_vectorized_layer_norm_kernel(
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
|
||||
#ifdef USE_ROCM
|
||||
// the blocks.x contains the max grid x dimention without invalid configuration error
|
||||
// the blocks.x contains the max grid x dimension without invalid configuration error
|
||||
// Fix invalid configuration https://github.com/pytorch/pytorch/issues/136291
|
||||
// Ensure all elements are processed. Prepare for next round
|
||||
int64_t remaining = M - blocks.x;
|
||||
|
||||
@ -177,7 +177,7 @@ bool use_ragged_in_dense(
|
||||
TORCH_WARN_ONCE(
|
||||
"TORCH_CUDNN_SDPA_AVOID_RECOMPILE=1 only works with Q, K, V, and output in BSHD memory layout,"
|
||||
"e.g., Q, K, V must be allocated with torch.randn((B, S, H, D).transpose(1, 2)."
|
||||
"Falling back to regualr dense case, which may trigger excessive recompilation.");
|
||||
"Falling back to regular dense case, which may trigger excessive recompilation.");
|
||||
}
|
||||
return all_bshd;
|
||||
}
|
||||
@ -771,7 +771,7 @@ std::unique_ptr<fe::graph::Graph> build_graph_nestedtensor(
|
||||
if (attn_bias.has_value()) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"attn_bias not yet supportd with cuDNN Attention and NestedTensor");
|
||||
"attn_bias not yet supported with cuDNN Attention and NestedTensor");
|
||||
scaled_dot_product_flash_attention_options.set_bias(
|
||||
mha_graph->tensor(fe::graph::Tensor_attributes()
|
||||
.set_uid(BIAS)
|
||||
@ -1196,7 +1196,7 @@ std::unique_ptr<fe::graph::Graph> build_graph_backward_nestedtensor(
|
||||
if (attn_bias.has_value()) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"attn_bias not yet supportd with cuDNN Attention and NestedTensor");
|
||||
"attn_bias not yet supported with cuDNN Attention and NestedTensor");
|
||||
sdpa_backward_options.set_bias(
|
||||
mha_graph->tensor(fe::graph::Tensor_attributes()
|
||||
.set_uid(BIAS)
|
||||
@ -1864,7 +1864,7 @@ void run_cudnn_SDP_bprop_nestedtensor(
|
||||
}
|
||||
TORCH_CHECK(
|
||||
!attn_bias.has_value(),
|
||||
"attn_bias not yet supportd with cuDNN Attention and NestedTensor");
|
||||
"attn_bias not yet supported with cuDNN Attention and NestedTensor");
|
||||
|
||||
auto workspace_size = mha_graph.get_workspace_size();
|
||||
auto workspace_ptr =
|
||||
|
||||
@ -30,7 +30,7 @@ static const std::unordered_map<
|
||||
};
|
||||
|
||||
|
||||
// This is the heursitic to choose a kernel based on inputs
|
||||
// This is the heuristic to choose a kernel based on inputs
|
||||
BGEMMKernel_BFloat16 dispatch_bfloat16_bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
|
||||
// Optional/future use: directly lookup shape tuples to map to instances
|
||||
/*
|
||||
|
||||
@ -11,7 +11,7 @@ using S = ck::Sequence<Is...>;
|
||||
namespace at::native {
|
||||
|
||||
void dispatch_bfloat16_gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
// If any of the shapes cant be tiled, we must use padding.
|
||||
// If any of the shapes can't be tiled, we must use padding.
|
||||
bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0));
|
||||
// Dispatch to best implementation.
|
||||
// TODO add more configurations. Optimize.
|
||||
@ -471,7 +471,7 @@ void dispatch_bfloat16_gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
}
|
||||
|
||||
void dispatch_bfloat16_gemm_wmma(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
// If any of the shapes cant be tiled, we must use padding.
|
||||
// If any of the shapes can't be tiled, we must use padding.
|
||||
bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0));
|
||||
// Dispatch to best implementation.
|
||||
// TODO add more configurations. Optimize.
|
||||
|
||||
@ -11,7 +11,7 @@ using S = ck::Sequence<Is...>;
|
||||
namespace at::native {
|
||||
|
||||
void dispatch_float_gemm(CUDABLAS_GEMM_ARGTYPES(float)) {
|
||||
// If any of the shapes cant be tiled, we must use padding.
|
||||
// If any of the shapes can't be tiled, we must use padding.
|
||||
bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0));
|
||||
// Dispatch to best implementation.
|
||||
// TODO add more configurations. Optimize.
|
||||
|
||||
@ -13,7 +13,7 @@ namespace at::native {
|
||||
|
||||
void dispatch_half_gemm(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
#if 0
|
||||
// If any of the shapes cant be tiled, we must use padding.
|
||||
// If any of the shapes can't be tiled, we must use padding.
|
||||
bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0));
|
||||
// Dispatch to best implementation.
|
||||
// TODO add more configurations. Optimize.
|
||||
@ -299,7 +299,7 @@ void dispatch_half_gemm(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
#endif
|
||||
}
|
||||
void dispatch_half_gemm_wmma(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
// If any of the shapes cant be tiled, we must use padding.
|
||||
// If any of the shapes can't be tiled, we must use padding.
|
||||
bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0));
|
||||
// Dispatch to best implementation.
|
||||
// TODO add more configurations. Optimize.
|
||||
|
||||
@ -545,7 +545,7 @@ kernel void reshape(texture2d_array<half, access::read> in_arr[[texture(0), func
|
||||
const ushort slices2 = divRoundUp(C2, 4);
|
||||
const ushort slices1 = divRoundUp(C1, 4);
|
||||
const ushort n2 = gid.z / slices2; //image index
|
||||
const ushort s2 = gid.z - n2 * slices2; // slice offest
|
||||
const ushort s2 = gid.z - n2 * slices2; // slice offset
|
||||
half4 value;
|
||||
for (int idx = 0; idx < 4; ++idx){
|
||||
// we compute the "linear index" of the output element,
|
||||
|
||||
@ -86,4 +86,4 @@ TORCH_LIBRARY_IMPL(aten, Metal, m) {
|
||||
m.impl(TORCH_SELECTIVE_NAME("aten::hardsigmoid_"), TORCH_FN(hardsigmoid_));
|
||||
}
|
||||
|
||||
} // namepsace at::native::metal
|
||||
} // namespace at::native::metal
|
||||
|
||||
@ -34,7 +34,7 @@ namespace at::native::onednn {
|
||||
|
||||
/*
|
||||
oneDNN postops usage:
|
||||
Currently, oneDNN supports 5 kinds of post ops. More details can be refered
|
||||
Currently, oneDNN supports 5 kinds of post ops. More details can be referred
|
||||
to oneDNN doc.
|
||||
https://oneapi-src.github.io/oneDNN/dev_guide_attributes_post_ops.html#doxid-dev-guide-attributes-post-ops-1dev-guide-attributes-post-ops-eltwise
|
||||
|
||||
@ -399,7 +399,7 @@ static inline void construct_attr_for_unary(
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
unary_post_op == "none",
|
||||
"onednn qlinear: unspported unary post op",
|
||||
"onednn qlinear: unsupported unary post op",
|
||||
unary_post_op);
|
||||
}
|
||||
}
|
||||
|
||||
@ -856,7 +856,7 @@ id<MTLLibrary> MetalShaderLibrary::getLibrary(const std::initializer_list<std::s
|
||||
break;
|
||||
}
|
||||
default:
|
||||
TORCH_INTERNAL_ASSERT(false, "Unsupported number of paramaters ", nparams);
|
||||
TORCH_INTERNAL_ASSERT(false, "Unsupported number of parameters ", nparams);
|
||||
}
|
||||
return libMap[key] = lib;
|
||||
}
|
||||
@ -1184,9 +1184,9 @@ void MetalKernelFunction::dispatch(uint64_t length, std::optional<uint64_t> grou
|
||||
}
|
||||
|
||||
void MetalKernelFunction::dispatch(c10::ArrayRef<uint64_t> length, c10::OptionalArrayRef<uint64_t> group_size) {
|
||||
TORCH_CHECK(!length.empty() && length.size() < 4, "Dispatch dimentions must be less than 3 and non-empty");
|
||||
TORCH_CHECK(!length.empty() && length.size() < 4, "Dispatch dimensions must be less than 3 and non-empty");
|
||||
TORCH_CHECK(!group_size.has_value() || group_size->size() == length.size(),
|
||||
"size and group_size must have same number of dimentions");
|
||||
"size and group_size must have same number of dimensions");
|
||||
const auto max_tg_size = getMaxThreadsPerThreadgroup();
|
||||
const auto group_size_length = group_size.has_value() ? group_size->size() : 0;
|
||||
auto tg_size = MTLSizeMake(group_size_length > 0 ? group_size->at(0) : max_tg_size,
|
||||
|
||||
@ -59,7 +59,7 @@ static GridSamplerOffsets find_grid_sampler_offsets(
|
||||
return offsets;
|
||||
}
|
||||
|
||||
// Mod function which gives postive output when `a` is negative
|
||||
// Mod function which gives positive output when `a` is negative
|
||||
static int32_t mod(int32_t a, int32_t b) {
|
||||
auto r = a % b;
|
||||
return r + (r < 0 ? b : 0);
|
||||
@ -191,9 +191,9 @@ void grid_sampler_single_element(
|
||||
int32_t right_indices[3];
|
||||
opmath_t<T> scales[3];
|
||||
|
||||
// For each dimension, find the pair of indices in the cooresponding dimension
|
||||
// For each dimension, find the pair of indices in the corresponding dimension
|
||||
// of `input` which surround the grid coordinate in that dimension. We'll do
|
||||
// this by mapping different coordiante spaces onto each other. There are
|
||||
// this by mapping different coordinate spaces onto each other. There are
|
||||
// basically three different coordinate spaces to keep in mind:
|
||||
//
|
||||
// * aligned grid space
|
||||
|
||||
@ -137,7 +137,7 @@ kernel void index_put_serial(
|
||||
constant int64_t* index_strides,
|
||||
constant uint4& ndim_nindices_numel,
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
(void)thread_index; // Suppress unused vairable varning
|
||||
(void)thread_index; // Suppress unused variable warning
|
||||
for (uint idx = 0; idx < ndim_nindices_numel.z; ++idx) {
|
||||
index_put_impl(
|
||||
output,
|
||||
|
||||
@ -112,7 +112,7 @@ kernel void int4pack_mm(constant T *A [[buffer(0)]],
|
||||
constant uchar *B_ptr = B + ((n * K) / k_pack_factor);
|
||||
|
||||
thread float4 result = float4(0.0);
|
||||
// We multipy group of 4 channels with these scales.
|
||||
// We multiply group of 4 channels with these scales.
|
||||
// Because corresponding values from weight matrix are effectively left
|
||||
// shifted. This is to avoid doing right shift on those values which ends up
|
||||
// affecting performance. This is the trick applied in MLX kernels.
|
||||
|
||||
@ -387,7 +387,7 @@ struct log1p_functor {
|
||||
}
|
||||
template <typename T>
|
||||
inline enable_if_t<is_complex_v<T>, T> operator()(const T x) {
|
||||
// TODO: Implement proper log1p algoirthm
|
||||
// TODO: Implement proper log1p algorithm
|
||||
auto magnitude = ::precise::sqrt((1.0f + x.x) * (1.0f + x.x) + x.y * x.y);
|
||||
auto real = ::precise::log(magnitude);
|
||||
auto imag = (x.x == -1 && x.y == 0) ? 0 : ::precise::atan2(x.y, 1.0 + x.x);
|
||||
|
||||
@ -448,7 +448,7 @@ kernel void upsample_trilinear_backward(
|
||||
|
||||
// See Note [ Weights computation for uint8_t and multiplication trick ]
|
||||
// Essentially fall back to fixed floating point arithmetic during uint8
|
||||
// interpolation, which is not necesserily more accurate (see example below),
|
||||
// interpolation, which is not necessarily more accurate (see example below),
|
||||
// but matches closes to what CPU can deliver
|
||||
// I.e. mid-point 152+249+172+35 is 152, but algorithm yields 153 as horizontal
|
||||
// and vertical interpolation is done in separate steps and results are rounded
|
||||
|
||||
@ -41,7 +41,7 @@ Tensor pad_tensor_to_shape(
|
||||
const Tensor& t,
|
||||
IntArrayRef goal_shape,
|
||||
double value = 0) {
|
||||
std::vector<int64_t> padd;
|
||||
std::vector<int64_t> padding;
|
||||
auto tup = t.sizes();
|
||||
TORCH_CHECK(
|
||||
t.dim() == (int64_t)(goal_shape.size()),
|
||||
@ -51,10 +51,10 @@ Tensor pad_tensor_to_shape(
|
||||
goal_shape.size(),
|
||||
" of goal shape.");
|
||||
for (int64_t i = static_cast<int64_t>(tup.size()) - 1; i >= 0; i--) {
|
||||
padd.push_back(0);
|
||||
padd.push_back(goal_shape[i] - tup[i]);
|
||||
padding.push_back(0);
|
||||
padding.push_back(goal_shape[i] - tup[i]);
|
||||
}
|
||||
Tensor new_tensor = at::constant_pad_nd(t, IntArrayRef(padd), value);
|
||||
Tensor new_tensor = at::constant_pad_nd(t, IntArrayRef(padding), value);
|
||||
new_tensor = new_tensor.reshape(goal_shape);
|
||||
return new_tensor;
|
||||
}
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
// Implementation of specal math functions for Metal
|
||||
// Implementation of special math functions for Metal
|
||||
#pragma once
|
||||
#include <c10/metal/expm1f.h>
|
||||
#include <c10/metal/igamma.h>
|
||||
|
||||
@ -34,7 +34,7 @@ struct MemEvent {
|
||||
bool overlaps(const MemBlock& a, const MemBlock& b) {
|
||||
// two blocks dont overlap if
|
||||
// |---a--------|--------------b--------|
|
||||
// strat_a end_a <= start_b end_b
|
||||
// start_a end_a <= start_b end_b
|
||||
return !(
|
||||
(a.end_offset <= b.start_offset) || (b.end_offset <= a.start_offset));
|
||||
}
|
||||
|
||||
@ -33,7 +33,7 @@ struct bitset final {
|
||||
constexpr bitset() noexcept = default;
|
||||
constexpr bitset(const bitset&) noexcept = default;
|
||||
constexpr bitset(bitset&&) noexcept = default;
|
||||
// there is an issure for gcc 5.3.0 when define default function as constexpr
|
||||
// there is an issue for gcc 5.3.0 when define default function as constexpr
|
||||
// see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=68754.
|
||||
bitset& operator=(const bitset&) noexcept = default;
|
||||
bitset& operator=(bitset&&) noexcept = default;
|
||||
|
||||
@ -554,6 +554,17 @@ class DeviceCachingAllocator {
|
||||
}
|
||||
}
|
||||
|
||||
double getMemoryFraction() {
|
||||
if (!set_fraction) {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
c10::xpu::DeviceProp device_prop;
|
||||
c10::xpu::get_device_properties(&device_prop, device_index);
|
||||
return static_cast<double>(allowed_memory_maximum) /
|
||||
static_cast<double>(device_prop.global_mem_size);
|
||||
}
|
||||
|
||||
void setMemoryFraction(double fraction) {
|
||||
c10::xpu::DeviceProp device_prop;
|
||||
c10::xpu::get_device_properties(&device_prop, device_index);
|
||||
@ -724,6 +735,11 @@ class XPUAllocator : public DeviceAllocator {
|
||||
device_allocators[device]->resetAccumulatedStats();
|
||||
}
|
||||
|
||||
double getMemoryFraction(DeviceIndex device) {
|
||||
assertValidDevice(device);
|
||||
return device_allocators[device]->getMemoryFraction();
|
||||
}
|
||||
|
||||
void setMemoryFraction(double fraction, DeviceIndex device) {
|
||||
assertValidDevice(device);
|
||||
TORCH_CHECK_VALUE(
|
||||
@ -777,6 +793,10 @@ void recordStream(const DataPtr& dataPtr, XPUStream stream) {
|
||||
return allocator.recordStream(dataPtr, stream);
|
||||
}
|
||||
|
||||
double getMemoryFraction(DeviceIndex device) {
|
||||
return allocator.getMemoryFraction(device);
|
||||
}
|
||||
|
||||
void setMemoryFraction(double fraction, DeviceIndex device) {
|
||||
return allocator.setMemoryFraction(fraction, device);
|
||||
}
|
||||
|
||||
@ -25,6 +25,8 @@ C10_XPU_API void raw_delete(void* ptr);
|
||||
|
||||
C10_XPU_API void recordStream(const DataPtr& dataPtr, XPUStream stream);
|
||||
|
||||
C10_XPU_API double getMemoryFraction(DeviceIndex device);
|
||||
|
||||
C10_XPU_API void setMemoryFraction(double fraction, DeviceIndex device);
|
||||
|
||||
} // namespace c10::xpu::XPUCachingAllocator
|
||||
|
||||
@ -38,7 +38,7 @@ uint32_t crc32_combine (uint32_t crcA, uint32_t crcB, size_t lengthB);
|
||||
|
||||
/// compute CRC32 (bitwise algorithm)
|
||||
uint32_t crc32_bitwise (const void* data, size_t length, uint32_t previousCrc32 = 0);
|
||||
/// compute CRC32 (half-byte algoritm)
|
||||
/// compute CRC32 (half-byte algorithm)
|
||||
uint32_t crc32_halfbyte(const void* data, size_t length, uint32_t previousCrc32 = 0);
|
||||
|
||||
#ifdef CRC32_USE_LOOKUP_TABLE_BYTE
|
||||
@ -96,7 +96,7 @@ uint32_t crc32_16bytes_prefetch(const void* data, size_t length, uint32_t previo
|
||||
#define __BIG_ENDIAN 4321
|
||||
#endif
|
||||
|
||||
// define endianess and some integer data types
|
||||
// define endianness and some integer data types
|
||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||
// Windows always little endian
|
||||
#define __BYTE_ORDER __LITTLE_ENDIAN
|
||||
@ -168,7 +168,7 @@ namespace
|
||||
/// zlib's CRC32 polynomial
|
||||
const uint32_t Polynomial = 0xEDB88320;
|
||||
|
||||
/// swap endianess
|
||||
/// swap endianness
|
||||
static inline uint32_t swap(uint32_t x)
|
||||
{
|
||||
#if defined(__GNUC__) || defined(__clang__)
|
||||
@ -229,7 +229,7 @@ uint32_t crc32_bitwise(const void* data, size_t length, uint32_t previousCrc32)
|
||||
}
|
||||
|
||||
|
||||
/// compute CRC32 (half-byte algoritm)
|
||||
/// compute CRC32 (half-byte algorithm)
|
||||
uint32_t crc32_halfbyte(const void* data, size_t length, uint32_t previousCrc32)
|
||||
{
|
||||
uint32_t crc = ~previousCrc32; // same as previousCrc32 ^ 0xFFFFFFFF
|
||||
@ -662,7 +662,7 @@ uint32_t crc32_combine(uint32_t crcA, uint32_t crcB, size_t lengthB)
|
||||
// - if you append length(B) zeros to A and call it A' (think of it as AAAA000)
|
||||
// and prepend length(A) zeros to B and call it B' (think of it as 0000BBB)
|
||||
// then exists a C' = A' ^ B'
|
||||
// - remember: if you XOR someting with zero, it remains unchanged: X ^ 0 = X
|
||||
// - remember: if you XOR something with zero, it remains unchanged: X ^ 0 = X
|
||||
// - that means C' = A concat B so that crc(A concat B) = crc(C') = crc(A') ^ crc(B')
|
||||
// - the trick is to compute crc(A') based on crc(A)
|
||||
// and crc(B') based on crc(B)
|
||||
|
||||
@ -76,7 +76,7 @@ typedef struct mz_zip_archive mz_zip_archive;
|
||||
// 2) Writing with 1-pass sequential access
|
||||
// -> We must take care not to require updating values that have already
|
||||
// been written. We place the variable-length index at the end and do
|
||||
// not put any indicies into the header to fulfill this constraint.
|
||||
// not put any index into the header to fulfill this constraint.
|
||||
|
||||
// The model.json, which contains all the metadata information,
|
||||
// should be written as the last file. One reason is that the size of tensor
|
||||
|
||||
@ -519,7 +519,7 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoadWithAllocator) {
|
||||
std::tie(data_ptr, size) = reader.getRecord("key1", &overrideAllocator);
|
||||
EXPECT_EQ(overrideAllocator.getAllocatedBytes(), kBytes1);
|
||||
EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes);
|
||||
// allcoate with base allocator
|
||||
// allocate with base allocator
|
||||
std::tie(data_ptr, size) = reader.getRecord("key1");
|
||||
EXPECT_EQ(overrideAllocator.getAllocatedBytes(), kBytes1);
|
||||
EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes + kBytes1);
|
||||
|
||||
@ -76,6 +76,7 @@
|
||||
:nosignatures:
|
||||
|
||||
empty_cache
|
||||
get_per_process_memory_fraction
|
||||
max_memory_allocated
|
||||
max_memory_reserved
|
||||
mem_get_info
|
||||
|
||||
2
setup.py
2
setup.py
@ -1106,7 +1106,7 @@ class build_ext(setuptools.command.build_ext.build_ext):
|
||||
continue
|
||||
self.copy_file(source_lib, target_lib)
|
||||
# Delete old rpath and add @loader_lib to the rpath
|
||||
# This should prevent delocate from attempting to package another instance
|
||||
# This should prevent deallocate from attempting to package another instance
|
||||
# of OpenMP library in torch wheel as well as loading two libomp.dylib into
|
||||
# the address space, as libraries are cached by their unresolved names
|
||||
install_name_tool_args = [
|
||||
|
||||
@ -827,7 +827,7 @@ class TestFullyShardShardPlacementFnMultiProcess(FSDPTest):
|
||||
|
||||
torch.manual_seed(42 + self.rank)
|
||||
inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type)
|
||||
for iter_idx in range(5):
|
||||
for _ in range(5):
|
||||
ref_loss = ref_model(inp).sum()
|
||||
loss = model(inp).sum()
|
||||
self.assertEqual(ref_loss, loss)
|
||||
|
||||
@ -52,7 +52,7 @@ class TestComplement(TestCase):
|
||||
|
||||
_LOGGER.debug(f"{layout} => {layoutR}")
|
||||
|
||||
# Post-condition: test disjointness of the codomains
|
||||
# Post-condition: test disjointedness of the codomains
|
||||
for a in range(size(layout)):
|
||||
for b in range(size(layoutR)):
|
||||
assert (layout(a) != layoutR(b)) or (layout(a) == 0 and layoutR(b) == 0)
|
||||
|
||||
@ -31,17 +31,17 @@ if TEST_WITH_DEV_DBG_ASAN:
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
_DISTRIBUTED_STATE_DICT_IMPLS = (
|
||||
_DISTRIBUTED_STATE_DICT_IMPLS = {
|
||||
StateDictType.LOCAL_STATE_DICT,
|
||||
StateDictType.SHARDED_STATE_DICT,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class TestDistributedCheckpoint(FSDPTest):
|
||||
@property
|
||||
def world_size(self):
|
||||
if torch.cuda.is_available():
|
||||
gpu_cnt = torch.cuda.device_count()
|
||||
if torch.accelerator.is_available():
|
||||
gpu_cnt = torch.accelerator.device_count()
|
||||
if gpu_cnt < 2:
|
||||
return gpu_cnt
|
||||
return 2
|
||||
@ -93,7 +93,9 @@ class TestDistributedCheckpoint(FSDPTest):
|
||||
# TODO: add resharding test case.
|
||||
|
||||
|
||||
devices = ("cuda", "hpu")
|
||||
instantiate_device_type_tests(TestDistributedCheckpoint, globals(), only_for=devices)
|
||||
devices = ("cuda", "hpu", "xpu")
|
||||
instantiate_device_type_tests(
|
||||
TestDistributedCheckpoint, globals(), only_for=devices, allow_xpu=True
|
||||
)
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -36,8 +36,8 @@ device_type = torch.device(get_devtype())
|
||||
class TestApply(FSDPTest):
|
||||
@property
|
||||
def world_size(self):
|
||||
if torch.cuda.is_available():
|
||||
gpu_cnt = torch.cuda.device_count()
|
||||
if torch.accelerator.is_available():
|
||||
gpu_cnt = torch.accelerator.device_count()
|
||||
if gpu_cnt < 2:
|
||||
return gpu_cnt
|
||||
return 2
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@ -45,53 +44,19 @@ class TestInstantiator(TestCase):
|
||||
self.assertEqual(return_type_str, "Tuple[Tensor, int, str]")
|
||||
|
||||
def test_instantiate_scripted_remote_module_template(self):
|
||||
dir_path = Path(instantiator.INSTANTIATED_TEMPLATE_DIR_PATH)
|
||||
|
||||
# Cleanup.
|
||||
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
|
||||
for file_path in file_paths:
|
||||
file_path.unlink()
|
||||
|
||||
# Check before run.
|
||||
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
|
||||
num_files_before = len(list(file_paths))
|
||||
self.assertEqual(num_files_before, 0)
|
||||
|
||||
generated_module = instantiator.instantiate_scriptable_remote_module_template(
|
||||
MyModuleInterface
|
||||
)
|
||||
self.assertTrue(hasattr(generated_module, "_remote_forward"))
|
||||
self.assertTrue(hasattr(generated_module, "_generated_methods"))
|
||||
|
||||
# Check after run.
|
||||
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
|
||||
num_files_after = len(list(file_paths))
|
||||
self.assertEqual(num_files_after, 1)
|
||||
|
||||
def test_instantiate_non_scripted_remote_module_template(self):
|
||||
dir_path = Path(instantiator.INSTANTIATED_TEMPLATE_DIR_PATH)
|
||||
|
||||
# Cleanup.
|
||||
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
|
||||
for file_path in file_paths:
|
||||
file_path.unlink()
|
||||
|
||||
# Check before run.
|
||||
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
|
||||
num_files_before = len(list(file_paths))
|
||||
self.assertEqual(num_files_before, 0)
|
||||
|
||||
generated_module = (
|
||||
instantiator.instantiate_non_scriptable_remote_module_template()
|
||||
)
|
||||
self.assertTrue(hasattr(generated_module, "_remote_forward"))
|
||||
self.assertTrue(hasattr(generated_module, "_generated_methods"))
|
||||
|
||||
# Check after run.
|
||||
file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py")
|
||||
num_files_after = len(list(file_paths))
|
||||
self.assertEqual(num_files_after, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -64,6 +64,10 @@ class TestDTensorDebugMode(TestCase):
|
||||
self.assertTrue(isinstance(debug_mode.operators[2], _RedistributeCall))
|
||||
self.assertEqual(next(iter(debug_mode.operators[1])), torch.ops.aten.mm.default)
|
||||
|
||||
# check stringification
|
||||
self.assertTrue(hasattr(debug_mode.operators[0], "args_str"))
|
||||
self.assertFalse(hasattr(debug_mode.operators[0], "args"))
|
||||
|
||||
def test_debug_string_inside_context(self):
|
||||
mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
|
||||
|
||||
@ -267,6 +271,7 @@ class TestDTensorDebugMode(TestCase):
|
||||
record_torchfunction=True,
|
||||
record_faketensor=True,
|
||||
record_tensor_attributes=["a1", "a2"],
|
||||
store_original_args=True,
|
||||
) as debug_mode:
|
||||
torch.matmul(y, x)
|
||||
|
||||
@ -279,6 +284,9 @@ class TestDTensorDebugMode(TestCase):
|
||||
aten::_unsafe_view(t: f32[64, 8], [8, 8, 8])""",
|
||||
)
|
||||
|
||||
self.assertTrue(hasattr(debug_mode.operators[0], "args"))
|
||||
self.assertEqual(id(debug_mode.operators[0].args[0]), id(y))
|
||||
|
||||
@parametrize("has_inner_mode", [True, False])
|
||||
@parametrize("has_outer_mode", [True, False])
|
||||
def test_nested_debug_mode(self, has_inner_mode, has_outer_mode):
|
||||
|
||||
@ -20,18 +20,18 @@ from torch.distributed.tensor.experimental._attention import (
|
||||
_cp_options,
|
||||
_disable_context_parallel_dispatcher,
|
||||
_enable_context_parallel_dispatcher,
|
||||
_HeadTailLoadBalancer,
|
||||
_is_causal_behavior,
|
||||
_LoadBalancer,
|
||||
_PerDocumentHeadTailLoadBalancer,
|
||||
_PTRRLoadBalancer,
|
||||
_RotateMethod,
|
||||
context_parallel,
|
||||
context_parallel_unshard,
|
||||
set_rotate_method,
|
||||
)
|
||||
from torch.distributed.tensor.experimental._cp_custom_ops import flex_cp_allgather
|
||||
from torch.distributed.tensor.experimental._load_balancer import (
|
||||
_HeadTailLoadBalancer,
|
||||
_LoadBalancer,
|
||||
_PerDocumentHeadTailLoadBalancer,
|
||||
_PTRRLoadBalancer,
|
||||
from torch.distributed.tensor.experimental._context_parallel._cp_custom_ops import (
|
||||
flex_cp_allgather,
|
||||
)
|
||||
from torch.distributed.tensor.parallel import parallelize_module
|
||||
from torch.nn.attention import sdpa_kernel, SDPBackend
|
||||
@ -52,7 +52,9 @@ from torch.testing._internal.common_cuda import (
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_utils import run_tests, skipIfRocm
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
create_local_tensor_test_class,
|
||||
DTensorTestBase,
|
||||
map_local_tensor_for_rank,
|
||||
with_comms,
|
||||
)
|
||||
|
||||
@ -800,11 +802,47 @@ class TestSharding(DTensorTestBase):
|
||||
chunks = freqs_cis.chunk(self.world_size * 2)
|
||||
self.assertEqual(
|
||||
freqs_cis_shard,
|
||||
torch.cat(
|
||||
[chunks[self.rank], chunks[self.world_size * 2 - self.rank - 1]], dim=0
|
||||
map_local_tensor_for_rank(
|
||||
chunks,
|
||||
self.rank,
|
||||
lambda chunks, rank: torch.cat(
|
||||
[chunks[rank], chunks[self.world_size * 2 - rank - 1]],
|
||||
dim=0,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
RingAttentionTestWithLocalTensor = create_local_tensor_test_class(
|
||||
RingAttentionTest,
|
||||
skipped_tests=[
|
||||
# Need to make attention implementation local tensor friendly, e.g.
|
||||
# rewrite "rank local" logic
|
||||
"test_ring_attention_sdpa",
|
||||
],
|
||||
)
|
||||
|
||||
CPFlexAttentionTestWithLocalTensor = create_local_tensor_test_class(
|
||||
CPFlexAttentionTest,
|
||||
skipped_tests=[
|
||||
# Missing support for batched tensors
|
||||
"test_cp_flex_attention_causal_mask",
|
||||
"test_cp_flex_attention_document_mask",
|
||||
],
|
||||
)
|
||||
|
||||
TestCPCustomOpsWithLocalTensor = create_local_tensor_test_class(
|
||||
TestCPCustomOps,
|
||||
skipped_tests=[
|
||||
# Missing support for fake tensors
|
||||
"test_flex_cp_custom_op",
|
||||
],
|
||||
)
|
||||
|
||||
TestShardingWithLocalTensor = create_local_tensor_test_class(
|
||||
TestSharding,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -16,6 +16,7 @@ from torch.distributed.tensor import (
|
||||
from torch.nn import functional as F
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
create_local_tensor_test_class,
|
||||
DTensorTestBase,
|
||||
skip_if_lt_x_gpu,
|
||||
with_comms,
|
||||
@ -232,5 +233,17 @@ class DistConvolutionOpsTest(DTensorTestBase):
|
||||
self.assertEqual(out_dt.shape, out.shape)
|
||||
|
||||
|
||||
DistConvolutionOpsTestWithLocalTensor = create_local_tensor_test_class(
|
||||
DistConvolutionOpsTest,
|
||||
# Send / recv ops are not supported
|
||||
skipped_tests=[
|
||||
"test_conv1d",
|
||||
"test_conv3d",
|
||||
"test_conv_backward_none_grad_inp",
|
||||
"test_depthwise_convolution",
|
||||
"test_downsampling_convolution",
|
||||
],
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -520,6 +520,21 @@ class DTensorExportTest(TestCase):
|
||||
2,
|
||||
)
|
||||
|
||||
def test_union_typed_annotation(self):
|
||||
def fn(leaf: torch.Tensor | DTensor):
|
||||
def nest_fn(leaf: torch.Tensor | DTensor):
|
||||
# def nest_fn(leaf: Union[torch.Tensor, DTensor]): # this works
|
||||
if isinstance(leaf, DTensor):
|
||||
leaf = leaf.to_local()
|
||||
return leaf
|
||||
|
||||
return nest_fn(leaf) + 1
|
||||
|
||||
z = torch.randn(16, 16)
|
||||
gm = graph_capture_and_aot_export_joint_with_descriptors(fn, (z,))
|
||||
|
||||
self.assertEqual(fn(z), gm(z)[0])
|
||||
|
||||
|
||||
instantiate_parametrized_tests(DTensorExportTest)
|
||||
|
||||
|
||||
@ -352,7 +352,7 @@ graph():
|
||||
self.rank, self.world_size, self.backend(device_type), fake_pg=True
|
||||
):
|
||||
# all_reduces remain in order!
|
||||
# note: this isnt actually invariant of pass currently..
|
||||
# note: this isn't actually invariant of pass currently..
|
||||
# but we should keep collectives stable without reordering opportunities
|
||||
|
||||
_, code = run_and_get_aten_graph(fn, g1, g2, g3)
|
||||
@ -887,6 +887,135 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
|
||||
correct = func(a, b, c, d, ranks=ranks)
|
||||
self.assertTrue(same(test_out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches())
|
||||
def test_custom_estimation_with_fake_tensor_mode(self):
|
||||
"""Test that custom estimation can use FakeTensorMode for analysis."""
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
|
||||
estimation_calls = 0
|
||||
|
||||
def estimate_with_fake_mode(fx_node, compute_multiplier=1.0):
|
||||
with FakeTensorMode():
|
||||
nonlocal estimation_calls
|
||||
estimation_calls += 1
|
||||
assert isinstance(torch.rand([20]), torch._subclasses.FakeTensor)
|
||||
|
||||
return 1.0
|
||||
|
||||
patches = get_bucket_patches()
|
||||
patches["aten_distributed_optimizations.custom_runtime_estimation"] = (
|
||||
estimate_with_fake_mode
|
||||
)
|
||||
|
||||
def func(a, b, *, ranks):
|
||||
# Two independent all_gathers that should be bucketed
|
||||
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
|
||||
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
|
||||
|
||||
# Matmul that can hide the collectives
|
||||
mm1 = torch.matmul(a, a)
|
||||
|
||||
return ag1.sum() + ag2.sum() + mm1.sum()
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
inputs_a = torch.ones(4, 4, dtype=torch.float, device=device_type)
|
||||
inputs_b = torch.ones(4, 4, dtype=torch.float, device=device_type) * 2
|
||||
ranks = list(range(self.world_size))
|
||||
|
||||
func_c = functools.partial(func, ranks=ranks)
|
||||
with torch._inductor.config.patch(patches):
|
||||
compiled = torch.compile(func_c)
|
||||
out, aten_graph_str = run_and_get_aten_graph(
|
||||
compiled, inputs_a, inputs_b
|
||||
)
|
||||
|
||||
# Verify the custom estimation was called
|
||||
self.assertTrue(
|
||||
estimation_calls > 0, "Custom estimation should have been called"
|
||||
)
|
||||
|
||||
correct = func(inputs_a, inputs_b, ranks=ranks)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches())
|
||||
def test_multidtype_bucketing(self):
|
||||
"""Test that all_gathers with different dtypes get bucketed together."""
|
||||
|
||||
def func(a, b, c, *, ranks):
|
||||
# Three all_gathers with different dtypes
|
||||
ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks) # float32
|
||||
ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks) # float16
|
||||
ag3 = _functional_collectives.all_gather_tensor(c, 0, ranks) # float16
|
||||
|
||||
# Use all results
|
||||
return ag1.sum() + ag2.sum() + ag3.sum()
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
a = torch.ones(4, 4, dtype=torch.float32, device=device_type)
|
||||
b = torch.ones(4, 4, dtype=torch.float16, device=device_type) * 2
|
||||
c = torch.ones(4, 4, dtype=torch.float16, device=device_type) * 3
|
||||
ranks = list(range(self.world_size))
|
||||
|
||||
func_c = functools.partial(func, ranks=ranks)
|
||||
compiled = torch.compile(func_c)
|
||||
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c)
|
||||
|
||||
# Should have 1 bucketed all_gather despite different dtypes
|
||||
FileCheck().check_count(
|
||||
"torch.ops._c10d_functional.wait_tensor.default", 1, exactly=True
|
||||
).run(aten_graph_str)
|
||||
|
||||
# Verify correctness
|
||||
correct = func(a, b, c, ranks=ranks)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches())
|
||||
def test_basic_all_reduce_bucketing(self):
|
||||
"""Test that independent all_reduce operations get bucketed together."""
|
||||
|
||||
def func(a, b, c):
|
||||
# Three independent all_reduces that should be bucketed
|
||||
ar1 = _functional_collectives.all_reduce(a, "sum", "0")
|
||||
ar2 = _functional_collectives.all_reduce(b, "sum", "0")
|
||||
ar3 = _functional_collectives.all_reduce(c, "sum", "0")
|
||||
|
||||
return ar1.sum() + ar2.sum() + ar3.sum()
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
a = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank
|
||||
b = torch.ones(4, 4, dtype=torch.float, device=device_type) * 2
|
||||
c = torch.ones(4, 4, dtype=torch.float, device=device_type) * 3
|
||||
|
||||
compiled = torch.compile(func)
|
||||
out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c)
|
||||
|
||||
# Should see a single bucketed all_reduce
|
||||
FileCheck().check_count(
|
||||
"torch.ops._c10d_functional.wait_tensor.default", 1, exactly=True
|
||||
).run(aten_graph_str)
|
||||
|
||||
# Verify correctness
|
||||
correct = func(a, b, c)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
@ -414,6 +414,7 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||
|
||||
with _dynamo_dist_per_rank_init(self.rank, self.world_size):
|
||||
model = Model().to(self.device)
|
||||
model.emb.weight.requires_grad = False
|
||||
model_compiled = torch.compile(model)
|
||||
inp = torch.tensor([[2, 1, 3, 0]], dtype=torch.long, device=self.device)
|
||||
out = model_compiled(inp, self.world_size, **self.get_world_trs())
|
||||
@ -1340,13 +1341,11 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
|
||||
assert counter.op_count == 3 # It generates 2 getattr to unpack the array
|
||||
assert same(out, correct)
|
||||
|
||||
# This doesn't work in all cases, and now we properly loudly error.
|
||||
# See: https://github.com/pytorch/pytorch/issues/151240
|
||||
# When differentiable funcols are implemented can revert.
|
||||
@unittest.expectedFailure
|
||||
def test_backwards(self):
|
||||
"""
|
||||
It's probably not that common to need backwards support for collectives.
|
||||
|
||||
However, I wanted to at least see if it was possible to support it as a design goal.
|
||||
"""
|
||||
|
||||
def func(inp):
|
||||
ar = _functional_collectives.all_reduce(inp, "sum", "0")
|
||||
return ar
|
||||
@ -1672,7 +1671,7 @@ class TestCollectivesInductor(DynamoDistributedSingleProcTestCase):
|
||||
compiled = torch.compile(func)
|
||||
code = run_and_get_triton_code(compiled, *inputs, **self.get_world_trs())
|
||||
|
||||
# shouldnt have bucketed
|
||||
# shouldn't have bucketed
|
||||
FileCheck().check_count("wait_tensor.default(", 2, exactly=True).run(code)
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
|
||||
572
test/distributed/test_overlap_bucketing_unit.py
Normal file
572
test/distributed/test_overlap_bucketing_unit.py
Normal file
@ -0,0 +1,572 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
import torch._dynamo.logging
|
||||
import torch._dynamo.test_case
|
||||
import torch.distributed as dist
|
||||
import torch.fx as fx
|
||||
|
||||
# for some reason importing functional collectives after dynamo breaks collectives handling!
|
||||
from torch._C import FileCheck
|
||||
from torch._inductor.test_case import TestCase as InductorTestCase
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.testing._internal.common_distributed import requires_accelerator_dist_backend
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
)
|
||||
from torch.testing._internal.inductor_utils import HAS_GPU
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
|
||||
# flake8: noqa: B950
|
||||
# Owner(s): ["module: inductor"]
|
||||
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
from torch.testing._internal.common_fsdp import get_devtype
|
||||
|
||||
|
||||
device_type = str(get_devtype())
|
||||
|
||||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
import torch._dynamo.logging
|
||||
import torch._dynamo.test_case
|
||||
|
||||
|
||||
# for some reason importing functional collectives after dynamo breaks collectives handling!
|
||||
|
||||
|
||||
@requires_accelerator_dist_backend(["nccl", "xccl"])
|
||||
def build_collective_info(graph, hiding_annotations):
|
||||
"""
|
||||
Build CollectiveInfo dict from manual hiding annotations.
|
||||
|
||||
hiding_annotations: dict mapping collective_start -> hiding_compute_node
|
||||
"""
|
||||
from torch._inductor.fx_passes.overlap_scheduling import CollectiveInfo
|
||||
|
||||
collective_info = {}
|
||||
|
||||
# Find all collective starts and their corresponding waits
|
||||
start_to_wait = {}
|
||||
for node in graph.nodes:
|
||||
if node.op == "call_function" and "wait_tensor" in str(node.target):
|
||||
wait_input = node.args[0]
|
||||
if isinstance(wait_input, fx.Node):
|
||||
start_to_wait[wait_input] = node
|
||||
|
||||
# Build CollectiveInfo for each collective
|
||||
for start_node, wait_node in start_to_wait.items():
|
||||
hiding_node = hiding_annotations.get(start_node)
|
||||
|
||||
# Estimate size and time
|
||||
size_bytes = 16 * 4 # 4x4 tensor of floats
|
||||
estimated_time_ms = 1.0 # Dummy time
|
||||
exposed_time_ms = 0.0 if hiding_node else 1.0 # Hidden if has hiding_node
|
||||
|
||||
collective_info[start_node] = CollectiveInfo(
|
||||
start_node=start_node,
|
||||
wait_node=wait_node,
|
||||
size_bytes=size_bytes,
|
||||
estimated_time_ms=estimated_time_ms,
|
||||
exposed_time_ms=exposed_time_ms,
|
||||
hiding_node=hiding_node,
|
||||
)
|
||||
|
||||
return collective_info
|
||||
|
||||
|
||||
def compute_ancestors(graph):
|
||||
"""Compute ancestor sets for all nodes in the graph."""
|
||||
node_ancestors = {}
|
||||
|
||||
for node in graph.nodes:
|
||||
ancestors = OrderedSet()
|
||||
stack = list(node.all_input_nodes)
|
||||
visited = set()
|
||||
|
||||
while stack:
|
||||
current = stack.pop()
|
||||
if current in visited:
|
||||
continue
|
||||
visited.add(current)
|
||||
ancestors.add(current)
|
||||
stack.extend(current.all_input_nodes)
|
||||
|
||||
node_ancestors[node] = ancestors
|
||||
|
||||
return node_ancestors
|
||||
|
||||
|
||||
@requires_accelerator_dist_backend()
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@instantiate_parametrized_tests
|
||||
class TestOverlapPreservingBucketing(InductorTestCase):
|
||||
"""
|
||||
Unit tests for overlap-preserving bucketing pass.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
from torch.testing._internal.distributed.fake_pg import FakeStore
|
||||
|
||||
store = FakeStore()
|
||||
dist.init_process_group(backend="fake", rank=0, world_size=2, store=store)
|
||||
cls.device = "cuda"
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
super().tearDownClass()
|
||||
dist.destroy_process_group()
|
||||
|
||||
def test_can_bucket_independent_collectives(self):
|
||||
"""
|
||||
Test that independent collectives with separate hiding nodes CAN bucket.
|
||||
|
||||
Graph structure:
|
||||
ag1_start -> ag2_start -> mm1 (hides ag1) -> mm2 (hides ag2) -> ag1_wait -> ag2_wait
|
||||
"""
|
||||
|
||||
def func(a, b):
|
||||
group_name = "0"
|
||||
group_size = 1
|
||||
|
||||
# Start both collectives
|
||||
ag1 = torch.ops._c10d_functional.all_gather_into_tensor(
|
||||
a, group_size, group_name
|
||||
)
|
||||
ag2 = torch.ops._c10d_functional.all_gather_into_tensor(
|
||||
b, group_size, group_name
|
||||
)
|
||||
|
||||
# Independent compute that can hide both
|
||||
mm1 = torch.mm(a, a)
|
||||
mm2 = torch.mm(b, b)
|
||||
|
||||
# Wait for both
|
||||
ag1_out = torch.ops._c10d_functional.wait_tensor(ag1)
|
||||
ag2_out = torch.ops._c10d_functional.wait_tensor(ag2)
|
||||
|
||||
return ag1_out.sum() + ag2_out.sum() + mm1.sum() + mm2.sum()
|
||||
|
||||
# Use fake mode to trace without executing
|
||||
with FakeTensorMode():
|
||||
a = torch.ones(4, 4, device=self.device)
|
||||
b = torch.ones(4, 4, device=self.device) * 2
|
||||
|
||||
# Trace with make_fx
|
||||
traced = make_fx(func)(a, b)
|
||||
|
||||
# Find nodes using find_nodes
|
||||
ag1, ag2 = traced.graph.find_nodes(
|
||||
op="call_function",
|
||||
target=torch.ops._c10d_functional.all_gather_into_tensor.default,
|
||||
)
|
||||
mm1, mm2 = traced.graph.find_nodes(
|
||||
op="call_function", target=torch.ops.aten.mm.default
|
||||
)
|
||||
|
||||
# Manually annotate hiding relationships
|
||||
hiding_annotations = {
|
||||
ag1: mm1, # mm1 hides ag1
|
||||
ag2: mm2, # mm2 hides ag2
|
||||
}
|
||||
|
||||
# Build collective info and ancestors
|
||||
collective_info = build_collective_info(traced.graph, hiding_annotations)
|
||||
node_ancestors = compute_ancestors(traced.graph)
|
||||
scheduled = OrderedSet(traced.graph.nodes)
|
||||
|
||||
# Run bucketing
|
||||
from torch._inductor.fx_passes.overlap_preserving_bucketer import (
|
||||
OverlapPreservingBucketer,
|
||||
)
|
||||
|
||||
bucketer = OverlapPreservingBucketer(
|
||||
traced.graph,
|
||||
collective_info,
|
||||
node_ancestors,
|
||||
scheduled,
|
||||
)
|
||||
bucketer.bucket_collectives()
|
||||
|
||||
# Verify: should have 1 bucketed collective (all_gather_into_tensor_out)
|
||||
graph_str = str(traced.graph)
|
||||
FileCheck().check_count("all_gather_into_tensor_out", 1, exactly=False).run(
|
||||
graph_str
|
||||
)
|
||||
|
||||
    def test_cant_bucket_nested_hiding_intervals(self):
        """
        Test that nested hiding intervals prevent bucketing.

        Graph structure:
        ag1_start -> ag2_start -> mm2 (hides ag2) -> ag2_wait -> mm1 (hides ag1) -> ag1_wait

        ag2's hiding interval is nested inside ag1's hiding interval.
        """

        def func(a, b):
            group_name = "0"
            group_size = 1

            # ag1 starts first
            ag1 = torch.ops._c10d_functional.all_gather_into_tensor(
                a, group_size, group_name
            )

            # ag2 starts (inside ag1's interval)
            ag2 = torch.ops._c10d_functional.all_gather_into_tensor(
                b, group_size, group_name
            )

            # mm2 hides ag2
            mm2 = torch.mm(b[:2, :2], b[:2, :2])

            # ag2 waits (still inside ag1's interval)
            ag2_out = torch.ops._c10d_functional.wait_tensor(ag2)

            # mm1 uses ag2's result and hides ag1
            mm1 = torch.mm(a + ag2_out[:4, :4], a)

            # ag1 waits last
            ag1_out = torch.ops._c10d_functional.wait_tensor(ag1)

            return ag1_out.sum() + ag2_out.sum() + mm1.sum() + mm2.sum()

        # Use fake mode to trace without executing
        with FakeTensorMode():
            a = torch.ones(4, 4, device=self.device)
            b = torch.ones(4, 4, device=self.device) * 2

            # Trace with make_fx
            traced = make_fx(func)(a, b)

            # Find nodes using find_nodes
            ag1, ag2 = traced.graph.find_nodes(
                op="call_function",
                target=torch.ops._c10d_functional.all_gather_into_tensor.default,
            )
            mm_nodes = traced.graph.find_nodes(
                op="call_function", target=torch.ops.aten.mm.default
            )
            # mm2 is the first mm, mm1 is the second (based on graph order)
            mm2 = mm_nodes[0]
            mm1 = mm_nodes[1]

            # Manually annotate hiding relationships
            hiding_annotations = {
                ag1: mm1,  # mm1 hides ag1
                ag2: mm2,  # mm2 hides ag2
            }

            # Build collective info and ancestors
            collective_info = build_collective_info(traced.graph, hiding_annotations)
            node_ancestors = compute_ancestors(traced.graph)
            scheduled = OrderedSet(traced.graph.nodes)

            # Run bucketing
            from torch._inductor.fx_passes.overlap_preserving_bucketer import (
                OverlapPreservingBucketer,
            )

            bucketer = OverlapPreservingBucketer(
                traced.graph,
                collective_info,
                node_ancestors,
                scheduled,
            )
            bucketer.bucket_collectives()

            # Verify: nested hiding intervals should prevent bucketing
            # Should have 2 separate all_gathers, not 1 bucketed one
            graph_str = str(traced.graph)
            FileCheck().check_count("all_gather_into_tensor", 2, exactly=False).run(
                graph_str
            )
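
    # Why the nested case above cannot bucket (informal sketch, inferred from the
    # graph built in func): bucketing ag1 with ag2 would merge their waits into a
    # single wait. mm1 consumes ag2_out, so mm1 can only run after ag2's wait;
    # with a merged wait, mm1 would land after the bucketed collective's wait and
    # could no longer hide ag1.
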
    @parametrize("final_mm_hidden", (True, False))
    def test_cant_bucket_ag_with_rs_hiding_interval_between(self, final_mm_hidden):
        """
        Test that all_gathers can't bucket when a reduce_scatter's hiding interval is between them.

        Graph structure:
        ag1_start -> mm1 (hides ag1) -> ag1_wait ->
        rs_start -> mm2 (hides rs when final_mm_hidden) -> rs_wait ->
        ag2_start -> mm3 (hides ag2) -> ag2_wait

        if final_mm_hidden:
        Bucketing ag1 and ag2 would require moving one of them, which would break hiding relationships:
        - Moving ag2 earlier would break ag2's hiding by mm3
        - Moving ag1 later would break ag1's hiding by mm1
        - The rs hiding interval creates an obstacle between them

        Otherwise, the two all_gathers can bucket.
        """

        def func(a, b, c):
            group_name = dist.distributed_c10d._get_default_group().group_name
            group_size = 1

            # First all_gather
            ag1 = torch.ops._c10d_functional.all_gather_into_tensor(
                a, group_size, group_name
            )
            mm1 = torch.mm(a, a)  # hides ag1
            ag1_out = torch.ops._c10d_functional.wait_tensor(ag1)

            # Reduce scatter in between
            rs = torch.ops._c10d_functional.reduce_scatter_tensor(
                b, "sum", group_size, group_name
            )
            mm2 = torch.mm(b[:4, :4], b[:4, :4])  # hides rs
            rs_out = torch.ops._c10d_functional.wait_tensor(rs)

            # Second all_gather
            ag2 = torch.ops._c10d_functional.all_gather_into_tensor(
                c, group_size, group_name
            )
            mm3 = torch.mm(c, c)  # hides ag2
            ag2_out = torch.ops._c10d_functional.wait_tensor(ag2)

            return ag1_out.sum() + rs_out.sum() + ag2_out.sum(), mm1, mm2, mm3

        # Use fake mode to trace without executing
        with FakeTensorMode():
            a = torch.ones(4, 4, device=self.device)
            b = torch.ones(8, 4, device=self.device)
            c = torch.ones(4, 4, device=self.device)

            # Trace with make_fx
            traced = make_fx(func)(a, b, c)

            ag1, ag2 = traced.graph.find_nodes(
                op="call_function",
                target=torch.ops._c10d_functional.all_gather_into_tensor.default,
            )
            (rs,) = traced.graph.find_nodes(
                op="call_function",
                target=torch.ops._c10d_functional.reduce_scatter_tensor.default,
            )
            mm1, mm2, mm3 = traced.graph.find_nodes(
                op="call_function", target=torch.ops.aten.mm.default
            )

            # Manually annotate hiding relationships
            hiding_annotations = {
                ag1: mm1,  # mm1 hides ag1
                # rs: mm2 is added below only when final_mm_hidden
                ag2: mm3,
            }
            if final_mm_hidden:
                hiding_annotations[rs] = mm2

            # Build collective info and ancestors
            collective_info = build_collective_info(traced.graph, hiding_annotations)
            node_ancestors = compute_ancestors(traced.graph)
            scheduled = OrderedSet(traced.graph.nodes)

            # Run bucketing logic to find buckets (without applying them, which would require process groups)
            from torch._inductor.fx_passes.overlap_preserving_bucketer import (
                OverlapPreservingBucketer,
            )

            bucketer = OverlapPreservingBucketer(
                traced.graph,
                collective_info,
                node_ancestors,
                scheduled,
            )

            bucketer.bucket_collectives()

            graph_str = str(traced.graph)

            # check order of mms preserved
            FileCheck().check("%mm").check("%mm_1").check("%mm_2").run(graph_str)

            if final_mm_hidden:
                # Should NOT bucket - 2 separate all_gathers
                # Count all_gather node names (works even when wrapped in control_deps)
                FileCheck().check_count("%all_gather_into_tensor", 2, exactly=False).run(
                    graph_str
                )
            else:
                # Should bucket - 1 bucketed all_gather (all_gather_into_tensor_out)
                FileCheck().check_count(
                    "%all_gather_into_tensor_out", 1, exactly=False
                ).run(graph_str)
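
    # Informal reading of the parametrization above: only when final_mm_hidden is
    # True does rs get a hiding interval (rs -> mm2 -> rs_wait) sitting between the
    # two all_gather intervals, and that interval is what blocks merging them.
    # Without it, one all_gather can be moved next to the other, so a single
    # bucketed all_gather_into_tensor_out is expected.
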
    def test_can_bucket_all_reduce(self):
        """
        Test that all_reduce operations CAN bucket together.

        Graph structure:
        ar1_start -> ar2_start -> mm1 (hides ar1) -> mm2 (hides ar2) -> ar1_wait -> ar2_wait
        """

        def func(a, b):
            group_name = "0"

            # Start both all_reduce operations
            ar1 = torch.ops._c10d_functional.all_reduce(a, "sum", group_name)
            ar2 = torch.ops._c10d_functional.all_reduce(b, "sum", group_name)

            # Independent compute that can hide both
            mm1 = torch.mm(a, a)
            mm2 = torch.mm(b, b)

            # Wait for both
            ar1_out = torch.ops._c10d_functional.wait_tensor(ar1)
            ar2_out = torch.ops._c10d_functional.wait_tensor(ar2)

            return ar1_out.sum() + ar2_out.sum() + mm1.sum() + mm2.sum()

        # Use fake mode to trace without executing
        with FakeTensorMode():
            a = torch.ones(4, 4, device=self.device)
            b = torch.ones(4, 4, device=self.device) * 2

            # Trace with make_fx
            traced = make_fx(func)(a, b)

            # Find nodes
            ar1, ar2 = traced.graph.find_nodes(
                op="call_function",
                target=torch.ops._c10d_functional.all_reduce.default,
            )
            mm1, mm2 = traced.graph.find_nodes(
                op="call_function", target=torch.ops.aten.mm.default
            )

            # For all_reduce, start_node == wait_node (no separate wait)
            hiding_annotations = {
                ar1: mm1,
                ar2: mm2,
            }

            # Build collective info
            collective_info = build_collective_info(traced.graph, hiding_annotations)
            node_ancestors = compute_ancestors(traced.graph)
            scheduled = OrderedSet(traced.graph.nodes)

            # Run bucketing
            from torch._inductor.fx_passes.overlap_preserving_bucketer import (
                OverlapPreservingBucketer,
            )

            bucketer = OverlapPreservingBucketer(
                traced.graph,
                collective_info,
                node_ancestors,
                scheduled,
            )
            bucketer.bucket_collectives()

            # Verify: should have 1 bucketed all_reduce
            # After bucketing, there should be only one all_reduce node (the bucketed one)
            graph_str = str(traced.graph)
            FileCheck().check_count("%all_reduce", 1, exactly=True).check_count(
                "%mm", 2
            ).run(graph_str)

    def test_can_bucket_multidtype_collectives(self):
        """
        Test that all_gathers with different dtypes CAN bucket together.

        Graph structure:
        ag1_float32 -> mm1 (hides ag1) -> ag1_wait
        ag2_bfloat16 -> mm2 (hides ag2) -> ag2_wait
        """

        def func(a, b):
            group_name = "0"
            group_size = 1

            # Start both collectives with different dtypes
            ag1 = torch.ops._c10d_functional.all_gather_into_tensor(
                a,
                group_size,
                group_name,  # float32
            )
            ag2 = torch.ops._c10d_functional.all_gather_into_tensor(
                b,
                group_size,
                group_name,  # bfloat16
            )

            # Independent compute that can hide both
            mm1 = torch.mm(a, a)
            mm2 = torch.mm(b.float(), b.float())

            # Wait for both
            ag1_out = torch.ops._c10d_functional.wait_tensor(ag1)
            ag2_out = torch.ops._c10d_functional.wait_tensor(ag2)

            return ag1_out.sum() + ag2_out.sum() + mm1.sum() + mm2.sum()

        # Use fake mode to trace without executing
        with FakeTensorMode():
            a = torch.ones(4, 4, device=self.device, dtype=torch.float32)
            b = torch.ones(4, 4, device=self.device, dtype=torch.bfloat16)

            # Trace with make_fx
            traced = make_fx(func)(a, b)

            # Find nodes using find_nodes
            ag1, ag2 = traced.graph.find_nodes(
                op="call_function",
                target=torch.ops._c10d_functional.all_gather_into_tensor.default,
            )
            mm_nodes = traced.graph.find_nodes(
                op="call_function", target=torch.ops.aten.mm.default
            )
            mm1 = mm_nodes[0]
            mm2 = mm_nodes[1]

            # Manually annotate hiding relationships
            hiding_annotations = {
                ag1: mm1,  # mm1 hides ag1
                ag2: mm2,  # mm2 hides ag2
            }

            # Build collective info and ancestors
            collective_info = build_collective_info(traced.graph, hiding_annotations)
            node_ancestors = compute_ancestors(traced.graph)
            scheduled = OrderedSet(traced.graph.nodes)

            # Run bucketing with multidtype mode
            from torch._inductor.fx_passes.overlap_preserving_bucketer import (
                OverlapPreservingBucketer,
            )

            bucketer = OverlapPreservingBucketer(
                traced.graph,
                collective_info,
                node_ancestors,
                scheduled,
                bucket_mode="custom_ops_multidtype",
            )
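            # bucket_mode="custom_ops_multidtype" (passed above) is what allows the
            # float32 and bfloat16 all_gathers to share a single bucket here; the
            # default bucket mode presumably keeps different dtypes in separate
            # buckets, which is why this test overrides it.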
            bucketer.bucket_collectives()

            # Verify: should have 1 bucketed collective (all_gather_into_tensor_out)
            # even though dtypes are different
            graph_str = str(traced.graph)
            FileCheck().check_count("all_gather_into_tensor_out", 1, exactly=False).run(
                graph_str
            )


if __name__ == "__main__":
    run_tests()

@@ -363,6 +363,40 @@ class FxGraphRunnableTest(TestCase):

        self._exec_and_verify_payload()

    def test_metrics_context(self):
        """
        When TORCH_COMPILE_DEBUG is set, provenance_tracking_level is set to 1, and
        the generated fx_graph_runnable crashed with:
        RuntimeError: Cannot add inductor_provenance outside of a MetricsContext
        """
        import torch._inductor.config as inductor_config

        def f(x):
            return x * 2 + 1

        # Enable provenance tracking to trigger the code path that adds metrics
        with inductor_config.patch(
            {"trace.enabled": True, "trace.provenance_tracking_level": 1}
        ):
            x = torch.randn(4, 4)
            torch.compile(f)(x)
        self._exec_and_verify_payload()
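
    # Note: the inductor_config.patch above stands in for TORCH_COMPILE_DEBUG=1,
    # which (per the docstring) is what bumps provenance_tracking_level to 1 and
    # previously made the generated fx_graph_runnable crash outside a MetricsContext.
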
    @torch._dynamo.config.patch(assume_static_by_default=False)
    def test_dynamic_expression(self):
        """
        Test not emitting something like "s27*s53**2 = 36"
        """

        def f(x):
            return torch.ops.aten._adaptive_avg_pool2d(
                x, (6, 6)
            ), torch.ops.aten._adaptive_avg_pool2d(x + 1, (2, 5))

        x = torch.randn(2, 4, 16, 16)
        torch.compile(f)(x)
        self._exec_and_verify_payload()
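
    # With assume_static_by_default=False the pooled shapes become symbolic
    # (e.g. s27, s53), and the generated repro must not serialize a shape relation
    # such as "s27*s53**2 = 36" as if it were a Python statement; that is the
    # regression this test guards against (informal reading of the docstring above).
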


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

Some files were not shown because too many files have changed in this diff