Compare commits


38 Commits

Author SHA1 Message Date
19e09ed25d [Dynamo] added warning message for tracing lru_cache wrapped functions #153744 2025-05-19 11:35:18 -07:00
c54b9f2969 [Monitoring] Add util for linux build (#153456)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153456
Approved by: https://github.com/huydhn
2025-05-19 17:28:17 +00:00
be36bacdaa [pytorch] Delete TorchScript based Android demo app and point user to ExecuTorch (#153767)
Summary: A retry of #153656. This time start from co-dev to make sure we capture internal signals.

Test Plan: Rely on CI jobs.

Differential Revision: D74911818

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153767
Approved by: https://github.com/kirklandsign, https://github.com/cyyever, https://github.com/Skylion007
2025-05-19 17:20:36 +00:00
6487ea30b3 [c10d] Fix new_subgroups(group=) bug (#153798)
Summary: The bug, introduced in https://github.com/pytorch/pytorch/pull/152765, was caused by passing the `group` parameter to the `get_rank()` function, which caused the function to return the rank of the entire group instead of the rank of the current process. The fix involves removing the `group` parameter from the `get_rank()` function call.
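
A minimal single-process sketch of the two `get_rank` call forms the summary distinguishes (assumes the gloo backend and a localhost rendezvous; in a real multi-process job the two values differ):

```python
import os
import torch.distributed as dist

# get_rank() returns this process's global rank; get_rank(group=...) returns
# its rank *within* that group. Both are 0 in this single-process demo, but
# they diverge once subgroups are real, which is the distinction the fix
# depends on.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

subgroup = dist.new_group(ranks=[0])
print(dist.get_rank())                 # global rank of this process
print(dist.get_rank(group=subgroup))   # rank within `subgroup`

dist.destroy_process_group()
```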

Test Plan: contbuild & OSS CI

Differential Revision: D74964213

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153798
Approved by: https://github.com/Skylion007
2025-05-19 17:01:10 +00:00
b0e5402377 Revert "Recheck autotune cache on static cuda launcher load (#153565)"
This reverts commit 02af4e88e4e76309672dbc9b5970ae630df525c7.

Reverted https://github.com/pytorch/pytorch/pull/153565 on behalf of https://github.com/malfet due to Looks like it broke ROCM, see ee72c53c88/1 ([comment](https://github.com/pytorch/pytorch/pull/153565#issuecomment-2891673913))
2025-05-19 16:52:48 +00:00
ee72c53c88 Enable ruff check for all ipynb files (#153820)
Fixes #146411, following #148654

After testing, it seems this can be enabled for all ipynb files.

```bash
lintrunner --take RUFF --all-files
Warning: Could not find a lintrunner config at: '.lintrunner.private.toml'. Continuing without using configuration file.
ok No lint issues.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153820
Approved by: https://github.com/Skylion007
2025-05-19 16:45:26 +00:00
ed5f4a4fa8 Replace size() checks with empty() (#153805)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153805
Approved by: https://github.com/nareshrajkumar866, https://github.com/Skylion007

Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
2025-05-19 16:20:57 +00:00
0ec8fe46d7 cleanup, refactor and add missing self._dde_suppressed checks (#152657)
Two changes besides the cleanups and refactoring:
1) do not use propagate_real_tensors to resolve eval under guard_or_true/guard_or_false.
2) do not guard for dimensions of type DimDynamic.OBLIVIOUS_SIZE under guard_or_true/guard_or_false.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152657
Approved by: https://github.com/pianpwk
2025-05-19 16:15:14 +00:00
dccd19c2ef [Inductor] Construct subgraph with benchmarking args not example_inputs (#153753)
If the inputs to a subgraph have FlexibleLayout, the subgraph does not currently freeze the layouts here. Therefore, the `example_inputs` generated might not be consistent in layout with the `args` passed in for benchmarking.

Differential Revision: [D74900879](https://our.internmc.facebook.com/intern/diff/D74900879/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153753
Approved by: https://github.com/eellison
2025-05-19 15:58:40 +00:00
7a46f4bde0 Enable accelerator to perform streaming backward (#153412)
Also see https://github.com/pytorch/pytorch/pull/142097
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153412
Approved by: https://github.com/albanD
ghstack dependencies: #151079
2025-05-19 15:52:42 +00:00
c5cba39d46 Improve torch.ops typing (#153558)
Fixes a longstanding issue where direct references to aten operations are seen as untyped by type checkers. This is accomplished by setting attributes on several classes more consistently, so that `__getattr__` can return a single type in all other cases.

Decisions made along the way:

1. `torch.ops.higher_order` is now implemented by a single-purpose class. This was effectively true before, but the class implementing it attempted to be generalized unnecessarily. Fixing this simplified typing for the `_Ops` class.
2. `__getattr__` is only called when all other lookup methods have failed, so several constant special-cases in the function could be implemented as class variables.

The remainder of this PR is fixing up all the bugs exposed by the updated typing, as well as all the nitpicky typing issues.
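
As an illustration of the lookups this typing work targets, a small hedged example (ordinary `torch.ops` attribute access only, nothing from this PR's internals):

```python
import torch

# torch.ops.aten.add is an OpOverloadPacket and torch.ops.aten.add.Tensor is a
# concrete OpOverload; these lookups go through __getattr__ on the ops
# namespaces, which the PR wants type checkers to see as typed rather than Any.
packet = torch.ops.aten.add
overload = torch.ops.aten.add.Tensor
print(type(packet).__name__, type(overload).__name__)
print(overload(torch.ones(2), torch.ones(2)))

# torch.ops.higher_order is its own namespace (now a single-purpose class per
# point 1); e.g. the cond higher-order op is registered there once torch is
# imported in recent versions.
print(torch.ops.higher_order.cond)
```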

Test plan: CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153558
Approved by: https://github.com/rec, https://github.com/Skylion007, https://github.com/cyyever
2025-05-19 14:52:32 +00:00
3cd5b3b1e7 [AOTI] Skip a rocm test (#153828)
Summary: Skip test_aot_inductor_package.test_compile_after_package. https://github.com/pytorch/pytorch/pull/150739 added an opt-in feature which doesn't work for rocm yet.
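
For context, a hedged sketch of the usual way such a test is skipped on ROCm builds; the class and test names mirror the ones mentioned above, and the body is a placeholder:

```python
import unittest
import torch

# torch.version.hip is non-None only on ROCm builds, so this is the standard
# guard for ROCm-only gaps.
IS_ROCM = torch.version.hip is not None

class TestAOTInductorPackage(unittest.TestCase):
    @unittest.skipIf(IS_ROCM, "compile_after_package is opt-in and not yet supported on ROCm")
    def test_compile_after_package(self):
        self.assertTrue(True)  # placeholder body

if __name__ == "__main__":
    unittest.main()
```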

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153828
Approved by: https://github.com/malfet
2025-05-19 14:13:19 +00:00
02af4e88e4 Recheck autotune cache on static cuda launcher load (#153565)
When loading statically launchable triton kernels from FxGraphCache, since we don't instantiate a CachingAutotuner like we do normally, we need to recheck the autotune cache based on the existing compile results. If we get a hit, we take the compile result whose config matches the best config.

Sometimes the best config comes from coordinate descent tuning. In that case, FxGraphCache today does not cache the resulting triton kernel, with or without the static cuda launcher, because coordinate descent tuning happens at runtime and the best config may not be one of the precompiled configs.
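
For orientation, a hedged sketch of the two inductor config knobs involved (the static cuda launcher toggle itself is intentionally not named here, since its flag is not spelled out above; running this end to end assumes a working inductor toolchain):

```python
import torch

# With the FX graph cache on and coordinate descent tuning enabled, a cache
# hit has to re-resolve which compiled config is "best", which is what this PR
# rechecks against the autotune cache.
torch._inductor.config.fx_graph_cache = True
torch._inductor.config.coordinate_descent_tuning = True

@torch.compile
def f(x):
    return torch.relu(x) * 2

print(f(torch.randn(64, 64)))
```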

Test Plan:
New unit test that failed before

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153565
Approved by: https://github.com/aorenste
2025-05-19 12:50:22 +00:00
c45515c2ed Update slow tests (#153815)
This PR is auto-generated weekly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/weekly.yml).
Update the list of slow tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153815
Approved by: https://github.com/pytorchbot
2025-05-19 11:15:25 +00:00
4f1a52fba4 [xla hash update] update the pinned xla hash (#153816)
This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml).
Update the pinned xla hash.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153816
Approved by: https://github.com/pytorchbot
2025-05-19 11:05:51 +00:00
f3daedb263 [BE]: Remove redundant copy (#153629)
Add typing and remove redundant copy
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153629
Approved by: https://github.com/cyyever, https://github.com/albanD
2025-05-19 08:25:20 +00:00
5506baa4ed Refactoring FSDP2 (_composable/fsdp) test cases to be device agnostic (#149848)
The motivation for this PR is to refactor the existing test cases in the folder test/distributed/_composable/fsdp/, or fsdp2 (as referred to in torchtitan), to be device agnostic so that any accelerator type is supported (e.g. CUDA, HPU, XPU, etc.)

The changes are in line with the previously merged changes for the fsdp test cases (in the folder test/distributed/fsdp/): https://github.com/pytorch/pytorch/pull/139184/
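
A hedged sketch of the device-agnostic pattern the refactor moves toward, assuming the `torch.accelerator` API of recent PyTorch (guarded so it still runs on older builds or CPU-only machines):

```python
import torch

# Detect whichever accelerator is present instead of hard-coding "cuda".
def pick_device() -> str:
    if hasattr(torch, "accelerator") and torch.accelerator.is_available():
        return torch.accelerator.current_accelerator().type
    return "cpu"

device_type = pick_device()
x = torch.ones(4, device=device_type)
print(device_type, x.device)
```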

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149848
Approved by: https://github.com/kwen2501, https://github.com/guangyey
2025-05-19 05:46:51 +00:00
6f835a4769 [amd] fix tunableop gemm (#153764)
Summary: TunableOp on AMD has had a perf regression for a while. It turns out that the TunableOp code path first runs the tuned GEMM and then the heuristics GEMM (so it runs two GEMMs...).
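
For context, a hedged sketch of how TunableOp is switched on for GEMMs via `torch.cuda.tunable` (which also covers ROCm builds); this only exercises the code path described above, it is not the fix itself:

```python
import torch

if torch.cuda.is_available():
    torch.cuda.tunable.enable(True)         # use tuned GEMM solutions
    torch.cuda.tunable.tuning_enable(True)  # allow online tuning of new shapes
    a = torch.randn(256, 256, device="cuda")
    b = torch.randn(256, 256, device="cuda")
    print((a @ b).shape, torch.cuda.tunable.is_enabled())
else:
    print("No GPU available; TunableOp sketch skipped")
```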

Test Plan:
```
CUDA_VISIBLE_DEVICES=0 buck test @//mode/opt-amd-gpu -c fbcode.rocm_arch=mi300 -c fbcode.rocm_ck_rtz=true fbcode//accelerators/workloads/microbench/RE:test_emu_v1p4 -- --exact 'accelerators/workloads/microbench/RE:test_emu_v1p4 - test_gemm (accelerators.workloads.microbench.RE.test_emu_v1p4.EMUv1p4PerfTest)' --run-disabled
```

Before the diff
```
  File "/data/users/mxz/fbsource/buck-out/v2/gen/fbcode/ecc11ed52295855f/accelerators/workloads/microbench/RE/__test_emu_v1p4__/test_emu_v1p4#link-tree/accelerators/workloads/microbench/RE/test_emu_v1p4.py", line 47, in test_gemm
    self.assertTrue(result < AMD_GEMM_BASELINE * AMD_GEMM_THRESHOLD)

Buck UI: https://www.internalfb.com/buck2/b4b8dfca-0301-4c5d-83d6-d866d840c42d
Test UI: https://www.internalfb.com/intern/testinfra/testrun/14355223896396807
Network: Up: 10MiB  Down: 1.9GiB  (reSessionID-23b213fe-a460-4788-86c6-a52343ff10f4)
Loading targets.   Remaining      0/5144                                      93161 dirs read, 753263 targets declared
Analyzing targets. Remaining      0/70523                                     2837379 actions, 3262810 artifacts declared
Executing actions. Remaining      0/472286                                    217:26:58.1s exec time total
Command: test.     Finished 122 local, 522 remote, 199785 cache (99% hit)     211:26:30.5s exec time cached (97%)
Time elapsed: 12:50.2s
Test execution completed but the tests failed
Tests finished: Pass 0. Fail 1. Fatal 0. Skip 0. Build failure 0
1 TESTS FAILED
  ✗ accelerators/workloads/microbench/RE:test_emu_v1p4 - test_gemm (accelerators.workloads.microbench.RE.test_emu_v1p4.EMUv1p4PerfTest)

Run $ fdb buck test <args> to debug accelerators/workloads/microbench/RE:test_emu_v1p4 - test_gemm (accelerators.workloads.microbench.RE.test_emu_v1p4.EMUv1p4PerfTest)
      ^^^ just prefix your previous command! ($ fdb !!)
Learn more at https://fburl.com/fdb
```

After the diff
```
Tests finished: Pass 1. Fail 0. Fatal 0. Skip 0. Build failure 0
```

Reviewed By: henryoier, henryhu6

Differential Revision: D74910115

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153764
Approved by: https://github.com/yangsiyu007, https://github.com/xw285cornell
2025-05-19 04:07:48 +00:00
2ade886412 [XPU] [Windows] Auto turn on kineto XPU build when compiler version support. (#153681)
Since SYCL compiler 20250101, the dependency on the Level Zero header has been removed, so we can turn on the kineto XPU build by default.
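
For context, a hedged sketch of what the kineto XPU build enables at the Python level: collecting XPU activity in the profiler. The XPU-specific pieces are guarded so the snippet runs on any build:

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Record CPU activity everywhere; add XPU activity only when the build exposes
# it and an XPU device is actually present.
activities = [ProfilerActivity.CPU]
if hasattr(ProfilerActivity, "XPU") and hasattr(torch, "xpu") and torch.xpu.is_available():
    activities.append(ProfilerActivity.XPU)

with profile(activities=activities) as prof:
    torch.ones(512, 512) @ torch.ones(512, 512)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```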

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153681
Approved by: https://github.com/chuanqi129, https://github.com/cyyever, https://github.com/EikanWang
2025-05-19 03:07:14 +00:00
1bc5762495 [Intel GPU][Inductor] Fallback embedding_dense_backward on XPU (#151637)
Reopen #146888; the modification now only affects the XPU device. We do not want to decompose embedding_dense_backward for torch.compile. Current XPU devices have hardware limitations on atomic ops; falling back to eager lets us use sort to implement this op. hf_T5 amp bf16 training in torchbench can get a 2x improvement on Max 1550. ~~I also align with cuda on gelu decomposition in _addmm_activation~~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151637
Approved by: https://github.com/guangyey, https://github.com/etaf, https://github.com/jansel, https://github.com/EikanWang
2025-05-19 02:19:37 +00:00
74d0300804 Change unsafe_marked_cacheable_functions to a dictionary, so that you can specify a static cache key (#152486)
Fixes https://github.com/pytorch/pytorch/issues/152434

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152486
Approved by: https://github.com/oulgen
2025-05-19 02:16:33 +00:00
694748dd9d [MPSInductor] Fix conv_transpose channels last (#153787)
Regardless of the input layout, transposed convolution always returns a contiguous tensor on MPS.
Add a test to validate that.
This fixes torch.compile for the SegmentAnything network.
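
A hedged sketch of the behavior the new test asserts (guarded so it only runs where MPS is present):

```python
import torch

# On MPS, transposed convolution returns a contiguous tensor even for a
# channels-last input, per the commit above.
if torch.backends.mps.is_available():
    conv = torch.nn.ConvTranspose2d(3, 3, kernel_size=3).to("mps")
    x = torch.randn(1, 3, 8, 8, device="mps").to(memory_format=torch.channels_last)
    print(conv(x).is_contiguous())  # expected True
else:
    print("MPS not available; skipping")
```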

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153787
Approved by: https://github.com/cyyever, https://github.com/Skylion007, https://github.com/jansel, https://github.com/dcci
ghstack dependencies: #153786
2025-05-19 02:01:48 +00:00
6fe5d9215f [EZ][MPS] Enable rsub op (#153786)
Nothing really to enable; just add it to native functions, and the TensorIterator abstraction takes care of the rest.
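
For reference, a small sketch of what `rsub` computes and the MPS path this enables:

```python
import torch

# rsub(input, other, alpha=1) computes other - alpha * input. CPU result shown
# for reference; the MPS call is guarded.
x = torch.tensor([1.0, 2.0])
print(torch.rsub(x, 10.0))  # tensor([9., 8.])
if torch.backends.mps.is_available():
    print(torch.rsub(x.to("mps"), 10.0).cpu())
```
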
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153786
Approved by: https://github.com/cyyever, https://github.com/Skylion007, https://github.com/dcci
2025-05-19 02:01:48 +00:00
a2d0ef242d [AOTI] Embed cubin files into .so (#150739)
Summary: Embed cubin files so AOTI is one step closer to generating a single binary. Controlled by a flag and off by default.

Differential Revision: [D72535357](https://our.internmc.facebook.com/intern/diff/D72535357)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150739
Approved by: https://github.com/angelayi
2025-05-19 01:11:46 +00:00
cyy
a8986963da Fix some CMake issues (#153686)
These issues were discovered when trying CMake 3.27:
1. set C++ language on HIP sources.
2. add missing link to gtest_main.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153686
Approved by: https://github.com/Skylion007
2025-05-19 00:31:34 +00:00
75eb2f3ff6 Revert "[Dynamo] added warning message for tracing lru_cache wrapped functions (#153744)"
This reverts commit aac30ef50366b03f0ef2d1e770f45a3465f6ea66.

Reverted https://github.com/pytorch/pytorch/pull/153744 on behalf of https://github.com/jeanschmidt due to Need to revert as it is breaking internal signals: [D74935585](https://www.internalfb.com/diff/D74935585) ([comment](https://github.com/pytorch/pytorch/pull/153744#issuecomment-2889187038))
2025-05-18 20:13:00 +00:00
cb57b19c3a [ATen-CPU] Use math.h for GeLU as well as cmath (#153742)
Summary:
## Context

See https://github.com/pytorch/pytorch/pull/149164 for more context.

Originally, this fix worked but more recently including `cmath` by itself no longer provides access to math constants on Windows platforms. I found that including `math.h` resolves this.

I'm not sure exactly what changed, but this PR updates the header to use both includes to fix the symbols not being found. It might be a bug introduced by a recent Windows update.

Test Plan:
CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153742
Approved by: https://github.com/swolchok, https://github.com/Skylion007
2025-05-18 19:06:45 +00:00
aa84c037f0 FakeTensorMode dispatch shouldn't include bypass in exception context (#153780)
In the FakeTensor cache when we get a bypass exception while computing the cache key (call this exc_1) we need to dispatch to the original operation.

It's possible for the dispatch to the original operation to get its own exception which we want to bubble up to the caller (call this exc_2).

If we directly dispatch from within the handler for exc_1 then exc_2 will have a `__context__` of exc_1 - which can cause deviations between cached and non-cached behavior - so we need to be a bit careful when we call the dispatch.
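
A plain-Python illustration of the `__context__` chaining described above (no FakeTensor internals, just standard exception semantics):

```python
# Dispatching inside the `except` block attaches exc_1 to exc_2 via
# __context__; dispatching after the handler keeps the context clean, matching
# the non-cached behavior.
class BypassCacheError(Exception):
    pass

def dispatch_original_op():
    raise RuntimeError("exc_2: failure from the original op")

def cached_call_naive():
    try:
        raise BypassCacheError("exc_1: cannot compute cache key")
    except BypassCacheError:
        dispatch_original_op()   # exc_2.__context__ becomes exc_1

def cached_call_careful():
    try:
        raise BypassCacheError("exc_1: cannot compute cache key")
    except BypassCacheError:
        pass
    dispatch_original_op()       # exc_2.__context__ stays None

for fn in (cached_call_naive, cached_call_careful):
    try:
        fn()
    except RuntimeError as exc_2:
        ctx = type(exc_2.__context__).__name__ if exc_2.__context__ else None
        print(fn.__name__, "->", ctx)
```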

Testing:
test_aotdispatch.py::TestAOTExport::test_aot_export_predispatch_outdtype fails before this change and passes after.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153780
Approved by: https://github.com/oulgen
2025-05-18 17:21:46 +00:00
68034198e5 [HOP] Mutation and alias rework (#146658)
This PR reworks the way the input mutations and various aliases are checked

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146658
Approved by: https://github.com/ydwu4
2025-05-18 08:05:22 +00:00
0e805aad7f [ONNX] Support float4 (#151069)
- Support exporting float4 models (note: currently we use IR version 10 universally in the exporter, which does not include float4 support. Eventually, when ONNX Runtime and the ecosystem move to support the new IR version 11, we should bump our version to 11 in the exporter as well)
- The shape of the type is set according to https://github.com/pytorch/pytorch/pull/148791#discussion_r2038704986 (added last dim with size 2)
- Use ml_dtypes types when converting to numpy for consistency with ONNX IR

Fix https://github.com/pytorch/pytorch/issues/150202

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151069
Approved by: https://github.com/titaiwangms
2025-05-18 03:19:35 +00:00
8568dbce1d [inductor] Clean typing in codegen/common.py and codecache.py (#150767)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150767
Approved by: https://github.com/aorenste
2025-05-17 13:56:50 +00:00
27f7b65a69 [BE] Ensure generated stub files by gen_pyi are properly formatted (#150730)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150730
Approved by: https://github.com/aorenste
2025-05-17 12:30:40 +00:00
7ebea09986 [Cutlass] Enable fusion with FusedSchedulerNodes (#153588)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153588
Approved by: https://github.com/eellison
ghstack dependencies: #152815
2025-05-17 12:29:10 +00:00
f604732e2e [Cutlass] E2E Tests for EVT (#152815)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152815
Approved by: https://github.com/henrylhtsang, https://github.com/eellison
2025-05-17 12:29:10 +00:00
b4fb801b2d [export] Move PT2 constants to torch::_export (#153206)
Test Plan:
`buck2 test //sigmoid/...`
https://www.internalfb.com/intern/testinfra/testrun/1970325119807758

Differential Revision: D74417085

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153206
Approved by: https://github.com/zhxchen17, https://github.com/dolpm
2025-05-17 08:21:59 +00:00
40339c1e99 Revert "[CUDA][cuBLAS][cuBLASLt] avoid polluting prefer cuBLAS/Lt setting across tests (#153655)"
This reverts commit 3bde364996d53571a9fb799f5951a203a352ed18.

Reverted https://github.com/pytorch/pytorch/pull/153655 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to fail a test in trunk ([comment](https://github.com/pytorch/pytorch/pull/153655#issuecomment-2888212597))
2025-05-17 08:11:54 +00:00
9b2a45ac7d Refactor torch/utils/data/datapipes/gen_pyi.py with torchgen (#150626)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150626
Approved by: https://github.com/aorenste
2025-05-17 06:21:41 +00:00
eqy
e802b29ed4 [SDPA][EZ] Abate narrowing conversion warning spam in flash_api.cpp (#153643)
for messages like
```/workspace/pytorch/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp:1396:38: warning: narrowing conversion of ‘(char)(& q)->at::Tensor::<anonymous>.at::TensorBase::get_device()’ from ‘char’ to ‘c10::DeviceIndex’ {aka ‘signed ```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153643
Approved by: https://github.com/Skylion007
2025-05-17 02:07:35 +00:00
189 changed files with 2785 additions and 3612 deletions

View File

@ -23,6 +23,12 @@ inputs:
type: string
description: 'the job name of the test'
required: True
artifact_prefix:
type: string
description: |
'the prefix of the raw utilization data, for data stored in zip file, this is the prefix of the parent zip file'
default: ""
required: False
runs:
using: composite
@ -35,6 +41,7 @@ runs:
echo "workflow_Name: ${{inputs.workflow_name}}"
echo "job_id: ${{inputs.job_id}}"
echo "job_name: ${{inputs.job_name}}"
echo "artifact_prefix: ${{inputs.artifact_prefix}}"
- uses: nick-fields/retry@v3.0.0
name: Setup dependencies
with:
@ -53,4 +60,5 @@ runs:
--workflow-name "${{inputs.workflow_name}}" \
--workflow-run-attempt "${{inputs.workflow_attempt}}" \
--job-id "${{inputs.job_id}}" \
--job-name "${{inputs.job_name}}"
--job-name "${{inputs.job_name}}" \
--artifact-prefix "${{inputs.artifact_prefix}}"

View File

@ -1 +1 @@
8d9e34b352af09c81ff8df448fd27f9c4aae1382
edc1a882d872dd7f1362e4312fd045a1d81b3355

View File

@ -74,6 +74,24 @@ on:
Overwrite the number of jobs to use for the build
required: false
type: string
disable-monitor:
description: |
Disable utilization monitoring for build job
required: false
type: boolean
default: false
monitor-log-interval:
description: |
Set the interval for the monitor script to log utilization.
required: false
type: number
default: 5
monitor-data-collect-interval:
description: |
Set the interval for the monitor script to collect data.
required: false
type: number
default: 1
secrets:
HUGGING_FACE_HUB_TOKEN:
@ -176,6 +194,27 @@ jobs:
selected-test-configs: ${{ inputs.selected-test-configs }}
job-name: ${{ steps.get-job-id.outputs.job-name }}
- name: Start monitoring script
id: monitor-script
if: ${{ !inputs.disable-monitor }}
shell: bash
continue-on-error: true
env:
JOB_ID: ${{ steps.get-job-id.outputs.job-id }}
JOB_NAME: ${{ steps.get-job-id.outputs.job-name }}
WORKFLOW_NAME: ${{ github.workflow }}
WORKFLOW_RUN_ID: ${{github.run_id}}
MONITOR_LOG_INTERVAL: ${{ inputs.monitor-log-interval }}
MONITOR_DATA_COLLECT_INTERVAL: ${{ inputs.monitor-data-collect-interval }}
run: |
mkdir -p ../../usage_logs
python3 -m pip install psutil==5.9.1 dataclasses_json==0.6.7
python3 -m tools.stats.monitor \
--log-interval "$MONITOR_LOG_INTERVAL" \
--data-collect-interval "$MONITOR_DATA_COLLECT_INTERVAL" \
> "../../usage_logs/usage_log_build_${JOB_ID}.txt" 2>&1 &
echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}"
- name: Download pytest cache
uses: ./.github/actions/pytest-cache-download
continue-on-error: true
@ -280,6 +319,15 @@ jobs:
END_TIME=$(date +%s)
echo "build_time=$((END_TIME - START_TIME))" >> "$GITHUB_OUTPUT"
- name: Stop monitoring script
if: ${{ always() && steps.monitor-script.outputs.monitor-script-pid }}
shell: bash
continue-on-error: true
env:
MONITOR_SCRIPT_PID: ${{ steps.monitor-script.outputs.monitor-script-pid }}
run: |
kill "$MONITOR_SCRIPT_PID"
- name: Archive artifacts into zip
if: inputs.build-generates-artifacts && steps.build.outcome != 'skipped'
run: |
@ -304,6 +352,25 @@ jobs:
if-no-files-found: error
path: artifacts.zip
- name: copy logs
shell: bash
if: ${{ always() && steps.build.outcome != 'skipped' && !inputs.disable-monitor && inputs.build-environment != 'linux-s390x-binary-manywheel'}}
continue-on-error: true
run: |
rm -f ./usage_logs
mkdir -p ./usage_logs
cp ../../usage_logs/usage_log_build_*.txt ./usage_logs/
- name: Upload raw usage log to s3
if: ${{ always() && steps.build.outcome != 'skipped' && !inputs.disable-monitor && inputs.build-environment != 'linux-s390x-binary-manywheel'}}
uses: seemethere/upload-artifact-s3@v5
with:
s3-prefix: |
${{ github.repository }}/${{ github.run_id }}/${{ github.run_attempt }}/artifact
retention-days: 14
if-no-files-found: warn
path: usage_logs/usage_log_build_*.txt
- name: Upload sccache stats
if: steps.build.outcome != 'skipped' && inputs.build-environment != 'linux-s390x-binary-manywheel'
uses: ./.github/actions/upload-sccache-stats
@ -311,6 +378,18 @@ jobs:
github-token: ${{ secrets.GITHUB_TOKEN }}
build-time: ${{ steps.build.outputs.build_time }}
- name: Upload utilization stats
if: ${{ always() && steps.build.outcome != 'skipped' && !inputs.disable-monitor && inputs.build-environment != 'linux-s390x-binary-manywheel' }}
continue-on-error: true
uses: ./.github/actions/upload-utilization-stats
with:
job_id: ${{ steps.get-job-id.outputs.job-id }}
job_name: ${{ steps.get-job-id.outputs.job-name }}
workflow_name: ${{ github.workflow }}
workflow_run_id: ${{github.run_id}}
workflow_attempt: ${{github.run_attempt}}
artifact_prefix: usage_log_build_${{ steps.get-job-id.outputs.job-id }}
- name: Teardown Linux
uses: pytorch/test-infra/.github/actions/teardown-linux@main
if: always() && inputs.build-environment != 'linux-s390x-binary-manywheel'

1
.gitignore vendored
View File

@ -47,6 +47,7 @@ docs/source/generated/
docs/source/compile/generated/
log
usage_log.txt
usage_log*
test-reports/
test/*.bak
test/**/*.bak

View File

@ -1521,7 +1521,7 @@ code = 'RUFF'
include_patterns = [
'**/*.py',
'**/*.pyi',
'torch/utils/data/*.ipynb',
'**/*.ipynb',
'pyproject.toml',
]
exclude_patterns = [

View File

@ -2,7 +2,9 @@
## Demo applications and tutorials
Demo applications with code walk-through can be find in [this github repo](https://github.com/pytorch/android-demo-app).
Please refer to [pytorch-labs/executorch-examples](https://github.com/pytorch-labs/executorch-examples/tree/main/dl3/android/DeepLabV3Demo) for the Android demo app based on [ExecuTorch](https://github.com/pytorch/executorch).
Please join our [Discord](https://discord.com/channels/1334270993966825602/1349854760299270284) for any questions.
## Publishing
@ -119,8 +121,6 @@ We also have to add all transitive dependencies of our aars.
As `pytorch_android` [depends](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/build.gradle#L76-L77) on `'com.facebook.soloader:nativeloader:0.10.5'` and `'com.facebook.fbjni:fbjni-java-only:0.2.2'`, we need to add them.
(In case of using maven dependencies they are added automatically from `pom.xml`).
You can check out [test app example](https://github.com/pytorch/pytorch/blob/master/android/test_app/app/build.gradle) that uses aars directly.
## Linking to prebuilt libtorch library from gradle dependency
In some cases, you may want to use libtorch from your android native build.
@ -202,7 +202,7 @@ find_library(FBJNI_LIBRARY fbjni
NO_CMAKE_FIND_ROOT_PATH)
target_link_libraries(${PROJECT_NAME}
${PYTORCH_LIBRARY})
${PYTORCH_LIBRARY}
${FBJNI_LIBRARY})
```
@ -233,8 +233,6 @@ void loadAndForwardModel(const std::string& modelPath) {
To load torchscript model for mobile we need some special setup which is placed in `struct JITCallGuard` in this example. It may change in future, you can track the latest changes keeping an eye in our [pytorch android jni code]([https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp#L28)
[Example of linking to libtorch from aar](https://github.com/pytorch/pytorch/tree/master/android/test_app)
## PyTorch Android API Javadoc
You can find more details about the PyTorch Android API in the [Javadoc](https://pytorch.org/javadoc/).

View File

@ -1,30 +0,0 @@
#!/bin/bash
set -eux
PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)"
PYTORCH_ANDROID_DIR=$PYTORCH_DIR/android
echo "PYTORCH_DIR:$PYTORCH_DIR"
source "$PYTORCH_ANDROID_DIR/common.sh"
check_android_sdk
check_gradle
parse_abis_list "$@"
build_android
# To set proxy for gradle add following lines to ./gradle/gradle.properties:
# systemProp.http.proxyHost=...
# systemProp.http.proxyPort=8080
# systemProp.https.proxyHost=...
# systemProp.https.proxyPort=8080
if [ "$CUSTOM_ABIS_LIST" = true ]; then
NDK_DEBUG=1 $GRADLE_PATH -PnativeLibsDoNotStrip=true -PABI_FILTERS=$ABIS_LIST -p $PYTORCH_ANDROID_DIR clean test_app:assembleDebug
else
NDK_DEBUG=1 $GRADLE_PATH -PnativeLibsDoNotStrip=true -p $PYTORCH_ANDROID_DIR clean test_app:assembleDebug
fi
find $PYTORCH_ANDROID_DIR -type f -name *apk
find $PYTORCH_ANDROID_DIR -type f -name *apk | xargs echo "To install apk run: $ANDROID_HOME/platform-tools/adb install -r "

View File

@ -1,32 +0,0 @@
#!/bin/bash
###############################################################################
# This script tests the custom selective build flow for PyTorch Android, which
# optimizes library size by only including ops used by a specific model.
###############################################################################
set -eux
PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)"
PYTORCH_ANDROID_DIR="${PYTORCH_DIR}/android"
BUILD_ROOT="${PYTORCH_DIR}/build_pytorch_android_custom"
source "${PYTORCH_ANDROID_DIR}/common.sh"
prepare_model_and_dump_root_ops() {
cd "${BUILD_ROOT}"
MODEL="${BUILD_ROOT}/MobileNetV2.pt"
ROOT_OPS="${BUILD_ROOT}/MobileNetV2.yaml"
python "${PYTORCH_ANDROID_DIR}/test_app/make_assets_custom.py"
cp "${MODEL}" "${PYTORCH_ANDROID_DIR}/test_app/app/src/main/assets/mobilenet2.pt"
}
# Start building
mkdir -p "${BUILD_ROOT}"
check_android_sdk
check_gradle
parse_abis_list "$@"
prepare_model_and_dump_root_ops
SELECTED_OP_LIST="${ROOT_OPS}" build_android
# TODO: change this to build test_app instead
$GRADLE_PATH -PABI_FILTERS=$ABIS_LIST -p $PYTORCH_ANDROID_DIR clean assembleRelease

View File

@ -3,4 +3,3 @@ include ':app', ':pytorch_android', ':pytorch_android_torchvision', ':pytorch_ho
project(':pytorch_android_torchvision').projectDir = file('pytorch_android_torchvision')
project(':pytorch_host').projectDir = file('pytorch_android/host')
project(':test_app').projectDir = file('test_app/app')

View File

@ -1,9 +0,0 @@
local.properties
**/*.iml
.gradle
gradlew*
gradle/wrapper
.idea/*
.DS_Store
build
.externalNativeBuild

View File

@ -1,38 +0,0 @@
cmake_minimum_required(VERSION 3.5)
set(PROJECT_NAME pytorch_testapp_jni)
project(${PROJECT_NAME} CXX)
set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.")
set(CMAKE_VERBOSE_MAKEFILE ON)
set(build_DIR ${CMAKE_SOURCE_DIR}/build)
set(pytorch_testapp_cpp_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp)
message(STATUS "ANDROID_STL:${ANDROID_STL}")
file(GLOB pytorch_testapp_SOURCES
${pytorch_testapp_cpp_DIR}/pytorch_testapp_jni.cpp
)
add_library(${PROJECT_NAME} SHARED
${pytorch_testapp_SOURCES}
)
file(GLOB PYTORCH_INCLUDE_DIRS "${build_DIR}/pytorch_android*.aar/headers")
file(GLOB PYTORCH_LINK_DIRS "${build_DIR}/pytorch_android*.aar/jni/${ANDROID_ABI}")
target_compile_options(${PROJECT_NAME} PRIVATE
-fexceptions
)
set(BUILD_SUBDIR ${ANDROID_ABI})
target_include_directories(${PROJECT_NAME} PRIVATE
${PYTORCH_INCLUDE_DIRS}
)
find_library(PYTORCH_LIBRARY pytorch_jni
PATHS ${PYTORCH_LINK_DIRS}
NO_CMAKE_FIND_ROOT_PATH)
target_link_libraries(${PROJECT_NAME}
${PYTORCH_LIBRARY}
log)

View File

@ -1,190 +0,0 @@
apply plugin: 'com.android.application'
repositories {
jcenter()
maven {
url "https://oss.sonatype.org/content/repositories/snapshots"
}
flatDir {
dirs 'aars'
}
}
android {
configurations {
extractForNativeBuild
}
compileOptions {
sourceCompatibility 1.8
targetCompatibility 1.8
}
compileSdkVersion rootProject.compileSdkVersion
buildToolsVersion rootProject.buildToolsVersion
defaultConfig {
applicationId "org.pytorch.testapp"
minSdkVersion rootProject.minSdkVersion
targetSdkVersion rootProject.targetSdkVersion
versionCode 1
versionName "1.0"
ndk {
abiFilters ABI_FILTERS.split(",")
}
// Commented due to dependency on local copy of pytorch_android aar to aars folder
//externalNativeBuild {
// cmake {
// abiFilters ABI_FILTERS.split(",")
// arguments "-DANDROID_STL=c++_shared"
// }
//}
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet2q.pt\"")
buildConfigField("String", "LOGCAT_TAG", "@string/app_name")
buildConfigField("long[]", "INPUT_TENSOR_SHAPE", "new long[]{1, 3, 224, 224}")
buildConfigField("boolean", "NATIVE_BUILD", 'false')
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'false')
buildConfigField(
"int",
"BUILD_LITE_INTERPRETER",
System.env.BUILD_LITE_INTERPRETER != null ? System.env.BUILD_LITE_INTERPRETER : "1"
)
addManifestPlaceholders([APP_NAME: "@string/app_name", MAIN_ACTIVITY: "org.pytorch.testapp.MainActivity"])
}
buildTypes {
debug {
minifyEnabled false
debuggable true
}
release {
minifyEnabled false
}
}
// Commented due to dependency on local copy of pytorch_android aar to aars folder
//externalNativeBuild {
// cmake {
// path "CMakeLists.txt"
// }
//}
flavorDimensions "model", "build", "activity"
productFlavors {
mnet {
dimension "model"
applicationIdSuffix ".mnet"
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet_v2.ptl\"")
addManifestPlaceholders([APP_NAME: "MNET"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet\"")
}
// NB: This is not working atm https://github.com/pytorch/pytorch/issues/102966
mnetVulkan {
dimension "model"
applicationIdSuffix ".mnet_vulkan"
buildConfigField("String", "MODULE_ASSET_NAME", "\"mobilenet_v2_vulkan.ptl\"")
buildConfigField("boolean", "USE_VULKAN_DEVICE", 'true')
addManifestPlaceholders([APP_NAME: "MNET_VULKAN"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-mnet-vulkan\"")
}
resnet18 {
dimension "model"
applicationIdSuffix ".resnet18"
buildConfigField("String", "MODULE_ASSET_NAME", "\"resnet18.ptl\"")
addManifestPlaceholders([APP_NAME: "RN18"])
buildConfigField("String", "LOGCAT_TAG", "\"pytorch-resnet18\"")
}
local {
dimension "build"
}
nightly {
dimension "build"
}
aar {
dimension "build"
}
// Commented due to dependency on local copy of pytorch_android aar to aars folder
//nativeBuild {
// dimension "build"
// buildConfigField("boolean", "NATIVE_BUILD", "true")
//}
camera {
dimension "activity"
addManifestPlaceholders([MAIN_ACTIVITY: "org.pytorch.testapp.CameraActivity"])
}
base {
dimension "activity"
sourceSets {
main {
java {
exclude 'org/pytorch/testapp/CameraActivity.java'
}
}
}
}
}
packagingOptions {
doNotStrip '**.so'
}
// Filtering for CI
if (!testAppAllVariantsEnabled.toBoolean()) {
variantFilter { variant ->
def names = variant.flavors*.name
if (names.contains("nightly")
|| names.contains("camera")
|| names.contains("aar")
|| names.contains("nativeBuild")) {
setIgnore(true)
}
}
}
}
tasks.all { task ->
// Disable externalNativeBuild for all but nativeBuild variant
if (task.name.startsWith('externalNativeBuild')
&& !task.name.contains('NativeBuild')) {
task.enabled = false
}
}
dependencies {
implementation 'com.android.support:appcompat-v7:28.0.0'
implementation 'com.facebook.soloader:nativeloader:0.10.5'
localImplementation project(':pytorch_android')
localImplementation project(':pytorch_android_torchvision')
// Commented due to dependency on local copy of pytorch_android aar to aars folder
//nativeBuildImplementation(name: 'pytorch_android-release', ext: 'aar')
//nativeBuildImplementation(name: 'pytorch_android_torchvision-release', ext: 'aar')
//extractForNativeBuild(name: 'pytorch_android-release', ext: 'aar')
nightlyImplementation 'org.pytorch:pytorch_android:2.2.0-SNAPSHOT'
nightlyImplementation 'org.pytorch:pytorch_android_torchvision:2.2.0-SNAPSHOT'
aarImplementation(name:'pytorch_android', ext:'aar')
aarImplementation(name:'pytorch_android_torchvision', ext:'aar')
aarImplementation 'com.facebook.soloader:nativeloader:0.10.5'
aarImplementation 'com.facebook.fbjni:fbjni-java-only:0.2.2'
def camerax_version = "1.0.0-alpha05"
cameraImplementation "androidx.camera:camera-core:$camerax_version"
cameraImplementation "androidx.camera:camera-camera2:$camerax_version"
cameraImplementation 'com.google.android.material:material:1.0.0-beta01'
}
task extractAARForNativeBuild {
doLast {
configurations.extractForNativeBuild.files.each {
def file = it.absoluteFile
copy {
from zipTree(file)
into "$buildDir/$file.name"
include "headers/**"
include "jni/**"
}
}
}
}
tasks.whenTaskAdded { task ->
if (task.name.contains('externalNativeBuild')) {
task.dependsOn(extractAARForNativeBuild)
}
}

View File

@ -1,27 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="org.pytorch.testapp">
<application
android:allowBackup="true"
android:label="${APP_NAME}"
android:supportsRtl="true"
android:theme="@style/AppTheme">
<activity android:name="${MAIN_ACTIVITY}">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
</application>
<uses-permission android:name="android.permission.CAMERA" />
<!--
Permissions required by the Snapdragon Profiler to collect GPU metrics.
-->
<uses-permission android:name="android.permission.INTERNET" />
<uses-permission android:name="android.permission.WRITE_EXTERNAL_STORAGE" />
</manifest>

View File

@ -1,3 +0,0 @@
*
*/
!.gitignore

View File

@ -1,77 +0,0 @@
#include <android/log.h>
#include <pthread.h>
#include <unistd.h>
#include <cassert>
#include <cmath>
#include <vector>
#define ALOGI(...) \
__android_log_print(ANDROID_LOG_INFO, "PyTorchTestAppJni", __VA_ARGS__)
#define ALOGE(...) \
__android_log_print(ANDROID_LOG_ERROR, "PyTorchTestAppJni", __VA_ARGS__)
#include "jni.h"
#include <torch/script.h>
namespace pytorch_testapp_jni {
namespace {
template <typename T>
void log(const char* m, T t) {
std::ostringstream os;
os << t << std::endl;
ALOGI("%s %s", m, os.str().c_str());
}
struct JITCallGuard {
c10::InferenceMode guard;
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
};
} // namespace
static void loadAndForwardModel(JNIEnv* env, jclass, jstring jModelPath) {
const char* modelPath = env->GetStringUTFChars(jModelPath, 0);
assert(modelPath);
// To load torchscript model for mobile we need set these guards,
// because mobile build doesn't support features like autograd for smaller
// build size which is placed in `struct JITCallGuard` in this example. It may
// change in future, you can track the latest changes keeping an eye in
// android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp
JITCallGuard guard;
torch::jit::Module module = torch::jit::load(modelPath);
module.eval();
torch::Tensor t = torch::randn({1, 3, 224, 224});
log("input tensor:", t);
c10::IValue t_out = module.forward({t});
log("output tensor:", t_out);
env->ReleaseStringUTFChars(jModelPath, modelPath);
}
} // namespace pytorch_testapp_jni
JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void*) {
JNIEnv* env;
if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6) != JNI_OK) {
return JNI_ERR;
}
jclass c =
env->FindClass("org/pytorch/testapp/LibtorchNativeClient$NativePeer");
if (c == nullptr) {
return JNI_ERR;
}
static const JNINativeMethod methods[] = {
{"loadAndForwardModel",
"(Ljava/lang/String;)V",
(void*)pytorch_testapp_jni::loadAndForwardModel},
};
int rc = env->RegisterNatives(
c, methods, sizeof(methods) / sizeof(JNINativeMethod));
if (rc != JNI_OK) {
return rc;
}
return JNI_VERSION_1_6;
}

View File

@ -1,214 +0,0 @@
package org.pytorch.testapp;
import android.Manifest;
import android.content.pm.PackageManager;
import android.os.Bundle;
import android.os.Handler;
import android.os.HandlerThread;
import android.os.SystemClock;
import android.util.Log;
import android.util.Size;
import android.view.TextureView;
import android.view.ViewStub;
import android.widget.TextView;
import android.widget.Toast;
import androidx.annotation.Nullable;
import androidx.annotation.UiThread;
import androidx.annotation.WorkerThread;
import androidx.appcompat.app.AppCompatActivity;
import androidx.camera.core.CameraX;
import androidx.camera.core.ImageAnalysis;
import androidx.camera.core.ImageAnalysisConfig;
import androidx.camera.core.ImageProxy;
import androidx.camera.core.Preview;
import androidx.camera.core.PreviewConfig;
import androidx.core.app.ActivityCompat;
import java.nio.FloatBuffer;
import org.pytorch.IValue;
import org.pytorch.MemoryFormat;
import org.pytorch.Module;
import org.pytorch.PyTorchAndroid;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
public class CameraActivity extends AppCompatActivity {
private static final String TAG = BuildConfig.LOGCAT_TAG;
private static final int TEXT_TRIM_SIZE = 4096;
private static final int REQUEST_CODE_CAMERA_PERMISSION = 200;
private static final String[] PERMISSIONS = {Manifest.permission.CAMERA};
private long mLastAnalysisResultTime;
protected HandlerThread mBackgroundThread;
protected Handler mBackgroundHandler;
protected Handler mUIHandler;
private TextView mTextView;
private StringBuilder mTextViewStringBuilder = new StringBuilder();
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_camera);
mTextView = findViewById(R.id.text);
mUIHandler = new Handler(getMainLooper());
startBackgroundThread();
if (ActivityCompat.checkSelfPermission(this, Manifest.permission.CAMERA)
!= PackageManager.PERMISSION_GRANTED) {
ActivityCompat.requestPermissions(this, PERMISSIONS, REQUEST_CODE_CAMERA_PERMISSION);
} else {
setupCameraX();
}
}
@Override
protected void onPostCreate(@Nullable Bundle savedInstanceState) {
super.onPostCreate(savedInstanceState);
startBackgroundThread();
}
protected void startBackgroundThread() {
mBackgroundThread = new HandlerThread("ModuleActivity");
mBackgroundThread.start();
mBackgroundHandler = new Handler(mBackgroundThread.getLooper());
}
@Override
protected void onDestroy() {
stopBackgroundThread();
super.onDestroy();
}
protected void stopBackgroundThread() {
mBackgroundThread.quitSafely();
try {
mBackgroundThread.join();
mBackgroundThread = null;
mBackgroundHandler = null;
} catch (InterruptedException e) {
Log.e(TAG, "Error on stopping background thread", e);
}
}
@Override
public void onRequestPermissionsResult(
int requestCode, String[] permissions, int[] grantResults) {
if (requestCode == REQUEST_CODE_CAMERA_PERMISSION) {
if (grantResults[0] == PackageManager.PERMISSION_DENIED) {
Toast.makeText(
this,
"You can't use image classification example without granting CAMERA permission",
Toast.LENGTH_LONG)
.show();
finish();
} else {
setupCameraX();
}
}
}
private static final int TENSOR_WIDTH = 224;
private static final int TENSOR_HEIGHT = 224;
private void setupCameraX() {
final TextureView textureView =
((ViewStub) findViewById(R.id.camera_texture_view_stub))
.inflate()
.findViewById(R.id.texture_view);
final PreviewConfig previewConfig = new PreviewConfig.Builder().build();
final Preview preview = new Preview(previewConfig);
preview.setOnPreviewOutputUpdateListener(
new Preview.OnPreviewOutputUpdateListener() {
@Override
public void onUpdated(Preview.PreviewOutput output) {
textureView.setSurfaceTexture(output.getSurfaceTexture());
}
});
final ImageAnalysisConfig imageAnalysisConfig =
new ImageAnalysisConfig.Builder()
.setTargetResolution(new Size(TENSOR_WIDTH, TENSOR_HEIGHT))
.setCallbackHandler(mBackgroundHandler)
.setImageReaderMode(ImageAnalysis.ImageReaderMode.ACQUIRE_LATEST_IMAGE)
.build();
final ImageAnalysis imageAnalysis = new ImageAnalysis(imageAnalysisConfig);
imageAnalysis.setAnalyzer(
new ImageAnalysis.Analyzer() {
@Override
public void analyze(ImageProxy image, int rotationDegrees) {
if (SystemClock.elapsedRealtime() - mLastAnalysisResultTime < 500) {
return;
}
final Result result = CameraActivity.this.analyzeImage(image, rotationDegrees);
if (result != null) {
mLastAnalysisResultTime = SystemClock.elapsedRealtime();
CameraActivity.this.runOnUiThread(
new Runnable() {
@Override
public void run() {
CameraActivity.this.handleResult(result);
}
});
}
}
});
CameraX.bindToLifecycle(this, preview, imageAnalysis);
}
private Module mModule;
private FloatBuffer mInputTensorBuffer;
private Tensor mInputTensor;
@WorkerThread
@Nullable
protected Result analyzeImage(ImageProxy image, int rotationDegrees) {
Log.i(TAG, String.format("analyzeImage(%s, %d)", image, rotationDegrees));
if (mModule == null) {
Log.i(TAG, "Loading module from asset '" + BuildConfig.MODULE_ASSET_NAME + "'");
mModule = PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME);
mInputTensorBuffer = Tensor.allocateFloatBuffer(3 * TENSOR_WIDTH * TENSOR_HEIGHT);
mInputTensor =
Tensor.fromBlob(mInputTensorBuffer, new long[] {1, 3, TENSOR_WIDTH, TENSOR_HEIGHT});
}
final long startTime = SystemClock.elapsedRealtime();
TensorImageUtils.imageYUV420CenterCropToFloatBuffer(
image.getImage(),
rotationDegrees,
TENSOR_WIDTH,
TENSOR_HEIGHT,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
TensorImageUtils.TORCHVISION_NORM_STD_RGB,
mInputTensorBuffer,
0,
MemoryFormat.CHANNELS_LAST);
final long moduleForwardStartTime = SystemClock.elapsedRealtime();
final Tensor outputTensor = mModule.forward(IValue.from(mInputTensor)).toTensor();
final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime;
final float[] scores = outputTensor.getDataAsFloatArray();
final long analysisDuration = SystemClock.elapsedRealtime() - startTime;
return new Result(scores, moduleForwardDuration, analysisDuration);
}
@UiThread
protected void handleResult(Result result) {
int ixs[] = Utils.topK(result.scores, 1);
String message =
String.format(
"forwardDuration:%d class:%s",
result.moduleForwardDuration, Constants.IMAGENET_CLASSES[ixs[0]]);
Log.i(TAG, message);
mTextViewStringBuilder.insert(0, '\n').insert(0, message);
if (mTextViewStringBuilder.length() > TEXT_TRIM_SIZE) {
mTextViewStringBuilder.delete(TEXT_TRIM_SIZE, mTextViewStringBuilder.length());
}
mTextView.setText(mTextViewStringBuilder.toString());
}
}

View File

@ -1,22 +0,0 @@
package org.pytorch.testapp;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
public final class LibtorchNativeClient {
public static void loadAndForwardModel(final String modelPath) {
NativePeer.loadAndForwardModel(modelPath);
}
private static class NativePeer {
static {
if (!NativeLoader.isInitialized()) {
NativeLoader.init(new SystemDelegate());
}
NativeLoader.loadLibrary("pytorch_testapp_jni");
}
private static native void loadAndForwardModel(final String modelPath);
}
}

View File

@ -1,171 +0,0 @@
package org.pytorch.testapp;
import android.content.Context;
import android.os.Bundle;
import android.os.Handler;
import android.os.HandlerThread;
import android.os.SystemClock;
import android.util.Log;
import android.widget.TextView;
import androidx.annotation.Nullable;
import androidx.annotation.UiThread;
import androidx.annotation.WorkerThread;
import androidx.appcompat.app.AppCompatActivity;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.FloatBuffer;
import org.pytorch.Device;
import org.pytorch.IValue;
import org.pytorch.MemoryFormat;
import org.pytorch.Module;
import org.pytorch.PyTorchAndroid;
import org.pytorch.Tensor;
public class MainActivity extends AppCompatActivity {
private static final String TAG = BuildConfig.LOGCAT_TAG;
private static final int TEXT_TRIM_SIZE = 4096;
private TextView mTextView;
protected HandlerThread mBackgroundThread;
protected Handler mBackgroundHandler;
private Module mModule;
private FloatBuffer mInputTensorBuffer;
private Tensor mInputTensor;
private StringBuilder mTextViewStringBuilder = new StringBuilder();
private final Runnable mModuleForwardRunnable =
new Runnable() {
@Override
public void run() {
final Result result = doModuleForward();
runOnUiThread(
new Runnable() {
@Override
public void run() {
handleResult(result);
if (mBackgroundHandler != null) {
mBackgroundHandler.post(mModuleForwardRunnable);
}
}
});
}
};
public static String assetFilePath(Context context, String assetName) {
File file = new File(context.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}
try (InputStream is = context.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
} catch (IOException e) {
Log.e(TAG, "Error process asset " + assetName + " to file path");
}
return null;
}
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
if (BuildConfig.NATIVE_BUILD) {
final String modelFileAbsoluteFilePath =
new File(assetFilePath(this, BuildConfig.MODULE_ASSET_NAME)).getAbsolutePath();
LibtorchNativeClient.loadAndForwardModel(modelFileAbsoluteFilePath);
return;
}
setContentView(R.layout.activity_main);
mTextView = findViewById(R.id.text);
startBackgroundThread();
mBackgroundHandler.post(mModuleForwardRunnable);
}
protected void startBackgroundThread() {
mBackgroundThread = new HandlerThread(TAG + "_bg");
mBackgroundThread.start();
mBackgroundHandler = new Handler(mBackgroundThread.getLooper());
}
@Override
protected void onDestroy() {
stopBackgroundThread();
super.onDestroy();
}
protected void stopBackgroundThread() {
mBackgroundThread.quitSafely();
try {
mBackgroundThread.join();
mBackgroundThread = null;
mBackgroundHandler = null;
} catch (InterruptedException e) {
Log.e(TAG, "Error stopping background thread", e);
}
}
@WorkerThread
@Nullable
protected Result doModuleForward() {
if (mModule == null) {
final long[] shape = BuildConfig.INPUT_TENSOR_SHAPE;
long numElements = 1;
for (int i = 0; i < shape.length; i++) {
numElements *= shape[i];
}
mInputTensorBuffer = Tensor.allocateFloatBuffer((int) numElements);
mInputTensor =
Tensor.fromBlob(
mInputTensorBuffer, BuildConfig.INPUT_TENSOR_SHAPE, MemoryFormat.CHANNELS_LAST);
PyTorchAndroid.setNumThreads(1);
mModule =
BuildConfig.USE_VULKAN_DEVICE
? PyTorchAndroid.loadModuleFromAsset(
getAssets(), BuildConfig.MODULE_ASSET_NAME, Device.VULKAN)
: PyTorchAndroid.loadModuleFromAsset(getAssets(), BuildConfig.MODULE_ASSET_NAME);
}
final long startTime = SystemClock.elapsedRealtime();
final long moduleForwardStartTime = SystemClock.elapsedRealtime();
final Tensor outputTensor = mModule.forward(IValue.from(mInputTensor)).toTensor();
final long moduleForwardDuration = SystemClock.elapsedRealtime() - moduleForwardStartTime;
final float[] scores = outputTensor.getDataAsFloatArray();
final long analysisDuration = SystemClock.elapsedRealtime() - startTime;
return new Result(scores, moduleForwardDuration, analysisDuration);
}
static class Result {
private final float[] scores;
private final long totalDuration;
private final long moduleForwardDuration;
public Result(float[] scores, long moduleForwardDuration, long totalDuration) {
this.scores = scores;
this.moduleForwardDuration = moduleForwardDuration;
this.totalDuration = totalDuration;
}
}
@UiThread
protected void handleResult(Result result) {
String message = String.format("forwardDuration:%d", result.moduleForwardDuration);
mTextViewStringBuilder.insert(0, '\n').insert(0, message);
if (mTextViewStringBuilder.length() > TEXT_TRIM_SIZE) {
mTextViewStringBuilder.delete(TEXT_TRIM_SIZE, mTextViewStringBuilder.length());
}
mTextView.setText(mTextViewStringBuilder.toString());
}
}

View File

@ -1,14 +0,0 @@
package org.pytorch.testapp;
class Result {
public final float[] scores;
public final long totalDuration;
public final long moduleForwardDuration;
public Result(float[] scores, long moduleForwardDuration, long totalDuration) {
this.scores = scores;
this.moduleForwardDuration = moduleForwardDuration;
this.totalDuration = totalDuration;
}
}

View File

@ -1,28 +0,0 @@
package org.pytorch.testapp;
import java.util.Arrays;
public class Utils {
public static int[] topK(float[] a, final int topk) {
float values[] = new float[topk];
Arrays.fill(values, -Float.MAX_VALUE);
int ixs[] = new int[topk];
Arrays.fill(ixs, -1);
for (int i = 0; i < a.length; i++) {
for (int j = 0; j < topk; j++) {
if (a[i] > values[j]) {
for (int k = topk - 1; k >= j + 1; k--) {
values[k] = values[k - 1];
ixs[k] = ixs[k - 1];
}
values[j] = a[i];
ixs[j] = i;
break;
}
}
}
return ixs;
}
}

View File

@ -1,23 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<FrameLayout
xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".CameraActivity">
<ViewStub
android:id="@+id/camera_texture_view_stub"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout="@layout/texture_view"/>
<TextView
android:id="@+id/text"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_gravity="top"
android:textSize="16sp"
android:textStyle="bold"
android:textColor="#ff0000"/>
</FrameLayout>

View File

@ -1,17 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
<TextView
android:id="@+id/text"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:layout_gravity="top"
android:textSize="14sp"
android:background="@android:color/black"
android:textColor="@android:color/white" />
</FrameLayout>

View File

@ -1,5 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<TextureView xmlns:android="http://schemas.android.com/apk/res/android"
android:id="@+id/texture_view"
android:layout_width="match_parent"
android:layout_height="0dp" />

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.0 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.7 KiB

View File

@ -1,6 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<resources>
<color name="colorPrimary">#008577</color>
<color name="colorPrimaryDark">#00574B</color>
<color name="colorAccent">#D81B60</color>
</resources>

View File

@ -1,3 +0,0 @@
<resources>
<string name="app_name">PyTest</string>
</resources>

View File

@ -1,11 +0,0 @@
<resources>
<!-- Base application theme. -->
<style name="AppTheme" parent="Theme.AppCompat.Light.DarkActionBar">
<!-- Customize your theme here. -->
<item name="colorPrimary">@color/colorPrimary</item>
<item name="colorPrimaryDark">@color/colorPrimaryDark</item>
<item name="colorAccent">@color/colorAccent</item>
</style>
</resources>

View File

@ -1,24 +0,0 @@
from torchvision import models
import torch
print(torch.version.__version__)
resnet18 = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
resnet18.eval()
resnet18_traced = torch.jit.trace(resnet18, torch.rand(1, 3, 224, 224)).save(
"app/src/main/assets/resnet18.pt"
)
resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
resnet50.eval()
torch.jit.trace(resnet50, torch.rand(1, 3, 224, 224)).save(
"app/src/main/assets/resnet50.pt"
)
mobilenet2q = models.quantization.mobilenet_v2(pretrained=True, quantize=True)
mobilenet2q.eval()
torch.jit.trace(mobilenet2q, torch.rand(1, 3, 224, 224)).save(
"app/src/main/assets/mobilenet2q.pt"
)

View File

@ -1,27 +0,0 @@
"""
This is a script for PyTorch Android custom selective build test. It prepares
MobileNetV2 TorchScript model, and dumps root ops used by the model for custom
build script to create a tailored build which only contains these used ops.
"""
import yaml
from torchvision import models
import torch
# Download and trace the model.
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
model.eval()
example = torch.rand(1, 3, 224, 224)
# TODO: create script model with `torch.jit.script`
traced_script_module = torch.jit.trace(model, example)
# Save traced TorchScript model.
traced_script_module.save("MobileNetV2.pt")
# Dump root ops used by the model (for custom build optimization).
ops = torch.jit.export_opnames(traced_script_module)
with open("MobileNetV2.yaml", "w") as output:
yaml.dump(ops, output)

View File

@ -471,7 +471,7 @@ struct CachingHostAllocatorImpl {
virtual B* get_free_block(size_t size) {
auto index = size_index(size);
std::lock_guard<std::mutex> g(free_list_[index].mutex_);
if (free_list_[index].list_.size() > 0) {
if (!free_list_[index].list_.empty()) {
B* block = free_list_[index].list_.back();
free_list_[index].list_.pop_back();
block->allocated_ = true;

View File

@ -639,7 +639,7 @@ IntArrayRef MPSHeapAllocatorImpl::getBufferShape(const void* ptr) {
std::lock_guard<std::recursive_mutex> lock(m_mutex);
BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
if (buffer_block && buffer_block->shape.size() > 0) {
if (buffer_block && !buffer_block->shape.empty()) {
return IntArrayRef{buffer_block->shape};
}
return IntArrayRef();

View File

@ -1456,7 +1456,6 @@ static inline at::MemoryFormat determine_backend_memory_format(
}
break;
case ConvBackend::Mps:
case ConvBackend::MpsTranspose:
if (mps_conv_use_channels_last(input, weight)) {
#ifdef USE_MPS
if (!mps::is_macos_13_or_newer(mps::MacOSVersion::MACOS_VER_15_0_PLUS)) {

View File

@ -880,7 +880,7 @@ struct FullBidirectionalLayer
step_inputs = input_w.unbind(0);
auto fw_result = layer_(
step_inputs, input_hidden.first, params.first, true);
TORCH_CHECK(fw_result.outputs.size() > 0, "Expected sequence length to be larger than 0 in RNN");
TORCH_CHECK(!fw_result.outputs.empty(), "Expected sequence length to be larger than 0 in RNN");
auto fw_output = at::stack(fw_result.outputs, 0);
input_w = params.second.linear_ih(input);
step_inputs = input_w.unbind(0);
@ -895,7 +895,7 @@ struct FullBidirectionalLayer
step_inputs = input.unbind(0);
auto fw_result = layer_(step_inputs, input_hidden.first, params.first);
TORCH_CHECK(fw_result.outputs.size() > 0, "Expected sequence length to be larger than 0 in RNN");
TORCH_CHECK(!fw_result.outputs.empty(), "Expected sequence length to be larger than 0 in RNN");
auto fw_output = at::stack(fw_result.outputs, 0);
auto rev_step_inputs = reverse(std::move(step_inputs));
auto rev_result =

View File

@ -485,13 +485,13 @@ void _assert_async_cpu(const Tensor& self) {
void _assert_async_msg_cpu(const Tensor& self, std::string_view assert_msg) {
TORCH_CHECK(
native::is_nonzero(self),
assert_msg != "" ? assert_msg : "Assertion is failed");
!assert_msg.empty() ? assert_msg : "Assertion is failed");
}
void _assert_scalar(const Scalar& scalar, std::string_view assert_msg) {
TORCH_SYM_CHECK(
scalar.toSymBool(),
assert_msg != "" ? assert_msg : "Assertion is failed");
!assert_msg.empty() ? assert_msg : "Assertion is failed");
}
Tensor _functional_assert_scalar(

View File

@ -5,6 +5,7 @@
#ifdef _WIN32
#define _USE_MATH_DEFINES
#include <cmath>
#include <math.h>
#endif // _WIN32
#include <ATen/cpu/vec/vec.h>

View File

@ -67,7 +67,7 @@ void stack_serial_kernel_impl(Tensor& result, TensorListType tensors, int64_t di
// - tensors dtype is Double or Float
template <typename TensorListType>
bool can_use_native_serial_stack_impl(Tensor& result, TensorListType tensors, int64_t dim) {
TORCH_CHECK(tensors.size() > 0, "expected a non-empty list of Tensors");
TORCH_CHECK(!tensors.empty(), "expected a non-empty list of Tensors");
const Tensor& first_tensor = tensors[0];
// stack dimension should be in range [0,firstTensor.dim())
// dim == firstTensor.dim() is a valid input, but it is handled by default code path

View File

@ -467,9 +467,8 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
alpha,
(&result != &self) ? self.const_data_ptr<scalar_t>() : nullptr,
activation_to_gemm_and_blas_arg(activation));
}
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
} else {
okay = at::cuda::blas::gemm_and_bias<scalar_t>(
args.transa == 't',
args.transb == 't',
args.m,
@ -486,7 +485,8 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
args.result->data_ptr<scalar_t>(),
args.result_ld,
activation_to_gemm_and_blas_arg(activation)
);
);
}
});
}
if (!okay) {

View File

@ -433,7 +433,7 @@ struct TensorDescriptorListParams {
int64_t batch_sizes_sum; // == sum(batch_sizes)
bool is_input_packed() const {
return batch_sizes.size() != 0;
return !batch_sizes.empty();
}
void set(
@ -465,8 +465,7 @@ struct TensorDescriptorListParams {
#ifndef USE_CUDNN_RNN_V8_API
// TODO: check x for consistency with input_size?
std::vector<TensorDescriptor> descriptors(Tensor x) const {
auto is_input_packed = batch_sizes.size() != 0;
if (is_input_packed) {
if (is_input_packed()) {
return rnn_descriptor_sequence(x, batch_sizes);
} else {
return rnn_descriptor(x[0], seq_length);
@ -474,8 +473,7 @@ struct TensorDescriptorListParams {
}
#else
auto descriptors(Tensor x) const {
auto is_input_packed = batch_sizes.size() != 0;
if (is_input_packed) {
if (is_input_packed()) {
return rnn_descriptor_sequence(
x, mini_batch, batch_sizes, seq_length, x.size(-1));
} else {
@ -1253,7 +1251,7 @@ int64_t _cudnn_rnn_flatten_weight_prologue(
// typeMetaToScalarType is a surprisingly nontrivial function. We should
// avoid it if we can.
TORCH_CHECK(
weight_arr.size() > 0,
!weight_arr.empty(),
"copy_weights_to_flat_buf_views: cannot flatten empty weight list");
rnn.set(
@ -1306,7 +1304,7 @@ copy_weights_to_flat_buf_views(
bool set_orig_weights_to_flat_buf,
bool allow_type_change /*=false*/,
bool include_bias /*=true*/) {
TORCH_CHECK(weight_arr.size() > 0, "empty weight list");
TORCH_CHECK(!weight_arr.empty(), "empty weight list");
auto handle = getCudnnHandle();
RNNDescriptorParams rnn;
RNNDescriptor rnn_desc;
@ -1390,7 +1388,7 @@ Tensor _cudnn_rnn_flatten_weight(
int64_t fn_num_layers,
bool batch_first,
bool fn_bidirectional) {
TORCH_CHECK(weight_arr.size() > 0, "empty weight list");
TORCH_CHECK(!weight_arr.empty(), "empty weight list");
// returns flat weight_buf
return std::get<0>(copy_weights_to_flat_buf_views(
weight_arr,
@ -1417,7 +1415,7 @@ Tensor _cudnn_rnn_flatten_weight_meta(
int64_t num_layers,
bool batch_first,
bool bidirectional) {
TORCH_CHECK(weight_arr.size() > 0, "empty weight list");
TORCH_CHECK(!weight_arr.empty(), "empty weight list");
auto handle = getCudnnHandle();
RNNDescriptorParams rnn;
RNNDescriptor rnn_desc;
@ -1498,7 +1496,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _cudnn_rnn(
datatype);
#else
auto input_size = input_r.size(-1);
auto packed = fn_batch_sizes.size() != 0;
auto packed = !fn_batch_sizes.empty();
fn.rnn.set(
fn_mode,
input_size,
@ -1520,7 +1518,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _cudnn_rnn(
}
// TODO: can batch_first be a wrapper around this function?
auto is_input_packed = fn.tensors.batch_sizes.size() != 0;
auto is_input_packed = !fn.tensors.batch_sizes.empty();
if (batch_first && !is_input_packed) {
input = input.transpose(0, 1);
}
@ -1775,7 +1773,7 @@ std::tuple<Tensor, Tensor, Tensor> _cudnn_rnn_backward_input(
datatype);
#else
auto cudnn_input_size = input_r.size(-1);
auto packed = fn_batch_sizes.size() != 0;
auto packed = !fn_batch_sizes.empty();
fn.rnn.set(
fn_mode,
cudnn_input_size,
@ -1797,7 +1795,7 @@ std::tuple<Tensor, Tensor, Tensor> _cudnn_rnn_backward_input(
TORCH_CHECK(!cx.defined(), "rnn: illegal defined cx for non-LSTM RNN");
}
auto is_input_packed = fn_batch_sizes.size() != 0;
auto is_input_packed = !fn_batch_sizes.empty();
if (batch_first && !is_input_packed) {
input = input.transpose(0, 1);
grad_output = grad_output.transpose(0, 1);
@ -2004,7 +2002,7 @@ std::vector<Tensor> _cudnn_rnn_backward_weight(
datatype);
#else
auto cudnn_input_size = input_r.size(-1);
auto packed = fn_batch_sizes.size() != 0;
auto packed = !fn_batch_sizes.empty();
fn.rnn.set(
fn_mode,
cudnn_input_size,
@ -2025,7 +2023,7 @@ std::vector<Tensor> _cudnn_rnn_backward_weight(
TORCH_CHECK(!cx.defined(), "rnn: illegal defined cx for non-LSTM RNN");
}
auto is_input_packed = fn_batch_sizes.size() != 0;
auto is_input_packed = !fn_batch_sizes.empty();
if (batch_first && !is_input_packed) {
input = input.transpose(0, 1);
output = output.transpose(0, 1);

View File

@ -203,7 +203,7 @@ struct TensorDescriptorListParams {
int64_t batch_sizes_sum;
bool is_input_packed() const {
return batch_sizes.size() != 0;
return !batch_sizes.empty();
}
void set(IntArrayRef input_sizes, IntArrayRef batch_sizes_, bool batch_first) {
@ -227,8 +227,7 @@ struct TensorDescriptorListParams {
}
std::vector<TensorDescriptor> descriptors(Tensor x) const {
auto is_input_packed = batch_sizes.size() != 0;
if (is_input_packed) {
if (is_input_packed()) {
return rnn_descriptor_sequence(x, batch_sizes);
} else {
return rnn_descriptor(x[0], seq_length);
@ -545,7 +544,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> miopen_rnn(
TORCH_CHECK(!cx.defined(), "miopen_rnn: illegal defined cx for non-LSTM RNN.");
}
auto is_input_packed = fn.tensors.batch_sizes.size() != 0;
auto is_input_packed = !fn.tensors.batch_sizes.empty();
if (batch_first && !is_input_packed) {
input = input.transpose(0, 1);
}
@ -656,7 +655,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> miopen_rnn_backward_input(
TORCH_CHECK(!cx.defined(), "rnn: illegal defined cx for non-LSTM RNN");
}
auto is_input_packed = fn_batch_sizes.size() != 0;
auto is_input_packed = !fn_batch_sizes.empty();
if (batch_first && !is_input_packed) {
input = input.transpose(0, 1);
grad_output = grad_output.transpose(0, 1);
@ -773,7 +772,7 @@ std::vector<Tensor> miopen_rnn_backward_weight(
TORCH_CHECK(!cx.defined(), "rnn: illegal defined cx for non-LSTM RNN");
}
auto is_input_packed = fn_batch_sizes.size() != 0;
auto is_input_packed = !fn_batch_sizes.empty();
if (batch_first && !is_input_packed) {
input = input.transpose(0, 1);
output = output.transpose(0, 1);

View File

@ -457,7 +457,7 @@ static std::tuple<Tensor, Tensor, Tensor> mkldnn_rnn(
int64_t mode, int64_t hidden_size,
int64_t num_layers, bool has_biases, bool batch_first, double dropout_p,
bool train, bool bidirectional, IntArrayRef batch_sizes) {
TORCH_CHECK(batch_sizes.size() == 0, "mkldnn_rnn doesn't support packed input");
TORCH_CHECK(batch_sizes.empty(), "mkldnn_rnn doesn't support packed input");
if (static_cast<ideep::rnn_kind>(mode) != ideep::rnn_kind::LSTM) {
TORCH_CHECK(!cx_.defined(), "mkldnn_rnn: illegal defined cx for non-LSTM RNN");
}

View File

@ -1160,7 +1160,7 @@ void MetalKernelFunction::dispatch(uint64_t length, std::optional<uint64_t> grou
}
void MetalKernelFunction::dispatch(c10::ArrayRef<uint64_t> length, c10::OptionalArrayRef<uint64_t> group_size) {
TORCH_CHECK(length.size() > 0 && length.size() < 4, "Dispatch dimentions must be less than 3 and non-empty");
TORCH_CHECK(!length.empty() && length.size() < 4, "Dispatch dimentions must be less than 3 and non-empty");
TORCH_CHECK(!group_size.has_value() || group_size->size() == length.size(),
"size and group_size must have same number of dimentions");
const auto max_tg_size = getMaxThreadsPerThreadgroup();

View File

@ -180,7 +180,7 @@ static void multi_tensor_apply_for_fused_optimizer(const std::string& kernel_nam
[computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index])
usage:MTLResourceUsageRead | MTLResourceUsageWrite];
}
if (state_steps.size() > 0) {
if (!state_steps.empty()) {
mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors + tensor_loc);
[computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead];
}
@ -230,7 +230,7 @@ static void multi_tensor_apply_for_fused_optimizer(const std::string& kernel_nam
[computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index])
usage:MTLResourceUsageWrite | MTLResourceUsageRead];
}
if (state_steps.size() > 0) {
if (!state_steps.empty()) {
mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors);
[computeEncoder useResource:getMTLBufferStorage(state_steps[tensor_index]) usage:MTLResourceUsageRead];
}

View File

@ -493,7 +493,7 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t,
// Reduction axes
axes = [NSMutableArray<NSNumber*> arrayWithCapacity:1];
axes[0] = @0;
} else if (!keepdim && use_dim && dim_value.size() > 0) {
} else if (!keepdim && use_dim && !dim_value.empty()) {
int64_t num_reduce_dims = dim_value.size();
num_output_dims = num_input_dims;
@ -528,7 +528,7 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t,
correction_n *= input_shape[wrap_dim];
}
// (3, 4, 5) --> (3, 5)
} else if ((keepdim && !use_dim) || (keepdim && use_dim && dim_value.size() <= 0)) {
} else if ((keepdim && !use_dim) || (keepdim && use_dim && dim_value.empty())) {
num_output_dims = 0;
int64_t num_reduce_dims = 0;
set_axes(axes, num_reduce_dims, dim_value, input_shape.size());
@ -540,7 +540,7 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t,
correction_n *= input_shape[i];
}
// scalar --> vector case [[1.0034567]]
} else if (keepdim && use_dim && dim_value.size() > 0) {
} else if (keepdim && use_dim && !dim_value.empty()) {
int64_t num_reduce_dims = dim_value.size();
num_output_dims = num_input_dims;

View File

@ -168,7 +168,7 @@ TORCH_IMPL_FUNC(cat_out_mps)
TORCH_CHECK(canCast(out_dtype, out.scalar_type()),
"torch.cat(): input types can't be cast to the desired output type ",
out.scalar_type());
TORCH_CHECK(inputs.size() > 0, "torch.cat(): invalid number of inputs ", inputs.size());
TORCH_CHECK(!inputs.empty(), "torch.cat(): invalid number of inputs ", inputs.size());
dimension = legacy_cat_wrap_dim(dimension, materialized_inputs);
TORCH_CHECK(dimension >= 0, "torch.cat(): invalid dimension ", dimension);

View File

@ -272,7 +272,7 @@ static std::tuple<Tensor, Tensor, Tensor> _unique_impl_mps(const Tensor& self,
}
int64_t lengthScalar = length.item<int64_t>() + 1; // length actually holds max index, add 1
if (output.sizes().size() != 0) {
if (!output.sizes().empty()) {
output = at::slice(output, dim, 0, lengthScalar);
}
if (return_counts)

View File

@ -215,7 +215,7 @@ Tensor as_strided_tensorimpl_mps(const Tensor& self,
// when we create/run the view graph.
IntArrayRef base_shape = mps::updateTensorBaseShape(self);
TORCH_INTERNAL_ASSERT(
base_shape.size() > 0, "Failed to update the base shape of tensor's buffer at ", self.storage().data());
!base_shape.empty(), "Failed to update the base shape of tensor's buffer at ", self.storage().data());
return result;
}

View File

@ -6981,7 +6981,7 @@
device_check: NoCheck # TensorIterator
variants: function
dispatch:
CPU, CUDA: rsub
CPU, CUDA, MPS: rsub
autogen: rsub.Tensor_out
- func: heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!)
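A small sketch of what the added MPS dispatch entry enables: calling rsub directly on an MPS tensor. The availability check is only there so the snippet degrades gracefully on machines without an MPS device.
```python
# Sketch: rsub now dispatches through the shared CPU/CUDA/MPS kernel entry.
import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.arange(4.0, device=device)
print(torch.rsub(x, 10.0))  # computes 10 - x on the chosen device
```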

View File

@ -429,7 +429,7 @@ Tensor sparse_compressed_tensor_with_dims(
compressed_indices_size.push_back(compressed_size / blocksize[d0] + 1);
values_size.append(DimVector(blocksize));
} else {
TORCH_CHECK(blocksize.size() == 0, "sparse_compressed_tensor_with_dims: blocksize cannot be specified for non-block layout ", layout_);
TORCH_CHECK(blocksize.empty(), "sparse_compressed_tensor_with_dims: blocksize cannot be specified for non-block layout ", layout_);
compressed_indices_size.push_back(size[compressedDimension(layout_, size, dense_dim)] + 1);
}

View File

@ -573,7 +573,7 @@ Tensor _sparse_sum_backward_cuda(const Tensor& grad_, const SparseTensor& input_
}
const bool sum_all_sparse_dim = (input_sparse_dim == sparse_dims_to_sum_size);
const bool sum_dense_dim = (dense_dims_to_sum_v.size() > 0);
const bool sum_dense_dim = !dense_dims_to_sum_v.empty();
const bool sum_sparse_dim = (sparse_dims_to_sum_size > 0);
if (sum_all_sparse_dim) {

View File

@ -479,7 +479,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
at::cuda::CUDAGuard device_guard{static_cast<signed char>(q.get_device())};
auto opts = q.options();
@ -705,7 +705,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
at::cuda::CUDAGuard device_guard{static_cast<signed char>(q.get_device())};
auto opts = q.options();
@ -940,7 +940,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
at::cuda::CUDAGuard device_guard{static_cast<signed char>(q.get_device())};
auto opts = q.options();
auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
@ -1163,7 +1163,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
at::cuda::CUDAGuard device_guard{static_cast<signed char>(q.get_device())};
auto opts = q.options();
auto softmax_d = at::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat));
@ -1393,7 +1393,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
at::cuda::CUDAGuard device_guard{static_cast<signed char>(q.get_device())};
auto opts = q.options();

View File

@ -40,7 +40,7 @@ class BatchNormPackedContext final : virtual public VulkanPackedContext,
static BatchNormPackedContext pack(c10::impl::GenericList);
const c10::impl::GenericList unpack() const override {
TORCH_CHECK(unpacked_.size() > 0u, "unpacked_ does not have any elements!");
TORCH_CHECK(!unpacked_.empty(), "unpacked_ does not have any elements!");
return unpacked_;
}

View File

@ -1283,7 +1283,7 @@ Tensor Conv2dOpContext::run(const Tensor& input_arg) const {
Conv2dOpContext::State Conv2dOpContext::unpack() const {
const c10::impl::GenericList unpacked_ = conv_context_.unpack();
TORCH_CHECK(unpacked_.size() > 0u, "unpacked_ does not have any elements!");
TORCH_CHECK(!unpacked_.empty(), "unpacked_ does not have any elements!");
return Conv2dOpContext::State(
unpacked_.get(Conv2dPackedContext::Unpacked::Weight).toTensor(),

View File

@ -115,7 +115,7 @@ class Conv2dPackedContext final : virtual public VulkanPackedContext,
static Conv2dPackedContext pack(c10::impl::GenericList);
const c10::impl::GenericList unpack() const override {
TORCH_CHECK(unpacked_.size() > 0u, "unpacked_ does not have any elements!");
TORCH_CHECK(!unpacked_.empty(), "unpacked_ does not have any elements!");
return unpacked_;
}
@ -275,7 +275,7 @@ class Conv1dPackedContext final : virtual public VulkanPackedContext,
static Conv1dPackedContext pack(c10::impl::GenericList);
const c10::impl::GenericList unpack() const override {
TORCH_CHECK(unpacked_.size() > 0u, "unpacked_ does not have any elements!");
TORCH_CHECK(!unpacked_.empty(), "unpacked_ does not have any elements!");
return unpacked_;
}

View File

@ -240,7 +240,7 @@ const c10::impl::GenericList GruPackedContext::unpack() const {
packed_linear_context.toCustomClass<LinearPackedContext>()->unpack();
TORCH_CHECK(
unpacked_linear_context.size() > 0u,
!unpacked_linear_context.empty(),
"unpacked_linear_context does not have any elements!");
params_cpu.emplace_back(

View File

@ -282,7 +282,7 @@ const c10::impl::GenericList LstmPackedContext::unpack() const {
packed_linear_context.toCustomClass<LinearPackedContext>()->unpack();
TORCH_CHECK(
unpacked_linear_context.size() > 0u,
!unpacked_linear_context.empty(),
"unpacked_linear_context does not have any elements!");
params_cpu.emplace_back(

View File

@ -89,7 +89,7 @@ class LinearPackedContext final : virtual public VulkanPackedContext,
static LinearPackedContext pack(c10::impl::GenericList);
const c10::impl::GenericList unpack() const override {
TORCH_CHECK(unpacked_.size() > 0u, "unpacked_ does not have any elements!");
TORCH_CHECK(!unpacked_.empty(), "unpacked_ does not have any elements!");
return unpacked_;
}

View File

@ -17,10 +17,13 @@ if(BUILD_TEST)
if(WIN32 AND test_src MATCHES "^.*\.hip$")
set_source_files_properties(${test_src} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
hip_add_executable(${test_name} "${test_src}")
set_target_properties(${test_name} PROPERTIES LINKER_LANGUAGE CXX HIP_ARCHITECTURES ${PYTORCH_ROCM_ARCH})
set_target_properties(${test_name} PROPERTIES HIP_ARCHITECTURES ${PYTORCH_ROCM_ARCH})
else()
add_executable(${test_name} "${test_src}")
endif()
if(test_src MATCHES "^.*\.hip$")
set_target_properties(${test_name} PROPERTIES LINKER_LANGUAGE CXX)
endif()
target_link_libraries(${test_name} ${C10_CUDA_LIB} ${C10_LIB} gmock gtest gtest_main)
add_test(NAME ${test_name} COMMAND $<TARGET_FILE:${test_name}>)
if(INSTALL_TEST)

View File

@ -1939,10 +1939,13 @@ if(BUILD_TEST)
set(HIP_HIPCC_FLAGS ${BASE_HIPCC_FLAGS})
set_source_files_properties(${test_src} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
hip_add_executable(${test_name} "${test_src}")
set_target_properties(${test_name} PROPERTIES LINKER_LANGUAGE CXX HIP_ARCHITECTURES ${PYTORCH_ROCM_ARCH})
set_target_properties(${test_name} PROPERTIES HIP_ARCHITECTURES ${PYTORCH_ROCM_ARCH})
else()
add_executable(${test_name} "${test_src}")
endif()
if(test_src MATCHES "^.*\.hip$")
set_target_properties(${test_name} PROPERTIES LINKER_LANGUAGE CXX)
endif()
target_link_libraries(${test_name} torch_library gtest_main)
target_include_directories(${test_name} PRIVATE $<INSTALL_INTERFACE:include>)
target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE} ${Caffe2_HIP_INCLUDE})

View File

@ -51,6 +51,10 @@ else()
set(XPU_ENABLE_KINETO FALSE)
endif()
if(NOT WIN32)
if(WIN32)
if(${SYCL_COMPILER_VERSION} GREATER_EQUAL 20250101)
set(XPU_ENABLE_KINETO TRUE)
endif()
else()
set(XPU_ENABLE_KINETO TRUE)
endif()

View File

@ -859,3 +859,5 @@ API Reference
.. automodule:: torch.export.experimental
.. automodule:: torch.export.passes
.. autofunction:: torch.export.passes.move_to_device_pass
.. automodule:: torch.export.pt2_archive
.. automodule:: torch.export.pt2_archive.constants

View File

@ -1,5 +1,4 @@
from torch import cond # noqa: F401
from torch._higher_order_ops.cond import UnsupportedAliasMutationException # noqa: F401
from torch._higher_order_ops.map import ( # noqa: F401
_stack_pytree,
_unstack_pytree,

View File

@ -2384,15 +2384,6 @@
"torch.utils.collect_env": [
"namedtuple"
],
"torch.utils.data.datapipes.gen_pyi": [
"Any",
"Dict",
"List",
"Set",
"Tuple",
"Union",
"defaultdict"
],
"torch.utils.data.datapipes.utils.snapshot": [
"IterDataPipe",
"apply_random_seed"

View File

@ -19,7 +19,7 @@ target_compile_definitions(test_aoti_abi_check PRIVATE USE_GTEST)
# WARNING: DO NOT LINK torch!!!
# The purpose is to check if the used aten/c10 headers are writtern in a header-only way
target_link_libraries(test_aoti_abi_check PRIVATE gtest)
target_link_libraries(test_aoti_abi_check PRIVATE gtest_main)
target_include_directories(test_aoti_abi_check PRIVATE ${ATen_CPU_INCLUDE})
if(INSTALL_TEST)

View File

@ -55,7 +55,7 @@ add_custom_command(
target_link_libraries(test_aoti_inference PRIVATE
torch
gtest
gtest_main
-Wl,--no-as-needed aoti_custom_class
)

View File

@ -50,7 +50,7 @@ endif()
add_executable(test_api ${TORCH_API_TEST_SOURCES})
target_include_directories(test_api PRIVATE ${ATen_CPU_INCLUDE})
target_link_libraries(test_api PRIVATE torch gtest gmock)
target_link_libraries(test_api PRIVATE torch gtest_main gmock)
if(USE_CUDA)
target_compile_definitions(test_api PRIVATE "USE_CUDA")

View File

@ -7,7 +7,7 @@ if(USE_DISTRIBUTED AND NOT WIN32)
add_executable(test_dist_autograd ${DIST_AUTOGRAD_TEST_SOURCES})
target_include_directories(test_dist_autograd PRIVATE ${ATen_CPU_INCLUDE})
target_link_libraries(test_dist_autograd PRIVATE torch gtest)
target_link_libraries(test_dist_autograd PRIVATE torch gtest_main)
if(USE_CUDA)
target_compile_definitions(test_dist_autograd PRIVATE "USE_CUDA")

View File

@ -125,7 +125,7 @@ if(USE_MKLDNN)
target_link_libraries(test_jit PRIVATE caffe2::mkldnn)
endif()
set(JIT_TEST_DEPENDENCIES torch gtest jitbackend_test backend_with_compiler gmock)
set(JIT_TEST_DEPENDENCIES torch gtest_main jitbackend_test backend_with_compiler gmock)
if(MSVC)
list(APPEND JIT_TEST_DEPENDENCIES onnx_library)

View File

@ -28,7 +28,7 @@ add_executable(test_lazy
# TODO temporary until we can delete the old gtest polyfills.
target_compile_definitions(test_lazy PRIVATE USE_GTEST)
set(LAZY_TEST_DEPENDENCIES torch gtest)
set(LAZY_TEST_DEPENDENCIES torch gtest_main)
target_link_libraries(test_lazy PRIVATE ${LAZY_TEST_DEPENDENCIES})
target_include_directories(test_lazy PRIVATE ${ATen_CPU_INCLUDE})

View File

@ -21,7 +21,7 @@ target_include_directories(
${ATen_CPU_INCLUDE}
)
target_link_libraries(test_lite_interpreter_runtime PRIVATE torch gtest backend_with_compiler_runtime)
target_link_libraries(test_lite_interpreter_runtime PRIVATE torch gtest_main backend_with_compiler_runtime)
if(LINUX)
target_link_libraries(test_lite_interpreter_runtime PRIVATE "-Wl,--no-as-needed,$<TARGET_FILE:backend_with_compiler_runtime>,--as-needed")

View File

@ -17,7 +17,7 @@ add_executable(test_nativert
# TODO temporary until we can delete the old gtest polyfills.
target_compile_definitions(test_nativert PRIVATE USE_GTEST)
set(NATIVERT_TEST_DEPENDENCIES torch gtest)
set(NATIVERT_TEST_DEPENDENCIES torch gtest_main)
target_link_libraries(test_nativert PRIVATE ${NATIVERT_TEST_DEPENDENCIES})
target_link_libraries(test_nativert PRIVATE fmt::fmt-header-only)

View File

@ -5,7 +5,7 @@ set(TORCH_RPC_TEST_SOURCES
${TORCH_RPC_TEST_DIR}/test_wire_serialization.cpp
)
set(TORCH_RPC_TEST_DEPENDENCY_LIBS
torch gtest
torch gtest_main
)
if(USE_GLOO)

View File

@ -39,7 +39,7 @@ add_executable(test_tensorexpr
${TENSOREXPR_TEST_ROOT}/padded_buffer.cpp
${TENSOREXPR_TEST_SRCS})
target_link_libraries(test_tensorexpr PRIVATE torch gtest)
target_link_libraries(test_tensorexpr PRIVATE torch gtest_main)
target_include_directories(test_tensorexpr PRIVATE ${ATen_CPU_INCLUDE})
target_compile_definitions(test_tensorexpr PRIVATE USE_GTEST)

View File

@ -4,7 +4,6 @@ import collections
import copy
import functools
import itertools
import unittest
from typing import Any, Optional, Union
import torch
@ -12,13 +11,13 @@ import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard
from torch.nn.parallel.scatter_gather import _is_namedtuple
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
check_sharded_parity,
DoubleLinear,
FSDPTest,
FSDPTestMultiThread,
get_devtype,
MLP,
)
from torch.testing._internal.common_utils import run_tests
@ -28,10 +27,13 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
)
device_type = torch.device(get_devtype())
class TestFullyShardAutograd(FSDPTest):
@property
def world_size(self) -> int:
return min(4, torch.cuda.device_count())
return min(4, torch.get_device_module(device_type).device_count())
def _reduce_1d_partial_grads(
self, module: nn.Module, group: Optional[dist.ProcessGroup] = None
@ -58,7 +60,7 @@ class TestFullyShardAutograd(FSDPTest):
local_batch_size = 2
global_batch_size, dim = (self.world_size * local_batch_size, 24)
model = DoubleLinear(dim=dim, use_second_linear=True)
ref_model = copy.deepcopy(model).cuda()
ref_model = copy.deepcopy(model).to(device_type)
fully_shard(model.lin1, reshard_after_forward=reshard_after_forward)
fully_shard(model, reshard_after_forward=reshard_after_forward)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
@ -68,7 +70,7 @@ class TestFullyShardAutograd(FSDPTest):
for iter_idx in range(10):
# Use all forward outputs in the loss/backward for the first half
# of the iterations and only the 1st forward output for the rest
global_inp = torch.rand((global_batch_size, dim), device="cuda")
global_inp = torch.rand((global_batch_size, dim), device=device_type)
local_inp = global_inp[
self.rank * local_batch_size : (self.rank + 1) * local_batch_size
].detach()
@ -104,7 +106,7 @@ class TestFullyShardAutograd(FSDPTest):
local_batch_size, dim = (2, 24)
global_batch_size = self.world_size * local_batch_size
model = DoubleLinear(dim=dim, use_second_linear=False)
ref_model = copy.deepcopy(model).cuda()
ref_model = copy.deepcopy(model).to(device_type)
fully_shard(model.lin1, reshard_after_forward=reshard_after_forward)
fully_shard(model.lin2, reshard_after_forward=reshard_after_forward)
fully_shard(model, reshard_after_forward=reshard_after_forward)
@ -113,7 +115,7 @@ class TestFullyShardAutograd(FSDPTest):
torch.manual_seed(1) # same on all ranks
for iter_idx in range(10):
global_inp = torch.rand((global_batch_size, dim), device="cuda")
global_inp = torch.rand((global_batch_size, dim), device=device_type)
local_inp = global_inp[
self.rank * local_batch_size : (self.rank + 1) * local_batch_size
].detach()
@ -214,7 +216,7 @@ class TestFullyShardAutograd(FSDPTest):
Module(dim),
FromContainerType(container_type),
)
ref_model = copy.deepcopy(model).cuda()
ref_model = copy.deepcopy(model).to(device_type)
for module in model:
fully_shard(module)
fully_shard(model)
@ -223,7 +225,7 @@ class TestFullyShardAutograd(FSDPTest):
torch.manual_seed(1) # same on all ranks
for iter_idx in range(10):
global_inp = torch.rand((global_batch_size, dim), device="cuda")
global_inp = torch.rand((global_batch_size, dim), device=device_type)
local_inp = global_inp[
self.rank * local_batch_size : (self.rank + 1) * local_batch_size
].detach()
@ -245,7 +247,7 @@ class TestFullyShardPostAccGradHookMultiThread(FSDPTestMultiThread):
def world_size(self) -> int:
return 2
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_post_acc_grad_hook_runs(self):
param_name_to_hook_count = collections.defaultdict(int)
@ -260,7 +262,7 @@ class TestFullyShardPostAccGradHookMultiThread(FSDPTestMultiThread):
param_hook = functools.partial(hook, param_name)
param.register_post_accumulate_grad_hook(param_hook)
inp = torch.randn((2, 8), device="cuda")
inp = torch.randn((2, 8), device=device_type)
model(inp).sum().backward()
param_names = {param_name for param_name, _ in model.named_parameters()}
self.assertEqual(param_names, set(param_name_to_hook_count.keys()))
@ -271,7 +273,7 @@ class TestFullyShardPostAccGradHookMultiThread(FSDPTestMultiThread):
class TestFullyShardPostAccGradHookMultiProcess(FSDPTest):
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 2)
return min(torch.get_device_module(device_type).device_count(), 2)
@skip_if_lt_x_gpu(2)
def test_post_acc_grad_hook_optim_parity(self):
@ -283,7 +285,7 @@ class TestFullyShardPostAccGradHookMultiProcess(FSDPTest):
model_args = ModelArgs(dropout_p=0.0)
model = Transformer(model_args)
ref_model = copy.deepcopy(model).cuda()
ref_model = copy.deepcopy(model).to(device_type)
for module in itertools.chain(ref_model.layers, [ref_model]):
fully_shard(module)
optim_kwargs = {"lr": 1e-2, "foreach": False}
@ -312,7 +314,7 @@ class TestFullyShardPostAccGradHookMultiProcess(FSDPTest):
param.register_post_accumulate_grad_hook(optim_hook)
torch.manual_seed(42 + self.rank)
inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type)
for _ in range(10):
ref_loss = ref_model(inp).sum()
ref_loss.backward()
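The recurring change in these test hunks is a device-agnostic rewrite: hard-coded CUDA calls go through the device resolved from the internal get_devtype helper instead. A minimal sketch of the pattern (no distributed launch needed for this part):
```python
# Sketch of the device-agnostic substitutions applied throughout these tests.
# get_devtype is an internal test helper (shown in the hunks above) that
# resolves to the available accelerator, e.g. "cuda" or "xpu".
import copy

import torch
import torch.nn as nn
from torch.testing._internal.common_fsdp import get_devtype

device_type = torch.device(get_devtype())
device_module = torch.get_device_module(device_type)

world_size = min(4, device_module.device_count())  # was: torch.cuda.device_count()
model = nn.Linear(8, 8)
ref_model = copy.deepcopy(model).to(device_type)   # was: .cuda()
inp = torch.randn(2, 8, device=device_type)        # was: device="cuda"
```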

View File

@ -11,7 +11,7 @@ from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, MLPStack
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype, MLPStack
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
@ -20,6 +20,9 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
)
device_type = torch.device(get_devtype())
class _TestClipGradNormBase(FSDPTest):
def _test_clip_grad_norm(
self,
@ -33,7 +36,7 @@ class _TestClipGradNormBase(FSDPTest):
dp_mesh: Optional[DeviceMesh] = None,
):
vector_norm_fn = functools.partial(torch.linalg.vector_norm, ord=norm_type)
dp_mesh = dp_mesh or init_device_mesh("cuda", (self.world_size,))
dp_mesh = dp_mesh or init_device_mesh(device_type.type, (self.world_size,))
torch.manual_seed(42 + dp_mesh.get_local_rank() + 1)
for _ in range(10):
ref_optim.zero_grad()
@ -91,7 +94,7 @@ class _TestClipGradNormBase(FSDPTest):
class TestClipGradNormWorldSize2(_TestClipGradNormBase):
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 2)
return min(torch.get_device_module(device_type).device_count(), 2)
@skip_if_lt_x_gpu(2)
def test_clip_grad_norm_1d(self):
@ -99,14 +102,16 @@ class TestClipGradNormWorldSize2(_TestClipGradNormBase):
torch.manual_seed(42)
model_args = ModelArgs(dropout_p=0.0)
model = Transformer(model_args)
ref_model = replicate(copy.deepcopy(model).cuda())
ref_model = replicate(copy.deepcopy(model).to(device_type))
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
for module in model.modules():
if isinstance(module, TransformerBlock):
fully_shard(module)
fully_shard(model)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
inp = torch.randint(0, model.model_args.vocab_size, (3, 16), device="cuda")
inp = torch.randint(
0, model.model_args.vocab_size, (3, 16), device=device_type
)
self._test_clip_grad_norm(
1, norm_type, ref_model, ref_optim, model, optim, inp
)
@ -115,14 +120,14 @@ class TestClipGradNormWorldSize2(_TestClipGradNormBase):
class TestClipGradNormWorldSize4(_TestClipGradNormBase):
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 4)
return min(torch.get_device_module(device_type).device_count(), 4)
@skip_if_lt_x_gpu(4)
def test_clip_grad_norm_2d(self):
for norm_type in (2, 1, 3, float("inf")):
dp_size = 2
global_mesh = init_device_mesh(
"cuda",
device_type.type,
(dp_size, self.world_size // dp_size),
mesh_dim_names=("dp", "tp"),
)
@ -132,7 +137,7 @@ class TestClipGradNormWorldSize4(_TestClipGradNormBase):
# has some more significant numeric differences from the TP
model = MLPStack(16, with_seq_parallel=True)
ref_model = replicate(
copy.deepcopy(model).cuda(), process_group=dp_mesh.get_group()
copy.deepcopy(model).to(device_type), process_group=dp_mesh.get_group()
)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
model.parallelize(
@ -142,7 +147,7 @@ class TestClipGradNormWorldSize4(_TestClipGradNormBase):
reshard_after_forward=True,
)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
inp = torch.randn(2, 16, device="cuda")
inp = torch.randn(2, 16, device=device_type)
self._test_clip_grad_norm(
0.5, norm_type, ref_model, ref_optim, model, optim, inp, dp_mesh
)
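Mesh construction follows the same idea: the mesh device comes from device_type.type rather than the literal "cuda". A sketch, meant to run under torchrun with one accelerator per rank:
```python
# Sketch only: device-agnostic 2D mesh construction (launch with e.g.
# `torchrun --nproc_per_node=4 this_file.py`; uses the internal get_devtype).
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.testing._internal.common_fsdp import get_devtype

device_type = torch.device(get_devtype())
world_size = 4  # assumed to match the torchrun launch size
dp_size = 2

global_mesh = init_device_mesh(
    device_type.type,                      # was: "cuda"
    (dp_size, world_size // dp_size),
    mesh_dim_names=("dp", "tp"),
)
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
```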

View File

@ -3,7 +3,6 @@
import copy
import functools
import itertools
import unittest
from typing import Callable, Optional, Union
import torch
@ -35,7 +34,6 @@ from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.experimental import implicit_replication
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
check_sharded_parity,
@ -60,6 +58,12 @@ c10d_ops = torch.ops.c10d
# For recording FSDP events like unshard or post-backward
EventType = tuple[str, str, TrainingState]
from torch.testing._internal.common_fsdp import get_devtype
device_type = torch.device(get_devtype())
device_module = torch.get_device_module(device_type)
class TestFullyShardCollectiveOps(FSDPTestMultiThread):
@property
@ -68,7 +72,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
@property
def device(self) -> torch.device:
return torch.device("cuda:0")
return torch.device(device_type.type, 0)
def _get_param_sizes(self) -> list[torch.Size]:
# For world size 128, the fp32 all-gather and reduce-scatter testing
@ -116,11 +120,14 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
fsdp_param_group.lazy_init()
return fsdp_param_group
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_all_gather_fp32(self):
param_sizes = self._get_param_sizes()
default_stream = torch.cuda.current_stream()
stream1, stream2 = torch.cuda.Stream(), torch.cuda.Stream()
default_stream = device_module.current_stream()
stream1, stream2 = (
device_module.Stream(),
device_module.Stream(),
)
for async_op, streams, reshard_after_forward in itertools.product(
(False, True),
((default_stream, default_stream), (stream1, stream2)),
@ -146,8 +153,8 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
param_sizes: list[torch.Size],
reshard_after_forward: Union[bool, int],
async_op: bool,
all_gather_copy_in_stream: torch.cuda.Stream,
all_gather_stream: torch.cuda.Stream,
all_gather_copy_in_stream,
all_gather_stream,
):
def all_gather(fsdp_param_group: FSDPParamGroup, group: dist.ProcessGroup):
all_gather_result = foreach_all_gather(
@ -202,11 +209,11 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
)
check_all_gathered_params(orig_params, module)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_reduce_scatter_fp32(self):
param_sizes = self._get_param_sizes()
default_stream = torch.cuda.current_stream()
stream = torch.cuda.Stream()
default_stream = device_module.current_stream()
stream = device_module.Stream()
for reduce_scatter_stream in (default_stream, stream):
self._test_reduce_scatter(
param_sizes,
@ -214,11 +221,11 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
reduce_scatter_dtype=torch.float32,
)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_reduce_scatter_fp16(self):
param_sizes = self._get_param_sizes()
default_stream = torch.cuda.current_stream()
stream = torch.cuda.Stream()
default_stream = torch.get_device_module(device_type).current_stream()
stream = device_module.Stream()
for reduce_scatter_stream in (default_stream, stream):
self._test_reduce_scatter(
param_sizes,
@ -229,7 +236,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
def _test_reduce_scatter(
self,
param_sizes: list[torch.Size],
reduce_scatter_stream: torch.cuda.Stream,
reduce_scatter_stream,
reduce_scatter_dtype: torch.dtype,
):
# Set up the reference parameters and construct the FSDP group
@ -248,7 +255,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
unsharded_grads = [torch.ones_like(param) * self.rank for param in orig_params]
group = fsdp_param_group.mesh_info.shard_process_group
self.assertEqual(group.size(), self.world_size)
all_reduce_stream = torch.cuda.Stream()
all_reduce_stream = device_module.Stream()
(
_,
_,
@ -271,7 +278,9 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
all_reduce_grads=True,
partial_reduce_output=None,
)
torch.cuda.current_stream().wait_event(post_reduce_event)
torch.get_device_module(device_type).current_stream().wait_event(
post_reduce_event
)
# Check reduce-scatter correctness
predivide_factor, postdivide_factor = _get_gradient_divide_factors(
@ -295,7 +304,7 @@ class TestFullyShardCollectiveOps(FSDPTestMultiThread):
class TestFullyShardCommunication(FSDPTest):
@property
def world_size(self) -> int:
return min(4, torch.cuda.device_count())
return min(4, torch.get_device_module(device_type).device_count())
@skip_if_lt_x_gpu(2)
def test_fully_shard_communication_count(self):
@ -327,7 +336,7 @@ class TestFullyShardCommunication(FSDPTest):
# We construct `num_blocks` plus 1 FSDP states/communication groups
torch.manual_seed(42 + self.rank)
inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type)
with CommDebugMode() as fwd_comm_mode:
loss = model(inp)
fwd_comm_counts = fwd_comm_mode.get_comm_counts()
@ -364,7 +373,7 @@ class TestFullyShardCommunication(FSDPTest):
)
torch.manual_seed(42 + self.rank)
inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type)
with CommDebugMode() as fwd_comm_mode:
loss = model(inp)
fwd_comm_counts = fwd_comm_mode.get_comm_counts()
@ -395,7 +404,7 @@ class TestFullyShardCommunication(FSDPTest):
torch.manual_seed(42)
model_args = ModelArgs(dropout_p=0.0, weight_tying=False)
model = Transformer(model_args)
ref_model = copy.deepcopy(model).cuda()
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
for module in model.modules():
if isinstance(module, TransformerBlock):
@ -405,7 +414,7 @@ class TestFullyShardCommunication(FSDPTest):
model.set_reduce_scatter_divide_factor(divide_factor)
torch.manual_seed(42 + self.rank)
inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type)
for _ in range(10):
ref_loss = ref_model(inp).sum()
@ -441,7 +450,7 @@ class TestFullyShardCommunication(FSDPTest):
):
torch.manual_seed(42)
model_args = ModelArgs()
model = Transformer(model_args)
model = Transformer(model_args).to(device_type)
fully_shard_fn = functools.partial(
fully_shard, reshard_after_forward=not set_reshard_after_forward
)
@ -459,7 +468,7 @@ class TestFullyShardCommunication(FSDPTest):
)
torch.manual_seed(42 + self.rank)
inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type)
with CommDebugMode() as fwd_comm_mode:
loss = model(inp)
fwd_comm_counts = fwd_comm_mode.get_comm_counts()
@ -484,7 +493,7 @@ class TestFullyShardCommunication(FSDPTest):
class TestFullyShardPrefetch(FSDPTest):
@property
def world_size(self) -> int:
return min(4, torch.cuda.device_count())
return min(4, torch.get_device_module(device_type).device_count())
@skip_if_lt_x_gpu(2)
def test_fully_shard_backward_prefetch(self):
@ -640,7 +649,7 @@ class TestFullyShardPrefetch(FSDPTest):
fully_shard(model[1].lin1, reshard_after_forward=reshard_after_forward)
fully_shard(model[1].lin2, reshard_after_forward=reshard_after_forward)
fully_shard(model, reshard_after_forward=reshard_after_forward)
inp = torch.randn((4, dim), device="cuda")
inp = torch.randn((4, dim), device=device_type.type)
events: list[EventType] = []
unshard_with_record = self._get_unshard_with_record(
FSDPParamGroup.unshard, events
@ -901,7 +910,10 @@ class TestFullyShardPrefetch(FSDPTest):
FSDPParamGroup.post_backward, events
)
inp = torch.randint(
0, model_args.vocab_size, (2, model_args.max_seq_len), device="cuda"
0,
model_args.vocab_size,
(2, model_args.max_seq_len),
device=device_type.type,
)
with patch_unshard(unshard_with_record), patch_post_backward(
post_backward_with_record
@ -981,7 +993,7 @@ class TestFullyShardPrefetch(FSDPTest):
post_backward_with_record = self._get_post_backward_with_record(
FSDPParamGroup.post_backward, events
)
inp = torch.randn((2, 16), device="cuda")
inp = torch.randn((2, 16), device=device_type.type)
with patch_unshard(unshard_with_record), patch_post_backward(
post_backward_with_record
):
@ -1019,7 +1031,7 @@ class TestFullyShardPrefetch(FSDPTest):
@skip_if_lt_x_gpu(2)
def test_backward_misprefetch(self):
torch.manual_seed(42)
model = MLP(dim=16, device="cuda")
model = MLP(dim=16, device=device_type)
ref_model = copy.deepcopy(model)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
fully_shard(model.in_proj)
@ -1033,7 +1045,7 @@ class TestFullyShardPrefetch(FSDPTest):
model.in_proj.set_modules_to_backward_prefetch([model.out_proj])
torch.manual_seed(self.rank + 1)
inp = torch.randn((2, 16), device="cuda")
inp = torch.randn((2, 16), device=device_type.type)
for _ in range(3):
ref_optim.zero_grad()
ref_loss = ref_model(inp).sum()
@ -1065,7 +1077,10 @@ class TestFullyShardPrefetch(FSDPTest):
fully_shard(model, reshard_after_forward=reshard_after_forward)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
inp = torch.randint(
0, model_args.vocab_size, (2, model_args.max_seq_len), device="cuda"
0,
model_args.vocab_size,
(2, model_args.max_seq_len),
device=device_type.type,
)
return model, optim, inp
@ -1115,7 +1130,7 @@ class TestFullyShardPrefetch(FSDPTest):
class TestFullyShardUnshardMultiProcess(FSDPTest):
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 2)
return min(torch.get_device_module(device_type).device_count(), 2)
@skip_if_lt_x_gpu(2)
def test_unshard_async(self):
@ -1169,10 +1184,10 @@ class TestFullyShardUnshardMultiProcess(FSDPTest):
self.mlps.mlp3.unshard(async_op=True)
return self.mlps([y1, y2, y3], [work1, work2, work3])
mesh = init_device_mesh("cuda", (self.world_size,))
mesh = init_device_mesh(device_type.type, (self.world_size,))
batch_size, dim = 2, 8
torch.manual_seed(42)
ref_model = replicate(ReduceModel(dim, mesh).cuda())
ref_model = replicate(ReduceModel(dim, mesh).to(device_type))
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
torch.manual_seed(42)
model = ReduceModel(dim, mesh)
@ -1180,10 +1195,10 @@ class TestFullyShardUnshardMultiProcess(FSDPTest):
fully_shard(model.mlps.mlp2, reshard_after_forward=False)
fully_shard(model.mlps.mlp3, reshard_after_forward=False)
fully_shard(model.mlps)
replicate(model.cuda())
replicate(model.to(device_type))
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((batch_size, dim), device="cuda")
inp = torch.randn((batch_size, dim), device=device_type.type)
for _ in range(10):
losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
@ -1200,7 +1215,7 @@ class TestFullyShardUnshardMultiThread(FSDPTestMultiThread):
def world_size(self) -> int:
return 2
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_unshard_no_param_group(self):
# Check that we can call `unshard()` on a module with no parameter
# group / no managed parameters without erroring
@ -1211,7 +1226,7 @@ class TestFullyShardUnshardMultiThread(FSDPTestMultiThread):
handle = model.unshard(async_op=True)
handle.wait()
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_unshard_without_lazy_init(self):
torch.manual_seed(42)
model = MLP(4)
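Stream handling gets the same treatment: the torch.cuda stream APIs are reached through the module returned by torch.get_device_module, so the test also runs on non-CUDA accelerators that expose streams. A sketch (needs an available accelerator):
```python
# Sketch of the stream substitutions (requires an accelerator with stream
# support; get_devtype is the internal test helper used above).
import torch
from torch.testing._internal.common_fsdp import get_devtype

device_type = torch.device(get_devtype())
device_module = torch.get_device_module(device_type)

default_stream = device_module.current_stream()  # was: torch.cuda.current_stream()
side_stream = device_module.Stream()             # was: torch.cuda.Stream()

with device_module.stream(side_stream):
    x = torch.ones(4, device=device_type) * 2
side_stream.synchronize()
```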

View File

@ -31,7 +31,7 @@ from torch.testing._internal.common_distributed import (
skip_if_lt_x_gpu,
sm_is_or_higher_than,
)
from torch.testing._internal.common_fsdp import FSDPTest, MLP
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype, MLP
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
@ -40,6 +40,8 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
from torch.testing._internal.inductor_utils import HAS_GPU
device_type = torch.device(get_devtype())
log = logging.getLogger(__name__)
@ -59,9 +61,9 @@ class Mod(torch.nn.Module):
super().__init__()
self.encoder = torch.nn.Sequential(
torch.nn.Linear(28 * 28, 1024, device="cuda"),
torch.nn.Linear(1024, 1024, device="cuda"),
torch.nn.Linear(1024, 4096, device="cuda"),
torch.nn.Linear(28 * 28, 1024, device=device_type),
torch.nn.Linear(1024, 1024, device=device_type),
torch.nn.Linear(1024, 4096, device=device_type),
)
def forward(self, x):
@ -104,10 +106,10 @@ class TestFullyShardCompileCompute(FSDPTest):
torch.distributed.barrier()
torch._dynamo.config.skip_fsdp_hooks = skip_fsdp_hooks
torch._dynamo.trace_rules.check = patched_trace_rules_check
model = MLP(4)
model = MLP(4).to(device_type)
fully_shard(model)
model.compile()
model(torch.randn((4, 4), device="cuda"))
model(torch.randn((4, 4), device=device_type))
torch.distributed.barrier()
torch._dynamo.config.skip_fsdp_hooks = original_skip_fsdp_hooks
torch._dynamo.trace_rules.check = orig_trace_rules_check
@ -127,7 +129,10 @@ class TestFullyShardCompile(FSDPTest):
def skipTestForOldSm(self):
# Assumption: This test class is only run on GPU. See `HAS_GPU` check at
# the top of the class.
device = torch.device("cuda", self.rank % torch.cuda.device_count())
device = torch.device(
device_type.type,
self.rank % torch.get_device_module(device_type).device_count(),
)
if not sm_is_or_higher_than(device, 8, 0):
self.skipTest("bf16 requires sm >= 8.0")
@ -139,7 +144,7 @@ class TestFullyShardCompile(FSDPTest):
(torch.nn.Linear(1, 1),), # module: Tuple[nn.Module, ...],
None, # mesh_info: FSDPMeshInfo,
None, # post_forward_mesh_info: Optional[FSDPMeshInfo],
torch.device("cuda"), # device: torch.device,
device_type, # device: torch.device,
None, # shard_placement_fn: Optional[Callable],
None, # mp_policy: MixedPrecisionPolicy,
None, # offload_policy: OffloadPolicy,
@ -592,11 +597,11 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
torch.manual_seed(self.rank)
fsdp_config = {}
model = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim, device="cuda"),
nn.Linear(hidden_dim, hidden_dim, device=device_type),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim, device="cuda"),
nn.Linear(hidden_dim, hidden_dim, device=device_type),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim, device="cuda"),
nn.Linear(hidden_dim, hidden_dim, device=device_type),
)
fully_shard(model, reshard_after_forward=True, **fsdp_config)
optim = torch.optim.SGD(model.parameters(), lr=1e-4)
@ -604,7 +609,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
def input_creation_fn():
torch.manual_seed(self.rank)
inp = torch.randn((2, hidden_dim), device="cuda", requires_grad=False)
inp = torch.randn((2, hidden_dim), device=device_type, requires_grad=False)
return inp
return model_init_fn, input_creation_fn
@ -641,11 +646,11 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
super().__init__()
self.param1 = nn.Parameter(
torch.zeros(
hidden_dim, hidden_dim, dtype=torch.float, device="cuda"
hidden_dim, hidden_dim, dtype=torch.float, device=device_type
)
)
self.param2 = nn.Parameter(
torch.zeros(hidden_dim, dtype=torch.float, device="cuda")
torch.zeros(hidden_dim, dtype=torch.float, device=device_type)
)
def forward(self, x):
@ -680,7 +685,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
def model_init_fn():
torch.manual_seed(self.rank)
fsdp_config = {}
mesh = init_device_mesh("cuda", (self.world_size,))
mesh = init_device_mesh(device_type.type, (self.world_size,))
model = TestModule(n_layers=3)
for mod in model.layers:
fully_shard(mod, mesh=mesh, reshard_after_forward=True, **fsdp_config)
@ -692,7 +697,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
def input_creation_fn():
torch.manual_seed(self.rank)
inp = torch.randn((2, hidden_dim), device="cuda", requires_grad=False)
inp = torch.randn((2, hidden_dim), device=device_type, requires_grad=False)
return inp
return model_init_fn, input_creation_fn
@ -854,7 +859,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
def model_init_fn():
torch.manual_seed(self.rank)
fsdp_config = {}
mesh = init_device_mesh("cuda", (self.world_size,))
mesh = init_device_mesh(device_type.type, (self.world_size,))
model_args = ModelArgs(
vocab_size=vocab_size,
n_layers=n_layers,
@ -883,7 +888,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
def input_creation_fn():
torch.manual_seed(self.rank)
inp = torch.randint(
0, vocab_size, (2, seq_len), device="cuda", requires_grad=False
0, vocab_size, (2, seq_len), device=device_type, requires_grad=False
)
return inp
@ -1092,7 +1097,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
new_child = torch.compile(child)
setattr(m.encoder, name, new_child)
m = FSDP(m, sharding_strategy=ShardingStrategy.FULL_SHARD, use_orig_params=True)
inp = torch.randn(32, 784, device="cuda")
inp = torch.randn(32, 784, device=device_type)
m(inp)
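Put together, the compile-over-FSDP flow exercised above reduces to roughly the following sketch, meant to run under torchrun and relying on the internal MLP and get_devtype helpers.
```python
# Sketch only: fully_shard the module, compile it, run a forward on the
# resolved device (launch under torchrun; internal test helpers assumed).
import torch
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard
from torch.testing._internal.common_fsdp import MLP, get_devtype

dist.init_process_group()  # backend inferred from the torchrun environment
device_type = torch.device(get_devtype())

model = MLP(4).to(device_type)
fully_shard(model)
model.compile()
model(torch.randn((4, 4), device=device_type))

dist.destroy_process_group()
```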

View File

@ -5,7 +5,6 @@ import copy
import functools
import math
import threading
import unittest
from typing import Any, Optional, Union
import torch
@ -15,18 +14,21 @@ import torch.utils._pytree as pytree
from torch.autograd.grad_mode import _unsafe_preserve_version_counter
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
check_sharded_parity,
FSDPTest,
FSDPTestMultiThread,
get_devtype,
MLP,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.two_tensor import TwoTensor
device_type = torch.device(get_devtype())
def two_tensor_fsdp_pre_all_gather_v1(
self, mesh: DeviceMesh
) -> tuple[tuple[torch.Tensor, ...], Any]:
@ -222,7 +224,7 @@ class TestFullyShardAllGatherExtensionsMultiProcess(
def _test_all_gather_extensions_train_parity(self, reshard_after_forward: bool):
torch.manual_seed(42)
model = self._init_two_tensor_mlp()
ref_model = copy.deepcopy(model).cuda()
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=True)
fully_shard_fn = functools.partial(
fully_shard, reshard_after_forward=reshard_after_forward
@ -234,7 +236,7 @@ class TestFullyShardAllGatherExtensionsMultiProcess(
check_sharded_parity(self, ref_model, model)
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((2, 8), device="cuda")
inp = torch.randn((2, 8), device=device_type)
for iter_idx in range(10):
losses: list[torch.Tensor] = []
for _model in (ref_model, model):
@ -261,9 +263,9 @@ class TestFullyShardAllGatherExtensionsMultiThread(
@property
def device(self) -> torch.device:
return torch.device("cuda:0")
return torch.device(device_type)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_all_gather_extensions_end_to_end(self):
with self._patch_two_tensor_fsdp_all_gather(pre_all_gather_version=1):
self.run_subtests(
@ -297,13 +299,13 @@ class TestFullyShardAllGatherExtensionsMultiThread(
# Run a few iterations to check for errors
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((2, 8), device="cuda")
inp = torch.randn((2, 8), device=device_type)
for _ in range(3):
model(inp).sum().backward()
optim.step()
optim.zero_grad()
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_all_gather_extensions_monkey_patch(self):
tls = threading.local()
tls.ran_pre_all_gather = False
@ -368,14 +370,14 @@ class TestFullyShardAllGatherExtensionsMultiThread(
# Run a few iterations to check for errors
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((2, 8), device="cuda")
inp = torch.randn((2, 8), device=device_type)
for _ in range(3):
model(inp).sum().backward()
optim.step()
optim.zero_grad()
assert tls.ran_pre_all_gather
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_all_gather_extension_outer_size_stride(self):
"""
NOTE: We cannot easily test the incorrect case where the user-defined
@ -395,19 +397,19 @@ class TestFullyShardAllGatherExtensionsMultiThread(
fully_shard(model)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2, fused=True)
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((2, 3), device="cuda")
inp = torch.randn((2, 3), device=device_type)
loss = model(inp).sum()
loss.backward()
optim.step()
optim.zero_grad()
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_all_gather_extension_hsdp_mesh(self):
tls = threading.local()
replicate_size = 2
shard_size = self.world_size // replicate_size
mesh = init_device_mesh(
"cuda",
device_type.type,
(replicate_size, shard_size),
mesh_dim_names=("dp_replicate", "dp_shard"),
)
@ -456,7 +458,7 @@ class TestFullyShardAllGatherExtensionsMultiThread(
local_param
)
inp = torch.randn((2, 8), device="cuda")
inp = torch.randn((2, 8), device=device_type)
model(inp)
# Check that FSDP passes only the shard mesh to the pre-all-gather
self.assertEqual(tls.mesh.ndim, 1)
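The skip decorators change in the same spirit: the CUDA-only unittest.skipIf(not TEST_CUDA, ...) guard becomes the accelerator-count based skip_if_lt_x_gpu already used by the multi-process tests. A sketch of the resulting shape of a test case:
```python
# Sketch of the decorator swap (internal test utilities assumed, as in the
# hunks above; run via the usual PyTorch test entry point).
import torch
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTestMultiThread, get_devtype
from torch.testing._internal.common_utils import run_tests

device_type = torch.device(get_devtype())


class ExampleMultiThreadTest(FSDPTestMultiThread):
    @property
    def world_size(self) -> int:
        return 2

    @skip_if_lt_x_gpu(1)  # was: @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_runs_on_any_accelerator(self):
        x = torch.ones(4, device=device_type)
        self.assertEqual(x.sum().item(), 4.0)


if __name__ == "__main__":
    run_tests()
```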

View File

@ -18,6 +18,7 @@ from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
check_sharded_parity,
FSDPTest,
get_devtype,
MLP,
patch_reduce_scatter,
patch_register_post_backward_hook_backward,
@ -26,10 +27,13 @@ from torch.testing._internal.common_fsdp import (
from torch.testing._internal.common_utils import run_tests
device_type = torch.device(get_devtype())
class TestFullyShardFrozen(FSDPTest):
@property
def world_size(self) -> int:
return min(4, torch.cuda.device_count())
return min(4, torch.get_device_module(device_type).device_count())
@skip_if_lt_x_gpu(2)
def test_train_mixed_requires_grad_per_group(self):
@ -66,7 +70,7 @@ class TestFullyShardFrozen(FSDPTest):
if "bias" not in param_name:
param.requires_grad_(False)
ref_model = replicate(
copy.deepcopy(model).cuda(),
copy.deepcopy(model).to(device_type),
device_ids=[self.rank],
find_unused_parameters=freeze_after_init,
)
@ -110,7 +114,7 @@ class TestFullyShardFrozen(FSDPTest):
return orig_backward(*args, **kwargs)
torch.manual_seed(42 + self.rank + 1)
device = torch.device("cuda")
device = device_type
with patch_reduce_scatter(
reduce_scatter
), patch_register_post_backward_hook_backward(backward_with_count):
@ -156,7 +160,7 @@ class TestFullyShardFrozen(FSDPTest):
modules += [nn.Linear(lin_dim, lin_dim), nn.ReLU()]
model = nn.Sequential(*modules)
ref_model = replicate(
copy.deepcopy(model).cuda(),
copy.deepcopy(model).to(device_type),
device_ids=[self.rank],
find_unused_parameters=True,
)
@ -184,7 +188,7 @@ class TestFullyShardFrozen(FSDPTest):
_set_requires_grad(ref_model, False)
num_iters, no_grad_iter_idx = (3, 1)
torch.manual_seed(42 + self.rank)
inp = torch.randn((8, lin_dim), device="cuda")
inp = torch.randn((8, lin_dim), device=device_type)
with patch_register_post_backward_hook_backward(backward_with_count):
for iter_idx in range(num_iters):
losses: list[torch.Tensor] = []
@ -242,7 +246,9 @@ class TestFullyShardFrozen(FSDPTest):
torch.manual_seed(42)
model = MultiForwardModule(torch.device("cpu"))
ref_model = replicate(copy.deepcopy(model).cuda(), device_ids=[self.rank])
ref_model = replicate(
copy.deepcopy(model).to(device_type), device_ids=[self.rank]
)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
for module in model.modules():
if isinstance(module, nn.Linear):
@ -250,7 +256,7 @@ class TestFullyShardFrozen(FSDPTest):
fully_shard(model, reshard_after_forward=reshard_after_forward)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
for iter_idx in range(10):
inp = torch.randn((8, 5), device="cuda")
inp = torch.randn((8, 5), device=device_type)
losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))

View File

@ -12,10 +12,13 @@ from torch.distributed.tensor.parallel import (
RowwiseParallel,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, MLP
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype, MLP
from torch.testing._internal.common_utils import run_tests
device_type = torch.device(get_devtype())
class TestFullyShardGradientScaler(FSDPTest):
@skip_if_lt_x_gpu(4)
def test_gradient_scaler(self):
@ -27,16 +30,16 @@ class TestFullyShardGradientScaler(FSDPTest):
def _test_gradient_scaler(self, has_inf: bool, test_2d: bool):
torch.manual_seed(0)
model = nn.Sequential(
*[nn.Linear(4, 4, device="cuda", bias=False) for _ in range(2)]
*[nn.Linear(4, 4, device=device_type, bias=False) for _ in range(2)]
)
for layer in model:
fully_shard(layer)
fully_shard(model)
input = torch.randn([4, 4], device="cuda")
input = torch.randn([4, 4], device=device_type)
if test_2d:
mesh_2d = init_device_mesh(
"cuda", (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
device_type.type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
)
dp_mesh, tp_mesh = mesh_2d["dp"], mesh_2d["tp"]
model = nn.Sequential(MLP(2), MLP(2), MLP(2))
@ -56,10 +59,10 @@ class TestFullyShardGradientScaler(FSDPTest):
for module in model:
fully_shard(module, mesh=dp_mesh)
fully_shard(model, mesh=dp_mesh)
input = torch.randn((2,), device="cuda")
input = torch.randn((2,), device=device_type)
loss = model(input).sum()
scaler = GradScaler(init_scale=2.0, enabled=True)
scaler = GradScaler(init_scale=2.0, enabled=True, device=device_type.type)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
scaler.scale(loss).backward()
inv_scale = scaler._scale.double().reciprocal().float()
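The gradient-scaler hunk passes the resolved device into the scaler instead of assuming CUDA. A minimal sketch with torch.amp.GradScaler (assuming that is the GradScaler imported by the test); the scaler is only enabled here when an accelerator is actually present:

```python
import torch
from torch.amp import GradScaler

device_type = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pass the device string explicitly, mirroring device=device_type.type above.
scaler = GradScaler(
    device=device_type.type, init_scale=2.0, enabled=device_type.type == "cuda"
)

model = torch.nn.Linear(4, 4, device=device_type)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
loss = model(torch.randn(4, 4, device=device_type)).sum()
scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()
```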


@ -12,7 +12,7 @@ from torch.distributed.tensor import DTensor
from torch.distributed.tensor.experimental import implicit_replication
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
run_tests,
@ -20,6 +20,8 @@ from torch.testing._internal.common_utils import (
)
device_type = torch.device(get_devtype())
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
@ -72,14 +74,14 @@ class A(nn.Module):
class Y(nn.Module):
def __init__(self) -> None:
super().__init__()
p = torch.randn(10, device="cuda")
p = torch.randn(10, device=device_type)
self.p = nn.Parameter(p)
class X(nn.Module):
def __init__(self) -> None:
super().__init__()
q = torch.randn(10, device="cuda")
q = torch.randn(10, device=device_type)
self.q = nn.Parameter(q)
self.y = Y()
@ -95,15 +97,15 @@ def _generate_model_and_input() -> nn.Module:
dim = 8
torch.manual_seed(42)
addend = torch.randn((dim, dim), device="cuda")
addend = torch.randn((dim, dim), device=device_type)
torch.manual_seed(70)
subend = torch.randn((dim, dim), device="cuda")
subend = torch.randn((dim, dim), device=device_type)
model = A(dim, addend, subend).cuda()
model = A(dim, addend, subend).to(device_type)
torch.manual_seed(84)
inp = torch.randn((dim, dim), device="cuda")
inp = torch.randn((dim, dim), device=device_type)
return model, inp
@ -229,7 +231,7 @@ class TestFullyShardIgnoreParams(FSDPTest):
@skip_if_lt_x_gpu(2)
def test_ddp_A_fsdp_B_ddp_C(self):
default_pg = dist.distributed_c10d._get_default_group()
mesh = init_device_mesh("cuda", mesh_shape=(default_pg.size(),))
mesh = init_device_mesh(device_type.type, mesh_shape=(default_pg.size(),))
ref_model, ref_inp = _generate_model_and_input()
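init_device_mesh now receives device_type.type rather than the literal "cuda". A single-process sketch that builds a one-dimensional mesh the same way; the process group is initialized inline (gloo on CPU-only hosts, nccl when CUDA is available), whereas the real tests spawn multiple ranks:

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

device_type = torch.device("cuda" if torch.cuda.is_available() else "cpu")

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(
    backend="nccl" if device_type.type == "cuda" else "gloo",
    rank=0,
    world_size=1,
)

# Same call shape as the hunks above: the mesh is keyed off the resolved
# device type, not a hard-coded "cuda" string.
mesh = init_device_mesh(device_type.type, mesh_shape=(dist.get_world_size(),))
print(mesh)
dist.destroy_process_group()
```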


@ -2,7 +2,6 @@
import copy
import itertools
import unittest
from typing import cast, Optional
import torch
@ -37,8 +36,8 @@ from torch.distributed.tensor.parallel import (
RowwiseParallel,
)
from torch.distributed.tensor.placement_types import _StridedShard
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_fsdp import FSDPTestMultiThread, MLP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTestMultiThread, get_devtype, MLP
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
@ -47,6 +46,9 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
)
device_type = torch.device(get_devtype())
class TestFullyShardDeviceTensor(FSDPTestMultiThread):
"""Tests that tensor parameters are moved to the expected device."""
@ -54,17 +56,19 @@ class TestFullyShardDeviceTensor(FSDPTestMultiThread):
def world_size(self) -> int:
return 1
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_move_states_to_device_tensor(self):
model = MLP(8, torch.device("cpu"), with_buffer=True)
for tensor in itertools.chain(model.parameters(), model.buffers()):
self.assertEqual(tensor.device, torch.device("cpu"))
fully_shard(model)
cuda_device = torch.device("cuda", torch.cuda.current_device())
accelerator_device = torch.device(
device_type.type, torch.get_device_module(device_type).current_device()
)
for tensor in itertools.chain(model.parameters(), model.buffers()):
self.assertEqual(tensor.device, cuda_device)
self.assertEqual(tensor.device, accelerator_device)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_move_states_to_device_ignored_param_device(self):
cpu_device = torch.device("cpu")
model = MLP(8, cpu_device, with_buffer=True)
@ -72,10 +76,12 @@ class TestFullyShardDeviceTensor(FSDPTestMultiThread):
fully_shard(model, ignored_params=set(ignored_params))
for tensor in ignored_params:
self.assertEqual(tensor.device, cpu_device)
cuda_device = torch.device("cuda", torch.cuda.current_device())
model.to(torch.device("cuda"))
accelerator_device = torch.device(
device_type.type, torch.get_device_module(device_type).current_device()
)
model.to(device_type)
for tensor in ignored_params:
self.assertEqual(tensor.device, cuda_device)
self.assertEqual(tensor.device, accelerator_device)
class TestFullyShardDeviceDTensor(FSDPTestMultiThread):
@ -85,12 +91,14 @@ class TestFullyShardDeviceDTensor(FSDPTestMultiThread):
def world_size(self) -> int:
return 4
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_move_states_to_device_dtensor_valid(self):
assert self.world_size >= 4, f"{self.world_size}"
dp_size = 2
global_mesh = init_device_mesh(
"cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
device_type.type,
(dp_size, self.world_size // dp_size),
mesh_dim_names=("dp", "tp"),
)
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
model = MLP(8, torch.device("cpu"), with_buffer=True)
@ -99,31 +107,35 @@ class TestFullyShardDeviceDTensor(FSDPTestMultiThread):
tp_mesh,
{"in_proj": ColwiseParallel(), "out_proj": RowwiseParallel()},
)
cuda_device = torch.device("cuda", torch.cuda.current_device())
accelerator_device = torch.device(
device_type.type, torch.get_device_module(device_type).current_device()
)
for tensor in itertools.chain(model.parameters(), model.buffers()):
if isinstance(tensor, DTensor):
# DTensor constructor moves to the mesh's device
self.assertEqual(tensor.device, cuda_device)
self.assertEqual(tensor._local_tensor.device, cuda_device)
self.assertEqual(tensor.device, accelerator_device)
self.assertEqual(tensor._local_tensor.device, accelerator_device)
else:
self.assertEqual(tensor.device, torch.device("cpu"))
fully_shard(model, mesh=dp_mesh)
for tensor in itertools.chain(model.parameters(), model.buffers()):
self.assertEqual(tensor.device, cuda_device)
self.assertEqual(tensor.device, accelerator_device)
if isinstance(tensor, DTensor):
self.assertEqual(tensor._local_tensor.device, cuda_device)
self.assertEqual(tensor._local_tensor.device, accelerator_device)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_move_states_to_device_dtensor_invalid(self):
assert self.world_size >= 4, f"{self.world_size}"
dp_size = 2
global_cuda_mesh = init_device_mesh(
"cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
global_accelerator_mesh = init_device_mesh(
device_type.type,
(dp_size, self.world_size // dp_size),
mesh_dim_names=("dp", "tp"),
)
global_cpu_mesh = init_device_mesh(
"cpu", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
)
dp_mesh = global_cuda_mesh["dp"]
dp_mesh = global_accelerator_mesh["dp"]
tp_mesh = global_cpu_mesh["tp"] # mismatched meshes!
model = MLP(8, torch.device("cpu"), with_buffer=True)
parallelize_module(
@ -135,7 +147,10 @@ class TestFullyShardDeviceDTensor(FSDPTestMultiThread):
self.assertEqual(tensor.device, torch.device("cpu"))
if isinstance(tensor, DTensor):
self.assertEqual(tensor._local_tensor.device, torch.device("cpu"))
regex = r"Requires DTensor to have mesh of the same type as the FSDP mesh but got cpu for DTensor and cuda for FSDP"
regex = (
rf"Requires DTensor to have mesh of the same type as the FSDP mesh but got "
rf"cpu for DTensor and {device_type.type} for FSDP"
)
with self.assertRaisesRegex(ValueError, regex):
fully_shard(model, mesh=dp_mesh)
@ -147,17 +162,17 @@ class TestFullyShardMeshArg(FSDPTestMultiThread):
def world_size(self) -> int:
return 4
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_invalid_mesh_ndim(self):
mesh = init_device_mesh("cuda", (self.world_size, 1, 1))
mesh = init_device_mesh(device_type.type, (self.world_size, 1, 1))
model = MLP(8)
regex = r"fully\_shard expects a 1D or 2D DeviceMesh but got DeviceMesh"
with self.assertRaisesRegex(ValueError, regex):
fully_shard(model, mesh=mesh)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_2d_mesh_without_mesh_dim_names(self):
mesh = init_device_mesh("cuda", (self.world_size // 2, 2))
mesh = init_device_mesh(device_type.type, (self.world_size // 2, 2))
model = MLP(8)
regex = "Please init the 2D mesh for HSDP with mesh_dim_names specified"
with self.assertRaisesRegex(AssertionError, regex):
@ -171,7 +186,7 @@ class TestFullyShardManagedModulesAndStates(FSDPTestMultiThread):
def world_size(self) -> int:
return 1
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_managed_modules_single(self):
model = MLP(8)
# Assume calling `fully_shard` on `model`
@ -179,7 +194,7 @@ class TestFullyShardManagedModulesAndStates(FSDPTestMultiThread):
expected_managed_modules = list(model.modules())
self._check_managed_modules(managed_modules, expected_managed_modules)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_managed_modules_nested(self):
model = nn.Sequential(*[MLP(8) for _ in range(2)])
fully_shard(model[0])
@ -188,7 +203,7 @@ class TestFullyShardManagedModulesAndStates(FSDPTestMultiThread):
expected_managed_modules = list(model[1].modules()) + [model]
self._check_managed_modules(managed_modules, expected_managed_modules)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_managed_modules_nested_fully_shard_and_replicate(self):
model = nn.Sequential(*[MLP(8) for _ in range(3)])
replicate(model[0])
@ -198,7 +213,7 @@ class TestFullyShardManagedModulesAndStates(FSDPTestMultiThread):
expected_managed_modules = list(model[1].modules()) + [model]
self._check_managed_modules(managed_modules, expected_managed_modules)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_managed_modules_duplicate(self):
mlp = MLP(8)
model = nn.Sequential(mlp, mlp) # duplicate MLP
@ -208,7 +223,7 @@ class TestFullyShardManagedModulesAndStates(FSDPTestMultiThread):
expected_managed_modules = list(mlp.modules()) + [model]
self._check_managed_modules(managed_modules, expected_managed_modules)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_managed_modules_list_of_mlps(self):
model = nn.Sequential(*[MLP(8) for _ in range(5)])
# Assume calling `fully_shard` on `[model[0], model[1], model[2]]`
@ -232,7 +247,7 @@ class TestFullyShardManagedModulesAndStates(FSDPTestMultiThread):
# Check set comparison since we do not require anything about the order
self.assertEqual(set(managed_modules), set(expected_managed_modules))
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_managed_states_shared_params_and_buffers(self):
model = nn.Sequential(*[MLP(8, with_buffer=True) for _ in range(3)])
model[0].in_proj.weight = model[1].in_proj.weight
@ -245,7 +260,7 @@ class TestFullyShardManagedModulesAndStates(FSDPTestMultiThread):
expected_buffers = list(model.buffers()) # de-dups shared
self._check_managed_states(params, buffers, expected_params, expected_buffers)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_managed_states_nested_fully_shard(self):
model = nn.Sequential(*[MLP(8, with_buffer=True) for _ in range(2)])
fully_shard(model[0])
@ -256,7 +271,7 @@ class TestFullyShardManagedModulesAndStates(FSDPTestMultiThread):
expected_buffers = list(model[1].buffers())
self._check_managed_states(params, buffers, expected_params, expected_buffers)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_managed_states_list_of_mlps(self):
model = nn.Sequential(*[MLP(8, with_buffer=True) for _ in range(5)])
# Assume calling `fully_shard` on `[model[0], model[1], model[2]]`
@ -292,7 +307,7 @@ class TestFullyShardParamModuleInfos(FSDPTestMultiThread):
def world_size(self) -> int:
return 2
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_get_param_module_infos_shared_params(self):
model = nn.Sequential(*[MLP(8) for _ in range(2)])
model[0].in_proj.weight = model[1].in_proj.weight
@ -313,7 +328,7 @@ class TestFullyShardParamModuleInfos(FSDPTestMultiThread):
self.assertEqual(len(param_module_infos), len(expected_param_module_infos))
self.assertEqual(param_module_infos, expected_param_module_infos)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_get_param_module_infos_duplicates(self):
mlp = MLP(8)
model = nn.Sequential(mlp, mlp) # shared MLP
@ -341,7 +356,7 @@ class TestFullyShardParamModuleInfos(FSDPTestMultiThread):
ParamModuleInfo(mlp.out_proj, "bias", [], []),
]
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_get_param_module_infos_list_of_mlps(self):
model = nn.Sequential(*[MLP(8) for _ in range(2)])
managed_modules = _get_managed_modules((model[0], model[1]))
@ -367,7 +382,7 @@ class TestFullyShardShardedParameterTensor(FSDPTestMultiThread):
def world_size(self) -> int:
return 2
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_shard_tensor_parameters(self):
# Use odd dim sizes to test uneven shards
model = nn.Sequential(*[MLP(3, dim_multiplier=3) for _ in range(3)])
@ -387,7 +402,7 @@ class TestFullyShardShardedParameterTensor(FSDPTestMultiThread):
self, orig_params: list[nn.Parameter], sharded_params: list[nn.Parameter]
):
self.assertEqual(len(orig_params), len(sharded_params))
global_mesh = init_device_mesh("cuda", (self.world_size,))
global_mesh = init_device_mesh(device_type.type, (self.world_size,))
for orig_param, sharded_param in zip(orig_params, sharded_params):
self.assertIsInstance(sharded_param, DTensor)
self.assertEqual(sharded_param.device_mesh, global_mesh)
@ -397,17 +412,19 @@ class TestFullyShardShardedParameterTensor(FSDPTestMultiThread):
chunks = torch.chunk(orig_param, self.world_size, dim=0)
self.assertEqual(sharded_param._local_tensor, chunks[self.rank])
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_raise_scalar_parameter(self):
"""Tests raising an exception when the model has scalar parameters."""
model = nn.Sequential(*[MLP(3, dim_multiplier=3) for _ in range(3)])
model.register_parameter("scalar_p", nn.Parameter(torch.tensor(1.0).cuda()))
model.register_parameter(
"scalar_p", nn.Parameter(torch.tensor(1.0).to(device_type))
)
with self.assertRaisesRegex(
ValueError, "Change scalar_p to a 1D tensor with numel equal to 1."
):
fully_shard(model)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_raise_noncontiguous_parameter(self):
"""
Tests raising an exception when the model has non-contiguous
@ -425,11 +442,13 @@ class TestFullyShardShardedParameterDTensor(FSDPTestMultiThread):
def world_size(self) -> int:
return 4
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_shard_dtensor_parameters(self):
dp_size = 2 if self.world_size > 2 else 1
global_mesh = init_device_mesh(
"cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
device_type.type,
(dp_size, self.world_size // dp_size),
mesh_dim_names=("dp", "tp"),
)
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
# Use odd dim sizes to test uneven shards
@ -468,7 +487,7 @@ class TestFullyShardLazyInit(FSDPTestMultiThread):
def world_size(self) -> int:
return 2
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_fully_shard_is_root(self):
"""
Tests that ``_is_root`` is set correctly after lazy initialization.
@ -497,7 +516,7 @@ class TestFullyShardLazyInit(FSDPTestMultiThread):
all_states, [root_state, model0_in_proj_state, model0_out_proj_state]
)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_fully_shard_module_and_param_fqns(self):
"""
Tests that the module and parameter FQNs are computed correctly after
@ -555,7 +574,7 @@ class TestFullyShardLazyInit(FSDPTestMultiThread):
model0_out_proj_param_fqns, {"0.out_proj.weight", "0.out_proj.bias"}
)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_fully_shard_double_lazy_init(self):
model = nn.Sequential(MLP(8), MLP(8))
fully_shard(model[0].in_proj)
@ -571,7 +590,7 @@ class TestFullyShardLazyInit(FSDPTestMultiThread):
with self.assertRaisesRegex(RuntimeError, regex):
root_state._lazy_init()
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_fully_shard_multi_module_root(self):
model = nn.Sequential(MLP(8), MLP(8))
fully_shard([model[0], model[1]])
@ -580,7 +599,7 @@ class TestFullyShardLazyInit(FSDPTestMultiThread):
with self.assertRaisesRegex(RuntimeError, regex):
root_state._lazy_init()
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_reset_sharded_param_in_lazy_init(self):
class MyModel(nn.Module):
def __init__(self):
@ -607,11 +626,11 @@ class TestFullyShardLazyInit(FSDPTestMultiThread):
fully_shard(model.layer2)
fully_shard(model)
model.layer1.to_empty(device="cuda")
model.layer2.to_empty(device="cuda")
model.layer1.to_empty(device=device_type.type)
model.layer2.to_empty(device=device_type.type)
model.init_weight_norm()
inp = torch.randn(3, 3, device="cuda")
inp = torch.randn(3, 3, device=device_type.type)
loss = model(inp).sum()
loss.backward()
@ -621,10 +640,10 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
def world_size(self) -> int:
return 4
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_meta_device_1d_init(self):
default_pg = torch.distributed.distributed_c10d._get_default_group()
mesh = init_device_mesh("cuda", mesh_shape=(default_pg.size(),))
mesh = init_device_mesh(device_type.type, mesh_shape=(default_pg.size(),))
# Test both even sharding (8) and uneven sharding (3)
for mlp_dim in (8, 3):
@ -652,12 +671,14 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
self.assertEqual(param.device, torch.device("meta"))
self._test_to_empty_and_reset_parameters(model, mesh, mlp_dim)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_meta_device_2d_init(self):
assert self.world_size >= 4, f"{self.world_size}"
dp_size = 2
global_mesh = init_device_mesh(
"cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
device_type.type,
(dp_size, self.world_size // dp_size),
mesh_dim_names=("dp", "tp"),
)
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
@ -685,7 +706,9 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
self, model: nn.Module, mesh: DeviceMesh, mlp_dim: int
):
# Check that we can materialize it on GPU with empty values
device = torch.device("cuda", torch.cuda.current_device())
device = torch.device(
device_type.type, torch.get_device_module(device_type).current_device()
)
model.to_empty(device=device)
for param in model.parameters():
self.assertEqual(param.device, device)
@ -706,14 +729,14 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
self.assertNotEqual(buffer, torch.ones_like(buffer) * const)
# Check that we can run an iteration without erroring
inp = torch.randn((4, mlp_dim), device="cuda")
inp = torch.randn((4, mlp_dim), device=device_type.type)
model(inp).sum().backward()
optim.step()
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_invalid_meta_device_init(self):
default_pg = torch.distributed.distributed_c10d._get_default_group()
mesh = init_device_mesh("cuda", mesh_shape=(default_pg.size(),))
mesh = init_device_mesh(device_type.type, mesh_shape=(default_pg.size(),))
mlp_dim = 8
with torch.device("meta"):
model = nn.Sequential(MLP(mlp_dim, with_buffer=True), MLP(mlp_dim))
@ -722,7 +745,7 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
fully_shard(model[0], mesh=mesh)
fully_shard(model[1], mesh=mesh)
fully_shard(model, mesh=mesh)
inp = torch.randn((4, mlp_dim), device="cuda")
inp = torch.randn((4, mlp_dim), device=device_type.type)
error_regex = (
"FSDP parameters should be materialized from meta device before training, "
"but the following were still on meta device: "
@ -731,7 +754,7 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
with self.assertRaisesRegex(RuntimeError, error_regex):
model(inp)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_rank0_broadcast_meta_device_init(self):
model_args = ModelArgs(dropout_p=0.0)
# Assume we have a CPU full state dict on rank 0
@ -743,7 +766,7 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
self.assertEqual(param.device, torch.device("cpu"))
# Initialize the sharded model on meta device
fsdp_mesh = init_device_mesh("cuda", (self.world_size,))
fsdp_mesh = init_device_mesh(device_type.type, (self.world_size,))
with torch.device("meta"):
model = Transformer(model_args)
for module in model.modules():
@ -763,7 +786,7 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
for (param_name, full_param), sharded_meta_param in zip(
full_sd.items(), meta_sharded_sd.values()
):
full_param = full_param.detach().cuda()
full_param = full_param.detach().to(device_type)
mesh = sharded_meta_param.device_mesh
dist.broadcast(full_param, src=0, group=mesh.get_group(0))
sharded_tensor = distribute_tensor(
@ -774,7 +797,7 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
for param_name, sharded_meta_param in meta_sharded_sd.items():
full_tensor = torch.empty(
sharded_meta_param.size(),
device="cuda",
device=device_type.type,
dtype=sharded_meta_param.dtype,
)
mesh = sharded_meta_param.device_mesh
@ -787,7 +810,7 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
model.load_state_dict(sharded_sd, assign=True)
for param in model.parameters():
self.assertIsInstance(param, DTensor)
self.assertEqual(param.device.type, "cuda")
self.assertEqual(param.device.type, device_type.type)
# Construct the reference model on nonzero ranks by broadcasting the
# unsharded model from rank 0 and sharding on all ranks
@ -807,7 +830,7 @@ class TestFullyShardMetaDeviceInit(FSDPTestMultiThread):
self.assertEqual(param, ref_param)
# Check one forward/backward for parity
inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type)
loss = model(inp).sum()
loss.backward()
ref_loss = ref_model(inp).sum()
@ -822,20 +845,22 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread):
def world_size(self) -> int:
return 4
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_1d_process_group_init(self):
assert self.world_size == 4, f"{self.world_size}"
# For convenience, use device mesh's infra to construct the DP PG
# (in practice, the trainer would do it manually via `new_group()`)
dp_size = 2
global_mesh = init_device_mesh(
"cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
device_type.type,
(dp_size, self.world_size // dp_size),
mesh_dim_names=("dp", "tp"),
)
ref_dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
dp_pg = ref_dp_mesh.get_group(0)
# Check the `from_group()` API for correctness
dp_mesh = DeviceMesh.from_group(dp_pg, "cuda", mesh_dim_names=("dp",))
dp_mesh = DeviceMesh.from_group(dp_pg, device_type.type, mesh_dim_names=("dp",))
# Only compare the mesh tensors, not `DeviceMesh` objects themselves,
# since the ref has a parent mesh, while the `from_group` one does not
self.assertEqual(dp_mesh.mesh, ref_dp_mesh.mesh)
@ -860,7 +885,7 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread):
fully_shard(module, mesh=dp_mesh)
# Ensure that TP ranks have the same input
inp = torch.randn((4, mlp_dim), device="cuda")
inp = torch.randn((4, mlp_dim), device=device_type.type)
if self.rank in (0, 1):
dist.broadcast(inp, src=0, group=tp_mesh.get_group(0))
elif self.rank in (2, 3):
@ -882,7 +907,7 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread):
param.grad.device_mesh.mesh, ref_param.grad.device_mesh.mesh
)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_2d_process_group_init(self):
shard_mesh_dim_size = 2
assert (
@ -891,7 +916,7 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread):
replicate_mesh_dim_size = self.world_size // shard_mesh_dim_size
mesh_dim_names = ("replicate", "shard")
ref_mesh = init_device_mesh(
"cuda",
device_type.type,
(replicate_mesh_dim_size, shard_mesh_dim_size),
mesh_dim_names=mesh_dim_names,
)
@ -910,7 +935,7 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread):
# Check the `from_group()` API for correctness
mesh = DeviceMesh.from_group(
[dp_replicate_group, dp_shard_group],
"cuda",
device_type.type,
mesh_dim_names=mesh_dim_names,
mesh=mesh_tensor,
)
@ -943,7 +968,7 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread):
for module in (model.in_proj, model.out_proj, model):
fully_shard(module, mesh=mesh)
inp = torch.randn((4, mlp_dim), device="cuda")
inp = torch.randn((4, mlp_dim), device=device_type.type)
ref_loss = ref_model(inp).sum()
ref_loss.backward()
loss = model(inp).sum()
@ -959,11 +984,13 @@ class TestFullyShardHSDPBroadcast(FSDPTestMultiThread):
def world_size(self) -> int:
return 4
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_hsdp_broadcast_across_replicas(self):
shard_size, replicate_size = 2, 2
mesh = init_device_mesh(
"cuda", (replicate_size, shard_size), mesh_dim_names=("replicate", "shard")
device_type.type,
(replicate_size, shard_size),
mesh_dim_names=("replicate", "shard"),
)
model_args = ModelArgs()
model = Transformer(model_args)
@ -1017,7 +1044,7 @@ class TestFullyShardHSDPBroadcast(FSDPTestMultiThread):
self.assertEqual(other_local_tensor, local_tensor_list[0])
# Check that we can run an iteration without erroring
inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type)
model(inp).sum().backward()
@ -1028,17 +1055,17 @@ class TestHSDPWithCustomHook(FSDPTestMultiThread):
def perThreadSetUp(self) -> None:
super().perThreadSetUp()
torch.set_default_device("cuda")
torch.set_default_device(device_type)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_custom_hook_custom_stream(self):
hsdp_mesh = init_device_mesh(
"cuda", (2, 2), mesh_dim_names=("replicate", "shard")
device_type.type, (2, 2), mesh_dim_names=("replicate", "shard")
)
model = MLP(10, bias=False)
fully_shard(model, mesh=hsdp_mesh)
model = cast(FSDPModule, model)
custom_stream = torch.cuda.Stream()
custom_stream = torch.get_device_module(device_type).Stream()
# native HSDP should reject
with self.assertRaises(ValueError) as cm:
@ -1051,7 +1078,7 @@ class TestHSDPWithCustomHook(FSDPTestMultiThread):
intra_pg = _init_intra_node_process_group(2)
fsdp_mesh = DeviceMesh.from_group(
intra_pg,
"cuda",
device_type.type,
dist.get_process_group_ranks(intra_pg),
mesh_dim_names=("shard",),
)
@ -1059,7 +1086,7 @@ class TestHSDPWithCustomHook(FSDPTestMultiThread):
def _hook(_output: torch.Tensor) -> None:
nonlocal hook_used_stream
hook_used_stream = torch.cuda.current_stream()
hook_used_stream = torch.get_device_module(device_type).current_stream()
model = MLP(10, bias=False)
fully_shard(model, mesh=fsdp_mesh)
@ -1069,17 +1096,17 @@ class TestHSDPWithCustomHook(FSDPTestMultiThread):
inp = torch.arange(10, dtype=torch.float32, requires_grad=True).view(1, 10)
out = model(inp)
out.sum().backward()
torch.cuda.synchronize()
torch.get_device_module(device_type).synchronize()
self.assertEqual(hook_used_stream, custom_stream)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_custom_hsdp_all_reduce_hook(self):
world_pg = dist.distributed_c10d._get_default_group()
intra_pg = _init_intra_node_process_group(2)
inter_pg = _init_inter_node_process_group(world_pg, 2)
mesh = DeviceMesh.from_group(
intra_pg,
"cuda",
device_type.type,
dist.get_process_group_ranks(intra_pg),
mesh_dim_names=("shard",),
)
@ -1106,7 +1133,7 @@ class TestHSDPWithCustomHook(FSDPTestMultiThread):
inp = torch.arange(10, dtype=torch.float32, requires_grad=True).view(1, 10)
out = model(inp)
out.sum().backward()
torch.cuda.synchronize()
torch.get_device_module(device_type).synchronize()
# custom hook was fired
self.assertTrue(hook_called)
# within each replica, FSDP shards the weights at dim 0
@ -1140,7 +1167,7 @@ class TestFullyShardShardPlacementFn(FSDPTestMultiThread):
ref_model = copy.deepcopy(model)
return model, ref_model
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_init_1d_transformer_shard_largest_dim(self):
model, ref_model = self._init_models()
@ -1168,7 +1195,7 @@ class TestFullyShardShardPlacementFn(FSDPTestMultiThread):
full_param = param.full_tensor()
self.assertEqual(full_param, ref_param)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_init_1d_transformer_shard_dim_neg1(self):
model, ref_model = self._init_models()
@ -1184,13 +1211,13 @@ class TestFullyShardShardPlacementFn(FSDPTestMultiThread):
full_param = param.full_tensor()
self.assertEqual(full_param, ref_param)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_init_2d_transformer_shard_diff_dim(self):
model, ref_model = self._init_models()
dp_size, tp_size = self.world_size // 2, 2
global_mesh = init_device_mesh(
"cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp")
device_type.type, (dp_size, tp_size), mesh_dim_names=("dp", "tp")
)
model = Transformer.parallelize(model, global_mesh["tp"], use_seq_parallel=True)
@ -1234,7 +1261,7 @@ class TestFullyShardShardPlacementFn(FSDPTestMultiThread):
full_param = param.full_tensor()
self.assertEqual(full_param, ref_param)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_init_1d_uneven_shard_largest_dim(self):
torch.manual_seed(42)
model = nn.Sequential(nn.Linear(16, 17), nn.Linear(17, 8))
@ -1255,7 +1282,7 @@ class TestFullyShardShardPlacementFn(FSDPTestMultiThread):
):
fully_shard(model, shard_placement_fn=shard_placement_fn)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_invalid_shard_dim(self):
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 8))
@ -1276,7 +1303,7 @@ class TestFullyShardOldImport(FSDPTestMultiThread):
def world_size(self) -> int:
return 2
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_old_import_training(self):
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._composable.fsdp.fully_shard import FSDPModule
@ -1291,7 +1318,7 @@ class TestFullyShardOldImport(FSDPTestMultiThread):
self.assertIsInstance(model[1], FSDPModule)
self.assertIsInstance(model, FSDPModule)
inp = torch.randn((8, 16), device="cuda")
inp = torch.randn((8, 16), device=device_type)
model(inp).sum().backward()
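Several hunks in this file replace torch.device("cuda", torch.cuda.current_device()) with a device built from the resolved type and its backend module, and swap @unittest.skipIf(not TEST_CUDA, ...) for @skip_if_lt_x_gpu(1). A small sketch of the device construction, with a CPU guard added here because torch.cpu.current_device() does not return an index:

```python
import torch

device_type = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_module = torch.get_device_module(device_type)

# Device-agnostic replacement for torch.device("cuda", torch.cuda.current_device()).
if device_type.type == "cpu":
    accelerator_device = torch.device("cpu")
else:
    accelerator_device = torch.device(device_type.type, device_module.current_device())

model = torch.nn.Linear(8, 8).to(accelerator_device)
assert all(p.device == accelerator_device for p in model.parameters())
print(accelerator_device)
```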


@ -15,6 +15,12 @@ requires_distributed = functools.partial(
unittest.skipIf, not dist.is_available(), "requires distributed"
)
import torch
from torch.testing._internal.common_fsdp import get_devtype
device_type = torch.device(get_devtype())
@skip_if_lt_x_gpu(2)
class LoggingTests(LoggingTestCase):
@ -27,7 +33,7 @@ class LoggingTests(LoggingTestCase):
env["MASTER_PORT"] = "34715"
env["MASTER_ADDR"] = "localhost"
_, stderr = self.run_process_no_exception(
"""\
f"""\
import logging
import torch
import torch.distributed as dist
@ -35,7 +41,7 @@ import torch.nn as nn
from torch.distributed.fsdp import fully_shard
logger = logging.getLogger("torch.distributed._composable.fsdp")
logger.setLevel(logging.DEBUG)
device = "cuda"
device = "{device_type.type}"
torch.manual_seed(0)
model = nn.Sequential(*[nn.Linear(4, 4, device=device, bias=False) for _ in range(2)])
for layer in model:
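The logging test switches its child-process source to an f-string so the resolved device name is baked into the generated script. A standalone sketch of that pattern; note the quotes around the interpolated value so the child sees a string literal rather than a bare name:

```python
import subprocess
import sys
import torch

device_type = torch.device("cuda" if torch.cuda.is_available() else "cpu")

child_src = f"""\
import torch
device = "{device_type.type}"
x = torch.randn(2, 2, device=device)
print(x.device)
"""
result = subprocess.run([sys.executable, "-c", child_src], capture_output=True, text=True)
print(result.stdout.strip())
```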


@ -2,12 +2,13 @@
import functools
import gc
import unittest
import torch
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, OffloadPolicy
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype
from torch.testing._internal.common_utils import run_tests, TEST_CUDA, TEST_HPU
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
Transformer,
@ -15,12 +16,16 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
)
device_type = torch.device(get_devtype())
class TestFullyShardMemory(FSDPTest):
@property
def world_size(self) -> int:
return min(2, torch.cuda.device_count())
return min(2, torch.get_device_module(device_type).device_count())
@skip_if_lt_x_gpu(2)
@unittest.skipIf(TEST_HPU, " 'empty_cache' is not supported on hpu")
def test_fully_shard_training_memory(self):
self.run_subtests(
{
@ -56,10 +61,10 @@ class TestFullyShardMemory(FSDPTest):
# Pre-run a linear forward (gemm and bias) and backward (gemm) to
# allocate the cuBLAS workspaces before measuring the memory usage
# since the workspace size can differ between hardwares
lin = torch.nn.Linear(768, 768, device="cuda")
inp = torch.randn(1, 768, device="cuda")
lin = torch.nn.Linear(768, 768, device=device_type)
inp = torch.randn(1, 768, device=device_type)
lin(inp).sum().backward()
torch.cuda.empty_cache()
torch.get_device_module(device_type).empty_cache()
base_mem_mb = self._get_peak_active_memory_mb()
vocab_size = 32
model_args = ModelArgs(
@ -108,7 +113,7 @@ class TestFullyShardMemory(FSDPTest):
self.assertLessEqual(curr_mem_mb - base_mem_mb, init_mem_mb)
# Use a small input to minimize activation memory usage
inp = torch.randint(0, vocab_size, (1, 4), device="cuda")
inp = torch.randint(0, vocab_size, (1, 4), device=device_type.type)
# Forward:
loss = model(inp)
@ -169,7 +174,7 @@ class TestFullyShardMemory(FSDPTest):
) * 4 / 1e6 + buffer_mb
self.assertLessEqual(mem_mb - base_mem_mb, expected_mem_mb)
del loss
torch.cuda.reset_peak_memory_stats()
torch.get_device_module(device_type).reset_peak_memory_stats()
# Optimizer step: unsharded parameters/gradients freed
if not run_optim_in_backward:
@ -187,7 +192,9 @@ class TestFullyShardMemory(FSDPTest):
# Zero grad: sharded gradients freed
if not run_optim_in_backward:
optim.zero_grad()
torch.cuda.reset_peak_memory_stats() # reset after freeing
torch.get_device_module(
device_type
).reset_peak_memory_stats() # reset after freeing
mem_mb = self._get_peak_active_memory_mb()
expected_mem_mb = 0
if not use_cpu_offload:
@ -228,12 +235,18 @@ class TestFullyShardMemory(FSDPTest):
self.assertEqual(mem_mb, base_mem_mb)
def _get_peak_active_memory_mb(self) -> int:
mem_stats = torch.cuda.memory_stats()
return round(mem_stats["active_bytes.all.peak"] / 1e6)
mem_stats = torch.get_device_module(device_type).memory_stats()
if TEST_CUDA:
return round(mem_stats["active_bytes.all.peak"] / 1e6)
if TEST_HPU:
return round(mem_stats["MaxInUse"] / 1e6)
def _get_curr_active_memory_mb(self) -> int:
mem_stats = torch.cuda.memory_stats()
return round(mem_stats["active_bytes.all.current"] / 1e6)
mem_stats = torch.get_device_module(device_type).memory_stats()
if TEST_CUDA:
return round(mem_stats["active_bytes.all.current"] / 1e6)
if TEST_HPU:
return round(mem_stats["InUse"] / 1e6)
def _register_optim_in_backward(
self, model: torch.nn.Module, **optim_kwargs
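The memory helpers above branch on the backend because the stat dictionaries use different keys: CUDA reports "active_bytes.all.peak" and "active_bytes.all.current", while HPU exposes "MaxInUse" and "InUse". A sketch of the peak-memory lookup, with a CPU fallback added purely for illustration:

```python
import torch

device_type = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def peak_active_memory_mb() -> int:
    if device_type.type == "cpu":
        return 0  # no caching-allocator stats to query on CPU
    stats = torch.get_device_module(device_type).memory_stats()
    if device_type.type == "cuda":
        return round(stats["active_bytes.all.peak"] / 1e6)
    if device_type.type == "hpu":
        return round(stats["MaxInUse"] / 1e6)
    raise NotImplementedError(device_type.type)

buf = torch.ones(1 << 20, device=device_type)  # touch the allocator so stats exist
print(peak_active_memory_mb())
```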


@ -22,17 +22,21 @@ from torch.testing._internal.common_fsdp import (
check_sharded_parity,
FSDPTest,
FSDPTestMultiThread,
get_devtype,
MLP,
patch_reduce_scatter,
reduce_scatter_with_assert,
)
from torch.testing._internal.common_utils import run_tests, skipIfRocm
from torch.testing._internal.common_utils import run_tests, skipIfRocm, TEST_HPU
device_type = torch.device(get_devtype())
class TestFullyShardMixedPrecisionTraining(FSDPTest):
@property
def world_size(self) -> int:
return min(4, torch.cuda.device_count())
return min(4, torch.get_device_module(device_type).device_count())
def _init_models_and_optims(
self,
@ -43,7 +47,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
):
torch.manual_seed(42)
model = nn.Sequential(*[MLP(16, torch.device("cpu")) for _ in range(3)])
ref_model = copy.deepcopy(model).cuda()
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
@ -123,7 +127,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
)
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((4, 16), device="cuda", dtype=param_dtype)
inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype)
for iter_idx in range(10):
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
fsdp_loss = model(inp).sum()
@ -209,7 +213,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
)
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((4, 16), device="cuda", dtype=param_dtype)
inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype)
for iter_idx in range(10):
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
fsdp_loss = model(inp).sum()
@ -258,7 +262,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
)
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((4, 16), device="cuda", dtype=param_dtype)
inp = torch.randn((4, 16), device=device_type.type, dtype=param_dtype)
for iter_idx in range(10):
optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
fsdp_loss = model(inp).sum()
@ -309,7 +313,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
# To emulate the mixed precision implementation where forward/backward
# compute use bf16 and optimizer uses fp32, we maintain both an fp32
# and a bf16 copy of the reference model
ref_model = copy.deepcopy(model).cuda()
ref_model = copy.deepcopy(model).to(device_type)
ref_model_compute = copy.deepcopy(ref_model).to(param_dtype)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
for mlp in model:
@ -329,7 +333,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest):
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
)
torch.manual_seed(42 + self.rank + 1)
device = torch.device("cuda")
device = device_type
# Train on the same input to avoid loss explosion
num_microbatches = 4
inp = torch.randn((2 * num_microbatches, 16), device=device, dtype=param_dtype)
@ -389,7 +393,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
@skip_if_lt_x_gpu(1)
def test_float16_on_one_submodule(self):
x = torch.zeros(2, 100, device="cuda")
x = torch.zeros(2, 100, device=device_type)
# Subtest 1: use fp16 on the second child submodule -- does not require
# any additional casting logic
@ -397,7 +401,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
model = SaveForwardInputsModel(
forward_inputs,
cast_forward_inputs=False,
).cuda()
).to(device_type)
fully_shard(model.c2, mp_policy=MixedPrecisionPolicy(param_dtype=torch.float16))
fully_shard(model)
model(x).sum().backward()
@ -410,7 +414,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
forward_inputs: dict[nn.Module, torch.Tensor] = {}
model = SaveForwardInputsModel(
forward_inputs=forward_inputs, cast_forward_inputs=True
).cuda()
).to(device_type)
fully_shard(
model.c2,
mp_policy=MixedPrecisionPolicy(
@ -428,7 +432,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
forward_inputs: dict[nn.Module, torch.Tensor] = {}
model = SaveForwardInputsModel(
forward_inputs=forward_inputs, cast_forward_inputs=False
).cuda()
).to(device_type)
fully_shard(
model.c1,
mp_policy=MixedPrecisionPolicy(
@ -470,13 +474,13 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
def forward(self, x: torch.Tensor) -> torch.Tensor:
self.forward_inputs["model_input_x"] = x
y = torch.ones(
2, 100, device="cuda", dtype=torch.float32
2, 100, device=device_type.type, dtype=torch.float32
) # external input
return self.l2(self.l1(x), y)
forward_inputs: dict[str, torch.Tensor] = {}
model = ToyModel(forward_inputs).cuda()
x = torch.zeros(2, 100, device="cuda", dtype=torch.float32)
model = ToyModel(forward_inputs).to(device_type)
x = torch.zeros(2, 100, device=device_type.type, dtype=torch.float32)
fully_shard(
model.l2,
mp_policy=MixedPrecisionPolicy(
@ -529,9 +533,15 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
for module in (model[0], model[1], model[2], model):
fully_shard(module, mp_policy=mp_policy)
with self.assertRaisesRegex(RuntimeError, "Expected running_mean to have type"):
# Errors in batch norm 2D backward
if TEST_HPU:
inner(model, torch.randn((3, 1, 9, 9)))
else:
with self.assertRaisesRegex(
RuntimeError,
"Expected running_mean to have type", # Error not seen on HPUs and hence it can be skipped
):
# Errors in batch norm 2D backward
inner(model, torch.randn((3, 1, 9, 9)))
# Batch norm 2D: cast buffers down to lower precision
model = nn.Sequential(nn.Conv2d(1, 5, 3), nn.BatchNorm2d(5), nn.Conv2d(5, 4, 3))
@ -557,7 +567,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
model = nn.Sequential(
nn.Linear(32, 32, dtype=init_dtype),
nn.Linear(32, 32, dtype=init_dtype),
)
).to(device_type.type)
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16
)
@ -579,7 +589,7 @@ class TestFullyShardMixedPrecisionCasts(FSDPTestMultiThread):
reduce_scatter_with_assert, self, orig_reduce_scatter, assert_fn
)
with patch_reduce_scatter(reduce_scatter):
inp = torch.randn((4, 32), device="cuda")
inp = torch.randn((4, 32), device=device_type.type)
loss = model(inp).sum()
loss.backward()
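The mixed-precision hunks use .to(device_type) as the drop-in for .cuda() and keep an explicit dtype on the inputs. A small sketch of that move-then-cast flow, using bfloat16 only when CUDA is present:

```python
import copy
import torch
import torch.nn as nn

device_type = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
ref_model = copy.deepcopy(model).to(device_type)  # replaces .cuda()

param_dtype = torch.bfloat16 if device_type.type == "cuda" else torch.float32
inp = torch.randn((4, 16), device=device_type, dtype=param_dtype)
out = ref_model.to(param_dtype)(inp)
print(out.dtype, out.device)
```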


@ -2,6 +2,7 @@
import copy
import functools
import unittest
from typing import Callable
import torch
@ -12,10 +13,15 @@ from torch.distributed.tensor.experimental import implicit_replication
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
FSDPTest,
get_devtype,
patch_all_gather,
patch_reduce_scatter,
)
from torch.testing._internal.common_utils import get_cycles_per_ms, run_tests
from torch.testing._internal.common_utils import get_cycles_per_ms, run_tests, TEST_HPU
device_type = torch.device(get_devtype())
device_module = torch.get_device_module(device_type)
class TestFullyShardOverlap(FSDPTest):
@ -35,9 +41,10 @@ class TestFullyShardOverlap(FSDPTest):
@property
def world_size(self) -> int:
return min(2, torch.cuda.device_count())
return min(2, torch.get_device_module(device_type).device_count())
@skip_if_lt_x_gpu(2)
@unittest.skipIf(TEST_HPU, "Sleep is not supported on HPU")
def test_fully_shard_training_overlap(self):
torch.manual_seed(42)
@ -46,7 +53,7 @@ class TestFullyShardOverlap(FSDPTest):
model = nn.Sequential(
*[LinearWithSleep(dim, compute_sleep_ms) for _ in range(num_linears)]
)
ref_model = copy.deepcopy(model).cuda()
ref_model = copy.deepcopy(model).to(device_type)
for lin in model:
assert len(list(lin.parameters())) == 1, "Expects only one weight"
fully_shard(lin, reshard_after_forward=True)
@ -54,15 +61,21 @@ class TestFullyShardOverlap(FSDPTest):
orig_all_gather_into_tensor = dist.all_gather_into_tensor
orig_reduce_scatter_tensor = dist.reduce_scatter_tensor
comm_stream = torch.cuda.Stream()
comm_stream = torch.get_device_module(device_type).Stream()
def delay_collective():
# Share a stream so that all-gather and reduce-scatter block each
# other like in `ProcessGroupNCCL`
comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(comm_stream):
torch.cuda._sleep(int(comm_sleep_ms * get_cycles_per_ms()))
torch.cuda.current_stream().wait_stream(comm_stream)
comm_stream.wait_stream(
torch.get_device_module(device_type).current_stream()
)
with torch.get_device_module(device_type).stream(comm_stream):
torch.get_device_module(device_type)._sleep(
int(comm_sleep_ms * get_cycles_per_ms())
)
torch.get_device_module(device_type).current_stream().wait_stream(
comm_stream
)
def delayed_all_gather(*args, **kwargs):
delay_collective()
@ -72,7 +85,7 @@ class TestFullyShardOverlap(FSDPTest):
delay_collective()
return orig_reduce_scatter_tensor(*args, **kwargs)
inp = torch.randn((2, dim), device="cuda")
inp = torch.randn((2, dim), device=device_type.type)
loss = model(inp).sum() # warmup CUDA and allocator
loss.backward()
@ -144,6 +157,7 @@ class TestFullyShardOverlap(FSDPTest):
self.assertLessEqual(fwd_bwd_time, ref_fwd_bwd_time)
@skip_if_lt_x_gpu(2)
@unittest.skipIf(TEST_HPU, "Sleep is not supported on HPU")
def test_fully_shard_post_optim_event_overlap(self):
torch.manual_seed(42)
@ -153,17 +167,19 @@ class TestFullyShardOverlap(FSDPTest):
# low-compute linear, where only the low-compute linear uses FSDP
model = nn.Sequential(
LinearWithSleep(dim, compute_sleep_ms), nn.Linear(dim, dim)
).cuda()
).to(device_type)
fully_shard(model[1], reshard_after_forward=False)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
orig_all_gather_into_tensor = dist.all_gather_into_tensor
def delayed_all_gather(*args, **kwargs):
torch.cuda._sleep(int(comm_sleep_ms * get_cycles_per_ms()))
torch.get_device_module(device_type)._sleep(
int(comm_sleep_ms * get_cycles_per_ms())
)
return orig_all_gather_into_tensor(*args, **kwargs)
inp = torch.randn((2, dim), device="cuda")
inp = torch.randn((2, dim), device=device_type)
def run_train_steps(num_iters: int, use_post_optim_event: bool):
for _ in range(num_iters):
@ -174,7 +190,11 @@ class TestFullyShardOverlap(FSDPTest):
with implicit_replication():
optim.step()
if use_post_optim_event:
post_optim_event = torch.cuda.current_stream().record_event()
post_optim_event = (
torch.get_device_module(device_type)
.current_stream()
.record_event()
)
model[1].set_post_optim_event(post_optim_event)
run_train_steps(1, False) # warmup CUDA and allocator
@ -205,14 +225,14 @@ class TestFullyShardOverlap(FSDPTest):
self.assertGreater(baseline_time, test_time)
def _time_fn(self, fn: Callable):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event = device_module.Event(enable_timing=True)
end_event = device_module.Event(enable_timing=True)
dist.barrier()
torch.cuda.synchronize()
device_module.synchronize()
start_event.record()
fn()
end_event.record()
torch.cuda.synchronize()
device_module.synchronize()
elapsed_time = start_event.elapsed_time(end_event)
return elapsed_time
@ -223,13 +243,15 @@ class Matmul(torch.autograd.Function):
def forward(ctx, input: torch.Tensor, weight: torch.Tensor, sleep_ms: int):
ctx.save_for_backward(input, weight)
ctx.sleep_ms = sleep_ms
torch.cuda._sleep(int(sleep_ms * get_cycles_per_ms()))
torch.get_device_module(device_type)._sleep(int(sleep_ms * get_cycles_per_ms()))
return input @ weight
@staticmethod
def backward(ctx, grad_output: torch.Tensor):
(input, weight) = ctx.saved_tensors
torch.cuda._sleep(int(2 * ctx.sleep_ms * get_cycles_per_ms()))
torch.get_device_module(device_type)._sleep(
int(2 * ctx.sleep_ms * get_cycles_per_ms())
)
grad_input = grad_output @ weight.T
grad_weight = input.T @ grad_output
return grad_input, grad_weight, None
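The overlap test routes streams, sleeps, and timing events through the backend module instead of torch.cuda. A sketch of the _time_fn idea using elapsed-time events, with a wall-clock fallback on CPU since not every backend module exposes Event:

```python
import time
import torch

device_type = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_module = torch.get_device_module(device_type)

def time_fn_ms(fn) -> float:
    if device_type.type == "cpu":
        start = time.perf_counter()
        fn()
        return (time.perf_counter() - start) * 1e3
    start_event = device_module.Event(enable_timing=True)
    end_event = device_module.Event(enable_timing=True)
    device_module.synchronize()
    start_event.record()
    fn()
    end_event.record()
    device_module.synchronize()
    return start_event.elapsed_time(end_event)

x = torch.randn(512, 512, device=device_type)
print(f"{time_fn_ms(lambda: x @ x):.3f} ms")
```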


@ -1,11 +1,10 @@
# Owner(s): ["oncall: distributed"]
import copy
import unittest
import torch.nn as nn
from torch.distributed.fsdp import FSDPModule, fully_shard
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTestMultiThread, MLP
from torch.testing._internal.common_utils import run_tests
@ -15,7 +14,7 @@ class TestFullyShardState(FSDPTestMultiThread):
def world_size(self) -> int:
return 1
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_fully_shard_state(self):
"""
Tests the ability to get the state object from a fully sharded module.
@ -31,7 +30,7 @@ class TestFullyShardState(FSDPTestMultiThread):
# Check that each `fully_shard` call constructs a distinct state object
self.assertEqual(len(set(all_states)), num_mlps + 1)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_fully_shard_reapply(self):
model = MLP(8)
fully_shard(model)
@ -41,7 +40,7 @@ class TestFullyShardState(FSDPTestMultiThread):
):
fully_shard(model)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_fully_shard_cls(self):
# Check that we only swap class for the module passed to `fully_shard`
model = MLP(8)
@ -64,7 +63,7 @@ class TestFullyShardState(FSDPTestMultiThread):
self.assertTrue(isinstance(sliced_model, nn.Sequential))
self.assertFalse(isinstance(sliced_model, FSDPModule))
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_fully_shard_unsupported_module_cls(self):
regex = (
r"fully\_shard does not support containers that do not implement forward"
@ -76,7 +75,7 @@ class TestFullyShardState(FSDPTestMultiThread):
with self.assertRaisesRegex(ValueError, regex):
fully_shard(model)
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_fully_shard_deepcopy(self):
model = MLP(8)
fully_shard(model)


@ -2,7 +2,6 @@
import copy
import functools
import unittest
from contextlib import nullcontext
from typing import Optional
@ -16,9 +15,13 @@ from torch.distributed.tensor.parallel import (
parallelize_module,
RowwiseParallel,
)
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, FSDPTestMultiThread, MLP
from torch.testing._internal.common_fsdp import (
FSDPTest,
FSDPTestMultiThread,
get_devtype,
MLP,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
@ -27,14 +30,17 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
)
device_type = torch.device(get_devtype())
class TestFullyShardStateDictMultiProcess(FSDPTest):
@property
def world_size(self) -> int:
return min(8, torch.cuda.device_count())
return min(8, torch.get_device_module(device_type).device_count())
@skip_if_lt_x_gpu(2)
def test_dp_state_dict_save_load(self):
fsdp_mesh = init_device_mesh("cuda", (self.world_size,))
fsdp_mesh = init_device_mesh(device_type.type, (self.world_size,))
self.run_subtests(
{"mlp_dim": [2, 3, 4, 5], "mesh": [fsdp_mesh]},
self._test_dp_state_dict_save_load,
@ -53,7 +59,7 @@ class TestFullyShardStateDictMultiProcess(FSDPTest):
if self.world_size % 2 != 0:
return
hsdp_mesh = init_device_mesh(
"cuda",
device_type.type,
(self.world_size // 2, 2),
mesh_dim_names=("dp_replicate", "dp_shard"),
)
@ -103,7 +109,7 @@ class TestFullyShardStateDictMultiProcess(FSDPTest):
fully_shard_fn(model2, reshard_after_forward=False)
self._test_state_dict_save_load(model2)
ref_sharded_sd = model2.state_dict()
inp = torch.randn((2, mlp_dim), device="cuda")
inp = torch.randn((2, mlp_dim), device=device_type.type)
model2(inp) # parameters are not resharded after this forward
# Check that state dict hooks reshard
sharded_sd = model2.state_dict()
@ -155,12 +161,12 @@ class TestFullyShardStateDictMultiProcess(FSDPTest):
model.load_state_dict(sd, assign=True, strict=False)
# lazy init without error
inp = torch.rand((mlp_dim, mlp_dim), device="cuda")
inp = torch.rand((mlp_dim, mlp_dim), device=device_type.type)
context = (
self.assertRaisesRegex(
RuntimeError,
r"Found following parameters on non-CPU device: \[\('0.weight', device\(type='cuda'",
rf"Found following parameters on non-CPU device: \[\('0.weight', device\(type='{device_type.type}'",
)
if not cpu_state_dict
else nullcontext()
@ -171,10 +177,13 @@ class TestFullyShardStateDictMultiProcess(FSDPTest):
for name, dtensor in state_dict.items():
self.assertEqual(dtensor.device.type, "cpu")
@skip_if_lt_x_gpu(2)
def test_2d_state_dict_correctness(self):
dp_size = 2
global_mesh = init_device_mesh(
"cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
device_type.type,
(dp_size, self.world_size // dp_size),
mesh_dim_names=("dp", "tp"),
)
dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
torch.manual_seed(42)
@ -214,7 +223,9 @@ class TestFullyShardStateDictMultiProcess(FSDPTest):
def test_dp_tp_state_dict_save_load(self):
dp_size = 2
global_mesh = init_device_mesh(
"cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
device_type.type,
(dp_size, self.world_size // dp_size),
mesh_dim_names=("dp", "tp"),
)
self.run_subtests(
{"mlp_dim": [4, 6, 8, 10]},
@ -245,7 +256,7 @@ class TestFullyShardStateDictMultiProcess(FSDPTest):
@skip_if_lt_x_gpu(4)
def test_hsdp_tp_state_dict_save_load(self):
global_mesh = init_device_mesh(
"cuda",
device_type.type,
(2, 2, self.world_size // 4),
mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
)
@ -345,12 +356,12 @@ class TestFullyShardStateDictMultiThread(FSDPTestMultiThread):
def world_size(self):
return 2
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_rank0_offload_full_state_dict(self):
# Construct a reference unsharded model on all ranks
model_args = ModelArgs(dropout_p=0.0)
torch.manual_seed(42)
ref_model = Transformer(model_args).cuda()
ref_model = Transformer(model_args).to(device_type)
for param in ref_model.parameters():
torch.distributed.broadcast(param.detach(), src=0)
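The state-dict test rebuilds its expected-error regex with an rf-string so the resolved device name is interpolated into the pattern. A tiny standalone check of that construction against an illustrative message (the message text here is made up for the example):

```python
import re
import torch

device_type = torch.device("cuda" if torch.cuda.is_available() else "cpu")

regex = (
    rf"Found following parameters on non-CPU device: "
    rf"\[\('0.weight', device\(type='{device_type.type}'"
)
message = (
    "Found following parameters on non-CPU device: "
    f"[('0.weight', device(type='{device_type.type}', index=0)"
)
assert re.search(regex, message) is not None
print("regex matches for", device_type.type)
```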


@ -27,7 +27,6 @@ from torch.distributed.fsdp import (
)
from torch.distributed.tensor import DTensor, init_device_mesh, Shard
from torch.distributed.tensor.debug import CommDebugMode
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
check_sharded_parity,
@ -42,6 +41,7 @@ from torch.testing._internal.common_fsdp import (
from torch.testing._internal.common_utils import (
get_cycles_per_ms,
run_tests,
TEST_HPU,
wrapSwapTensorsTest,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
@ -54,15 +54,20 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
c10d_ops = torch.ops.c10d
funcol = torch.ops.c10d_functional
from torch.testing._internal.common_fsdp import get_devtype
device_type = torch.device(get_devtype())
class TestFullyShardForwardInputs(FSDPTestMultiThread):
@property
def world_size(self) -> int:
return 2
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_root_move_forward_input_to_device(self):
device = torch.device("cuda", 0)
device = torch.device(device_type.type, 0)
class ParamlessModule(nn.Module):
def forward(self, x: torch.Tensor, ys: tuple[torch.Tensor, ...]):
@ -78,8 +83,8 @@ class TestFullyShardForwardInputs(FSDPTestMultiThread):
y = ys[0] + ys[1]
return x + y + 1
model = ParamlessModule()
fully_shard(model)
model = ParamlessModule().to(device)
fully_shard(model).to(device)
x = torch.randn((3,))
ys = (torch.randn((3,)), torch.randn((3,)))
self.assertEqual(x.device, torch.device("cpu"))
@ -93,10 +98,10 @@ class TestFullyShardRegisteredParams(FSDPTestMultiThread):
def world_size(self) -> int:
return 4
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_param_registration_after_forward(self):
"""Tests the parameter registration after forward."""
device = torch.device("cuda", 0)
device = torch.device(device_type.type, 0)
# Single FSDP group
for reshard_after_forward in (True, False, 2):
torch.manual_seed(42)
@ -107,7 +112,7 @@ class TestFullyShardRegisteredParams(FSDPTestMultiThread):
dist.broadcast(param, src=0)
ref_model = copy.deepcopy(model)
fully_shard(model, reshard_after_forward=reshard_after_forward) # root only
inp = torch.randn((2, 3), device="cuda")
inp = torch.randn((2, 3), device=device_type.type)
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model(inp) # root does not reshard after forward
@ -147,15 +152,15 @@ class TestFullyShardRegisteredParams(FSDPTestMultiThread):
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_param_registration_after_backward(self):
"""Tests the parameter registration after backward."""
device = torch.device("cuda", 0)
device = torch.device(device_type.type, 0)
# Single FSDP group
for reshard_after_forward in (True, False, 2):
model = MLP(8, device)
fully_shard(model, reshard_after_forward=reshard_after_forward) # root only
inp = torch.randn((2, 8), device="cuda")
inp = torch.randn((2, 8), device=device_type.type)
self._assert_dtensor_params(model.parameters())
model(inp).sum().backward()
self._assert_dtensor_params(model.parameters())
@ -198,14 +203,14 @@ class TestFullyShardCastAfterInit(FSDPTestMultiThread):
def world_size(self) -> int:
return 2
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
@wrapSwapTensorsTest(True)
def test_to_float64_after_init(self):
"""Tests that the user can cast the module to float64 after init."""
# NOTE: Test fp64 instead of a lower precision dtype like bf16 for
# better numerics. The important part is changing the dtype.
torch.manual_seed(42)
mlp_dim, device, dtype = 4, torch.device("cuda"), torch.float64
mlp_dim, device, dtype = 4, device_type, torch.float64
model = MLP(mlp_dim, device=device)
for param in model.parameters():
dist.broadcast(param, src=0)
@ -222,7 +227,7 @@ class TestFullyShardCastAfterInit(FSDPTestMultiThread):
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
check_sharded_parity(self, ref_model, model)
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn((2, mlp_dim), device="cuda", dtype=dtype)
inp = torch.randn((2, mlp_dim), device=device_type.type, dtype=dtype)
for iter_idx in range(10):
losses: list[torch.Tensor] = []
for _model in (ref_model, model):
@ -245,7 +250,7 @@ class TestFullyShardCastAfterInit(FSDPTestMultiThread):
class TestFullyShard1DTrainingCore(FSDPTest):
@property
def world_size(self) -> int:
return min(8, torch.cuda.device_count())
return min(8, torch.get_device_module(device_type).device_count())
@skip_if_lt_x_gpu(2)
def test_train_parity_single_group_shard_dim0(self):
@ -287,7 +292,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
model = nn.Sequential(
nn.Linear(*lin_shapes[0]), nn.ReLU(), nn.Linear(*lin_shapes[1])
)
ref_model = copy.deepcopy(model).cuda()
ref_model = copy.deepcopy(model).to(device_type)
replicate(ref_model, device_ids=[self.rank])
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
@ -298,7 +303,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
fully_shard(model, shard_placement_fn=shard_placement_fn)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
torch.manual_seed(42 + self.rank + 1)
inp = (torch.randn((4, lin_shapes[0][0]), device="cuda"),)
inp = (torch.randn((4, lin_shapes[0][0]), device=device_type.type),)
for iter_idx in range(10):
losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
@ -309,6 +314,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
self.assertEqual(losses[0], losses[1])
@skip_if_lt_x_gpu(2)
@unittest.skipIf(TEST_HPU, "Sleep kernel not supported for HPU")
@compiled_fsdp_test(compile_compute_on_module=Transformer)
def test_train_parity_multi_group(self):
"""
@ -319,7 +325,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
self.run_subtests(
{
"reshard_after_forward": [True, False, 2],
"device_type": ["cuda"],
"device_type": [device_type.type],
"offload_policy": [OffloadPolicy()],
"delay_after_forward": [False, True],
"delay_before_all_gather": [False, True],
@ -331,6 +337,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
)
@skip_if_lt_x_gpu(2)
@unittest.skipIf(TEST_HPU, "sleep kernel not supported on HPU")
def test_train_parity_multi_group_cpu_offload_eager(self):
"""
Tests train parity against DDP when using multiple parameter groups for
@ -343,7 +350,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
CPUOffloadPolicy(pin_memory=True),
CPUOffloadPolicy(pin_memory=False),
],
"device_type": ["cuda"],
"device_type": [device_type.type],
"delay_after_forward": [False, True],
"delay_before_all_gather": [False, True],
"delay_before_reduce_scatter": [False, True],
@ -354,6 +361,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
)
@skip_if_lt_x_gpu(2)
@unittest.skipIf(TEST_HPU, "sleep kernel not supported on HPU")
@compiled_fsdp_test(compile_compute_on_module=Transformer)
def test_train_parity_multi_group_unshard_async_op(self):
"""
@ -363,7 +371,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
self.run_subtests(
{
"reshard_after_forward": [True],
"device_type": ["cuda"],
"device_type": [device_type.type],
"offload_policy": [OffloadPolicy()],
"delay_after_forward": [False, True],
"delay_before_all_gather": [False, True],
@ -394,7 +402,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
in (2, 3)
):
return
assert device_type in ("cuda", "cpu"), f"{device_type}"
assert device_type in ("cuda", "hpu", "xpu", "cpu"), f"{device_type}"
torch.manual_seed(42)
vocab_size = 1024
model_args = ModelArgs(
@ -406,8 +414,11 @@ class TestFullyShard1DTrainingCore(FSDPTest):
)
model = Transformer(model_args)
ref_model = copy.deepcopy(model)
if device_type == "cuda":
replicate(ref_model.cuda(), device_ids=[self.rank])
if device_type == device_type:
replicate(
ref_model.to(device_type),
device_ids=[self.rank],
)
else:
gloo_pg = dist.new_group(backend="gloo")
replicate(ref_model, process_group=gloo_pg)
@ -432,11 +443,15 @@ class TestFullyShard1DTrainingCore(FSDPTest):
orig_reduce_scatter = dist.reduce_scatter_tensor
def delayed_all_gather(*args, **kwargs):
torch.cuda._sleep(int(delay_in_ms * get_cycles_per_ms()))
torch.get_device_module(device_type)._sleep(
int(delay_in_ms * get_cycles_per_ms())
)
return orig_all_gather(*args, **kwargs)
def delayed_reduce_scatter(*args, **kwargs):
torch.cuda._sleep(int(delay_in_ms * get_cycles_per_ms()))
torch.get_device_module(device_type)._sleep(
int(delay_in_ms * get_cycles_per_ms())
)
return orig_reduce_scatter(*args, **kwargs)
torch.manual_seed(42 + self.rank + 1)
@ -458,10 +473,14 @@ class TestFullyShard1DTrainingCore(FSDPTest):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
losses.append(_model(inp).sum())
if _model is model and delay_after_forward:
torch.cuda._sleep(int(delay_in_ms * get_cycles_per_ms()))
torch.get_device_module(device_type)._sleep(
int(delay_in_ms * get_cycles_per_ms())
)
losses[-1].backward()
if _model is model and delay_before_optim:
torch.cuda._sleep(int(delay_in_ms * get_cycles_per_ms()))
torch.get_device_module(device_type)._sleep(
int(delay_in_ms * get_cycles_per_ms())
)
_optim.step()
self.assertEqual(losses[0], losses[1])
@ -474,14 +493,14 @@ class TestFullyShard1DTrainingCore(FSDPTest):
torch.manual_seed(42)
lin_dim = 32
model = nn.Sequential(*[MLP(lin_dim, torch.device("cpu")) for _ in range(3)])
ref_model = copy.deepcopy(model).cuda()
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
for mlp in model:
fully_shard(mlp)
fully_shard(model)
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
torch.manual_seed(42 + self.rank)
inp = torch.randn((8, lin_dim), device=torch.device("cuda"))
inp = torch.randn((8, lin_dim), device=device_type)
ref_root_loss = ref_model(inp).sum()
ref_root_loss.backward()
@ -500,7 +519,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
root_loss = model(inp).sum()
root_loss.backward()
torch.cuda._sleep(int(100 * get_cycles_per_ms()))
torch.get_device_module(device_type)._sleep(int(100 * get_cycles_per_ms()))
optim.step()
optim.zero_grad()
nonroot_loss = model[0](inp).sum()
@ -535,16 +554,19 @@ class TestFullyShard1DTrainingCore(FSDPTest):
return self.outer(i + j)
torch.manual_seed(42)
model = MultiForwardModule(device="cuda")
model = MultiForwardModule(device=device_type.type)
ref_model = copy.deepcopy(model)
replicate(ref_model, device_ids=[self.rank])
replicate(
ref_model,
device_ids=[self.rank],
)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
fully_shard(model.inner)
fully_shard(model)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
torch.manual_seed(42 + self.rank)
inp = torch.randn((32, 4), device="cuda")
inp = torch.randn((32, 4), device=device_type.type)
for iter_idx in range(10):
losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
@ -559,7 +581,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
torch.manual_seed(42)
model_args = ModelArgs(n_layers=8, dropout_p=0.0)
model = Transformer(model_args)
ref_model = replicate(copy.deepcopy(model).cuda())
ref_model = replicate(copy.deepcopy(model).to(device_type))
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
for layer in itertools.chain(model.layers, [model]):
fully_shard(layer)
@ -582,7 +604,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
layer.set_modules_to_backward_prefetch(layers_to_prefetch)
torch.manual_seed(42 + self.rank)
inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda")
inp = torch.randint(0, model_args.vocab_size, (2, 8), device=device_type.type)
for _ in range(10):
losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
@ -593,11 +615,12 @@ class TestFullyShard1DTrainingCore(FSDPTest):
self.assertEqual(losses[0], losses[1])
@skip_if_lt_x_gpu(2)
@unittest.skipIf(TEST_HPU, "Sleep is not supported on HPU")
def test_post_optim_event(self):
torch.manual_seed(42)
model_args = ModelArgs(dropout_p=0.0)
model = Transformer(model_args)
ref_model = replicate(copy.deepcopy(model).cuda())
ref_model = replicate(copy.deepcopy(model).to(device_type.type))
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
for layer in itertools.chain(model.layers, [model]):
fully_shard(layer)
@ -606,13 +629,15 @@ class TestFullyShard1DTrainingCore(FSDPTest):
def step_post_hook(
fsdp_module: FSDPModule, opt: torch.optim.Optimizer, args, kwargs
) -> None:
post_optim_event = torch.cuda.current_stream().record_event()
post_optim_event = (
torch.get_device_module(device_type).current_stream().record_event()
)
fsdp_module.set_post_optim_event(post_optim_event)
optim.register_step_post_hook(functools.partial(step_post_hook, model))
torch.manual_seed(42 + self.rank)
inp = torch.randint(0, model_args.vocab_size, (2, 8), device="cuda")
inp = torch.randint(0, model_args.vocab_size, (2, 8), device=device_type.type)
# Track all losses and check for equality at the end to avoid a CPU
# sync point after each iteration
ref_losses: list[torch.Tensor] = []
@ -629,7 +654,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
optim.step()
# Sleep after the optimizer step to allow CPU to run ahead into the
# next iteration's forward, exercising the post-optim stream sync
torch.cuda._sleep(int(25 * get_cycles_per_ms()))
torch.get_device_module(device_type)._sleep(int(25 * get_cycles_per_ms()))
for ref_loss, loss in zip(ref_losses, losses):
self.assertEqual(ref_loss, loss)
@ -639,7 +664,7 @@ class TestFullyShard1DTrainingCompose(FSDPTest):
def world_size(self) -> int:
# Since these tests run with a larger transformer model, they may see
# some numeric drift with >2 GPUs
return min(torch.cuda.device_count(), 2)
return min(torch.get_device_module(device_type).device_count(), 2)
@skip_if_lt_x_gpu(2)
@compiled_fsdp_test(compile_compute_on_module=Transformer)
@ -669,7 +694,7 @@ class TestFullyShard1DTrainingCompose(FSDPTest):
return
torch.manual_seed(42)
vocab_size = 1024
with torch.device(torch.device("cuda")):
with torch.device(device_type):
model_args = ModelArgs(
n_layers=3,
n_heads=4,
@ -683,7 +708,10 @@ class TestFullyShard1DTrainingCompose(FSDPTest):
weight_tying=module_grouping != "mem_eff",
)
model = Transformer(model_args)
ref_model = replicate(copy.deepcopy(model), device_ids=[self.rank])
ref_model = replicate(
copy.deepcopy(model),
device_ids=[self.rank],
)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
# Apply activation checkpointing
@ -723,7 +751,7 @@ class TestFullyShard1DTrainingCompose(FSDPTest):
torch.manual_seed(42 + self.rank)
# Reuse the same input across iterations to avoid loss explosion from
# trying to learn from random inputs
inp = torch.randint(0, vocab_size, (3, 64), device="cuda")
inp = torch.randint(0, vocab_size, (3, 64), device=device_type.type)
check_sharded_parity(
self, ref_model, model, prefixes_to_ignore=prefixes_to_ignore
)
@ -750,14 +778,14 @@ class TestFullyShard1DTrainingCompose(FSDPTest):
class TestFullyShardShardPlacementFnMultiProcess(FSDPTest):
@property
def world_size(self) -> int:
return min(8, torch.cuda.device_count())
return min(8, torch.get_device_module(device_type).device_count())
@skip_if_lt_x_gpu(2)
def test_train_parity_shard_placement_fn_shard_largest_dim(self):
torch.manual_seed(42)
model_args = ModelArgs(n_layers=3, dropout_p=0.0)
model = Transformer(model_args)
ref_model = copy.deepcopy(model).cuda()
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
@ -773,7 +801,7 @@ class TestFullyShardShardPlacementFnMultiProcess(FSDPTest):
self.assertEqual(full_param, ref_param)
torch.manual_seed(42 + self.rank)
inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type)
for iter_idx in range(5):
ref_loss = ref_model(inp).sum()
loss = model(inp).sum()
@ -800,7 +828,7 @@ class TestFullyShardShardPlacementFnMultiThread(FSDPTestMultiThread):
def world_size(self) -> int:
return 4
@unittest.skipIf(not TEST_CUDA, "no cuda")
@skip_if_lt_x_gpu(1)
def test_shard_placement_fn_contiguous_params_grads(self):
dim = 4
model = MLP(dim=dim)
@ -825,7 +853,7 @@ class TestFullyShardShardPlacementFnMultiThread(FSDPTestMultiThread):
self.assertTrue(param.is_contiguous())
self.assertTrue(param.to_local().is_contiguous())
inp = torch.randn((2, dim), device="cuda")
inp = torch.randn((2, dim), device=device_type.type)
model(inp).sum().backward()
for param in model.parameters():
@ -838,7 +866,7 @@ class TestFullyShardShardPlacementFnMultiThread(FSDPTestMultiThread):
class TestFullyShardSharedParams(FSDPTest):
@property
def world_size(self) -> int:
return min(4, torch.cuda.device_count())
return min(4, torch.get_device_module(device_type).device_count())
@skip_if_lt_x_gpu(2)
def test_train_parity_with_shared_params(self):
@ -858,8 +886,11 @@ class TestFullyShardSharedParams(FSDPTest):
torch.manual_seed(42)
model_args = ModelArgs(n_layers=3, dropout_p=0.0, weight_tying=True)
model = Transformer(model_args)
ref_model = copy.deepcopy(model).cuda()
replicate(ref_model, device_ids=[self.rank])
ref_model = copy.deepcopy(model).to(device_type)
replicate(
ref_model,
device_ids=[self.rank],
)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
for module in model.modules():
if isinstance(module, TransformerBlock):
@ -871,7 +902,9 @@ class TestFullyShardSharedParams(FSDPTest):
torch.manual_seed(42 + self.rank + 1)
for iter_idx in range(10):
inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
inp = torch.randint(
0, model_args.vocab_size, (2, 16), device=device_type.type
)
losses: list[torch.Tensor] = []
for _model, _optim in ((ref_model, ref_optim), (model, optim)):
_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
@ -884,7 +917,7 @@ class TestFullyShardSharedParams(FSDPTest):
class TestFullyShardGradientAccumulation(FSDPTest):
@property
def world_size(self) -> int:
return min(4, torch.cuda.device_count())
return min(4, torch.get_device_module(device_type).device_count())
@skip_if_lt_x_gpu(2)
def test_gradient_accumulation(self):
@ -892,12 +925,14 @@ class TestFullyShardGradientAccumulation(FSDPTest):
Tests gradient accumulation with/without gradient reduction and
with/without resharding after backward.
"""
meshes = [init_device_mesh("cuda", (self.world_size,))] # always test FSDP
meshes = [
init_device_mesh(device_type.type, (self.world_size,))
] # always test FSDP
if self.world_size == 4: # test HSDP too if enough GPUs
shard_size, replicate_size = 2, 2
meshes.append(
init_device_mesh(
"cuda",
device_type.type,
(replicate_size, shard_size),
mesh_dim_names=("dp_replicate", "dp_shard"),
)
@ -951,7 +986,7 @@ class TestFullyShardGradientAccumulation(FSDPTest):
modules = [nn.Linear(lin_dim, lin_dim)]
modules.extend(MLP(lin_dim) for _ in range(num_mlps))
model = nn.Sequential(*modules)
ref_model = copy.deepcopy(model).cuda()
ref_model = copy.deepcopy(model).to(device_type)
fully_shard_fn = functools.partial(
fully_shard,
mesh=mesh,
@ -994,7 +1029,7 @@ class TestFullyShardGradientAccumulation(FSDPTest):
for microbatch_idx in range(num_microbatches):
is_last_microbatch = microbatch_idx == num_microbatches - 1
set_backward_flags(model, is_last_microbatch)
inp = torch.randn(batch_size, lin_dim, device="cuda")
inp = torch.randn(batch_size, lin_dim, device=device_type.type)
losses: list[torch.Tensor] = []
for _model in (ref_model, model):
with CommDebugMode() as comm_mode:
@ -1083,7 +1118,7 @@ class TestFullyShardGradientAccumulation(FSDPTest):
torch.manual_seed(42)
model_args = ModelArgs(dropout_p=0.0)
model = Transformer(model_args)
ref_model = copy.deepcopy(model).cuda()
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
for module in model.modules():
if isinstance(module, TransformerBlock):
@ -1096,7 +1131,10 @@ class TestFullyShardGradientAccumulation(FSDPTest):
torch.manual_seed(42 + self.rank + 1)
inps = [
torch.randint(
0, model_args.vocab_size, (local_batch_size, 16), device="cuda"
0,
model_args.vocab_size,
(local_batch_size, 16),
device=device_type.type,
)
for _ in range(num_microbatches)
]
@ -1136,14 +1174,14 @@ class TestFullyShardGradientAccumulation(FSDPTest):
class TestFullyShardNDTraining(FSDPTest):
@property
def world_size(self) -> int:
return min(8, torch.cuda.device_count())
return min(8, torch.get_device_module(device_type).device_count())
def init_global_mesh(self) -> DeviceMesh:
# Prefer to test with >=8 GPUs, but for 2 GPUs, use 2-way TP
dp_size = 2 if self.world_size > 2 else 1
pp_size = 2 if self.world_size > 4 else 1
return init_device_mesh(
"cuda",
device_type.type,
(pp_size, dp_size, self.world_size // (dp_size * pp_size)),
mesh_dim_names=("pp", "dp", "tp"),
)
@ -1179,8 +1217,12 @@ class TestFullyShardNDTraining(FSDPTest):
torch.manual_seed(42)
model = MLPStack(mlp_dim)
ref_model = copy.deepcopy(model).cuda()
replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
ref_model = copy.deepcopy(model).to(device_type)
replicate(
ref_model,
device_ids=[self.rank],
process_group=dp_pg,
)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=foreach)
model.parallelize(
tp_mesh,
@ -1191,7 +1233,7 @@ class TestFullyShardNDTraining(FSDPTest):
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=foreach)
torch.manual_seed(42 + dp_pg.rank() + 1)
device = torch.device("cuda")
device = device_type
for iter_idx in range(10):
inp = torch.randn((8, mlp_dim), device=device)
losses: list[torch.Tensor] = []
@ -1212,11 +1254,11 @@ class TestFullyShardNDTraining(FSDPTest):
class TestFullyShardHSDP3DTraining(FSDPTest):
@property
def world_size(self) -> int:
return min(8, torch.cuda.device_count())
return min(8, torch.get_device_module(device_type).device_count())
def init_global_mesh(self) -> DeviceMesh:
return init_device_mesh(
"cuda",
device_type.type,
(2, 2, 2),
mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
)
@ -1248,8 +1290,12 @@ class TestFullyShardHSDP3DTraining(FSDPTest):
torch.manual_seed(42)
model = MLPStack(mlp_dim)
ref_model = copy.deepcopy(model).cuda()
replicate(ref_model, device_ids=[self.rank], process_group=dp_pg)
ref_model = copy.deepcopy(model).to(device_type)
replicate(
ref_model,
device_ids=[self.rank],
process_group=dp_pg,
)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=foreach)
model.parallelize(
tp_mesh,
@ -1266,7 +1312,7 @@ class TestFullyShardHSDP3DTraining(FSDPTest):
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=foreach)
torch.manual_seed(42 + dp_pg.rank() + 1)
device = torch.device("cuda")
device = device_type
for iter_idx in range(10):
inp = torch.randn((8, mlp_dim), device=device)
losses: list[torch.Tensor] = []
@ -1289,14 +1335,14 @@ class TestFullyShardHSDP3DTraining(FSDPTest):
class TestFullyShardHSDPTraining(FSDPTest):
@property
def world_size(self) -> int:
return min(4, torch.cuda.device_count())
return min(4, torch.get_device_module(device_type).device_count())
@skip_if_lt_x_gpu(2)
def test_train_parity_hsdp(self):
shard_size = 2 if self.world_size > 2 else 1
replicate_size = self.world_size // shard_size
global_mesh = init_device_mesh(
"cuda",
device_type.type,
(replicate_size, shard_size),
mesh_dim_names=("dp_replicate", "dp_shard"),
)
@ -1325,8 +1371,11 @@ class TestFullyShardHSDPTraining(FSDPTest):
MLP(mlp_dim),
MLP(mlp_dim, dim_multiplier=3),
)
ref_model = copy.deepcopy(model).cuda()
replicate(ref_model, device_ids=[self.rank])
ref_model = copy.deepcopy(model).to(device_type)
replicate(
ref_model,
device_ids=[self.rank],
)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
for mlp in model:
if use_activation_checkpointing:
@ -1340,7 +1389,7 @@ class TestFullyShardHSDPTraining(FSDPTest):
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
check_sharded_parity(self, ref_model, model)
torch.manual_seed(42 + self.rank + 1)
device = torch.device("cuda")
device = device_type
num_microbatches = 3
for iter_idx in range(5):
for microbatch_idx in range(num_microbatches):
@ -1363,7 +1412,7 @@ class TestFullyShardHSDPTraining(FSDPTest):
class TestFullyShardCustomForwardMethod(FSDPTest):
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 2)
return min(torch.get_device_module(device_type).device_count(), 2)
@skip_if_lt_x_gpu(2)
def test_register_fsdp_forward_method(self):
@ -1392,14 +1441,14 @@ class TestFullyShardCustomForwardMethod(FSDPTest):
torch.manual_seed(42)
model = Model()
ref_model = copy.deepcopy(model).cuda()
ref_model = copy.deepcopy(model).to(device_type)
fully_shard(model.vit)
fully_shard(model.projector)
fully_shard(model)
register_fsdp_forward_method(model.vit, "forward_features")
torch.manual_seed(42 + self.rank + 1)
inp = torch.randn(4, 3, 224, 224, device="cuda")
inp = torch.randn(4, 3, 224, 224, device=device_type.type)
ref_loss = ref_model(inp).sum()
loss = model(inp).sum()
self.assertEqual(ref_loss, loss)
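Beyond tensor placement, the training-test hunks above route every direct `torch.cuda` call through `torch.get_device_module`; a minimal sketch of that substitution, assuming the same `get_devtype()` helper and noting that `_sleep` is a private busy-wait helper the tests already rely on:

```python
# Hedged sketch of the torch.cuda -> torch.get_device_module(device_type)
# substitution used throughout the hunks above.
import torch
from torch.testing._internal.common_fsdp import get_devtype

device_type = torch.device(get_devtype())
device_mod = torch.get_device_module(device_type)  # torch.cuda, torch.hpu, torch.xpu, or torch.cpu

# world_size derivation no longer assumes CUDA:
world_size = min(8, device_mod.device_count())

def spin(delay_cycles: int) -> None:
    # _sleep is a private busy-wait kernel; guard it so a CPU-only run never
    # touches it (the tests skip these paths on HPU for the same reason).
    if device_type.type != "cpu":
        device_mod._sleep(delay_cycles)

print(device_type, world_size)
```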

View File

@ -375,7 +375,9 @@ class AOTAutogradCacheTests(InductorTestCase):
"Allow in graph produces an unserializable cache artifact"
)
with inductor_config.patch("unsafe_marked_cacheable_functions", [fn_name]):
with inductor_config.patch(
"unsafe_marked_cacheable_functions", {fn_name: "key1"}
):
fn(*args)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
@ -390,6 +392,36 @@ class AOTAutogradCacheTests(InductorTestCase):
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
self._clear_dynamo_and_codecache()
with inductor_config.patch(
"unsafe_marked_cacheable_functions", {fn_name: "key2"}
):
fn(*args)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
self._clear_dynamo_and_codecache()
fn(*args)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 2)
self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
# On second try with same key, it should hit once more
with inductor_config.patch(
"unsafe_marked_cacheable_functions", {fn_name: "key1"}
):
self._clear_dynamo_and_codecache()
fn(*args)
self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 3)
self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
@inductor_config.patch("fx_graph_remote_cache", False)
@inductor_config.patch("fx_graph_cache", False)
@functorch_config.patch({"enable_autograd_cache": True})

View File

@ -1873,7 +1873,7 @@ def forward(self, x, y):
return x + x
def false_fn(x):
return x[:2]
return x[:2].clone()
return cond(x.shape[0] <= 2, true_fn, false_fn, [x])
@ -1883,7 +1883,7 @@ def forward(self, x, y):
return x + x
def false_fn(x):
return x[:2]
return x[:2].clone()
return cond(x.shape[0] <= 2, true_fn, false_fn, (x,))
@ -1924,7 +1924,8 @@ def forward(self, l_x_):
def forward(self, l_x_):
l_x__1 = l_x_
getitem = l_x__1[slice(None, 2, None)]; l_x__1 = None
return (getitem,)""",
clone = getitem.clone(); getitem = None
return (clone,)""",
)
# We could successfully export branches that return different sizes
torch._dynamo.export(mod)(torch.randn(3, 2))
@ -3302,7 +3303,12 @@ def forward(self, x):
def test_cond_raise_user_error_on_branch_return_multiple_tensors(self):
def f_branch_return_multiple_tensors(pred, x, y):
return cond(pred, lambda x: (x, x), lambda x: (x, x), [y])
return cond(
pred,
lambda x: (x.clone(), x.clone()),
lambda x: (x.clone(), x.clone()),
[y],
)
example_inputs = (torch.tensor(True), torch.randn(4), torch.randn(2))
gm, _ = torch._dynamo.export(
@ -3324,10 +3330,10 @@ def forward(self, x):
def test_cond_raise_user_error_on_mismatch_return_length(self):
def true_fn(x):
return x
return x.clone()
def false_fn(x):
return (x, x)
return (x.clone(), x.clone())
def f_mismatch_return_length(x):
return cond(torch.tensor(100), true_fn, false_fn, [x])
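The `cond` hunks above converge on one pattern: branch functions return fresh tensors (clones) instead of the input or a view of it; the aliasing restriction itself is inferred from the edits, not stated in this diff. A minimal sketch:

```python
# Sketch of the clone-before-return pattern applied in the hunks above.
import torch
from functorch.experimental.control_flow import cond

def true_fn(x):
    return x + x

def false_fn(x):
    # x[:2] is a view of the input; cloning returns a standalone tensor, as in
    # the edited tests above.
    return x[:2].clone()

def f(x):
    return cond(x.shape[0] <= 2, true_fn, false_fn, [x])

# The tests exercise this through export, e.g. torch._dynamo.export(f)(torch.randn(3, 2)).
```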

View File

@ -170,8 +170,8 @@ class FunctionTests(torch._dynamo.test_case.TestCase):
in warning_message
):
break
else:
self.assertTrue(False, "Expected warning about lru_cache not found")
else:
self.assertTrue(False, "Expected warning about lru_cache not found")
@make_test
def test_add(a, b):
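The re-indentation in this hunk is invisible here because the page extraction dropped leading whitespace; the fix appears to attach the `else` to the `for` loop, so the assertion only fires when no warning matched (Python's for/else runs the else block only if the loop finishes without `break`). A small sketch of that control flow:

```python
# for/else: the else clause runs only when the loop is not exited via break.
warnings_seen = ["unrelated warning", "lru_cache-wrapped function was traced"]

for msg in warnings_seen:
    if "lru_cache" in msg:
        break
else:
    raise AssertionError("Expected warning about lru_cache not found")
```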

View File

@ -1791,7 +1791,13 @@ def forward(self, child : torch.Tensor):
def test_map_pytree_return(self):
def _construct_pytree(a):
return (a, [[[a]]], a, (a, (a,), a), {"a": a})
return (
a.clone(),
[[[a.clone()]]],
a.clone(),
(a.clone(), (a.clone(),), a.clone()),
{"a": a.clone()},
)
def f(x):
def inner_f(xs):
@ -1823,7 +1829,14 @@ def forward(self, L_x_ : torch.Tensor):
body_graph,
"""\
def forward(self, child : torch.Tensor):
return (child, child, child, child, child, child, child)""",
child_1 = child.clone()
child_2 = child.clone()
child_3 = child.clone()
child_4 = child.clone()
child_5 = child.clone()
child_6 = child.clone()
child_7 = child.clone(); child = None
return (child_1, child_2, child_3, child_4, child_5, child_6, child_7)""",
)
def test_map_kwargs(self):
@ -6902,7 +6915,7 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
def test(pred, x):
def true_fn(x):
return x
return x.clone()
def false_fn(x):
return -x
@ -6926,7 +6939,7 @@ class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase):
def test(pred, mode, x):
def true_fn(x):
return x
return x.clone()
def false_fn(x):
return -x

View File

@ -5931,7 +5931,7 @@ utils_device.CURRENT_DEVICE == None""".split(
from functorch.experimental.control_flow import cond
def true_fn(x):
return x
return x.clone()
def false_fn(x):
return x.sin()

View File

@ -1,5 +1,4 @@
# Owner(s): ["module: dynamo"]
from unittest import expectedFailure
from unittest.mock import patch
import torch
@ -394,7 +393,6 @@ class RecompileTests(torch._dynamo.test_case.TestCase):
self.assertEqual(counter.frame_count, 2) # not three or four!
@expectedFailure # TODO(laithsakka, pianpwk): handle guard_or_false before oblivious hint fallback
@torch._dynamo.config.patch(automatic_dynamic_shapes_mark_as="oblivious")
def test_automatic_dynamic_shapes_mark_as_oblivious(self):
counter = torch._dynamo.testing.CompileCounter()

View File

@ -9,7 +9,7 @@ append_cxx_flag_if_supported("-Wno-unused-private-field" CMAKE_CXX_FLAGS)
# Generate unboxing kernels
set(GEN_COMMAND
"${Python_EXECUTABLE}" -m torchgen.gen_executorch
Python::Interpreter -m torchgen.gen_executorch
--source-path=${TEST_ROOT}
--install-dir=${OUTPUT_DIRECTORY}
--tags-path=${TORCH_ROOT}/aten/src/ATen/native/tags.yaml
@ -58,11 +58,7 @@ add_executable(test_edge_op_registration
target_compile_definitions(test_edge_op_registration PRIVATE USE_GTEST)
set(TEST_DEPENDENCIES gtest unbox_lib)
target_link_libraries(test_edge_op_registration PRIVATE
${TEST_DEPENDENCIES}
)
target_link_libraries(test_edge_op_registration PRIVATE gtest_main unbox_lib)
if((CMAKE_CXX_COMPILER_ID MATCHES "AppleClang") OR (APPLE AND CMAKE_CXX_COMPILER_ID MATCHES "Clang"))
target_link_options(test_edge_op_registration PRIVATE
"-Wl,-force_load,$<TARGET_FILE:unbox_lib>"

View File

@ -323,7 +323,7 @@ class TestDraftExport(TestCase):
self.assertEqual(
report.failures[0].failure_type, FailureType.DATA_DEPENDENT_ERROR
)
self.assertEqual(report.failures[0].data["expr"], "Eq(9380*u1, 0)")
self.assertEqual(report.failures[0].data["expr"], "Eq(Mod(10, 2*u1), 0)")
def test_dedup_data_dependent_failure(self):
class M(torch.nn.Module):

View File

@ -7604,20 +7604,25 @@ def forward(self, b_a_buffer, x):
self.assertTrue(torch.allclose(ep.module()(xs), module_out))
@requires_cuda
@testing.expectedFailureCppRuntime
def test_export_associative_scan_lifted_buffers(self):
device = torch.device("cuda")
combine_mode = "pointwise"
class A(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.buffer = torch.nn.Buffer(torch.ones(3, 2, device=device))
def forward(self):
return self.buffer.cos()
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_buffer(
"buf", torch.ones(3, 2, device=device), persistent=False
)
self.a = A()
def combine_fn(self, x, y):
return x + y * self.buf
return (x + y) * self.a()
def forward(self, x):
return associative_scan(

View File

@ -4572,17 +4572,17 @@ class <lambda>(torch.nn.Module):
body_graph_0 = self.body_graph_0
map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = None
getitem: "f32[2, 2]" = map_impl[0]; map_impl = None
getitem_2: "f32[2, 2]" = map_impl[0]; map_impl = None
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem_2); getitem_2 = None
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(cos, sum_1); sum_1 = None
body_graph_1 = self.body_graph_1
map_impl_1 = torch.ops.higher_order.map_impl(body_graph_1, [cos], [arg1_1]); body_graph_1 = cos = arg1_1 = None
getitem_1: "f32[2, 2]" = map_impl_1[0]; map_impl_1 = None
getitem_5: "f32[2, 2]" = map_impl_1[0]; map_impl_1 = None
sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_5); getitem_5 = None
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, sum_2); add = sum_2 = None
return (add_1,)
@ -4635,9 +4635,9 @@ class <lambda>(torch.nn.Module):
body_graph_0 = self.body_graph_0
map_impl = torch.ops.higher_order.map_impl(body_graph_0, [cos], [arg1_1]); body_graph_0 = arg1_1 = None
getitem: "f32[2, 2]" = map_impl[0]; map_impl = None
getitem_2: "f32[2, 2]" = map_impl[0]; map_impl = None
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem_2); getitem_2 = None
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
return (add,)

Some files were not shown because too many files have changed in this diff.