Compare commits


42 Commits

Author SHA1 Message Date
3ddec713b8 Revert "[cuDNN][Quantization] Don't print when plan finalization fails in cuDNN quantization backend (#128177)"
This reverts commit cac7a22b92478d897488688010e562b7bd36b97f.

Reverted https://github.com/pytorch/pytorch/pull/128177 on behalf of https://github.com/clee2000 due to broke test/test_quantization.py::TestQuantizedLinear::test_qlinear_cudnn on sm86 tests cac7a22b92 https://github.com/pytorch/pytorch/actions/runs/9470648757/job/26100448913. Probably a landrace; the test ran on the PR and succeeded ([comment](https://github.com/pytorch/pytorch/pull/128177#issuecomment-2161977110))
2024-06-12 02:20:15 +00:00
85eeb90d2c [dynamo] Fix graph breaks related to HF ModelOutput (#127780)
Fixes https://github.com/pytorch/pytorch/issues/126028 and https://github.com/pytorch/pytorch/issues/126027.

Changes:
- Support building `CustomizedDictVariable` in `VariableBuilder` (but only for HF `ModelOutput` subclasses)
- Remove `DataClassVariable` since it's not really being used anywhere (`CustomizedDictVariable` can be used instead)
- Support side effects for `CustomizedDictVariable`
- Allow `NO_HASATTR` leaf guard on `DictSubclassGuardManager`
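
A minimal sketch of the kind of code these changes target, assuming the `transformers` package is installed (the `MyOutput` class and `f` below are illustrative, not the repros from the linked issues):

```python
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.utils import ModelOutput  # HF ModelOutput base (a dict subclass)


@dataclass
class MyOutput(ModelOutput):
    logits: Optional[torch.Tensor] = None
    hidden: Optional[torch.Tensor] = None


@torch.compile(fullgraph=True)
def f(x):
    out = MyOutput(logits=x.sin())  # built inside the compiled region (VariableBuilder path)
    out.hidden = x.cos()            # attribute mutation -> side effect on the dict subclass
    return out.logits + out.hidden


f(torch.randn(4))
```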

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127780
Approved by: https://github.com/jansel, https://github.com/anijain2305
2024-06-12 02:16:24 +00:00
7f6daf289b [inductor] parallel compile: set LD_LIBRARY_PATH for sub-processes in internal (#128376)
Test Plan: `TORCHINDUCTOR_WORKER_START=subprocess TORCHINDUCTOR_COMPILE_THREADS=16 buck run mode/opt scripts/slarsen/torch_compile:run`

Differential Revision: D58371264

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128376
Approved by: https://github.com/eellison
2024-06-12 01:55:53 +00:00
3d55d84ec2 [Fix] Check tensor dtype before using torch.allclose in _trace log (#128438)
#### Issue
`torch.allclose` errors out during logging due to different dtypes.
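
A hedged, minimal illustration of the failure mode (a toy example, not the `_trace` logging code itself):

```python
import torch

a = torch.zeros(3)                        # float32
b = torch.zeros(3, dtype=torch.float64)   # float64

try:
    torch.allclose(a, b)                  # raises because the dtypes differ
except RuntimeError as e:
    print(e)

# the fix: only compare when the dtypes agree
if a.dtype == b.dtype:
    print(torch.allclose(a, b))
```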

#### Test
* `pytest test/test_jit.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128438
Approved by: https://github.com/angelayi
2024-06-12 01:52:09 +00:00
bb2a995529 Back out "[Dynamo] Treat integers stored on nn.Modules as dynamic (#126466)" (#128432)
Summary:
Original commit changeset: c7d2e6b13922

Original Phabricator Diff: D57618942

Differential Revision: D58383241

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128432
Approved by: https://github.com/ezyang, https://github.com/Yuzhen11
2024-06-12 01:34:32 +00:00
cyy
9538bf4e7c [2/N] Remove inclusion of c10/util/string_utils.h (#128372)
Follows  #128300.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128372
Approved by: https://github.com/aaronenyeshi
2024-06-12 01:18:20 +00:00
cyy
219da29dfd [7/N] Remove unused functions (#128407)
Follows  #128309
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128407
Approved by: https://github.com/ezyang
2024-06-12 01:10:33 +00:00
cyy
fb013ecb24 Remove unused private List::ptr_to_first_element (#128405)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128405
Approved by: https://github.com/ezyang
2024-06-12 01:07:14 +00:00
6af4c6acad Migrate test to internal base class, fixes (#128367)
Summary:
## Remove etcd deps
Converted tests to a non-etcd-based rdzv handler so that the tests don't depend on an etcd server.

## Adopt pytorch test conventions
- test starts with `test_TESTS.py`
- Test base class is torch.testing._internal.common_utils.TestCase
- include `__main__` handler

## Reduce test timing (used to take > 300 seconds):

3.05s call     test/distributed/launcher/run_test.py::ElasticLaunchTest::test_init_method_env_with_torchelastic
2.59s call     test/distributed/launcher/run_test.py::ElasticLaunchTest::test_init_method_tcp_with_torchelastic
2.33s call     test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_elastic_worker_raise_exception
2.33s call     test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_run_path
2.30s call     test/distributed/launcher/run_test.py::ElasticLaunchTest::test_nproc_launch_auto_configurations
2.24s call     test/distributed/launcher/run_test.py::ElasticLaunchTest::test_is_torchelastic_launched_with_logs_spec_defined
2.24s call     test/distributed/launcher/run_test.py::ElasticLaunchTest::test_is_torchelastic_launched
2.17s call     test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_elastic_multiple_agents
2.12s call     test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_elastic
2.08s call     test/distributed/launcher/run_test.py::ElasticLaunchTest::test_nproc_gpu_launch_configurations
1.32s call     test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_standalone
1.05s call     test/distributed/launcher/run_test.py::ElasticLaunchTest::test_nproc_launch_number_configurations
1.05s call     test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_with_env_vars
1.05s call     test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_user_script_python
1.05s call     test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_user_script_python_caffe2_bc
1.04s call     test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_user_script_bash
1.03s call     test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_user_script_default_nproc
0.04s call     test/distributed/launcher/run_test.py::ElasticLaunchTest::test_logs_logs_spec_entrypoint_must_be_defined
0.01s call     test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_elastic_agent_raise_exception
0.01s call     test/distributed/launcher/run_test.py::ElasticLaunchTest::test_launch_shutdown

Test Plan: pytest --durations=0  test/distributed/launcher/run_test.py

Differential Revision: D58388182

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128367
Approved by: https://github.com/d4l3k
2024-06-12 01:03:40 +00:00
786c24a4cd [inductor] Always realize sigmoid for CPU (#128339)
Summary: Currently the cpu backend prefers to always realize exp because it's a heavy op on CPU. For the same reason, we need to realize sigmoid as well. This solves a problem in llama2 inference where exp was recomputed many times in an inner loop.
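
A hedged toy example of the pattern described above (names are illustrative, not taken from the llama2 model); realizing the `sigmoid` materializes its buffer once instead of recomputing it per consumer in the generated CPU kernel:

```python
import torch


def gated(x, weights):
    g = torch.sigmoid(x)                # heavy on CPU; realized so it is computed once
    return sum(g * w for w in weights)  # many consumers reuse the same sigmoid value


compiled = torch.compile(gated)
x = torch.randn(4096)
weights = [torch.randn(4096) for _ in range(8)]
compiled(x, weights)
```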

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128339
Approved by: https://github.com/eellison, https://github.com/helloguo, https://github.com/jansel, https://github.com/jgong5, https://github.com/peterbell10
2024-06-12 00:46:33 +00:00
5d8c7f39d4 Revert "Introduce int_oo (#127693)"
This reverts commit 9cab5987bdeb66df8efbc581b3469bfe300e168c.

Reverted https://github.com/pytorch/pytorch/pull/127693 on behalf of https://github.com/clee2000 due to sorry executorch CI is a bit weird regarding pins, I'll make a chat with mergen with the choices of what to do and how it'll affect executorch CI, reverting for now to prevent more divergences in the meantime ([comment](https://github.com/pytorch/pytorch/pull/127693#issuecomment-2161775400))
2024-06-11 23:36:08 +00:00
c9c1fed065 Revert "Flip default value for mypy disallow_untyped_defs [10+2/11] (#128374)"
This reverts commit c13e03c87428b986972a48d8fc78dbffc2579f63.

Reverted https://github.com/pytorch/pytorch/pull/128374 on behalf of https://github.com/clee2000 due to sorry I need to revert this in order to revert something else, to remerge, just rebase and fix the merge conflict ([comment](https://github.com/pytorch/pytorch/pull/128374#issuecomment-2161772864))
2024-06-11 23:34:03 +00:00
94fea82d66 init sub comment (#128082)
Fixes #127905

### Description

Add docstring to torch/onnx/symbolic_opset9.py:sigmoid function

### Checklist
- [x] The issue that is being fixed is referred in the description
- [x] Only one issue is addressed in this pull request
- [x] Labels from the issue that this PR is fixing are added to this pull request
- [x] No unnecessary issues are included into this pull request

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128082
Approved by: https://github.com/titaiwangms
2024-06-11 22:42:35 +00:00
447173198b Add docstring for the torch.fx.operator_schemas.create_type_hint func… (#128139)
Fixes: #127916

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128139
Approved by: https://github.com/SherlockNoMad
2024-06-11 22:42:11 +00:00
b79d056e76 [export] FIx unflattener for preserving modules containing unused inputs (#128260)
Currently the unflattener fails if the module whose signature it is preserving contains unused inputs/outputs.

This also fixes unflattener issues in D57829276.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128260
Approved by: https://github.com/pianpwk
2024-06-11 22:32:08 +00:00
eb567b1f40 Pass params to dump_nccl_trace_pickle (#128307)
Summary:
Pass parameters from the request to the dump_nccl_trace_pickle handler.
The supported parameter names and values are all lowercase.
includecollectives={true, false}
includestacktraces={true, false}
onlyactive={true, false}

Example post is:
/handler/dump_nccl_trace_pickle?includecollectives=true&includestacktraces=false&onlyactive=true
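
A hedged sketch of exercising the handler (the host/port and the shape of the returned object are assumptions, not part of this PR):

```python
import pickle
import urllib.request

# assume a worker process is already serving the debug handlers at this address
url = ("http://localhost:1234/handler/dump_nccl_trace_pickle"
       "?includecollectives=true&includestacktraces=false&onlyactive=true")

req = urllib.request.Request(url, method="POST")
with urllib.request.urlopen(req, timeout=10) as resp:
    dump = pickle.loads(resp.read())  # the handler returns a pickled trace dump

print(type(dump))
```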

Test Plan:
unit tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128307
Approved by: https://github.com/d4l3k
ghstack dependencies: #128191
2024-06-11 22:28:53 +00:00
1dd2431f86 [Test] Add test for only_active flag (#128191)
Summary:
Add a unit test for the only_active flag to _dump_nccl_trace API call.
With this flag, we only expect active records to be returned.

Test Plan:
Unit test.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128191
Approved by: https://github.com/d4l3k
2024-06-11 22:26:01 +00:00
5fcb5f0c8b init reshape_from_tensor_shape comment (#128171)
Fixes #127897

### Description
Add docstring to torch/onnx/symbolic_opset9.py:sigmoid function

### Checklist
- [x] The issue that is being fixed is referred in the description
- [x] Only one issue is addressed in this pull request
- [x] Labels from the issue that this PR is fixing are added to this pull request
- [x] No unnecessary issues are included into this pull request

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128171
Approved by: https://github.com/titaiwangms
2024-06-11 21:56:33 +00:00
a55d0d9718 Fix side effect pruning (#128028)
Summary:
The previous side effect pruning algorithm would keep many dead cell
variables alive. For example, in
https://github.com/pytorch/pytorch/issues/125078, the compiled function
has one return but there were three in the Dynamo graph due to two
dead cell variables not being pruned away.

This PR adds a corrected algorithm. "new cell variables" are alive if
they can be reached from one of the following:
1. any of the tx.symbolic_locals or tx.stack (that is, if they are
   involved in a return from the function or intermediate variable
   during a graph break). Example: an alive NestedUserFunctionVariable
2. "mutations to pre-existing objects". Example: appending a
   NestedUserFunctionVariable to a global list

The new algorithm reflects this, but please let me know if there are
more cases to handle.
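
A hedged illustration of the two liveness cases above (the functions and the global list are made up for the example):

```python
import torch

REGISTRY = []  # pre-existing object; mutating it keeps referenced cells alive (case 2)


def fn(x):
    y = x.sin()

    def unused():
        # never returned or stored anywhere reachable, so the cell it captures
        # is a dead "new cell variable" and can now be pruned
        return y + 1

    def registered():
        return y * 2

    REGISTRY.append(registered)  # mutation of a pre-existing object -> alive (case 2)
    return y                     # reachable from the return value -> alive (case 1)


torch.compile(fn)(torch.randn(3))
```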

Test Plan:
- existing tests (afaict, test/dynamo/test_python_autograd is the best
  SideEffects test case we have)
- see in test/dynamo/test_higher_order_ops that the expecttests changed
  -- the functorch dynamo graphs no longer return dead cellvars.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128028
Approved by: https://github.com/jansel
2024-06-11 21:40:48 +00:00
8c1247cffb [BE] Fixed CPU autocast warning (#127774)
This PR fixes
```
/data/users/andgu/pytorch/torch/utils/checkpoint.py:1398: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
```
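
For reference, a minimal sketch of the replacement spelling suggested by the warning:

```python
import torch

# deprecated: torch.cpu.amp.autocast(...)
# replacement suggested by the warning:
with torch.amp.autocast("cpu"):
    x = torch.randn(8, 8)
    y = x @ x  # runs under CPU autocast (bfloat16 by default)
```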

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127774
Approved by: https://github.com/soulitzer, https://github.com/Skylion007, https://github.com/tianyu-l
2024-06-11 21:33:35 +00:00
70a1e85718 [Traceable FSDP2] Use custom ops for AllGather copy-in / copy-out and ReduceScatter copy-in (#127856)
Making these operations into custom ops helps Inductor identify these ops and enforce the FSDP communication op ordering.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127856
Approved by: https://github.com/awgu
2024-06-11 20:15:03 +00:00
adb699189b Revert "[RELAND][dynamo][nn-modules] Trace through nn.Module dunder methods for UnspecializedNNModule (#126578)"
This reverts commit b2d602306a9eb19e30328cbaee941c874f8148a9.

Reverted https://github.com/pytorch/pytorch/pull/126578 on behalf of https://github.com/clee2000 due to failed internal test D58394084.  Author has forward fix but includes external changes so reverting is a bit easier to coordinate ([comment](https://github.com/pytorch/pytorch/pull/126578#issuecomment-2161481839))
2024-06-11 19:41:41 +00:00
eqy
45dccfddcd [cuDNN][SDPA] Support different key, value dimension in cuDNN SDPA (#128350)
CC @vedaanta-nvidia @drisspg
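
A hedged example of what "different key, value dimension" means for `scaled_dot_product_attention` (shapes are illustrative; whether the cuDNN backend is actually selected depends on the hardware and cuDNN version):

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq, head_dim); value uses a different head_dim than query/key
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 96, device="cuda", dtype=torch.float16)

out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 8, 128, 96])
```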

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128350
Approved by: https://github.com/Skylion007
2024-06-11 19:22:21 +00:00
3e09123797 Enable UFMT on test_nestedtensor.py (#128359)
split it into two PRs since it is more than 2k lines of change

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128359
Approved by: https://github.com/davidberard98
2024-06-11 19:14:04 +00:00
61f922c2ca Fix 'get_real_value' on placeholder nodes (#127698)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127698
Approved by: https://github.com/jansel
ghstack dependencies: #127695, #127696
2024-06-11 18:57:25 +00:00
984b1a8c35 Fix 'get_attr' call in dynamo 'run_node' (#127696)
Fixes #124858

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127696
Approved by: https://github.com/jansel
ghstack dependencies: #127695
2024-06-11 18:57:25 +00:00
205410cb44 add xpu to torch.tensors (#127280)
As support for Intel GPU has been upstreamed, this PR adds the XPU-related content to the torch.tensors doc.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127280
Approved by: https://github.com/svekars
2024-06-11 18:13:01 +00:00
cac7a22b92 [cuDNN][Quantization] Don't print when plan finalization fails in cuDNN quantization backend (#128177)
Similar in spirit to #125790, hopefully addresses failures seen for the cuDNN 9.1 upgrade: https://github.com/pytorch/pytorch/pull/128166

CC @nWEIdia @atalman

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128177
Approved by: https://github.com/nWEIdia, https://github.com/Skylion007
2024-06-11 18:09:25 +00:00
8a09940a54 [inductor] fix compile time regression by caching get_gpu_type (#128363)
We observed a significant compile time regression in torchtitan when turning
on 2D parallel + torch.compile recently, so I decided to get a deeper
understanding of why.

It turns out this is affecting **all the trainings** that have functional collectives
captured in the graph, not only 2D parallel (2D parallel was just the
job that happened to have collectives captured in the TP region).

The root cause is that when doing inductor lowering, we call the comm
analysis pass to get an estimated collective time for each collective
node in the graph, and for each such node we call `get_gpu_type()`,
which under the hood calls `torch.utils.collect_env.run` to get the GPU
info. However, this call is super expensive! It effectively spawns a new
process and calls `nvidia-smi` to get the GPU info, so the cost is **linear**
in the number of collective nodes in the graph.

see https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py#L75

The fix is to add an lru cache to the function, so that we only call it
once and reuse the cached result afterwards.
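
A hedged sketch of the idea (the real change caches inductor's `get_gpu_type()`; the helper below is illustrative):

```python
import functools
import subprocess


@functools.lru_cache(maxsize=None)
def gpu_name() -> str:
    # spawning nvidia-smi is expensive; the cache pays that cost once per process
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
        capture_output=True,
        text=True,
    )
    return out.stdout.strip()
```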

torchtitan benchmark shows:
* before this fix: 2D parallel + fp8 compile time: 6min +
* after this fix: 2D parallel + fp8 compile time: 2min 48s (more than 100% improvement)

There's more room to improve the compile time, but this PR fixes the biggest regression I have found so far.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128363
Approved by: https://github.com/yf225
2024-06-11 18:02:13 +00:00
1d233b8f50 Revert "Make nn.Module state_dict load_state_dict pre-hook and state_dict post hook public (#126704)"
This reverts commit c38b3381a12a0ec033dd417827c530c4474b8165.

Reverted https://github.com/pytorch/pytorch/pull/126704 on behalf of https://github.com/clee2000 due to broke internal typecheck D58394110 (which probably means the code wouldn't work either but I guess it didn't run on the diff). Probably an easy fix? ([comment](https://github.com/pytorch/pytorch/pull/126704#issuecomment-2161299193))
2024-06-11 17:45:20 +00:00
491c4a5dcb Revert "Make sure #126704 is BC for torch.save-ed nn.Module (#128344)"
This reverts commit 841d87177a900c2bbd59b6589165189141c4e8bb.

Reverted https://github.com/pytorch/pytorch/pull/128344 on behalf of https://github.com/clee2000 due to broke internal typecheck D58394110 (which probably means the code wouldn't work either but I guess it didn't run on the diff). Probably an easy fix? ([comment](https://github.com/pytorch/pytorch/pull/126704#issuecomment-2161299193))
2024-06-11 17:45:20 +00:00
4345d98663 [dynamo] Fix for #127696 (#128358)
Test Plan:
`buck2 test @//mode/dev-nosan //executorch/exir/backend/...`
https://www.internalfb.com/intern/testinfra/testrun/12666373989243932

Differential Revision: D58384518

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128358
Approved by: https://github.com/ydwu4
2024-06-11 16:43:15 +00:00
a838e90964 Add Intel Gaudi device/HPU to auto load in instantiate_device_type_tests (#126970)
### Motivation
The Intel Gaudi accelerator (device name hpu) is seen to have a good pass rate with the pytorch framework UTs; however, being an out-of-tree device, we face challenges in adapting the device to natively run the existing pytorch UTs under pytorch/test. The UTs are nevertheless a good indicator of the device stack health, and as such we run them regularly with adaptations.
Although we can add the Gaudi/HPU device to generate the device-specific tests using the TORCH_TEST_DEVICES environment variable, we miss out on a lot of features such as executing for specific dtypes and skipping or overriding OpInfo entries. With the significant changes introduced every PyTorch release, maintaining these adaptations becomes difficult and time consuming.
Hence with this PR we introduce the Gaudi device in the common_device_type framework, so that the tests are instantiated for Gaudi when the library is loaded.
The eventual goal is to introduce Gaudi out-of-tree support as equivalent to in-tree devices.

### Changes
Add HPUTestBase, a subclass of DeviceTypeTestBase, specifying appropriate attributes for Gaudi/HPU.
Include code to check whether the Intel Gaudi software library is loaded and, if so, add the device to the list of devices considered for instantiation of device type tests.
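
A hedged sketch of the shape of the change (class and helper names below are illustrative; the actual implementation lives in torch/testing/_internal/common_device_type.py, and the package name used for detection is an assumption):

```python
import importlib.util

from torch.testing._internal.common_device_type import DeviceTypeTestBase


class HPUTestBase(DeviceTypeTestBase):
    device_type = "hpu"


def maybe_add_hpu(test_bases: list) -> None:
    # only instantiate HPU variants when the Intel Gaudi software stack is importable
    if importlib.util.find_spec("habana_frameworks") is not None:
        test_bases.append(HPUTestBase)
```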

### Additional Context
Please refer to the following RFC: https://github.com/pytorch/rfcs/pull/63/

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126970
Approved by: https://github.com/albanD
2024-06-11 16:35:17 +00:00
29081059b6 [Static Runtime] Fix & run gen_static_runtime_ops (#128299)
gen_static_runtime_ops hasn't been updated in a while. In preparation for https://github.com/pytorch/pytorch/pull/127675 in which I need to re-run the codegen step for cumprod, I want to land these changes beforehand in case there are any other issues that arise.

I added a number of ops to the blocklist:
```
+        "_nested_tensor_storage_offsets",
+        "_nested_get_values",  # no CPU backend
+        "_nested_get_values_copy",  # no CPU backend
+        "_nested_view_from_jagged",  # testing needs to be patched
+        "_nested_view_from_jagged_copy",  # testing needs to be patched
+        "_nested_view_from_buffer",  # testing needs to be patched
+        "_nested_view_from_buffer_copy",  # testing needs to be patched
+        "_int_mm",  # testing needs to be patched
+        "_to_sparse_csc",  # testing needs to be patched
+        "_to_sparse_csr",  # testing needs to be patched
+        "segment_reduce",  # testing needs to be patched
```

Most of these are added just because testing doesn't work right now.

Additionally, a few `fft` ops seem to have been removed from native_functions.yaml; I'm guessing it's unlikely FFT would have been used in many real models though.

Differential Revision: [D58329403](https://our.internmc.facebook.com/intern/diff/D58329403/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128299
Approved by: https://github.com/YuqingJ
2024-06-11 16:27:39 +00:00
f8c45996d5 [MPS] Make erfinv compilable for bfloat16 (#128375)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128375
Approved by: https://github.com/Skylion007
ghstack dependencies: #128373
2024-06-11 16:04:11 +00:00
c13e03c874 Flip default value for mypy disallow_untyped_defs [10+2/11] (#128374)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128374
Approved by: https://github.com/Skylion007
2024-06-11 15:58:28 +00:00
053930e194 [MPS][BE] Remove code duplication (#128373)
Use `scalarToMetalTypeString` instead of `getMetalType`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128373
Approved by: https://github.com/Skylion007
2024-06-11 15:58:04 +00:00
9a38cae299 [AOTI] Switch to use shim v2 (#127674)
Differential Revision: D56709309

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127674
Approved by: https://github.com/desertfire
2024-06-11 15:01:25 +00:00
55901fb3da [fx] Preserve Fx graph node order in partitioner across runs (#115621)
Fixes #ISSUE_NUMBER
The partitioner generates a different graph on each recompilation run.
Co-authored-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115621
Approved by: https://github.com/ezyang
2024-06-11 14:04:52 +00:00
fc77fdca6f [guard_size_oblivious] Add gso ExpandUtils:_sym_to (#128224)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128224
Approved by: https://github.com/ezyang
2024-06-11 14:01:34 +00:00
648625b230 Make TraceUtils.h to be device-agnostic (#126969)
Some features of third-party devices depend on TraceUtils.h, so some of the CUDA code was removed and split into NCCLUtils files.

In addition, some common functions still remain in TraceUtils.h since I'm not sure if other devices will use them later.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126969
Approved by: https://github.com/c-p-i-o
2024-06-11 08:38:07 +00:00
207c2248a8 [inductor] Fix lowering full with SymBool value (#128213)
Fixes #128161, fixes #128095
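
A hedged repro sketch of the pattern named in the title (not necessarily the exact repro from the linked issues): passing a SymBool-valued fill value to `torch.full` inside a compiled region.

```python
import torch


@torch.compile(dynamic=True)
def f(x):
    flag = x.size(0) == 4                 # a SymBool under dynamic shapes
    return torch.full((x.size(0),), flag)


f(torch.randn(4))
f(torch.randn(5))
```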

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128213
Approved by: https://github.com/lezcano
2024-06-11 08:33:35 +00:00
169 changed files with 3385 additions and 3215 deletions

View File

@ -1099,7 +1099,6 @@ exclude_patterns = [
'test/test_namedtuple_return_api.py',
'test/test_native_functions.py',
'test/test_native_mha.py',
'test/test_nestedtensor.py',
'test/test_nn.py',
'test/test_out_dtype_op.py',
'test/test_overrides.py',

View File

@ -462,7 +462,7 @@ inline Tensor _sum_to(
reduce_dims.push_back(i);
}
for (int64_t i = leading_dims; i < static_cast<int64_t>(sizes.size()); ++i) {
if (shape[i - leading_dims] == 1 &&
if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(shape[i - leading_dims], 1)) &&
TORCH_GUARD_SIZE_OBLIVIOUS(sym_ne(sizes[i], 1))) {
reduce_dims.push_back(i);
}

View File

@ -478,8 +478,6 @@ namespace impl {
// (maybe except for some internal prim ops).
using GenericList = List<IValue>;
const IValue* ptr_to_first_element(const GenericList& list);
}
}

View File

@ -350,11 +350,4 @@ void List<T>::unsafeSetElementType(TypePtr t) {
impl_->elementType = std::move(t);
}
namespace impl {
inline const IValue* ptr_to_first_element(const GenericList& list) {
return &list.impl_->list[0];
}
}
}

View File

@ -1195,15 +1195,6 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> ho
#undef REPR
}
static Tensor istft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
const optional<int64_t> win_lengthOpt, const Tensor& window,
const bool center, const bool normalized, const optional<bool> onesidedOpt,
const optional<int64_t> lengthOpt) {
return at::native::istft(
self, n_fft, hop_lengthOpt, win_lengthOpt, window, center, normalized,
onesidedOpt, lengthOpt, /*return_complex=*/false);
}
void _fft_fill_with_conjugate_symmetry_(const Tensor& input, IntArrayRef dim_) {
const auto input_sizes = input.sizes();
const auto input_strides = input.strides();

View File

@ -210,7 +210,6 @@
#include <ATen/ops/zeros_native.h>
#endif
#include <c10/util/StringUtil.h>
#include <algorithm>
#include <cstdint>
#include <utility>

View File

@ -13,7 +13,8 @@ void run_cudnn_SDP_fprop(
int64_t h,
int64_t s_q,
int64_t s_kv,
int64_t d,
int64_t d_qk,
int64_t d_v,
float scaling_factor,
bool isTraining,
bool is_causal,
@ -34,7 +35,8 @@ void run_cudnn_SDP_bprop(
int64_t h,
int64_t s_q,
int64_t s_kv,
int64_t d,
int64_t d_qk,
int64_t d_v,
float scaling_factor,
bool is_causal,
float dropout_probability,
@ -128,7 +130,8 @@ struct MHAParams {
int64_t h;
int64_t s_q;
int64_t s_kv;
int64_t d;
int64_t d_qk;
int64_t d_v;
double dropout_probability;
bool is_causal;
bool return_softmaxstats;
@ -140,7 +143,8 @@ void setMHAParams(
int64_t h,
int64_t s_q,
int64_t s_kv,
int64_t d,
int64_t d_qk,
int64_t d_v,
const Tensor& q,
const Tensor& k,
const Tensor& v,
@ -155,7 +159,8 @@ void setMHAParams(
}
params.b = b;
params.h = h;
params.d = d;
params.d_qk = d_qk;
params.d_v = d_v;
params.s_q = s_q;
params.s_kv = s_kv;
params.dropout_probability = dropout_probability;
@ -193,7 +198,8 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
int64_t h,
int64_t s_q,
int64_t s_kv,
int64_t d,
int64_t d_qk,
int64_t d_v,
const Tensor& q,
const Tensor& k,
const Tensor& v,
@ -206,7 +212,8 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
h,
s_q,
s_kv,
d,
d_qk,
d_v,
q,
k,
v,
@ -249,7 +256,8 @@ auto build_graph_and_tensors(
int64_t h,
int64_t s_q,
int64_t s_kv,
int64_t d,
int64_t d_qk,
int64_t d_v,
float scaling_factor,
bool return_softmaxstats,
bool is_causal,
@ -383,7 +391,8 @@ auto build_graph_and_tensors_backward(
int64_t h,
int64_t s_q,
int64_t s_kv,
int64_t d,
int64_t d_qk,
int64_t d_v,
float scaling_factor,
bool is_causal,
float dropout_probability,
@ -514,7 +523,8 @@ void run_cudnn_SDP_fprop(
int64_t h,
int64_t s_q,
int64_t s_kv,
int64_t d,
int64_t d_qk,
int64_t d_v,
float scaling_factor,
bool return_softmaxstats,
bool is_causal,
@ -528,7 +538,7 @@ void run_cudnn_SDP_fprop(
Tensor& dropoutoffset) {
cudnnHandle_t handle = getCudnnHandle();
o = at::empty_strided(
{b, h, s_q, d}, {s_q * h * d, d, h * d, 1}, q.options());
{b, h, s_q, d_v}, {s_q * h * d_v, d_v, h * d_v, 1}, q.options());
if (return_softmaxstats) {
// TODO(eqy): verify that this is correct
softmaxstats = at::empty({b, h, s_q}, q.options().dtype(kFloat));
@ -539,7 +549,8 @@ void run_cudnn_SDP_fprop(
h,
s_q,
s_kv,
d,
d_qk,
d_v,
q,
k,
v,
@ -556,7 +567,8 @@ void run_cudnn_SDP_fprop(
h,
s_q,
s_kv,
d,
d_qk,
d_v,
scaling_factor,
return_softmaxstats,
is_causal,
@ -599,7 +611,8 @@ void run_cudnn_SDP_bprop(
int64_t h,
int64_t s_q,
int64_t s_kv,
int64_t d,
int64_t d_qk,
int64_t d_v,
float scaling_factor,
bool is_causal,
float dropout_probability,
@ -623,7 +636,18 @@ void run_cudnn_SDP_bprop(
}
cudnnHandle_t handle = getCudnnHandle();
auto key = MHACacheKeyWrapper(
b, h, s_q, s_kv, d, q, k, v, dropout_probability, is_causal, true);
b,
h,
s_q,
s_kv,
d_qk,
d_v,
q,
k,
v,
dropout_probability,
is_causal,
true);
auto graph_and_tensors_backward_ptr = mhagraphbackwardcache.find(key);
graph_and_tensors_backward graph_and_tensors_backward_values;
if (graph_and_tensors_backward_ptr) {
@ -634,7 +658,8 @@ void run_cudnn_SDP_bprop(
h,
s_q,
s_kv,
d,
d_qk,
d_v,
scaling_factor,
is_causal,
dropout_probability,
@ -684,5 +709,4 @@ void run_cudnn_SDP_bprop(
} // namespace native
} // namespace at
#endif

View File

@ -9,7 +9,8 @@ void run_cudnn_SDP_fprop(
int64_t h,
int64_t s_q,
int64_t s_kv,
int64_t d,
int64_t d_k,
int64_t d_v,
float scaling_factor,
bool isTraining,
bool is_causal,
@ -27,7 +28,8 @@ void run_cudnn_SDP_bprop(
int64_t h,
int64_t s_q,
int64_t s_kv,
int64_t d,
int64_t d_k,
int64_t d_v,
float scaling_factor,
bool is_causal,
float dropout_probability,

View File

@ -18,26 +18,21 @@ kernel void erfinv_mps_kernel( device {0} *output [[buffer(0)]],
/* coefficients in rational expansion */
float y_abs = abs(y);
if(y_abs > 1.0f){{
output[index] = NAN;
if (y_abs >= 1.0f) {{
output[index] = {0}( y_abs > 1.0f ? NAN : copysign(INFINITY, y));
return;
}}
if(y_abs == 1.0f){{
output[index] = copysign(INFINITY, y);
return;
}}
if(y_abs <= 0.7f) {{
if (y_abs <= 0.7f) {{
z = y * y;
num = (((a[3]*z + a[2])*z + a[1])*z + a[0]);
dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0]) * z + 1.0f);
num = ((a[3] * z + a[2]) * z + a[1])*z + a[0];
dem = (((b[3] * z + b[2]) * z + b[1]) * z +b[0]) * z + 1.0f;
x = y * num / dem;
}}
else{{
}} else {{
z = sqrt(-1.0f*log((1.0-y_abs)/2.0));
num = ((c[3]*z + c[2])*z + c[1]) * z + c[0];
dem = (d[1]*z + d[0])*z + 1.0f;
num = ((c[3] * z + c[2]) * z + c[1]) * z + c[0];
dem = (d[1] * z + d[0]) * z + 1.0f;
x = copysign(num, y) / dem;
}}
output[index] = x;
}})METAL";
output[index] = {0}(x);
}})METAL";

View File

@ -143,7 +143,7 @@ TORCH_IMPL_FUNC(leaky_relu_out_mps)(const Tensor& self, const Scalar& negative_s
Tensor output_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);
@autoreleasepool {
string key = "leaky_relu" + getTensorsStringKey({self}) + ":" + to_string(negative_slope.to<double>());
string key = "leaky_relu" + getTensorsStringKey({self}) + ":" + std::to_string(negative_slope.to<double>());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
@ -193,8 +193,8 @@ TORCH_IMPL_FUNC(leaky_relu_backward_out_mps)
Tensor output_ = at::empty_like(self, self.suggest_memory_format());
@autoreleasepool {
string key =
"leaky_relu_backward" + getTensorsStringKey({self, grad_output}) + ":" + to_string(negative_slope.to<double>());
string key = "leaky_relu_backward" + getTensorsStringKey({self, grad_output}) + ":" +
std::to_string(negative_slope.to<double>());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
@ -242,7 +242,7 @@ TORCH_IMPL_FUNC(log_softmax_mps_out)
MPSStream* stream = at::mps::getCurrentMPSStream();
@autoreleasepool {
string key = "log_softmax_mps_out" + getTensorsStringKey({self}) + ":" + to_string(dim);
string key = "log_softmax_mps_out" + getTensorsStringKey({self}) + ":" + std::to_string(dim);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
@ -285,7 +285,7 @@ TORCH_IMPL_FUNC(log_softmax_backward_mps_out)
MPSStream* stream = at::mps::getCurrentMPSStream();
@autoreleasepool {
string key = "log_softmax_backward_mps_out:" + getMPSTypeString(grad_output) + ":" + to_string(dim);
string key = "log_softmax_backward_mps_out:" + getMPSTypeString(grad_output) + ":" + std::to_string(dim);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* gradOutputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_output));
MPSGraphTensor* outputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(output));
@ -539,8 +539,8 @@ TORCH_IMPL_FUNC(threshold_out_mps)
MPSStream* stream = getCurrentMPSStream();
@autoreleasepool {
string key = "threshold_out_mps" + getTensorsStringKey({self}) + ":" + to_string(threshold.to<double>()) + ":" +
to_string(value.to<double>());
string key = "threshold_out_mps" + getTensorsStringKey({self}) + ":" + std::to_string(threshold.to<double>()) +
":" + std::to_string(value.to<double>());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
@ -587,7 +587,7 @@ TORCH_IMPL_FUNC(threshold_backward_out_mps)
@autoreleasepool {
string key =
"threshold_backward_out_mps" + getTensorsStringKey({self, grad}) + ":" + to_string(threshold.to<double>());
"threshold_backward_out_mps" + getTensorsStringKey({self, grad}) + ":" + std::to_string(threshold.to<double>());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
@ -826,8 +826,8 @@ static void elu_variants_out_mps(const Tensor& self,
MPSStream* stream = getCurrentMPSStream();
@autoreleasepool {
string key = func_name + ":" + getTensorsStringKey({self}) + ":" + to_string(alpha.to<double>()) + ":" +
to_string(scale.to<double>()) + ":" + to_string(input_scale.to<double>());
string key = func_name + ":" + getTensorsStringKey({self}) + ":" + std::to_string(alpha.to<double>()) + ":" +
std::to_string(scale.to<double>()) + ":" + std::to_string(input_scale.to<double>());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
@ -916,8 +916,8 @@ TORCH_IMPL_FUNC(elu_backward_out_mps)
@autoreleasepool {
string key = "elu_backward_out_mps:" + getTensorsStringKey({grad_output, self_or_result}) + ":" +
to_string(alpha.to<double>()) + ":" + to_string(scale.to<double>()) + ":" +
to_string(input_scale.to<double>()) + ":" + to_string(is_result);
std::to_string(alpha.to<double>()) + ":" + std::to_string(scale.to<double>()) + ":" +
std::to_string(input_scale.to<double>()) + ":" + std::to_string(is_result);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
@ -1010,7 +1010,7 @@ TORCH_IMPL_FUNC(glu_out_mps)(const Tensor& self, const int64_t dim, const Tensor
MPSStream* stream = getCurrentMPSStream();
@autoreleasepool {
string key = "glu_out_mps" + getTensorsStringKey({self}) + ":" + to_string(dim);
string key = "glu_out_mps" + getTensorsStringKey({self}) + ":" + std::to_string(dim);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self));
NSArray<MPSGraphTensor*>* outputTensorsArray = [mpsGraph splitTensor:inputTensor
@ -1052,7 +1052,7 @@ Tensor& glu_backward_mps_out(const Tensor& grad_output, const Tensor& self, cons
MPSStream* stream = getCurrentMPSStream();
@autoreleasepool {
string key = "glu_backward_mps_out" + getTensorsStringKey({grad_output, self}) + ":" + to_string(dim);
string key = "glu_backward_mps_out" + getTensorsStringKey({grad_output, self}) + ":" + std::to_string(dim);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self));
MPSGraphTensor* gradOutputTensor =
@ -1855,8 +1855,8 @@ Tensor& hardtanh_backward_out_mps(const Tensor& grad_output,
MPSStream* stream = getCurrentMPSStream();
@autoreleasepool {
string key = "hardtanh_backward_out_mps:" + getTensorsStringKey({grad_output}) + ":" + to_string(min.to<double>()) +
":" + to_string(max.to<double>());
string key = "hardtanh_backward_out_mps:" + getTensorsStringKey({grad_output}) + ":" +
std::to_string(min.to<double>()) + ":" + std::to_string(max.to<double>());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);

View File

@ -136,8 +136,8 @@ static Tensor& addmv_out_mps_impl(const Tensor& self,
Tensor matMulVec = at::mm(mat, vec.unsqueeze(1)).squeeze(1);
@autoreleasepool {
string key = "addmv_out_mps_impl" + getTensorsStringKey({self, matMulVec}) + ":" + to_string(beta_.toDouble()) +
":" + to_string(alpha_.toDouble());
string key = "addmv_out_mps_impl" + getTensorsStringKey({self, matMulVec}) + ":" +
std::to_string(beta_.toDouble()) + ":" + std::to_string(alpha_.toDouble());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* matMulVecTensor = mpsGraphRankedPlaceHolder(mpsGraph, matMulVec);
MPSGraphTensor* selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);

View File

@ -33,7 +33,7 @@ static Tensor& fill_scalar_mps_impl(Tensor& self, const Scalar& value) {
};
@autoreleasepool {
string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + to_string(value.toDouble());
string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + std::to_string(value.toDouble());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type()));

View File

@ -193,24 +193,24 @@ static Tensor _mps_convolution_impl(const Tensor& input_t,
string bias_shape_key;
if (bias_defined) {
bias_shape_key = to_string(bias_shape[0]);
bias_shape_key = std::to_string(bias_shape[0]);
} else {
bias_shape_key = "nobias";
}
string key;
if (is3DConv) {
key = "mps_3d_convolution:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + to_string(stride[2]) +
":" + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(dilation[2]) + ":" +
to_string(padding[0]) + ":" + to_string(padding[1]) + ":" + to_string(padding[2]) + ":" + to_string(groups) +
":" + mem_format_key + mps::getTensorsStringKey({input_t, weight_t}) + ":" + to_string(bias_defined) + ":" +
bias_shape_key;
key = "mps_3d_convolution:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
std::to_string(stride[2]) + ":" + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" +
std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" +
std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key +
mps::getTensorsStringKey({input_t, weight_t}) + ":" + std::to_string(bias_defined) + ":" + bias_shape_key;
} else {
key = "mps_convolution:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + to_string(dilation[0]) +
":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" +
to_string(groups) + ":" + mem_format_key + mps::getTensorsStringKey({input_t, weight_t}) + ":" +
to_string(bias_defined) + ":" + bias_shape_key;
key = "mps_convolution:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" +
std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key +
mps::getTensorsStringKey({input_t, weight_t}) + ":" + std::to_string(bias_defined) + ":" + bias_shape_key;
}
MPSShape* inputShape = mps::getMPSShape(input_t, memory_format);
@ -388,16 +388,16 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size,
NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","];
string key;
if (is3DConv) {
key = "mps_3d_convolution_backward_input:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" + ":" +
to_string(stride[2]) + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(dilation[2]) +
":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" + to_string(padding[2]) + ":" +
to_string(groups) + ":" + mem_format_key + getTensorsStringKey({grad_output_t, weight_t}) + ":" +
string([ns_shape_key UTF8String]);
key = "mps_3d_convolution_backward_input:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
":" + std::to_string(stride[2]) + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" +
std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" +
std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key +
getTensorsStringKey({grad_output_t, weight_t}) + ":" + string([ns_shape_key UTF8String]);
} else {
key = "mps_convolution_backward_input:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" +
to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" +
to_string(padding[1]) + ":" + to_string(groups) + ":" + mem_format_key +
key = "mps_convolution_backward_input:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" +
std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key +
getTensorsStringKey({grad_output_t, weight_t}) + ":" + string([ns_shape_key UTF8String]);
}
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
@ -547,15 +547,15 @@ static Tensor mps_convolution_backward_weights(IntArrayRef weight_size,
NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","];
string key;
if (is3DConv) {
key = "mps_3d_convolution_backward_weights:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" +
to_string(stride[2]) + ":" + to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" +
to_string(dilation[2]) + ":" + to_string(padding[0]) + ":" + to_string(padding[1]) + ":" +
to_string(padding[2]) + ":" + to_string(groups) + ":" + mem_format_key +
key = "mps_3d_convolution_backward_weights:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
std::to_string(stride[2]) + ":" + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" +
std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" +
std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key +
getTensorsStringKey({grad_output_t, input_t, grad_weight_t}) + ":" + string([ns_shape_key UTF8String]);
} else {
key = "mps_convolution_backward_weights:" + to_string(stride[0]) + ":" + to_string(stride[1]) + ":" +
to_string(dilation[0]) + ":" + to_string(dilation[1]) + ":" + to_string(padding[0]) + ":" +
to_string(padding[1]) + ":" + to_string(groups) + ":" + mem_format_key +
key = "mps_convolution_backward_weights:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" +
std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key +
getTensorsStringKey({grad_output_t, input_t, grad_weight_t}) + ":" + string([ns_shape_key UTF8String]);
}
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {

View File

@ -63,7 +63,7 @@ Tensor& random_mps_impl(Tensor& self,
@autoreleasepool {
string key = op_name + getTensorsStringKey({self, mean_opt.value_or(Tensor()), std_opt.value_or(Tensor())}) + ":" +
to_string(val1) + ":" + to_string(val2);
std::to_string(val1) + ":" + std::to_string(val2);
auto cachedGraph = LookUpOrCreateCachedGraph<RandomCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
newCachedGraph->stateTensor =
mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @(at::mps::detail::PHILOX_STATE_N) ]);
@ -469,7 +469,7 @@ static Tensor& multinomial_with_replacement_mps_kernel(const Tensor& self,
MPSStream* stream = getCurrentMPSStream();
@autoreleasepool {
string key = "multinomial_with_replacement:" + getTensorsStringKey({self}) + ":" + to_string(n_sample);
string key = "multinomial_with_replacement:" + getTensorsStringKey({self}) + ":" + std::to_string(n_sample);
auto cachedGraph = LookUpOrCreateCachedGraph<RandomCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSShape* prob_shape = getMPSShape(self_v);
newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @7 ]);

View File

@ -236,7 +236,7 @@ static std::tuple<Tensor, Tensor> _mps_linear_backward_weights(const Tensor& gra
MPSStream* stream = getCurrentMPSStream();
@autoreleasepool {
string key = "mps_linear_backward_weights:" + to_string(bias_defined) + ":" +
string key = "mps_linear_backward_weights:" + std::to_string(bias_defined) + ":" +
getTensorsStringKey({input_reshaped, weight, grad_output_reshaped});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_reshaped);

View File

@ -229,8 +229,8 @@ static Tensor& addbmm_or_baddbmm_out_mps_impl(const Tensor& input,
@autoreleasepool {
string key = (opType == ADDBMM_OP_TYPE) ? ("addbmm_out_mps_impl") : ("baddbmm_out_mps_impl");
key += getTensorsStringKey({batch1, batch2, input}) + ":" + to_string(beta.toDouble()) + ":" +
to_string(alpha.toDouble());
key += getTensorsStringKey({batch1, batch2, input}) + ":" + std::to_string(beta.toDouble()) + ":" +
std::to_string(alpha.toDouble());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, input);
@ -331,8 +331,8 @@ static Tensor& addmm_out_mps_impl(const Tensor& bias,
};
@autoreleasepool {
string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_}) + ":" + to_string(beta.toDouble()) +
":" + to_string(alpha.toDouble());
string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_}) + ":" +
std::to_string(beta.toDouble()) + ":" + std::to_string(alpha.toDouble());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* selfTensor = nil;
MPSGraphTensor* otherTensor = nil;
@ -615,8 +615,8 @@ Tensor& addr_out_mps(const Tensor& self,
};
@autoreleasepool {
string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_}) + ":" + to_string(beta.toDouble()) +
":" + to_string(alpha.toDouble());
string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_}) + ":" +
std::to_string(beta.toDouble()) + ":" + std::to_string(alpha.toDouble());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* t1 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec1), inputShape);
MPSGraphTensor* t2 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec2), otherShape);

View File

@ -69,7 +69,7 @@ static Tensor& mse_loss_backward_out_impl(const Tensor& grad_output,
};
@autoreleasepool {
string key = op_name + reductionToString(reduction) + ":" + to_string(grad_input.sizes()[1]) +
string key = op_name + reductionToString(reduction) + ":" + std::to_string(grad_input.sizes()[1]) +
getTensorsStringKey({input, target, grad_output});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
@ -327,8 +327,8 @@ static void nllnd_loss_backward_impl(Tensor& grad_input_arg,
}
@autoreleasepool {
string key = "nllnd_loss_backward" + getTensorsStringKey({input, grad_output, target, weight, total_weight}) +
to_string(numClasses) + ":" + to_string(ignore_index) + ":" + to_string(isWeightsArrayValid) + ":" +
to_string(isTargetCasted) + ":" + reductionToString(reduction);
std::to_string(numClasses) + ":" + std::to_string(ignore_index) + ":" + std::to_string(isWeightsArrayValid) +
":" + std::to_string(isTargetCasted) + ":" + reductionToString(reduction);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);
@ -463,9 +463,9 @@ static void nllnd_loss_forward_impl(Tensor& output,
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
// TODO: Make the key
string key = "nllnd_loss_forward_impl:" + to_string(ignore_index) + ":" + to_string(isWeightsArrayValid) + ":" +
reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" + getMPSTypeString(input) + ":" +
getMPSTypeString(target) + ":" + to_string(isTargetCasted) + ":" + getMPSTypeString(weight);
string key = "nllnd_loss_forward_impl:" + std::to_string(ignore_index) + ":" + std::to_string(isWeightsArrayValid) +
":" + reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" + getMPSTypeString(input) + ":" +
getMPSTypeString(target) + ":" + std::to_string(isTargetCasted) + ":" + getMPSTypeString(weight);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input), input_shape);
MPSGraphTensor* targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(target), target_shape);
@ -598,7 +598,7 @@ static void smooth_l1_loss_impl(const Tensor& input,
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
string key = "smooth_l1_loss_impl:" + reductionToString(reduction) + ":" + [ns_shape_key UTF8String] + ":" +
to_string(beta) + ":" + getMPSTypeString(input) + ":" + getMPSTypeString(target);
std::to_string(beta) + ":" + getMPSTypeString(input) + ":" + getMPSTypeString(target);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
// smooth_l1_loss_mps:
// ln = 0.5 * ( xn - yn ) ^ 2 / beta, if |xn - yn| < beta
@ -734,7 +734,7 @@ static void smooth_l1_loss_backward_impl(const Tensor& grad_output,
@autoreleasepool {
string key = "smooth_l1_loss_backward" + getTensorsStringKey({input, grad_output, grad_input, target}) + ":" +
reductionToString(reduction) + ":" + to_string(beta);
reductionToString(reduction) + ":" + std::to_string(beta);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input);

View File

@ -106,7 +106,7 @@ Tensor& arange_mps_out(const Scalar& start, const Scalar& end, const Scalar& ste
auto stream = getCurrentMPSStream();
auto mpsDataType = getMPSDataType(result);
@autoreleasepool {
string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + to_string(size);
string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + std::to_string(size);
auto cachedGraph = cache_->LookUpAs<RangeCachedGraph>(key);
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<RangeCachedGraph>(key, ^MPSCachedGraph*() {
@ -173,7 +173,7 @@ Tensor& range_mps_out(const Scalar& start, const Scalar& end, const Scalar& step
auto stream = getCurrentMPSStream();
auto mpsDataType = getMPSDataType(result);
@autoreleasepool {
string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + to_string(size);
string key = "arange_mps_out" + getTensorsStringKey({result}) + ":" + std::to_string(size);
auto cachedGraph = cache_->LookUpAs<RangeCachedGraph>(key);
if (!cachedGraph) {
cachedGraph = cache_->CreateCachedGraphAs<RangeCachedGraph>(key, ^MPSCachedGraph*() {
@ -221,8 +221,8 @@ Tensor& linspace_out_mps(const Scalar& start, const Scalar& end, int64_t steps,
bool start_less_end = (start.to<double>() <= end.to<double>());
@autoreleasepool {
string key =
"linspace_out_mps:" + getTensorsStringKey({result}) + ":" + to_string(steps) + to_string(start_less_end);
string key = "linspace_out_mps:" + getTensorsStringKey({result}) + ":" + std::to_string(steps) +
std::to_string(start_less_end);
auto cachedGraph = cache_->LookUpAs<RangeCachedGraph>(key);
if (!cachedGraph) {

View File

@ -359,8 +359,8 @@ static void impl_func_norm_mps(const Tensor& input_tensor,
NSString* ns_key = [[wrappedAxes valueForKey:@"description"] componentsJoinedByString:@","];
string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0";
string tensor_key = cdist ? getTensorsStringKey({input_tensor, other_tensor}) : getTensorsStringKey({input_t});
string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + tensor_key + ":p" + to_string(p) + ":" +
keepdim_info + ":" + toString(in_dtype) + ":" + to_string(castInputData);
string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + tensor_key + ":p" + std::to_string(p) + ":" +
keepdim_info + ":" + toString(in_dtype) + ":" + std::to_string(castInputData);
auto cachedGraph = LookUpOrCreateCachedGraph<MPSBinaryCachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor);
@ -572,7 +572,7 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t,
string op_key = (stdVarType == STANDARD_DEVIATION) ? "std_mps" : "var_mps";
NSString* ns_key = [[wrappedAxes valueForKey:@"description"] componentsJoinedByString:@","];
string bessel_corrected = (use_correction && correction_value) ? "unbiased " : "biased ";
string use_dim_info = (use_dim) ? "use_dim=1:" + to_string(dim_value.size()) : "use_dim=0";
string use_dim_info = (use_dim) ? "use_dim=1:" + std::to_string(dim_value.size()) : "use_dim=0";
string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0";
string key = op_key + ":" + getTensorsStringKey(input_t) + ":" + use_dim_info + ":" + keepdim_info + ":" +
string([ns_key UTF8String]) + ":" + bessel_corrected + ":" + std::to_string(correction_value);
@ -700,7 +700,7 @@ static void min_max_out_mps(const Tensor& input_t,
auto stream = at::mps::getCurrentMPSStream();
@autoreleasepool {
string key = func_name + getTensorsStringKey({input_t, indices_t}) + ":" + to_string(dim_);
string key = func_name + getTensorsStringKey({input_t, indices_t}) + ":" + std::to_string(dim_);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
MPSGraphTensor* outputTensor = nil;
@ -860,7 +860,7 @@ static void argmax_argmin_out_mps(const Tensor& input_t,
@autoreleasepool {
NSString* ns_key = [[apparent_in_shape valueForKey:@"description"] componentsJoinedByString:@","];
string key =
func_name + ":" + to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + string([ns_key UTF8String]);
func_name + ":" + std::to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + string([ns_key UTF8String]);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
auto inputScalarType = input_t.scalar_type();
MPSGraphTensor* inputTensor =
@ -1217,7 +1217,7 @@ TORCH_IMPL_FUNC(any_out_mps)
@autoreleasepool {
MPSShape* input_t_shape = getMPSShape(input_t);
string key = string("any_out_mps:") + getMPSShapeString(input_t_shape) + ":" + to_string(dim_) + ":" +
string key = string("any_out_mps:") + getMPSShapeString(input_t_shape) + ":" + std::to_string(dim_) + ":" +
getMPSTypeString(input_t);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSDataType input_type = getMPSDataType(input_t);
@ -1313,7 +1313,7 @@ TORCH_IMPL_FUNC(all_out_mps)
@autoreleasepool {
MPSShape* input_t_shape = getMPSShape(input_t);
string key = string("all_out_mps:") + getMPSShapeString(input_t_shape) + ":" + to_string(dim_) + ":" +
string key = string("all_out_mps:") + getMPSShapeString(input_t_shape) + ":" + std::to_string(dim_) + ":" +
getMPSTypeString(input_t);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSDataType input_type = getMPSDataType(input_t);
@ -1531,8 +1531,8 @@ static void median_out_mps(const Tensor& input_t,
auto stream = at::mps::getCurrentMPSStream();
@autoreleasepool {
string key =
func_name + ":" + to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + getTensorsStringKey(indices_t);
string key = func_name + ":" + std::to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" +
getTensorsStringKey(indices_t);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
MPSGraphTensor* castInputTensor =

View File

@ -108,8 +108,8 @@ TORCH_IMPL_FUNC(topk_out_mps)
// Input as placeholders
MPSShape* input_shape = getMPSShape(self);
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
string key = string("topk:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":k" + to_string(k) +
":dim" + to_string(dim_) + ":largest" + to_string(largest);
string key = string("topk:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":k" + std::to_string(k) +
":dim" + std::to_string(dim_) + ":largest" + std::to_string(largest);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape);
@ -320,12 +320,12 @@ TORCH_IMPL_FUNC(cat_out_mps)
};
@autoreleasepool {
string key =
"cat_out_mps:" + to_string(dimension) + ":" + (memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
string key = "cat_out_mps:" + std::to_string(dimension) + ":" +
(memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
if (!all_same_dtype) {
key += getTensorsStringKey(input_tensors, true, all_same_sizes_and_stride);
} else {
key += ":" + getMPSTypeString(input_tensors[0].scalar_type(), true) + ":" + to_string(inputs.size());
key += ":" + getMPSTypeString(input_tensors[0].scalar_type(), true) + ":" + std::to_string(inputs.size());
}
for (auto idx : skipped_tensor_indices) {
key += "," + std::to_string(idx);

View File

@ -60,8 +60,8 @@ TORCH_IMPL_FUNC(sort_stable_out_mps)
// Input as placeholders
MPSShape* input_shape = getMPSShape(self);
NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
string key = string("sort:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":dim" + to_string(dim) +
":descending" + to_string(descending);
string key = string("sort:") + [ns_shape_key UTF8String] + ":" + getMPSTypeString(self) + ":dim" +
std::to_string(dim) + ":descending" + std::to_string(descending);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape);

View File

@ -240,8 +240,8 @@ static void clamp_scalar_out_mps(const Tensor& input_t,
@autoreleasepool {
// the optional min/max refs could affect how we build the cached graph
string key = op_name + (has_min ? ("_min:" + to_string(min_scalar)) : "") +
(has_max ? ("_max:" + to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
string key = op_name + (has_min ? ("_min:" + std::to_string(min_scalar)) : "") +
(has_max ? ("_max:" + std::to_string(max_scalar)) : "") + "_scalar:" + getTensorsStringKey({input_t});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
if (has_min)
newCachedGraph->minTensor = [mpsGraph

View File

@ -13,32 +13,6 @@
#include <fmt/format.h>
namespace at::native {
static const std::string& getMetalType(const c10::ScalarType& t) {
// Mapping from c10::ScalarType to integral type that can be used for unary ops
static std::unordered_map<c10::ScalarType, std::string> scalar_to_metal_type = {
{c10::ScalarType::Half, "half"},
{c10::ScalarType::Float, "float"},
{c10::ScalarType::Long, "long"},
{c10::ScalarType::Int, "int"},
{c10::ScalarType::Short, "short"},
{c10::ScalarType::Bool, "bool"},
{c10::ScalarType::Char, "int8_t"},
{c10::ScalarType::Byte, "uint8_t"},
};
auto it = scalar_to_metal_type.find(t);
TORCH_CHECK(it != scalar_to_metal_type.end(), "Unsupported type ", t);
return it->second;
}
static const std::string& getMetalType(const c10::Scalar& s) {
return getMetalType(s.type());
}
static const std::string& getMetalType(const Tensor& t) {
return getMetalType(t.scalar_type());
}
static mps::MetalShaderLibrary lib(UNARY_KERNEL_TEMPLATE, 2);
TORCH_IMPL_FUNC(erfinv_out_mps)(const Tensor& self, const Tensor& output_) {
@ -57,7 +31,8 @@ TORCH_IMPL_FUNC(erfinv_out_mps)(const Tensor& self, const Tensor& output_) {
}
using namespace mps;
@autoreleasepool {
auto cplState = lib.getPipelineStateForFunc("erfinv_mps_kernel", {getMetalType(outputTensor), getMetalType(self)});
auto cplState = lib.getPipelineStateForFunc("erfinv_mps_kernel",
{scalarToMetalTypeString(outputTensor), scalarToMetalTypeString(self)});
if (!self.is_contiguous()) {
inputTensor = inputTensor.contiguous();

View File

@ -36,8 +36,8 @@ static std::string getUniqueKey(const ScalarType& dtype,
const bool consecutive,
c10::optional<int64_t> dimOpt) {
return "_unique2_mps:" + getMPSTypeString(dtype) + "[" + getArrayRefString(base_shape) + "]:[" +
(dimOpt.has_value() ? to_string(dimOpt.value()) : "None") + "]:[" + to_string(return_inverse) + "]:[" +
to_string(return_counts) + "]:[" + to_string(consecutive) + "]";
(dimOpt.has_value() ? std::to_string(dimOpt.value()) : "None") + "]:[" + std::to_string(return_inverse) + "]:[" +
std::to_string(return_counts) + "]:[" + std::to_string(consecutive) + "]";
}
// dim arg not supported when non consecutive, ie sorted

View File

@ -99,7 +99,7 @@ static void upsample_out_template(const Tensor& input,
@autoreleasepool {
string key = "upsample_" + std::string(resize_mode_str) + (align_corners ? "_aligned_corners" : "") +
getTensorsStringKey({input}) + ":[" + to_string(scale_h) + "," + to_string(scale_w) + "]:[" +
getTensorsStringKey({input}) + ":[" + std::to_string(scale_h) + "," + std::to_string(scale_w) + "]:[" +
(is_backward_pass ? getArrayRefString(input_size) : "Undefined") + "]";
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {

View File

@ -42,7 +42,7 @@ static std::string getStridedKey(const ScalarType& self_dtype,
}
return (is_scatter ? "scatter:" : "gather:") + dtype_key + "[" + getArrayRefString(base_shape) + "]:[" +
getArrayRefString(new_shape) + "]:[" + getArrayRefString(stride) + "]:[" + to_string(storage_offset) + "]";
getArrayRefString(new_shape) + "]:[" + getArrayRefString(stride) + "]:[" + std::to_string(storage_offset) + "]";
}
// initializes the MTLBuffers for tensor data and runs the MPSGraph for the view op

View File

@ -764,8 +764,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_c
const int64_t batch_size = query.size(0);
const int64_t num_heads = query.size(1);
const int64_t max_seqlen_batch_q = query.size(2);
const int64_t head_dim = query.size(3);
const int64_t head_dim_qk = query.size(3);
const int64_t head_dim_v = value.size(3);
const int64_t max_seqlen_batch_k = key.size(2);
const int64_t max_seqlen_batch_v = value.size(2);
TORCH_CHECK(
@ -806,7 +806,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_c
num_heads/*int64_t h*/,
max_seqlen_batch_q/*int64_t s_q*/,
max_seqlen_batch_k/*int64_t s_kv*/,
head_dim/*int64_t d*/,
head_dim_qk/*int64_t d_qk*/,
head_dim_v/*int64_t d_v*/,
softmax_scale/*float scaling_factor*/,
compute_logsumexp/* bool */,
is_causal/* bool */,

View File

@ -194,12 +194,11 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_
const int64_t batch_size = query.size(0);
const int64_t num_heads = query.size(1);
const int64_t head_dim = query.size(3);
const int64_t head_dim_qk = query.size(3);
const int64_t head_dim_v = value.size(3);
const int64_t max_seqlen_batch_q = query.size(1);
const int64_t max_seqlen_batch_k = key.size(1);
const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
auto dq = at::empty_like(query);
auto dk = at::empty_like(key);
auto dv = at::empty_like(value);
@ -207,7 +206,8 @@ std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_backward_
num_heads /*int64_t h*/,
max_seqlen_batch_q /*int64_t s_q*/,
max_seqlen_batch_k /*int64_t s_kv*/,
head_dim /*int64_t d*/,
head_dim_qk /*int64_t d_qk*/,
head_dim_v /*int64_t d_v*/,
softmax_scale /*float scaling_factor*/,
is_causal /*bool is_causal*/,
dropout_p /*float dropout_probability*/,
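
The two hunks above replace the single head_dim with separate head_dim_qk and head_dim_v values, so the cuDNN attention path can accept value heads whose width differs from the query/key heads. A minimal Python sketch of that shape convention (illustrative only; which backend actually serves the call depends on hardware and build):

import torch
import torch.nn.functional as F

# Query/key share one head dim (d_qk); value uses another (d_v).
batch, heads, s_q, s_kv, d_qk, d_v = 2, 4, 16, 16, 64, 32
q = torch.randn(batch, heads, s_q, d_qk)
k = torch.randn(batch, heads, s_kv, d_qk)
v = torch.randn(batch, heads, s_kv, d_v)

# The output inherits the value head dim: (batch, heads, s_q, d_v).
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 4, 16, 32])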

View File

@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9
BartForCausalLM,pass,12
BartForCausalLM,pass,6
BartForConditionalGeneration,pass,24
BartForConditionalGeneration,pass,8
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
BlenderbotSmallForCausalLM,pass,12
BlenderbotSmallForCausalLM,pass,6
BlenderbotSmallForConditionalGeneration,pass,24
BlenderbotSmallForConditionalGeneration,pass,8
@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4
MBartForCausalLM,pass,12
MBartForCausalLM,pass,6
MBartForConditionalGeneration,pass,24
MBartForConditionalGeneration,pass,8
@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3
OPTForCausalLM,pass,12
OPTForCausalLM,pass,6
PLBartForCausalLM,pass,12
PLBartForCausalLM,pass,6
PLBartForConditionalGeneration,pass,29
PLBartForConditionalGeneration,pass,8
PegasusForCausalLM,pass,12
PegasusForCausalLM,pass,6
PegasusForConditionalGeneration,pass,23
PegasusForConditionalGeneration,pass,7
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5
Speech2Text2ForCausalLM,pass,12
Speech2Text2ForCausalLM,pass,6
@ -170,11 +170,11 @@ T5Small,pass,5
TrOCRForCausalLM,pass,12
TrOCRForCausalLM,pass,6
XGLMForCausalLM,pass,12
XGLMForCausalLM,pass,6


View File

@ -150,7 +150,7 @@ hf_Bert_large,pass,0
hf_BigBird,pass,46
hf_BigBird,pass,43


View File

@ -98,7 +98,7 @@ hf_Bert_large,pass,6
hf_BigBird,pass,52
hf_BigBird,pass,49


View File

@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9
BartForCausalLM,pass,12
BartForCausalLM,pass,6
BartForConditionalGeneration,pass,24
BartForConditionalGeneration,pass,8
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
BlenderbotSmallForCausalLM,pass,12
BlenderbotSmallForCausalLM,pass,6
BlenderbotSmallForConditionalGeneration,pass,24
BlenderbotSmallForConditionalGeneration,pass,8
@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4
MBartForCausalLM,pass,12
MBartForCausalLM,pass,6
MBartForConditionalGeneration,pass,24
MBartForConditionalGeneration,pass,8
@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3
OPTForCausalLM,pass,12
OPTForCausalLM,pass,6
PLBartForCausalLM,pass,12
PLBartForCausalLM,pass,6
PLBartForConditionalGeneration,pass,29
PLBartForConditionalGeneration,pass,8
PegasusForCausalLM,pass,12
PegasusForCausalLM,pass,6
PegasusForConditionalGeneration,pass,23
PegasusForConditionalGeneration,pass,7
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5
Speech2Text2ForCausalLM,pass,12
Speech2Text2ForCausalLM,pass,6
@ -170,11 +170,11 @@ T5Small,pass,5
TrOCRForCausalLM,pass,12
TrOCRForCausalLM,pass,6
XGLMForCausalLM,pass,12
XGLMForCausalLM,pass,6


View File

@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9
BartForCausalLM,pass,12
BartForCausalLM,pass,6
BartForConditionalGeneration,pass,24
BartForConditionalGeneration,pass,8
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
BlenderbotSmallForCausalLM,pass,12
BlenderbotSmallForCausalLM,pass,6
BlenderbotSmallForConditionalGeneration,pass,24
BlenderbotSmallForConditionalGeneration,pass,8
@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4
MBartForCausalLM,pass,12
MBartForCausalLM,pass,6
MBartForConditionalGeneration,pass,24
MBartForConditionalGeneration,pass,8
@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3
OPTForCausalLM,pass,12
OPTForCausalLM,pass,6
PLBartForCausalLM,pass,12
PLBartForCausalLM,pass,6
PLBartForConditionalGeneration,pass,29
PLBartForConditionalGeneration,pass,8
PegasusForCausalLM,pass,12
PegasusForCausalLM,pass,6
PegasusForConditionalGeneration,pass,23
PegasusForConditionalGeneration,pass,7
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5
Speech2Text2ForCausalLM,pass,12
Speech2Text2ForCausalLM,pass,6
@ -170,11 +170,11 @@ T5Small,pass,5
TrOCRForCausalLM,pass,12
TrOCRForCausalLM,pass,6
XGLMForCausalLM,pass,12
XGLMForCausalLM,pass,6


View File

@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9
BartForCausalLM,pass,12
BartForCausalLM,pass,6
BartForConditionalGeneration,pass,24
BartForConditionalGeneration,pass,8
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
BlenderbotSmallForCausalLM,pass,12
BlenderbotSmallForCausalLM,pass,6
BlenderbotSmallForConditionalGeneration,pass,24
BlenderbotSmallForConditionalGeneration,pass,8
@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4
MBartForCausalLM,pass,12
MBartForCausalLM,pass,6
MBartForConditionalGeneration,pass,24
MBartForConditionalGeneration,pass,8
@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3
OPTForCausalLM,pass,12
OPTForCausalLM,pass,6
PLBartForCausalLM,pass,12
PLBartForCausalLM,pass,6
PLBartForConditionalGeneration,pass,29
PLBartForConditionalGeneration,pass,8
PegasusForCausalLM,pass,12
PegasusForCausalLM,pass,6
PegasusForConditionalGeneration,pass,23
PegasusForConditionalGeneration,pass,7
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5
Speech2Text2ForCausalLM,pass,12
Speech2Text2ForCausalLM,pass,6
@ -170,11 +170,11 @@ T5Small,pass,5
TrOCRForCausalLM,pass,12
TrOCRForCausalLM,pass,6
XGLMForCausalLM,pass,12
XGLMForCausalLM,pass,6


View File

@ -150,7 +150,7 @@ hf_Bert_large,pass,0
hf_BigBird,pass,46
hf_BigBird,pass,43


View File

@ -98,7 +98,7 @@ hf_Bert_large,pass,6
hf_BigBird,pass,52
hf_BigBird,pass,49


View File

@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9
BartForCausalLM,pass,12
BartForCausalLM,pass,6
BartForConditionalGeneration,pass,24
BartForConditionalGeneration,pass,8
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
BlenderbotSmallForCausalLM,pass,12
BlenderbotSmallForCausalLM,pass,6
BlenderbotSmallForConditionalGeneration,pass,24
BlenderbotSmallForConditionalGeneration,pass,8
@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4
MBartForCausalLM,pass,12
MBartForCausalLM,pass,6
MBartForConditionalGeneration,pass,24
MBartForConditionalGeneration,pass,8
@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3
OPTForCausalLM,pass,12
OPTForCausalLM,pass,6
PLBartForCausalLM,pass,12
PLBartForCausalLM,pass,6
PLBartForConditionalGeneration,pass,29
PLBartForConditionalGeneration,pass,8
PegasusForCausalLM,pass,12
PegasusForCausalLM,pass,6
PegasusForConditionalGeneration,pass,23
PegasusForConditionalGeneration,pass,7
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5
Speech2Text2ForCausalLM,pass,12
Speech2Text2ForCausalLM,pass,6
@ -170,11 +170,11 @@ T5Small,pass,5
TrOCRForCausalLM,pass,12
TrOCRForCausalLM,pass,6
XGLMForCausalLM,pass,12
XGLMForCausalLM,pass,6


View File

@ -150,7 +150,7 @@ hf_Bert_large,pass,0
hf_BigBird,fail_accuracy,46
hf_BigBird,fail_accuracy,43


View File

@ -98,7 +98,7 @@ hf_Bert_large,pass,6
hf_BigBird,pass,52
hf_BigBird,pass,49


View File

@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9
BartForCausalLM,pass,12
BartForCausalLM,pass,6
BartForConditionalGeneration,pass,24
BartForConditionalGeneration,pass,8
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
BlenderbotSmallForCausalLM,pass,12
BlenderbotSmallForCausalLM,pass,6
BlenderbotSmallForConditionalGeneration,pass,24
BlenderbotSmallForConditionalGeneration,pass,8
@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4
MBartForCausalLM,pass,12
MBartForCausalLM,pass,6
MBartForConditionalGeneration,pass,24
MBartForConditionalGeneration,pass,8
@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3
OPTForCausalLM,pass,12
OPTForCausalLM,pass,6
PLBartForCausalLM,pass,12
PLBartForCausalLM,pass,6
PLBartForConditionalGeneration,pass,29
PLBartForConditionalGeneration,pass,8
PegasusForCausalLM,pass,12
PegasusForCausalLM,pass,6
PegasusForConditionalGeneration,pass,23
PegasusForConditionalGeneration,pass,7
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5
Speech2Text2ForCausalLM,pass,12
Speech2Text2ForCausalLM,pass,6
@ -170,11 +170,11 @@ T5Small,pass,5
TrOCRForCausalLM,pass,12
TrOCRForCausalLM,pass,6
XGLMForCausalLM,pass,12
XGLMForCausalLM,pass,6


View File

@ -150,7 +150,7 @@ hf_Bert_large,pass,0
hf_BigBird,pass,46
hf_BigBird,pass,43


View File

@ -98,7 +98,7 @@ hf_Bert_large,pass,6
hf_BigBird,pass,52
hf_BigBird,pass,49


View File

@ -14,11 +14,11 @@ AllenaiLongformerBase,pass,9
BartForCausalLM,pass,12
BartForCausalLM,pass,6
BartForConditionalGeneration,pass,24
BartForConditionalGeneration,pass,8
@ -34,11 +34,11 @@ BlenderbotForCausalLM,eager_fail_to_run,0
BlenderbotSmallForCausalLM,pass,12
BlenderbotSmallForCausalLM,pass,6
BlenderbotSmallForConditionalGeneration,pass,24
BlenderbotSmallForConditionalGeneration,pass,8
@ -102,11 +102,11 @@ M2M100ForConditionalGeneration,pass,4
MBartForCausalLM,pass,12
MBartForCausalLM,pass,6
MBartForConditionalGeneration,pass,24
MBartForConditionalGeneration,pass,8
@ -130,23 +130,23 @@ MobileBertForQuestionAnswering,pass,3
OPTForCausalLM,pass,12
OPTForCausalLM,pass,6
PLBartForCausalLM,pass,12
PLBartForCausalLM,pass,6
PLBartForConditionalGeneration,pass,29
PLBartForConditionalGeneration,pass,8
PegasusForCausalLM,pass,12
PegasusForCausalLM,pass,6
PegasusForConditionalGeneration,pass,23
PegasusForConditionalGeneration,pass,7
@ -158,7 +158,7 @@ RobertaForQuestionAnswering,pass,5
Speech2Text2ForCausalLM,pass,12
Speech2Text2ForCausalLM,pass,6
@ -170,11 +170,11 @@ T5Small,pass,5
TrOCRForCausalLM,pass,12
TrOCRForCausalLM,pass,6
XGLMForCausalLM,pass,12
XGLMForCausalLM,pass,6


View File

@ -150,7 +150,7 @@ hf_Bert_large,pass,0
hf_BigBird,fail_accuracy,46
hf_BigBird,fail_accuracy,43


View File

@ -98,7 +98,7 @@ hf_Bert_large,pass,6
hf_BigBird,pass,52
hf_BigBird,pass,49


View File

@ -272,6 +272,38 @@ TEST(StaticRuntime, autogen_addr) {
/*check_resize=*/true);
}
TEST(StaticRuntime, autogen__test_functorch_fallback) {
const std::string script = R"IR(
graph(%self: Tensor, %other: Tensor):
%bias: None = prim::Constant()
%ret = aten::_test_functorch_fallback(%self, %other)
%cloned = aten::clone(%ret, %bias)
return (%cloned)
)IR";
auto self0 = at::rand({6, 6, 6});
auto other0 = at::rand({6, 6, 6});
std::vector<IValue> args{self0, other0};
testStaticRuntime(
script,
args,
{},
/*use_allclose=*/false,
/*use_equalnan=*/false,
/*check_resize=*/true);
auto self1 = at::rand({22, 22, 22});
auto other1 = at::rand({22, 22, 22});
std::vector<IValue> args2{self1, other1};
testStaticRuntime(
script,
args,
args2,
/*use_allclose=*/false,
/*use_equalnan=*/false,
/*check_resize=*/true);
}
TEST(StaticRuntime, autogen_argmax) {
const std::string script = R"IR(
graph(%self: Tensor, %dim: int?, %keepdim: bool):
@ -4440,6 +4472,40 @@ TEST(StaticRuntime, autogen_masked_select) {
/*check_resize=*/true);
}
TEST(StaticRuntime, autogen_nonzero_static) {
const std::string script = R"IR(
graph(%self: Tensor, %size: int, %fill_value: int):
%bias: None = prim::Constant()
%ret = aten::nonzero_static(%self, %size, %fill_value)
%cloned = aten::clone(%ret, %bias)
return (%cloned)
)IR";
auto self0 = at::rand({6, 6, 6});
auto size0 = 1;
auto fill_value0 = 1;
std::vector<IValue> args{self0, size0, fill_value0};
testStaticRuntime(
script,
args,
{},
/*use_allclose=*/false,
/*use_equalnan=*/false,
/*check_resize=*/true);
auto self1 = at::rand({22, 22, 22});
auto size1 = 1;
auto fill_value1 = 1;
std::vector<IValue> args2{self1, size1, fill_value1};
testStaticRuntime(
script,
args,
args2,
/*use_allclose=*/false,
/*use_equalnan=*/false,
/*check_resize=*/true);
}
TEST(StaticRuntime, autogen_gather) {
const std::string script = R"IR(
graph(%self: Tensor, %dim: int, %index: Tensor, %sparse_grad: bool):
@ -7106,222 +7172,6 @@ TEST(StaticRuntime, autogen_special_multigammaln) {
/*check_resize=*/true);
}
TEST(StaticRuntime, autogen_fft_fft) {
const std::string script = R"IR(
graph(%self: Tensor, %n: int?, %dim: int, %norm: str?):
%bias: None = prim::Constant()
%ret = aten::fft_fft(%self, %n, %dim, %norm)
%cloned = aten::clone(%ret, %bias)
return (%cloned)
)IR";
auto self0 = at::rand({6, 6, 6});
auto n0 = 1;
auto dim0 = 1;
auto norm0 = "forward";
std::vector<IValue> args{self0, n0, dim0, norm0};
testStaticRuntime(
script,
args,
{},
/*use_allclose=*/false,
/*use_equalnan=*/false,
/*check_resize=*/true);
auto self1 = at::rand({22, 22, 22});
auto n1 = 1;
auto dim1 = 1;
auto norm1 = "forward";
std::vector<IValue> args2{self1, n1, dim1, norm1};
testStaticRuntime(
script,
args,
args2,
/*use_allclose=*/false,
/*use_equalnan=*/false,
/*check_resize=*/true);
}
TEST(StaticRuntime, autogen_fft_ifft) {
const std::string script = R"IR(
graph(%self: Tensor, %n: int?, %dim: int, %norm: str?):
%bias: None = prim::Constant()
%ret = aten::fft_ifft(%self, %n, %dim, %norm)
%cloned = aten::clone(%ret, %bias)
return (%cloned)
)IR";
auto self0 = at::rand({6, 6, 6});
auto n0 = 1;
auto dim0 = 1;
auto norm0 = "forward";
std::vector<IValue> args{self0, n0, dim0, norm0};
testStaticRuntime(
script,
args,
{},
/*use_allclose=*/false,
/*use_equalnan=*/false,
/*check_resize=*/true);
auto self1 = at::rand({22, 22, 22});
auto n1 = 1;
auto dim1 = 1;
auto norm1 = "forward";
std::vector<IValue> args2{self1, n1, dim1, norm1};
testStaticRuntime(
script,
args,
args2,
/*use_allclose=*/false,
/*use_equalnan=*/false,
/*check_resize=*/true);
}
TEST(StaticRuntime, autogen_fft_rfft) {
const std::string script = R"IR(
graph(%self: Tensor, %n: int?, %dim: int, %norm: str?):
%bias: None = prim::Constant()
%ret = aten::fft_rfft(%self, %n, %dim, %norm)
%cloned = aten::clone(%ret, %bias)
return (%cloned)
)IR";
auto self0 = at::rand({6, 6, 6});
auto n0 = 1;
auto dim0 = 1;
auto norm0 = "forward";
std::vector<IValue> args{self0, n0, dim0, norm0};
testStaticRuntime(
script,
args,
{},
/*use_allclose=*/false,
/*use_equalnan=*/false,
/*check_resize=*/true);
auto self1 = at::rand({22, 22, 22});
auto n1 = 1;
auto dim1 = 1;
auto norm1 = "forward";
std::vector<IValue> args2{self1, n1, dim1, norm1};
testStaticRuntime(
script,
args,
args2,
/*use_allclose=*/false,
/*use_equalnan=*/false,
/*check_resize=*/true);
}
TEST(StaticRuntime, autogen_fft_irfft) {
const std::string script = R"IR(
graph(%self: Tensor, %n: int?, %dim: int, %norm: str?):
%bias: None = prim::Constant()
%ret = aten::fft_irfft(%self, %n, %dim, %norm)
%cloned = aten::clone(%ret, %bias)
return (%cloned)
)IR";
auto self0 = at::rand({6, 6, 6});
auto n0 = 1;
auto dim0 = 1;
auto norm0 = "forward";
std::vector<IValue> args{self0, n0, dim0, norm0};
testStaticRuntime(
script,
args,
{},
/*use_allclose=*/false,
/*use_equalnan=*/false,
/*check_resize=*/true);
auto self1 = at::rand({22, 22, 22});
auto n1 = 1;
auto dim1 = 1;
auto norm1 = "forward";
std::vector<IValue> args2{self1, n1, dim1, norm1};
testStaticRuntime(
script,
args,
args2,
/*use_allclose=*/false,
/*use_equalnan=*/false,
/*check_resize=*/true);
}
TEST(StaticRuntime, autogen_fft_hfft) {
const std::string script = R"IR(
graph(%self: Tensor, %n: int?, %dim: int, %norm: str?):
%bias: None = prim::Constant()
%ret = aten::fft_hfft(%self, %n, %dim, %norm)
%cloned = aten::clone(%ret, %bias)
return (%cloned)
)IR";
auto self0 = at::rand({6, 6, 6});
auto n0 = 1;
auto dim0 = 1;
auto norm0 = "forward";
std::vector<IValue> args{self0, n0, dim0, norm0};
testStaticRuntime(
script,
args,
{},
/*use_allclose=*/false,
/*use_equalnan=*/false,
/*check_resize=*/true);
auto self1 = at::rand({22, 22, 22});
auto n1 = 1;
auto dim1 = 1;
auto norm1 = "forward";
std::vector<IValue> args2{self1, n1, dim1, norm1};
testStaticRuntime(
script,
args,
args2,
/*use_allclose=*/false,
/*use_equalnan=*/false,
/*check_resize=*/true);
}
TEST(StaticRuntime, autogen_fft_ihfft) {
const std::string script = R"IR(
graph(%self: Tensor, %n: int?, %dim: int, %norm: str?):
%bias: None = prim::Constant()
%ret = aten::fft_ihfft(%self, %n, %dim, %norm)
%cloned = aten::clone(%ret, %bias)
return (%cloned)
)IR";
auto self0 = at::rand({6, 6, 6});
auto n0 = 1;
auto dim0 = 1;
auto norm0 = "forward";
std::vector<IValue> args{self0, n0, dim0, norm0};
testStaticRuntime(
script,
args,
{},
/*use_allclose=*/false,
/*use_equalnan=*/false,
/*check_resize=*/true);
auto self1 = at::rand({22, 22, 22});
auto n1 = 1;
auto dim1 = 1;
auto norm1 = "forward";
std::vector<IValue> args2{self1, n1, dim1, norm1};
testStaticRuntime(
script,
args,
args2,
/*use_allclose=*/false,
/*use_equalnan=*/false,
/*check_resize=*/true);
}
TEST(StaticRuntime, autogen_linalg_cross) {
const std::string script = R"IR(
graph(%self: Tensor, %other: Tensor, %dim: int):

View File

@ -779,4 +779,5 @@ Tensor class reference
Tensor.where
Tensor.xlogy
Tensor.xlogy_
Tensor.xpu
Tensor.zero_
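
The hunk above adds Tensor.xpu to the Tensor class reference. A one-line usage sketch (it requires a PyTorch build with an XPU device available, so treat it as illustrative):

import torch

t = torch.randn(2).xpu()  # copies the tensor to the default XPU device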

View File

@ -80,6 +80,48 @@ class WorkerServerTest(TestCase):
resp = pool.request("POST", "/handler/dump_nccl_trace_pickle")
self.assertEqual(resp.status, 200)
out = pickle.loads(resp.data)
self.assertIsInstance(out, dict)
self.assertIn("version", out)
@requires_cuda
def test_dump_nccl_trace_pickle_with_params(self) -> None:
with local_worker_server() as pool:
# bad key - not lower case
resp = pool.request(
"POST", "/handler/dump_nccl_trace_pickle?includeCollectives=true"
)
self.assertEqual(resp.status, 400)
# unknown key
resp = pool.request(
"POST", "/handler/dump_nccl_trace_pickle?unknownkey=true"
)
self.assertEqual(resp.status, 400)
# bad value - not a bool
resp = pool.request(
"POST", "/handler/dump_nccl_trace_pickle?includecollectives=notabool"
)
self.assertEqual(resp.status, 400)
# bad value - value not lowercase
resp = pool.request(
"POST", "/handler/dump_nccl_trace_pickle?includecollectives=True"
)
self.assertEqual(resp.status, 400)
# good key and value
resp = pool.request(
"POST", "/handler/dump_nccl_trace_pickle?includecollectives=true"
)
self.assertEqual(resp.status, 200)
# good key and value
resp = pool.request(
"POST", "/handler/dump_nccl_trace_pickle?includestacktraces=true"
)
self.assertEqual(resp.status, 200)
# multiple good keys and values
resp = pool.request(
"POST",
"/handler/dump_nccl_trace_pickle?includecollectives=true&includestacktraces=false&onlyactive=true",
)
self.assertEqual(resp.status, 200)
def test_tcp(self) -> None:
import requests
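
The new test above pins down the handler's contract: query keys and boolean values must be lower case, and unknown keys are rejected with HTTP 400. A rough client-side sketch of the same contract (the host/port and the use of the requests library are assumptions, not part of this change):

import requests

# Hypothetical address of a running worker server that exposes the handler.
base = "http://localhost:1234/handler/dump_nccl_trace_pickle"

ok = requests.post(base, params={"includecollectives": "true", "onlyactive": "true"})
bad_key = requests.post(base, params={"includeCollectives": "true"})   # key not lower case
bad_val = requests.post(base, params={"includecollectives": "True"})   # value not lower case
print(ok.status_code, bad_key.status_code, bad_val.status_code)  # expected: 200 400 400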

View File

@ -13,7 +13,6 @@ import shutil
import subprocess
import sys
import tempfile
import unittest
import uuid
from contextlib import closing
from unittest import mock
@ -23,12 +22,13 @@ import torch.distributed.run as launch
from torch.distributed.elastic.agent.server.api import RunResult, WorkerState
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
from torch.distributed.elastic.utils import get_socket_with_port
from torch.distributed.elastic.utils.distributed import get_free_port
from torch.testing._internal.common_utils import (
run_tests,
skip_but_pass_in_sandcastle_if,
TEST_WITH_DEV_DBG_ASAN,
TestCase,
)
@ -63,19 +63,7 @@ class MockException(Exception):
pass
class ElasticLaunchTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
# start a standalone, single process etcd server to use for all tests
cls._etcd_server = EtcdServer()
cls._etcd_server.start()
cls._etcd_endpoint = cls._etcd_server.get_endpoint()
@classmethod
def tearDownClass(cls):
# stop the standalone etcd server
cls._etcd_server.stop()
class ElasticLaunchTest(TestCase):
def setUp(self):
self.test_dir = tempfile.mkdtemp()
@ -103,8 +91,6 @@ class ElasticLaunchTest(unittest.TestCase):
args = [
f"--nnodes={nnodes}",
f"--nproc-per-node={nproc_per_node}",
"--rdzv-backend=etcd",
f"--rdzv-endpoint={self._etcd_endpoint}",
f"--rdzv-id={run_id}",
"--monitor-interval=1",
"--start-method=spawn",
@ -156,8 +142,6 @@ class ElasticLaunchTest(unittest.TestCase):
args = [
f"--nnodes={nnodes}",
f"--nproc-per-node={nproc_per_node}",
"--rdzv-backend=etcd",
f"--rdzv-endpoint={self._etcd_endpoint}",
f"--rdzv-id={run_id}",
"--monitor-interval=1",
"--start-method=spawn",
@ -187,8 +171,6 @@ class ElasticLaunchTest(unittest.TestCase):
world_size = 1
args = [
f"--nnodes={nnodes}",
"--rdzv-backend=etcd",
f"--rdzv-endpoint={self._etcd_endpoint}",
f"--rdzv-id={run_id}",
"--monitor-interval=1",
"--start-method=spawn",
@ -220,8 +202,6 @@ class ElasticLaunchTest(unittest.TestCase):
os.environ["PET_NNODES"] = str(nnodes)
os.environ["PET_NPROC_PER_NODE"] = str(nproc_per_node)
os.environ["PET_RDZV_BACKEND"] = "etcd"
os.environ["PET_RDZV_ENDPOINT"] = self._etcd_endpoint
os.environ["PET_RDZV_ID"] = run_id
os.environ["PET_MONITOR_INTERVAL"] = "1"
os.environ["PET_START_METHOD"] = "spawn"
@ -250,8 +230,6 @@ class ElasticLaunchTest(unittest.TestCase):
args = [
f"--nnodes={nnodes}",
f"--nproc-per-node={nproc_type}",
"--rdzv-backend=etcd",
f"--rdzv-endpoint={self._etcd_endpoint}",
f"--rdzv-id={run_id}",
"--monitor-interval=1",
"--start-method=spawn",
@ -272,7 +250,8 @@ class ElasticLaunchTest(unittest.TestCase):
@skip_but_pass_in_sandcastle_if(
TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
)
def test_nproc_launch_auto_configurations(self):
@patch("torch.cuda.is_available", return_value=False)
def test_nproc_launch_auto_configurations(self, _mock1):
self._test_nproc_launch_configuration("auto", os.cpu_count())
@skip_but_pass_in_sandcastle_if(
@ -310,8 +289,9 @@ class ElasticLaunchTest(unittest.TestCase):
args = [
f"--nnodes={min_nodes}:{max_nodes}",
f"--nproc-per-node={nproc_per_node}",
"--rdzv-backend=etcd",
f"--rdzv-endpoint={self._etcd_endpoint}",
"--rdzv-backend=c10d",
f"--rdzv-endpoint=localhost:{get_free_port()}",
"--rdzv-conf='join_timeout=5,last_call_timeout=1,timeout=5'",
f"--rdzv-id={run_id}",
"--monitor-interval=1",
"--start-method=spawn",
@ -343,8 +323,9 @@ class ElasticLaunchTest(unittest.TestCase):
args = [
f"--nnodes={min_nodes}:{max_nodes}",
f"--nproc-per-node={nproc_per_node}",
"--rdzv-backend=etcd",
f"--rdzv-endpoint={self._etcd_endpoint}",
"--rdzv-backend=c10d",
f"--rdzv-endpoint=localhost:{get_free_port()}",
"--rdzv-conf='join_timeout=5,last_call_timeout=1,timeout=5'",
f"--rdzv-id={run_id}",
"--monitor-interval=1",
"--max-restarts=0",
@ -376,8 +357,9 @@ class ElasticLaunchTest(unittest.TestCase):
args = [
f"--nnodes={min_nodes}:{max_nodes}",
f"--nproc-per-node={nproc_per_node}",
"--rdzv-backend=etcd",
f"--rdzv-endpoint={self._etcd_endpoint}",
"--rdzv-backend=c10d",
f"--rdzv-endpoint=localhost:{get_free_port()}",
"--rdzv_conf=timeout=5",
f"--rdzv-id={run_id}",
"--monitor-interval=1",
"--max-restarts=0",
@ -452,8 +434,9 @@ class ElasticLaunchTest(unittest.TestCase):
args = [
f"--nnodes={min_nodes}:{max_nodes}",
f"--nproc-per-node={nproc_per_node}",
"--rdzv-backend=etcd",
f"--rdzv-endpoint={self._etcd_endpoint}",
"--rdzv-backend=c10d",
f"--rdzv-endpoint=localhost:{get_free_port()}",
"--rdzv_conf=timeout=5",
f"--rdzv-id={run_id}",
"--monitor-interval=1",
"--start-method=spawn",
@ -608,21 +591,6 @@ class ElasticLaunchTest(unittest.TestCase):
is_torchelastic_launched = fp.readline()
self.assertEqual("False", is_torchelastic_launched)
def test_init_method_tcp(self):
port = get_free_port()
with patch.object(
sys,
"argv",
[
path("bin/test_script_init_method.py"),
f"--init-method=tcp://localhost:{port}",
"--rank=0",
"--world-size=1",
],
):
runpy.run_path(sys.argv[0], run_name="__main__")
# nothing to validate, just make sure it runs
@skip_but_pass_in_sandcastle_if(
TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
)
@ -642,27 +610,6 @@ class ElasticLaunchTest(unittest.TestCase):
)
# nothing to validate, just make sure it runs
def test_init_method_env(self):
port = get_free_port()
with patch.dict(
os.environ,
{
"RANK": "0",
"WORLD_SIZE": "1",
"MASTER_ADDR": "localhost",
"MASTER_PORT": str(port),
},
), patch.object(
sys,
"argv",
[
path("bin/test_script_init_method.py"),
"--init-method=env://",
],
):
runpy.run_path(sys.argv[0], run_name="__main__")
# nothing to validate, just make sure it runs
@skip_but_pass_in_sandcastle_if(
TEST_WITH_DEV_DBG_ASAN, "test incompatible with dev/dbg asan"
)
@ -681,3 +628,7 @@ class ElasticLaunchTest(unittest.TestCase):
]
)
# nothing to validate, just make sure it runs
if __name__ == "__main__":
run_tests()
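
The diff above removes the standalone etcd server from these launcher tests and points rendezvous at the c10d backend instead. A minimal sketch of an equivalent launch, driving torch.distributed.run programmatically with the same style of arguments (script name and port are placeholders):

import torch.distributed.run as launch

args = [
    "--nnodes=1",
    "--nproc-per-node=2",
    "--rdzv-backend=c10d",
    "--rdzv-endpoint=localhost:29500",  # placeholder port
    "--rdzv-id=demo_run",
    "--monitor-interval=1",
    "--start-method=spawn",
    "train_script.py",  # hypothetical entry point
]
launch.main(args)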

View File

@ -3662,7 +3662,8 @@ class NCCLTraceTest(NCCLTraceTestBase):
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
@parametrize("timing_enabled", [True, False])
def test_trace_while_active(self, timing_enabled):
@parametrize("only_active", [True, False])
def test_trace_while_active(self, timing_enabled, only_active):
if self.rank == self.MAIN_PROCESS_RANK:
for c in self.children_pipes:
self.assertEqual(c.recv(), "next")
@ -3683,17 +3684,26 @@ class NCCLTraceTest(NCCLTraceTestBase):
if self.rank != 0:
pg.allreduce(a).wait()
e.synchronize()
t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace())
t = pickle.loads(
torch._C._distributed_c10d._dump_nccl_trace(onlyActive=only_active)
)
t = t["entries"]
self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce")
if self.rank == 0:
self.assertEqual(t[-1]["collective_seq_id"], 1)
self.assertEqual(t[-1]["state"], "completed")
else:
self.assertEqual(t[-1]["collective_seq_id"], 2)
self.assertEqual(
t[-1]["state"], self.started_or_scheduled(timing_enabled)
)
if only_active:
if self.rank == 0:
self.assertEqual(len(t), 0)
else:
self.assertEqual(len(t), 1)
if not only_active:
if self.rank == 0:
self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce")
self.assertEqual(t[-1]["collective_seq_id"], 1)
self.assertEqual(t[-1]["state"], "completed")
else:
self.assertEqual(t[-1]["profiling_name"], "nccl:all_reduce")
self.assertEqual(t[-1]["collective_seq_id"], 2)
self.assertEqual(
t[-1]["state"], self.started_or_scheduled(timing_enabled)
)
self.parent.send("next")
self.assertEqual("next", self.parent.recv())
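
The reworked assertions above reflect that passing onlyActive=True to the trace dump keeps only entries that have not completed. A hedged sketch of inspecting such a dump (it assumes a NCCL process group has already run some collectives, so it is illustrative rather than runnable standalone):

import pickle
import torch

dump = torch._C._distributed_c10d._dump_nccl_trace(onlyActive=True)
entries = pickle.loads(dump)["entries"]
print(len(entries), [e["state"] for e in entries])  # only in-flight work remains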

View File

@ -1084,14 +1084,12 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase):
# far from an exhaustive check of all the expected guards, just check a couple of them.
FileCheck().check("""local "L['self']" TYPE_MATCH""").check(
"""local "L['self']" ID_MATCH"""
).check(f"""{expected_guard_source} "L['self'].net" TYPE_MATCH""").check(
f"""{expected_guard_source} "L['self'].net" ID_MATCH"""
).check(
f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH"""
f"""{expected_guard_source} "L['self'].net[0]" TYPE_MATCH"""
).check(
f"""{expected_guard_source} "L['self']._modules['net']" ID_MATCH"""
).check(
f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH"""
).check(
f"""{expected_guard_source} "L['self']._modules['net']._modules['1']" ID_MATCH"""
f"""{expected_guard_source} "L['self'].net[0]" ID_MATCH"""
).run(
GUARDS_FILE.getvalue()
)

View File

@ -464,6 +464,44 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
self.assertEqual(cnt.frame_count, 1)
def test_assume_constant_result_on_user_defined_fn(self):
@torch._dynamo.assume_constant_result
def const_fn(n, s):
return torch.full([n], s)
def fn(B):
B = const_fn(B.size(0), 13)
X = B * 2
return X.tolist()
B_list = [8] * 32
B = torch.tensor(B_list, dtype=torch.int32)
torch._dynamo.decorators.mark_static(B, 0)
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True
self.assertEqual(
fn(B), torch.compile(fn, backend="eager", fullgraph=True, dynamic=True)(B)
)
def test_assume_constant_result_on_computation_with_graph_input(self):
@torch._dynamo.assume_constant_result
def check(y):
return y[0].item() == 1
def fn(x, y):
if check(y):
return x + 2
else:
return x + 1
y = torch.tensor([1])
x = torch.tensor(1)
self.assertEqual(fn(x, y), torch.compile(fn)(x, y))
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -253,6 +253,7 @@ Target Expressions:
==> (>= 0 s1)
==> (>= 0 s2)
==> (>= 0 s3)
==> (>= 9223372036854775806 s0)
Failed Source Expressions:
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",
@ -286,14 +287,14 @@ Failure occurred while running node:
Model:
==> L['shape'][0]: 1
==> L['shape'][1]: 1
==> L['shape'][2]: 0
==> L['shape'][2]: 2
==> L['x'].size()[0]: 3
==> L['x'].storage_offset(): 0
==> L['x'].stride()[0]: 1
==> s0: 3
==> s1: 1
==> s2: 1
==> s3: 0
==> s3: 2
Assertions:
==> (== 0 L['x'].storage_offset())
@ -317,6 +318,10 @@ Target Expressions:
==> (== L['shape'][2] s3)
==> (== L['x'].size()[0] s0)
==> (> s0 0)
==> (>= 9223372036854775806 s0)
==> (>= 9223372036854775807 s1)
==> (>= 9223372036854775807 s2)
==> (>= 9223372036854775807 s3)
Failed Source Expressions:
==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""",

View File

@ -3473,6 +3473,7 @@ class GraphModule(torch.nn.Module):
]
false_guard_code = [
"Ne(cast_symbool_to_symint_guardless(L['pred']), 1)",
"-9223372036854775808 <= cast_symbool_to_symint_guardless(L['pred'])",
]
test_symbool_guards(
f,

View File

@ -3,7 +3,6 @@ import enum
import functools
import pprint
import re
import sys
import unittest
import warnings
@ -2860,7 +2859,7 @@ class GraphModule(torch.nn.Module):
_vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1)
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim_1], retain_graph = True, create_graph = True); _add_batch_dim_1 = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim_1], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim_1 = None
batched_outputs = _autograd_grad[0]; _autograd_grad = None
chunked_result = torch._C._functorch._remove_batch_dim(batched_outputs, 3, 12, 0); batched_outputs = None
@ -2896,7 +2895,7 @@ class GraphModule(torch.nn.Module):
jac_out_in: "f32[4, 3, 4, 3, 12]" = split_2[0]; split_2 = None
unflatten: "f32[4, 3, 4, 3, 4, 3]" = jac_out_in.unflatten(-1, (4, 3)); jac_out_in = None
return (unflatten, diff_primals, o)
return (unflatten,)
""",
)
@ -2964,8 +2963,8 @@ class GraphModule(torch.nn.Module):
_saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func transforms don't yet support saved tensor hooks. Please open an issue with your use case.")
_grad_increment_nesting = torch._C._functorch._grad_increment_nesting()
_wrap_for_grad_2 = torch._C._functorch._wrap_for_grad(child_2, 3)
child_4 = torch._C._functorch._wrap_for_grad(child_3, 3)
_wrap_for_grad_2 = torch._C._functorch._wrap_for_grad(child_2, 3); child_2 = None
child_4 = torch._C._functorch._wrap_for_grad(child_3, 3); child_3 = None
set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True)
@ -3002,7 +3001,7 @@ class GraphModule(torch.nn.Module):
_vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim_1)
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [child_4], [_add_batch_dim_1], retain_graph = True, create_graph = True); _add_batch_dim_1 = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [child_4], [_add_batch_dim_1], retain_graph = True, create_graph = True); o = child_4 = _add_batch_dim_1 = None
child_5 = _autograd_grad[0]; _autograd_grad = None
child_6 = torch._C._functorch._remove_batch_dim(child_5, 3, 12, 0); child_5 = None
@ -3041,17 +3040,10 @@ class GraphModule(torch.nn.Module):
unflatten: "f32[4, 3, 3, 4, 3, 4]" = jac_out_in.unflatten(-1, (3, 4)); jac_out_in = None""",
)
# Python 3.10 and 3.11 produces slightly different graphs
if sys.version_info[:2] > (3, 10):
self.assertExpectedInline(
actual.split("\n")[-2],
""" return (unflatten, child_2, _wrap_for_grad_1, child_3, child_4, o)""",
)
else:
self.assertExpectedInline(
actual.split("\n")[-2],
""" return (unflatten, child_3, child_2, _wrap_for_grad_1, child_4, o)""",
)
self.assertExpectedInline(
actual.split("\n")[-2],
""" return (unflatten,)""",
)
@unittest.expectedFailure
def test_hessian_disable_capture(self):
@ -3160,7 +3152,7 @@ class GraphModule(torch.nn.Module):
_vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim)
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); _add_batch_dim = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None
batched_outputs = _autograd_grad[0]; _autograd_grad = None
chunked_result: "f32[12, 4, 3]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None
@ -3172,7 +3164,7 @@ class GraphModule(torch.nn.Module):
split_1: "f32[12, 4, 3]" = split[0]; split = None
output_input: "f32[4, 3, 4, 3]" = split_1.view((4, 3, 4, 3)); split_1 = None
return (output_input, diff_primals, o)
return (output_input,)
""",
)
@ -3243,7 +3235,7 @@ class GraphModule(torch.nn.Module):
_vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim)
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); _add_batch_dim = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None
batched_outputs = _autograd_grad[0]; _autograd_grad = None
chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None
@ -3255,7 +3247,7 @@ class GraphModule(torch.nn.Module):
split_1: "f32[12, 3, 4]" = split[0]; split = None
output_input: "f32[3, 4, 3, 4]" = split_1.view((3, 4, 3, 4)); split_1 = None
return (output_input, diff_primals, o)
return (output_input,)
""",
)
@ -3328,7 +3320,7 @@ class GraphModule(torch.nn.Module):
_vjp_treespec_compare = torch._functorch.eager_transforms._vjp_treespec_compare(o, _add_batch_dim)
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); _add_batch_dim = None
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([o], [diff_primals], [_add_batch_dim], retain_graph = True, create_graph = True); o = diff_primals = _add_batch_dim = None
batched_outputs = _autograd_grad[0]; _autograd_grad = None
chunked_result: "f32[12, 3, 4]" = torch._C._functorch._remove_batch_dim(batched_outputs, 1, 12, 0); batched_outputs = None
@ -3340,7 +3332,7 @@ class GraphModule(torch.nn.Module):
split_1: "f32[12, 3, 4]" = split[0]; split = None
output_input: "f32[3, 4, 3, 4]" = split_1.view((3, 4, 3, 4)); split_1 = None
return (output_input, aux_1, diff_primals, o)
return (output_input, aux_1)
""",
)
@ -3776,7 +3768,7 @@ class GraphModule(torch.nn.Module):
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting()
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (grad_input_1, y)
return (y, grad_input_1)
""",
)
@ -5187,10 +5179,10 @@ class GraphModule(torch.nn.Module):
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_self_buffers_tensor_constant0_: "f32[3, 3, 3]"):
l_self_buffers_tensor_constant0_ = L_self_buffers_tensor_constant0_
def forward(self, L_self_tensor_constant0: "f32[3, 3, 3]"):
l_self_tensor_constant0 = L_self_tensor_constant0
alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_buffers_tensor_constant0_); l_self_buffers_tensor_constant0_ = None
alias_default: "f32[3, 3, 3]" = torch.ops.aten.alias.default(l_self_tensor_constant0); l_self_tensor_constant0 = None
sin_default: "f32[3, 3, 3]" = torch.ops.aten.sin.default(alias_default)
@ -5209,16 +5201,16 @@ class GraphModule(torch.nn.Module):
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_: "f32[3, 3, 3]", L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"):
l_self_modules_fx_const_folded_attrs_parameters_0_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_0_
l_self_modules_fx_const_folded_attrs_parameters_1_ = L_self_modules_FX_CONST_FOLDED_ATTRS_parameters_1_
def forward(self, getattr_L_self_FX_CONST_FOLDED_ATTRS_0_: "f32[3, 3, 3]", getattr_L_self_FX_CONST_FOLDED_ATTRS_1_: "f32[3, 3, 3]", L_flat_tangents_1_: "f32[3, 3, 3]"):
getattr_l_self_fx_const_folded_attrs_0_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_0_
getattr_l_self_fx_const_folded_attrs_1_ = getattr_L_self_FX_CONST_FOLDED_ATTRS_1_
l_flat_tangents_1_ = L_flat_tangents_1_
_new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, l_self_modules_fx_const_folded_attrs_parameters_0_); l_self_modules_fx_const_folded_attrs_parameters_0_ = None
_new_zeros_with_same_feature_meta_default: "f32[3, 3, 3]" = torch.ops.aten._new_zeros_with_same_feature_meta.default(l_flat_tangents_1_, getattr_l_self_fx_const_folded_attrs_0_); getattr_l_self_fx_const_folded_attrs_0_ = None
copy__default: "f32[3, 3, 3]" = torch.ops.aten.copy_.default(_new_zeros_with_same_feature_meta_default, l_flat_tangents_1_); _new_zeros_with_same_feature_meta_default = l_flat_tangents_1_ = None
mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, l_self_modules_fx_const_folded_attrs_parameters_1_); copy__default = l_self_modules_fx_const_folded_attrs_parameters_1_ = None
mul_tensor: "f32[3, 3, 3]" = torch.ops.aten.mul.Tensor(copy__default, getattr_l_self_fx_const_folded_attrs_1_); copy__default = getattr_l_self_fx_const_folded_attrs_1_ = None
return (mul_tensor,)
""",
)

View File

@ -9309,7 +9309,7 @@ ShapeEnv not equal: field values don't match:
> Left: {0: 0, 1: 1, 2: s1, 3: s0}
> Right: {0: 0, 1: 1}
==> var_to_range: values don't match.
> Left: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
> Left: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]}
> Right: {}
==> var_to_sources: values don't match.
> Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=<TensorProperty.SIZE: 0>, idx=1)]}
@ -9343,7 +9343,7 @@ ShapeEnv not equal: field values don't match:
> Left: 2
> Right: 0
==> var_to_range: values don't match.
> Left: {u0: VR[-int_oo, int_oo], u1: VR[0, 1], zuf0: VR[-oo, oo]}
> Left: {u0: VR[-9223372036854775808, 9223372036854775807], u1: VR[0, 1], zuf0: VR[-oo, oo]}
> Right: {}
""",
)
@ -9420,8 +9420,8 @@ ShapeEnv not equal: field values don't match:
> Left: {s0: 3}
> Right: {}
==> var_to_range: values don't match.
> Left: {s0: VR[3, 3], s1: VR[2, int_oo]}
> Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
> Left: {s0: VR[3, 3], s1: VR[2, 9223372036854775806]}
> Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]}
""",
)
self._replay_and_check(main)
@ -9458,8 +9458,8 @@ ShapeEnv not equal: field values don't match:
> Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
> Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_}
==> var_to_range: values don't match.
> Left: {s0: VR[3, int_oo], s1: VR[2, int_oo]}
> Right: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
> Left: {s0: VR[3, 9223372036854775806], s1: VR[2, 9223372036854775806]}
> Right: {s0: VR[2, 9223372036854775806], s1: VR[2, 9223372036854775806]}
""",
)
self._replay_and_check(main)

View File

@ -101,6 +101,15 @@ class TestModelOutput(torch._dynamo.test_case.TestCase):
self._common(fn, 2)
@maybe_skip
def test_mo_getattr_missing(self):
def fn(obj: BaseModelOutput):
if getattr(obj, "asdf", None) is not None:
obj.asdf += 1
return obj.attentions + 1
self._common(fn, 1)
@maybe_skip
def test_mo_getitem(self):
def fn(obj: BaseModelOutput):
@ -166,6 +175,59 @@ class TestModelOutput(torch._dynamo.test_case.TestCase):
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 2)
@maybe_skip
def test_mo_init2(self):
# this ModelOutput subclass runs a different __post_init__ codepath
@dataclasses.dataclass
class MyDataClass(ModelOutput):
x: torch.FloatTensor = None
def fn(x):
obj = MyDataClass(x=x)
return obj
inp = torch.randn(3, 3)
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
self.assertEqual(fn(inp).x, opt_fn(inp).x)
@maybe_skip
def test_mo_init_with_disable(self):
# Can result in "non-function or method super: <slot wrapper '__setattr__' of 'object' objects>"
# graph breaks (although it may not be the first)
# Minimal repro for https://github.com/pytorch/pytorch/issues/126028
@dataclasses.dataclass
class MyDataClass(ModelOutput):
x: torch.FloatTensor = None
@torch._dynamo.disable(recursive=False)
def fn(x):
return MyDataClass(x=x)
inp = torch.randn(3, 3)
opt_fn = torch._dynamo.optimize("eager")(fn)
self.assertEqual(fn(inp).x, opt_fn(inp).x)
@maybe_skip
def test_mo_newkey(self):
obj = BaseModelOutput()
def fn(obj):
return obj["wwww"] + 1
inp = torch.randn(3, 3)
obj["wwww"] = inp
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
self.assertEqual(fn(obj), opt_fn(obj))
@maybe_skip
def test_mo_from_outside(self):
def fn(obj):
return obj.attentions + 1
obj = BaseModelOutput(attentions=torch.randn(3, 3))
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
self.assertEqual(fn(obj), opt_fn(obj))
@maybe_skip
def test_HF_bert_model_output(self):
class BertPooler(torch.nn.Module):

View File

@ -22,7 +22,6 @@ from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.eval_frame import unsupported
from torch._dynamo.mutation_guard import GenerationTracker
from torch._dynamo.testing import expectedFailureDynamic, same
from torch._dynamo.utils import ifdynstaticdefault
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import Parameter, UninitializedParameter
@ -1108,37 +1107,6 @@ class UnspecNonInlinableToplevelModule(torch.nn.Module):
return self.m(x)
class ModuleWithIntAttr(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(4, 4)
self.step = 10
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + 1
self.step += 1
return self.layer(x) + self.step
class UnspecInlinableModule(torch.nn.Module):
torchdynamo_force_dynamic = True # forced to be a UnspecializedNNModule
def forward(self, x):
return torch.sin(x)
class UnspecModuleWithIntAttr(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer = UnspecInlinableModule()
self.step = 10
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + 1
self.step += 1
return self.layer(x) + self.step
def make_test(fn, expected_ops=None):
def test_fn(self):
return torch._dynamo.testing.standard_test(
@@ -1392,31 +1360,6 @@ class NNModuleTests(torch._dynamo.test_case.TestCase):
self.assertTrue(torch._dynamo.testing.same(pre, opt_pre))
self.assertTrue(torch._dynamo.testing.same(out1, out_post))
def test_nn_module_unspec_int_attr(self):
for module_class in [ModuleWithIntAttr, UnspecModuleWithIntAttr]:
mod = module_class()
cnt = torch._dynamo.testing.CompileCounter()
opt_mod = torch.compile(backend=cnt)(copy.deepcopy(mod))
x = torch.randn(3, 4)
# Compiling self.step as static.
ref1 = mod(x)
res1 = opt_mod(x)
self.assertTrue(torch.allclose(ref1, res1))
self.assertEqual(cnt.frame_count, 1)
# Compiling self.step as dynamic.
ref2 = mod(x)
res2 = opt_mod(x)
self.assertTrue(torch.allclose(ref2, res2))
self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1))
# No re-compilation!
ref3 = mod(x)
res3 = opt_mod(x)
self.assertTrue(torch.allclose(ref3, res3))
self.assertEqual(cnt.frame_count, ifdynstaticdefault(2, 1))
# RuntimeError: SymIntArrayRef expected to contain only concrete integers
@expectedFailureDynamic
def test_lazy_module1(self):

View File

@@ -201,19 +201,6 @@ class TestDynamismExpression(TestCase):
dynamic_shapes={"x": {0: dim_x}},
)
def test_export_slice_maxsize(self):
class Slice(torch.nn.Module):
def forward(self, *args):
return torch.ops.aten.slice.Tensor(*args)
inp = (torch.rand((10, 3, 224, 224)), 0, 0, 9223372036854775807)
dynamic_shapes = (({0: Dim("dim")}, None, None, None),)
torch.export.export(
Slice(),
inp,
dynamic_shapes=dynamic_shapes,
)
def test_export_constraints_error(self):
class ConflictingConstraints(torch.nn.Module):
def forward(self, x):
@@ -5196,7 +5183,7 @@ def forward(self, x, y):
}
export(f, (inputs,), dynamic_shapes=dynamic_shapes)
def test_disable_forced_specializations_ok(self):
def test_disable_forced_specializations(self):
# check that _disable_forced_specializations and _allow_complex_guards_as_runtime_asserts flags
# both behave correctly, avoiding forced specializations and deferring to runtime.
# case 1: modulo guards

View File
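For reference, a minimal sketch of the torch.export dynamic_shapes / Dim API that the removed test_export_slice_maxsize relied on (the module and dimension names here are illustrative, not from the PR):

import torch
from torch.export import Dim, export

class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1

# Mark dim 0 of the single tensor input as dynamic, mirroring the {0: Dim("dim")} spec above.
ep = export(AddOne(), (torch.rand(10, 3),), dynamic_shapes=({0: Dim("batch")},))
print(ep)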

@@ -312,6 +312,31 @@ class TestUnflatten(TestCase):
export_module.module(), unflattened, (torch.randn((2, 3)),)
)
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
def test_unflatten_preserve_with_unused_input(self):
class M1(torch.nn.Module):
def forward(self, x, a, b):
return x + a, b
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.m1 = M1()
def forward(self, x, y):
a, b = torch.topk(y, 2)
return self.m1(x, a, b)[0]
ep = torch.export.export(
M(),
(torch.randn(2), torch.randn(5)),
preserve_module_call_signature=("m1",),
strict=False,
)
ep.graph.eliminate_dead_code()
unflattened = unflatten(ep)
self.compare_outputs(ep.module(), unflattened, (torch.randn(2), torch.randn(5)))
def test_unflatten_wrong_input(self):
class Mod(torch.nn.Module):
def __init__(self):

View File

@@ -0,0 +1,53 @@
# Owner(s): ["module: fx"]
import unittest
from typing import Mapping
import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport
from torch.testing._internal.common_utils import TestCase
class DummyDevOperatorSupport(OperatorSupport):
def is_node_supported(
self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
) -> bool:
return True
class DummyPartitioner(CapabilityBasedPartitioner):
def __init__(self, graph_module: torch.fx.GraphModule):
super().__init__(
graph_module,
DummyDevOperatorSupport(),
allows_single_node_partition=True,
)
class AddModule(torch.nn.Module):
def forward(self, x):
y = torch.add(x, x)
z = torch.add(y, x)
return z
class TestPartitionerOrder(TestCase):
# partitioner test to check graph node order
def test_partitioner_order(self):
m = AddModule()
traced_m = torch.fx.symbolic_trace(m)
partions = DummyPartitioner(traced_m).propose_partitions()
partion_nodes = [list(partition.nodes) for partition in partions]
node_order = [n.name for n in partion_nodes[0]]
for _ in range(10):
traced_m = torch.fx.symbolic_trace(m)
new_partion = DummyPartitioner(traced_m).propose_partitions()
new_partion_nodes = [list(partition.nodes) for partition in new_partion]
new_node_order = [n.name for n in new_partion_nodes[0]]
self.assertTrue(node_order == new_node_order)
if __name__ == "__main__":
unittest.main()

View File

@@ -3761,6 +3761,20 @@ class CPUReproTests(TestCase):
exactly=True,
).run(code)
def test_repeated_exp(self):
def fn(x):
y = x.sigmoid()
return y + 1, y.sum(-1)
x = torch.randn(1000, 1000)
opt_fn = torch.compile(fn)
_, code = run_and_get_cpp_code(opt_fn, x)
FileCheck().check_count(
".exp()",
1,
exactly=True,
).run(code)
if __name__ == "__main__":
from torch._inductor.test_case import run_tests

View File

@@ -5537,6 +5537,14 @@ class CommonTemplate:
for dtype in all_types():
self.common(fn, (make_tensor(8, dtype=dtype, device=self.device),))
def test_full_boolean(self):
def fn(n):
x = torch.full((1,), n >= 1024, device=self.device)
return x, x + 1
self.common(fn, (1024,))
self.common(fn, (1023,))
def test_index1(self):
def fn(a, b, c):
return aten.index(a, [b, c])

Some files were not shown because too many files have changed in this diff.