Compare commits


126 Commits

Author SHA1 Message Date
2ce56de80e Remove few xfails 2025-05-22 15:18:11 -07:00
34cd5614c5 Fix lint 2025-05-22 15:16:53 -07:00
15d7f6ac2b clean up 2025-05-22 17:32:21 -04:00
7b80b3fd13 Apply suggestions from code review 2025-05-22 14:19:29 -07:00
b0b1902739 Update aten/src/ATen/native/mps/operations/Pooling.mm 2025-05-22 14:19:06 -07:00
1d29dc5d9c fix test_max_pool3d 2025-05-22 17:04:15 -04:00
fe518636a6 update 2025-05-22 16:33:09 -04:00
765dd32545 One is expected to return Tensor by reference from function 2025-05-22 13:16:58 -07:00
b9ca9918ba [BE] Do not call explicit constructor
Compiler should do the work for you
2025-05-22 13:16:18 -07:00
003540fcb6 Fix build 2025-05-22 13:12:38 -07:00
a7f788143e [MPS] Implement max_pool3d_with_indices 2025-05-22 15:59:53 -04:00
befb5bd52a [dynamic shapes] simplify int(x / y) pattern (#153477)
Fixes #138853

Summary: Converts `TruncToInt(IntTrueDiv(x / y))` to `x // y` if divisible, helps detect symint specializations where we didn't previously
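
A small Python illustration of why the rewrite is only safe when the division is exact (truncation and floor division agree for exact quotients but can disagree otherwise):

```py
x, y = 12, 4
assert int(x / y) == x // y   # 3 == 3: trunc and floor agree when y divides x exactly

x, y = -7, 2
print(int(x / y), x // y)     # -3 vs -4: trunc and floor disagree for inexact negative quotients
```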

Differential Revision: D74664734

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153477
Approved by: https://github.com/bobrenjc93
2025-05-16 17:32:15 +00:00
3aa84775e7 [hipify] Replace cuda error cudaErrorContextIsDestroyed (#153576)
Summary: The cuda symbol cudaErrorContextIsDestroyed is not converted to hipErrorContextIsDestroyed. Add this conversion.

Test Plan: CI

Differential Revision: D74542735

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153576
Approved by: https://github.com/xw285cornell, https://github.com/cyyever
2025-05-16 16:19:42 +00:00
a060f3d272 Rewrite autograd producer consumer stream sync logic (#151079)
Also see previous work https://github.com/pytorch/pytorch/pull/142097

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151079
Approved by: https://github.com/albanD
2025-05-16 15:42:22 +00:00
2ce0b66db8 [dynamo] Make OptimizedModule more robust in attribute reads and writes (#153637)
Fixes #138157.

Differential Revision: [D74834872](https://our.internmc.facebook.com/intern/diff/D74834872)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153637
Approved by: https://github.com/williamwen42
2025-05-16 15:17:07 +00:00
f66a159db5 [Set] Raise TypeError if set is called with the wrong number of arguments (#152990)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152990
Approved by: https://github.com/anijain2305
ghstack dependencies: #150792, #152987, #152988, #152904, #152901, #152902, #152903, #152905, #152906, #152989, #152907, #152908
2025-05-16 14:28:32 +00:00
5a0ca65555 [Set] Add correct set/frozenset __init__ behavior (#152908)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152908
Approved by: https://github.com/anijain2305
ghstack dependencies: #150792, #152987, #152988, #152904, #152901, #152902, #152903, #152905, #152906, #152989, #152907
2025-05-16 14:28:32 +00:00
053025494f [Set] Raise KeyError on empty set.pop() (#152907)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152907
Approved by: https://github.com/anijain2305
ghstack dependencies: #150792, #152987, #152988, #152904, #152901, #152902, #152903, #152905, #152906, #152989
2025-05-16 14:28:32 +00:00
5964cb5eb1 [Set] Update set.union and set.update to support *args (#152989)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152989
Approved by: https://github.com/anijain2305
ghstack dependencies: #150792, #152987, #152988, #152904, #152901, #152902, #152903, #152905, #152906
2025-05-16 14:28:32 +00:00
4759922c5e [Set] Add set.intersection(_update) (#152906)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152906
Approved by: https://github.com/anijain2305
ghstack dependencies: #150792, #152987, #152988, #152904, #152901, #152902, #152903, #152905
2025-05-16 14:28:32 +00:00
ca96d55322 [Set] Add set.difference(_update) (#152905)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152905
Approved by: https://github.com/anijain2305
ghstack dependencies: #150792, #152987, #152988, #152904, #152901, #152902, #152903
2025-05-16 14:28:32 +00:00
5c6830ced0 [Set] Raise KeyError if elem not contained in the set (#152903)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152903
Approved by: https://github.com/anijain2305
ghstack dependencies: #150792, #152987, #152988, #152904, #152901, #152902
2025-05-16 14:28:32 +00:00
574f4c507a [Set] Add set.issubset and set.issuperset (#152902)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152902
Approved by: https://github.com/anijain2305
ghstack dependencies: #150792, #152987, #152988, #152904, #152901
2025-05-16 14:28:32 +00:00
5926b7a38f [Set] Add set.symmetric_difference(_update) (#152901)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152901
Approved by: https://github.com/anijain2305
ghstack dependencies: #150792, #152987, #152988, #152904
2025-05-16 14:28:32 +00:00
fe51ce62ca [Set] Raise TypeError if number of arguments mismatch (#152904)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152904
Approved by: https://github.com/anijain2305
ghstack dependencies: #150792, #152987, #152988
2025-05-16 14:28:32 +00:00
481c345f49 [Set] Raise TypeError if argument is unhashable (#152988)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152988
Approved by: https://github.com/anijain2305
ghstack dependencies: #150792, #152987
2025-05-16 14:28:32 +00:00
cf7021a0ee [Set] Handle exception in ConstantVariable operation (#152987)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152987
Approved by: https://github.com/williamwen42, https://github.com/anijain2305
ghstack dependencies: #150792
2025-05-16 14:28:32 +00:00
477f13c3fb [Set] Add CPython set tests (#150792)
Tests:
* test_set.py

This PR adds test_set.py from the CPython 3.13 branch and ~400 files to test/dynamo_expected_failures. Most of these are expected to be fixed in upcoming PRs. Only minimal changes were made to test_set.py to enable compilation with Dynamo using the PYTORCH_TEST_WITH_DYNAMO=1 environment variable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150792
Approved by: https://github.com/anijain2305
2025-05-16 14:28:32 +00:00
6592086ac3 Add metal kernel for log ops (#153398)
Move unary log ops to metal kernels
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153398
Approved by: https://github.com/kulinseth, https://github.com/malfet
2025-05-16 14:25:28 +00:00
8ca985b365 [Break XPU] Skip newly added test case on XPU that failed because torch._C._scatter not implemented. (#153685)
Fixes #153608
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153685
Approved by: https://github.com/malfet
2025-05-16 14:15:50 +00:00
9ccd601a14 [easy] Fix endif comments in functional_base.h (#153696)
The first one of these confused me on #152388. Happened to notice the second.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153696
Approved by: https://github.com/Skylion007, https://github.com/malfet
2025-05-16 14:08:41 +00:00
3443627e07 Revert "[BE]: Enable RUFF TRY400 rule - log.exception (#153473)"
This reverts commit 4f4ecc583e0f48ad2d062a53bf91c61ab40b4948.

Reverted https://github.com/pytorch/pytorch/pull/153473 on behalf of https://github.com/jeanschmidt due to seems to have broken internal signals, @albanD may I count on you to help the author merge his PR? D74837988 ([comment](https://github.com/pytorch/pytorch/pull/153473#issuecomment-2886017075))
2025-05-16 08:29:26 +00:00
86c6f71ddb Revert "[Ez][BE]: Remove accidental classvar (#153540)"
This reverts commit e0dece510b703376d50a5d6536be6c601ca67d9e.

Reverted https://github.com/pytorch/pytorch/pull/153540 on behalf of https://github.com/jeanschmidt due to Broken internal tests, @albanD may you help the author get his PR merged? D74804063 ([comment](https://github.com/pytorch/pytorch/pull/153540#issuecomment-2886011101))
2025-05-16 08:26:37 +00:00
4d073af58c Revert "[inductor][dynamo] Include operator name in size/stride/alignment assertion (#152353)"
This reverts commit 725bbb6b5fffa2f2d219a0692ed27e376c9dd48a.

Reverted https://github.com/pytorch/pytorch/pull/152353 on behalf of https://github.com/jeanschmidt due to seems to have broken a few internal tests, @jansel may you help the author get his PR merged? ([comment](https://github.com/pytorch/pytorch/pull/152353#issuecomment-2885997862))
2025-05-16 08:20:39 +00:00
741539a790 Split out second pass of LayerNorm for profiler attribution reasons (#153578)
Summary:
Split out second pass of LayerNorm so it's more likely to show up in
profiler output. In my testing with perf, the samples from the lambda in the
current implementation are attributed somewhat haphazardly.

Differential Revision: D74181627

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153578
Approved by: https://github.com/hl475
2025-05-16 08:07:13 +00:00
a9adc9a9b6 [Linter] Add linter to detect device-bias hard code in test cases. (#152948)
Since XPU does not gate community pull requests, we’ve observed that contributors often hardcode "cuda" in functions decorated with @requires_gpu() when adding new test cases. This causes the tests to fail on XPU and breaks XPU CI.
This PR adds a linter to detect such issues automatically. An example is shown below.

```
  Error (TEST_DEVICE_BIAS) [device-bias]
    `@requires_gpu` function should not hardcode device='cuda'

        11670  |                .contiguous()
        11671  |            )
        11672  |
    >>> 11673  |        inp = torch.rand((64, 64), device="cuda") * 2 - 1
        11674  |        boundaries = torch.tensor([-0.9, -0.8, 0.1, 0.2, 0.5, 0.9])
        11675  |
        11676  |        self.common(fn, (inp, boundaries), check_lowp=False)

  Error (TEST_DEVICE_BIAS) [device-bias]
    `@requires_gpu` function should not hardcode .cuda() call

        11700  |            self.assertEqual(ref, res)
        11701  |
        11702  |            for offset2 in (0, 1, 2, 3, 4):
    >>> 11703  |                base2 = torch.randn(64 * 64 + 64, dtype=torch.float32).cuda()
        11704  |                inp2 = torch.as_strided(base2, (64, 64), (64, 1), offset2)
        11705  |                ref2 = fn(inp2)
        11706  |                res2 = fn_c(inp2)

  Error (TEST_DEVICE_BIAS) [device-bias]
    `@requires_gpu` function should not hardcode torch.device('cuda:0')

        11723  |            return x.sin() + x.cos()
        11724  |
        11725  |        base = torch.randn(
    >>> 11726  |            64 * 64 + 64, dtype=torch.float32, device=torch.device("cuda:0")
        11727  |        )
        11728  |
        11729  |        inp1 = torch.as_strided(base, (32, 32), (32, 1), 4)

  Error (TEST_DEVICE_BIAS) [device-bias]
    `@requires_gpu` function should not hardcode .to('cuda') call

        11771  |            torch.manual_seed(42)
        11772  |            base = torch.randn(64 * 64 + 64, dtype=torch.float32, device=self.device)
        11773  |            torch.manual_seed(42)
    >>> 11774  |            base_ref = torch.randn(64 * 64 + 64, dtype=torch.float32).to("cuda")
        11775  |
        11776  |            inp = torch.as_strided(base, size, stride, offset)
        11777  |            inp_ref = torch.as_strided(base_ref, size, stride, offset)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152948
Approved by: https://github.com/EikanWang, https://github.com/cyyever, https://github.com/malfet, https://github.com/jansel
2025-05-16 08:03:54 +00:00
658d17dfb5 [ONNX] Add test for decomp_table update (#153671)
Added a test to strengthen the case for cherry-picking #153168. The original PR didn’t include this test since the fix for decomp_table and the registry was already covered by existing tests. However, it's reasonable to include a dedicated test for the specific issue (https://github.com/pytorch/pytorch/issues/150367) when considering the cherry-pick.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153671
Approved by: https://github.com/justinchuby
2025-05-16 08:00:16 +00:00
3fe42d4d5d [export] Dynamo symint support (#152677)
Basically adds native _IntWrapper support to dynamo. Here's my process of trying to make symint input support work on dynamo, and how I ended up with this approach [(doc)](https://docs.google.com/document/d/1GvNRQd8BnxlMay_hrEVgEta6VUeUW_hcFeRuB7q1nDY/edit?tab=t.0).

What I did was, before passing inputs to dynamo.export, I first wrap them with a class, `_IntWrapper`. When processing dynamic shapes, I will then add the corresponding dynamic shape specification to the `dynamism` field stored on the `_IntWrapper`. If there is no dynamism specified, then this will get unwrapped back to an integer. When dynamo tracing, when we encounter an `_IntWrapper`, we will convert this to a symint if the dynamism was specified as `Dim.DYNAMIC/AUTO`. Dynamo will then trace a graph that contains symint inputs, which will get passed to AOTAutograd and so on.
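
A minimal sketch of the wrapper idea described above; the real `_IntWrapper` lives in the export code, and everything here other than the `dynamism` field is an assumption for illustration:

```py
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class _IntWrapper:
    # hypothetical sketch, not the actual torch.export implementation
    val: int
    dynamism: Optional[Any] = None  # filled in from dynamic_shapes, e.g. Dim.DYNAMIC / Dim.AUTO

def maybe_unwrap(w: "_IntWrapper"):
    # with no dynamism specified, unwrap back to a plain int;
    # otherwise dynamo converts the wrapped value to a symint while tracing
    return w.val if w.dynamism is None else w
```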

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152677
Approved by: https://github.com/pianpwk
2025-05-16 07:51:50 +00:00
d965fa2c4b [CUDA][cuBLAS] Remove IS_ARM64 skip in test_matmul_cuda.py (#153660)
Original skip seems stale and the test appears to run fine on Grace + Hopper and Grace + Blackwell

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153660
Approved by: https://github.com/Skylion007
2025-05-16 07:31:16 +00:00
1503b3f897 [DSD] Don't pop tensors if they are on Meta device (#153185)
DSD currently pops tensors if they are on the Meta device. This prevents use cases where users would like DCP to directly initialize the tensors when loading.

This PR also removes test/distributed/checkpoint/e2e/test_pipeline.py, which is based on the above behavior, is not realistic, and is not used anywhere.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153185
Approved by: https://github.com/mori360
2025-05-16 07:18:39 +00:00
1a722f62c2 [Quant][X86] add an op to compute uint8 batch norm 2d (#152811)
**Summary**
This PR adds a new op, `onednn.qbatch_norm2d`, which accepts uint8 inputs on the CPU device (instead of QuantizedCPU).
The new op is implemented with AVX512 instructions and provides performance similar to its counterpart for the QuantizedCPU device, `quantized.batch_norm2d`.
The new op supports output dtypes other than uint8 (fp32, fp16, and bf16 are supported).

**Test plan**
```
pytest test/quantization/core/test_quantized_op.py -k test_int8_batch_norm_onednn
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152811
Approved by: https://github.com/leslie-fang-intel, https://github.com/jerryzh168, https://github.com/jgong5
ghstack dependencies: #152411
2025-05-16 06:13:40 +00:00
7e16cb99b6 [FlexAttention] Enforce Q,K,V memory layouts for fp8 flex attention to avoid perf degradation (#153357)
Fixes #147336

## Context

NCU analysis of the fp8 flex attention perf issue in #147336 showed an unexpected increase in shared memory access bank conflicts when loading the V tensor from HBM to SRAM.

Bringing this to the attention of triton developer @davidberard98 he identified the memory layout of the tensor in HBM to be causing non-pipelined loads into SRAM, causing the slowdown.

To summarize:

In flex attention when performing the FP8 GEMM `softmax_scores @ V` the right operand V must be in column-major memory layout. However, the `tl.load` of V blocks from HBM to SRAM cannot be pipelined if the V tensor isn't column-major in HBM already, leading to substantial performance degradation.

This is because triton does not perform async copies with the `cp.async` PTX instruction if the number of contiguous bytes is less than 4 (see [here](81f93f2c8e/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp (L403))).

i.e., when loading 4 bytes of contiguous data from a tensor stored in row-major in HBM, we have to perform 4 separate non-contiguous writes to SRAM to place those bytes in their new location in the col-major layout in SRAM. Thus the load is not a candidate for pipelining w/ cp.async and just moves data to registers then performs a series of single byte stores.
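
As a rough sketch (not the actual FlexAttention change), one way to put the last two dims of V into column-major order in memory while keeping its logical shape is:

```py
import torch

def to_column_major(t: torch.Tensor) -> torch.Tensor:
    # transpose, materialize contiguously, then transpose back:
    # the logical shape is unchanged but dim -2 now has stride 1
    return t.transpose(-2, -1).contiguous().transpose(-2, -1)

v = torch.randn(8, 1024, 128)      # row-major by default (fp8 dtype omitted for brevity)
v_cm = to_column_major(v)
print(v_cm.shape, v_cm.stride())   # torch.Size([8, 1024, 128]) (131072, 1, 1024)
```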

## Fix summary
- To fix this, we should enforce memory layouts for Q, K, V in FlexAttention when fp8 is being used, to ensure they each exist in HBM in the necessary memory layout to facilitate pipelined loads into SRAM ahead of the FP8 GEMMs

## Benchmarks
Rerunning the repro we see fp8 runtime is reduced from 120% of bf16 to 76% of bf16 runtime.

Before fix:

```
(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 19:07:33,402 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 19:07:35,885 - flex_bench - INFO - bf16: 424.87228804347734 us
2025-05-11 19:07:35,893 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 19:07:37,319 - flex_bench - INFO - fp8e4m3: 515.714000000001 us
```

After fix:
```
(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 17:34:38,223 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 17:34:41,157 - flex_bench - INFO - bf16: 423.4662032967036 us
2025-05-11 17:34:41,167 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 17:34:42,917 - flex_bench - INFO - fp8e4m3: 326.3694803493453 us
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153357
Approved by: https://github.com/ngimel, https://github.com/davidberard98
2025-05-16 04:56:50 +00:00
459ce6c12a [export] Flatten frame local logs (#153627)
Summary:
Some new errors have been showing up on the PT2 dashboard with
```
Invalid type for lengths: Expected BlobReference or torch.Tensor, got: Tensor(shape: torch.Size([10]), stride: (1,), storage_offset: 0)
```
This is caused by [this piece of code](https://fburl.com/code/5nbi9on7) which maps over a set of nodes (in this case type `IDListFeatureListField`) and turns the results into strings to be displayed later. However during pytree.tree_map we call pytree.tree_unflatten which will call the class's init function, which calls `assert_blob` (https://fburl.com/code/h3ainrn9). Because we've mapped over the values and converted them to strings, the assert_blob fails.

I initially thought to disable the assert_blob while tracing (D74684309), but now I think we should actually flatten the list first, because tlparse expects just a string of outputs instead of the actual structure.
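
A toy illustration of the failure mode (the `Blob` container below is hypothetical, standing in for `IDListFeatureListField` and its `assert_blob` check):

```py
import torch.utils._pytree as pytree

class Blob:
    def __init__(self, values):
        # stand-in for assert_blob: the constructor validates its inputs
        assert all(isinstance(v, int) for v in values), "expected ints"
        self.values = values

pytree.register_pytree_node(
    Blob,
    lambda b: (b.values, None),               # flatten: (children, context)
    lambda values, ctx: Blob(list(values)),   # unflatten: re-runs __init__
    serialized_type_name="Blob",
)

b = Blob([1, 2, 3])
pytree.tree_map(str, b)  # AssertionError: unflatten rebuilds Blob from the mapped (string) leaves
```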

Test Plan: `buck2 run mode/opt sigmoid/inference/ts_migration:pt2i_readiness_main -- --test_suite ads_all --mode test_full_model --model_id 542947220` fails with something else 😅

Differential Revision: D74744326

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153627
Approved by: https://github.com/yiming0416
2025-05-16 04:45:09 +00:00
7ed377f577 Reapply "Delete TorchScript based Android demo app and point to ExecuTorch (#153633)" (#153656)
This reverts commit ae0e8f0c7316addab3f415dc767a9d34f58b0dae.

Keep android/libs/fbjni because it's being used by other components of
PyTorch.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153656
Approved by: https://github.com/malfet
2025-05-16 04:35:42 +00:00
56e1c236bf [Dynamo] Catch unserialisable NN modules (#153503)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153503
Approved by: https://github.com/c00w, https://github.com/jansel
2025-05-16 02:55:28 +00:00
d1f1ff8610 [ddp] propagate use_python_reducer to C++ reducer (#152735)
The C++ reducer is silently incorrect under CA: its implementation no-ops the collective. I'm guessing it was no-op'd because, with DDP + the python reducer, the C++ reducer is still being initialized.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152735
Approved by: https://github.com/fegin
ghstack dependencies: #153300, #152689
2025-05-16 01:38:03 +00:00
1b4749f748 [ca][dtensor] run real PG dtensor tests under CA (#152689)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152689
Approved by: https://github.com/bdhirsh
ghstack dependencies: #153300
2025-05-16 01:38:03 +00:00
5aea57d653 [ca][dynamo] always run eager checkpoint region's recomputation in eager (#153300)
I slap disable on the recomputation hook, otherwise the partitioner may save less/more activations and mismatch with the expected eager count in checkpoint. See code comment `Note: [compiled autograd and checkpoint unpack hook]`.

This fixes all non-nested checkpointing tests. I also wrap nested checkpointing tests, and a few of them still fail.

This also seems to fix all PYTORCH_TEST_WITH_DYNAMO checkpointing tests except for `TestAutograd.test_checkpointing_without_reentrant_custom_function_works`. For those tests, it looks like we fail to HOPify the checkpointed region, and when the backward executes the unpack hooks, dynamo tries to trace them. This messes up the internal state tracking of checkpointing, with some tests raising _StopRecomputationError and others raising the same count mismatch error as CA.

FIXES https://github.com/pytorch/pytorch/issues/127115

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153300
Approved by: https://github.com/jansel
2025-05-16 01:37:48 +00:00
cyy 9d3b6ee4c1 [submodule] Update gtest to v1.17.0 (#153618)
And remove some outdated CMake code.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153618
Approved by: https://github.com/malfet
2025-05-16 01:24:19 +00:00
d1dd2c1fc8 gloo: cuda (#153406)
This enables Gloo CUDA when used with a backend that supports GPUDirect which currently is only the IBVERBS backend.

This requires some changes to Gloo which are in https://github.com/pytorch/gloo/pull/441

Since we're now depending on gloo_cuda we need to split ProcessGroupGloo into two pieces, one with the CPU bits (libtorch_cpu) and one with CUDA kernels in libtorch_cuda. This unfortunately requires some major refactoring as some CPU code is shared across both.

The gloo submodule is updated to depend on the new Gloo changes

Test plan:

```py
import os
import time

transport = "TCP"
#transport = "IBVERBS"

os.environ["GLOO_DEVICE_TRANSPORT"] = transport
rank = int(os.environ["RANK"])
os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)

ibv = "mlx5_0:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_9:1,mlx5_10:1,mlx5_11:1".split(",")[rank]
ibv_name, ibv_port = ibv.split(":")
os.environ["TORCH_GLOO_IBV_NAME"] = ibv_name
os.environ["TORCH_GLOO_IBV_PORT"] = ibv_port
os.environ["TORCH_GLOO_IBV_INDEX"] = "3"

import torch
import torch.distributed as dist

dist.init_process_group("gloo")

rank = dist.get_rank()

# initial sanity check
#device = "cpu"
#t = torch.zeros(10, device=device)
#dist.all_reduce(t)
#print("sanity complete")

device = "cpu"

iters = 10
warmup_iters = 2

for nelem in [10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000]:
    t = torch.zeros(nelem, device=device)

    torch.cuda.current_stream().synchronize()
    for i in range(warmup_iters):
        dist.all_reduce(t)

    torch.cuda.current_stream().synchronize()

    start = time.perf_counter()

    for i in range(iters):
        dist.all_reduce(t)

    torch.cuda.current_stream().synchronize()

    dur = (time.perf_counter() - start)
    qps = iters/dur

    bandwidth_gb = t.nbytes * iters / dur / 1e9

    gb = t.nbytes / 1e9

    if rank == 0:
        print(f"{transport=} {device=} {iters=} {nelem=} {qps=} {gb=} {bandwidth_gb=}\n", end="")
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153406
Approved by: https://github.com/fduwjj
2025-05-16 01:13:13 +00:00
ab757dcddc [MPS][Testing] Add GoogleFnet, YituTechConvBert and Super_SloMo to benchmarks (#153658)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153658
Approved by: https://github.com/atalman, https://github.com/ZainRizvi, https://github.com/cyyever
ghstack dependencies: #153657
2025-05-16 01:09:31 +00:00
754b758ea1 [BE] Extend empty_gpu_cache to mps (#153657)
And replace `if: elif:` with `getattr()`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153657
Approved by: https://github.com/atalman, https://github.com/wdvr, https://github.com/ZainRizvi
2025-05-16 01:08:54 +00:00
2489b6470b [c10d] Allow split_group to work with non nccl backends (#152175)
Summary:
Currently things are hardcoded to only work with the NCCL backend. Extend it
to allow NCCL + custom plugin backends.

The split-specific methods/attributes have not been added to the base
Backend and Options as some of them are specific to backend implementations.
Instead, explicit checks have been added to the split_group method for the
expected methods and attributes.

I am open to making them part of the base Backend if folks prefer.

Test Plan:
CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152175
Approved by: https://github.com/shuqiangzhang, https://github.com/kwen2501
2025-05-16 00:15:29 +00:00
cb5f31a4a1 Fix fake tensor caching when output has unbacked (#153034)
We handle fake tensor caching in two ways:
1. If the inputs have no symbols (SymInt, etc) then we cache on the FakeTensorMode.
2. If the inputs have symbols then we cache on the ShapeEnv.

This way the symbols in the inputs and outputs are associated with the guards in place at the time of the call.

However - it's possible to have an op where there are no symbols in the inputs but there is an unbacked symbol in the output.  In this case we shouldn't cache at all because what would that really mean?

So this PR changes the caching behavior so that if there's a symbol in the output which doesn't come in some way from the input then we refuse to cache that op.
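
A condensed sketch of the resulting policy (helper shape is hypothetical; the real logic lives in `_cached_dispatch_impl`):

```py
def pick_cache(input_symbols: set, output_symbols: set, fake_mode, shape_env):
    # hypothetical sketch of the caching decision described above
    if output_symbols - input_symbols:
        return None              # a symbol (e.g. unbacked) appears only in the output: bypass caching
    if not input_symbols:
        return fake_mode.cache   # no symbols involved: cache on the FakeTensorMode
    return shape_env.cache       # symbolic inputs: cache on the ShapeEnv, tied to its guards
```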

Added a test which checks for this case.

While in there I also did a couple other related changes:
1. Added negative caching - if we see that an (op, args) failed to cache previously we don't even bother trying to cache it again.
2. Reworked the inner behavior of _cached_dispatch_impl a little to make it more clear which bits we expect to be able to throw _BypassDispatchCache and add some comments.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153034
Approved by: https://github.com/masnesral, https://github.com/tugsbayasgalan
2025-05-15 23:18:52 +00:00
e7a40fb301 [Async TP] Fix dim swapping before reduction in fused_scaled_matmul_reduce_scatter (#153595)
## Summary
- The unit test `pytest test/distributed/test_symmetric_memory.py -k test_fused_scaled_matmul_reduce_scatter_scatter` was not running for some reason when #149247 was merged, giving false green CI signals. When it was run manually recently, the test failed, highlighting a bug causing incorrect numerics when `scatter_dim=1`.
- This PR fixes the bug, which was related to how we swap dims 0<=>scatter_dim at the beginning of the custom op (for more efficient cross-device data movement I believe), then swap it back prior to reduction.

## Test plan
- I confirmed the unit test `pytest test/distributed/test_symmetric_memory.py -k test_fused_scaled_matmul_reduce_scatter_scatter` is now passing.
- I confirmed e2e training w/ torchtitan looks good ([logs](https://www.internalfb.com/phabricator/paste/view/P1812054188))
- I analyzed the tlparse to verify the fused_all_gather_matmul and fused_scaled_matmul_reduce_scatter both appear at least once in the post grad graphs ([tlparse](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/.tmpVbUsdG/dedicated_log_torch_trace_65oh3qj_.log/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000))

## Next steps
1. I think for async TP `fused_scaled_matmul_reduce_scatter` we may only need `scatter_dim_after_maybe_reshape` and not `orig_scatter_dim` after all. I can confirm this and refactor if it is the case.
2. This op is specifically designed for async TP, and many of the arguments don't make sense for a user trying to use this as a standalone op. IMO we should have separate standalone custom op without all the extra function args and internal logic that doesn't apply to non-async TP cases.
3. In a follow up PR I want to add shape annotations to each line (e.g. `# (B, T, H)` etc) to make this easier to debug in the future.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153595
Approved by: https://github.com/fegin
2025-05-15 21:44:57 +00:00
ea17cd067d Add vec_reduce_all specialization for std::plus on AArch64 (#152388)
AArch64 has an instruction for this.

Differential Revision: [D73817183](https://our.internmc.facebook.com/intern/diff/D73817183/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152388
Approved by: https://github.com/Skylion007
ghstack dependencies: #152365, #152366
2025-05-15 21:26:18 +00:00
b972435158 vec::map: directly process reduced-precision floats when reasonable (#152366)
The immediate motivation is to make map support match
ExecuTorch so we can delete ExecuTorch-specific mapping functions, but
this should also straightforwardly improve performance.

Testing: there is existing coverage for this in
vec_test_all_types.cpp. Verified that it really does cover the newly
enabled "don't convert through float" paths by temporarily adding a
TORCH_INTERNAL_ASSERT(false).

Differential Revision: [D73802126](https://our.internmc.facebook.com/intern/diff/D73802126/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152366
Approved by: https://github.com/malfet
ghstack dependencies: #152365
2025-05-15 21:26:18 +00:00
e4adf5df39 [ROCm] cpp_extension allow user to override default flags (#152432)
We need -fgpu-rdc for projects such as DeepEP + rocSHMEM. The default of -no-gpu-rdc doesn't work for such cases.

As per https://github.com/pytorch/pytorch/pull/152432#issuecomment-2840899088:
"rocshmem shares the same global variable in different files, as deepEP uses CUDAExtention to build the project 65e2a700f0/setup.py (L51) and depends on rocshmem, this -fgpu-rdc is needed. The current logic in Pytorch prevents users from overriding this flag."

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152432
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
2025-05-15 21:06:18 +00:00
b8fad785d5 Change trigger for autoformat, use --all-files (#153289)
Change the trigger for autoformat to pull_request because the reusable action used gets the PR number from the pull_request event context, but only run it if ciflow/autoformat is attached to the PR. Tested this on a different PR, and it seems to be working.

Changed tag name because ciflow prefixed labels have special handling

Also change it to run on all files so it mimics the normal CI lintrunner call, and because lintrunner, either by itself or with -m mergebase, can miss some things. I don't know if it would miss anything for formatting, but it does for checking lint. Formatting seems to take less time than a normal lint run. I don't know whether the concern about making suggestions on files that weren't edited applies; I didn't really test that part.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153289
Approved by: https://github.com/atalman, https://github.com/malfet
2025-05-15 20:38:33 +00:00
90deff6d59 Refactor tests in test_max_autotune into a few separate test cases. (#153486)
Summary: To support running a subset of these tests with the remote autotuning utilities, I've split out some of the tests into separate classes so that I can derive from the "main" TestMaxAutotune class when creating new tests for remote. I'm not 100% sure what some of these tests do, so please suggest if another grouping / naming might make more sense. The remaining tests in TestMaxAutotune all smelled relevant to me.

Test Plan: existing unit tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153486
Approved by: https://github.com/eellison
2025-05-15 20:35:22 +00:00
a2e2f908fd add is_vec_specialized_for (#152365)
Let people detect at compile time whether Vectorized is specialized for a given type. See vec_base.h.

Differential Revision: [D73802129](https://our.internmc.facebook.com/intern/diff/D73802129/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152365
Approved by: https://github.com/jgong5, https://github.com/malfet
2025-05-15 20:21:48 +00:00
ae0e8f0c73 Revert "Delete TorchScript based Android demo app and point to ExecuTorch (#153633)"
This reverts commit b22f01fcb9d69bb7d77e08d69004c7265ef7fa4a.

Reverted https://github.com/pytorch/pytorch/pull/153633 on behalf of https://github.com/malfet due to But libtorch build regressions are real, fbjni is still used for C++ builds ([comment](https://github.com/pytorch/pytorch/pull/153633#issuecomment-2884951805))
2025-05-15 20:16:05 +00:00
b03e4f53d2 [Monitoring] enable windows monitoring test (#153453)
enable the utilization for win tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153453
Approved by: https://github.com/huydhn
2025-05-15 20:03:07 +00:00
f7ecc091a0 c10d/TCPStore: better logs on remote shutdown (#153586)
This makes it more obvious what's going on when TCPStore shuts down while waiting on a remote key and also shows the remote address.

Test plan:

```
[W514 18:33:36.536327028 TCPStore.cpp:138] [c10d] recvValueWithTimeout failed on SocketImpl(fd=3, addr=[localhost]:34658, remote=[localhost]:1234): Failed to recv, got 0 bytes. Connection was likely closed. Did the remote server shutdown or crash?
```

```py
import os
rank = int(os.environ["RANK"])

import time
from torch import distributed as dist

store = dist.TCPStore(
    host_name="localhost",
    port=1234,
    is_master=(rank == 0),
    wait_for_workers=False,
)

time.sleep(1)

print("starting")

if rank != 0:
    store.get("foo")
else:
    time.sleep(1)

print("done")
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153586
Approved by: https://github.com/XilunWu
2025-05-15 20:02:51 +00:00
064f4c18f9 [Monitoring] Enable perf tests (#153452)
Enable monitoring for more perf tests. Currently, for perf, we collect usage data every 4 seconds and aggregate every 15 seconds.

We can reduce the frequency if the monitoring does not affect the perf tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153452
Approved by: https://github.com/Skylion007, https://github.com/huydhn
2025-05-15 19:19:19 +00:00
a4c828199e [BE] Add __all__ to torch/nn/functional.pyi and torch/return_types.pyi (#150729)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150729
Approved by: https://github.com/aorenste
2025-05-15 19:01:57 +00:00
b22f01fcb9 Delete TorchScript based Android demo app and point to ExecuTorch (#153633)
Delete TorchScript demo app and point people to ExecuTorch demo app.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153633
Approved by: https://github.com/Skylion007, https://github.com/malfet, https://github.com/atalman, https://github.com/janeyx99, https://github.com/seemethere
2025-05-15 18:43:59 +00:00
00e5cb3db3 [ez][trymerge] Edit revert message for reverted ghstack PRs (#153573)
Change comment about successful revert so it also contains info about the original PR that got the comment (if it is a ghstacked PR)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153573
Approved by: https://github.com/atalman, https://github.com/malfet
2025-05-15 18:23:20 +00:00
480ae2dab8 Add needs_contiguous_strides to more collective ops (#153523)
Differential Revision: D74705770

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153523
Approved by: https://github.com/fmassa
2025-05-15 17:27:37 +00:00
cfee9046b6 cpu: enable gemm-bf16f32 for SDPA BF16 (#140159)
This PR enables SDPA BF16 (gemm:bf16f32) for aarch64. This will enable faster inference for models with attention layers in autocast mode (bf16).

Benchmark results from  [PyTorch CI HUD - branch](https://hud.pytorch.org/benchmark/huggingface/inductor_no_cudagraphs?dashboard=torchinductor&startTime=Fri%2C%2028%20Mar%202025%2021%3A26%3A20%20GMT&stopTime=Fri%2C%2004%20Apr%202025%2020%3A26%3A20%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cpu%20(aarch64)&lBranch=adi/gemm_bf16f32&lCommit=d5aeab452e4b1f0580a4636b15a604c77a02c57b&rBranch=main&rCommit=bc72420bcb37390af3fced885e019903e6e425bd)
Overall Geometric mean speedup in HUD dashboard  : for Huggingface: `[0.48x → 0.58x]` and for Blueberries: `[0.88x → 1.13x]`

Benchmark numbers for `torch.nn.functional.scaled_dot_product_attention`on Neoverse™ V1.

`batch_size = 1, num_attention_heads = 64, sequence_length = 512, attention_head_size = 128`
 `threads=16`
<img width="319" alt="Screenshot 2024-12-20 at 16 23 22" src="https://github.com/user-attachments/assets/c863f97d-0761-4fb8-aa6c-fc67b22ac3f9" />

Script to benchmark & profile SDPA:

    import torch
    import torch.nn as nn
    import time
    import numpy as np
    from torch.profiler import profile, record_function, ProfilerActivity
    class SimpleAttentionModel(nn.Module):
        def __init__(self, query, key, value):
            super(SimpleAttentionModel, self).__init__()
            self.query = query
            self.key = key
            self.value = value

        def forward(self, attn_mask=None):
            torch.nn.functional.scaled_dot_product_attention(
                        self.query,
                        self.key,
                        self.value,
                        attn_mask=attn_mask)

    #batch_size = 1, num_attention_heads = 64, sequence_length = 512, hidden_size = 128
    def bench_sdpa(batch_size = 1, num_attention_heads = 64, sequence_length = 512, query_sequence_length = 128 , hidden_size=128, precision=torch.float32):
        with torch.no_grad():
            attention_head_size = int(hidden_size / num_attention_heads)
            query = torch.rand(size=(batch_size, num_attention_heads, query_sequence_length, attention_head_size), dtype=precision)
            key = torch.rand(size=(batch_size, num_attention_heads, sequence_length, attention_head_size), dtype=precision)
            value = torch.rand(size=(batch_size, num_attention_heads, sequence_length, attention_head_size), dtype=precision)

            model = SimpleAttentionModel(query, key, value)
            model.eval()
            for _ in range(10):
                model()
            times = []
            n_iters = 100
            for _ in range(n_iters):
                s = time.time_ns()
                model()
                times.append((time.time_ns() - s) / 1e3)
            min_times = np.min(times)
            mean_times = np.mean(times)
            print(f"Min Times = {min_times} us")
            print(f"Mean Times = {mean_times} us")
            print("Times = ", times)

    print("BF16 mode:")
    with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
        with record_function("model_inference"):
            bench_sdpa(precision=torch.bfloat16)
    profile_data = prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total")
    print(profile_data)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140159
Approved by: https://github.com/jgong5, https://github.com/malfet, https://github.com/nikhil-arm, https://github.com/leslie-fang-intel, https://github.com/CaoE, https://github.com/cfRod, https://github.com/fadara01
2025-05-15 17:21:18 +00:00
236b08cbf8 Revert "[ca][dynamo] always run eager checkpoint region's recomputation in eager (#153300)"
This reverts commit 4863e5c843722eb2a34fb0ca1d518a33431a38c0.

Reverted https://github.com/pytorch/pytorch/pull/153300 on behalf of https://github.com/malfet due to Looks like it breaks rocm, see fa8543454a/1 ([comment](https://github.com/pytorch/pytorch/pull/153300#issuecomment-2884489459))
2025-05-15 16:58:52 +00:00
2327c9eedc Revert "[ca][dtensor] run real PG dtensor tests under CA (#152689)"
This reverts commit b297e01f4b1f43ffd1769313f077a2a68928f012.

Reverted https://github.com/pytorch/pytorch/pull/152689 on behalf of https://github.com/malfet due to Looks like it breaks rocm, see fa8543454a/1 ([comment](https://github.com/pytorch/pytorch/pull/153300#issuecomment-2884489459))
2025-05-15 16:58:51 +00:00
db26aeaec2 [MPSInductor] Support numpy scalars handling (#153598)
By default, numpy computes results in float64 format, but when such a value is passed as an argument to an MPS function it must be implicitly converted to float32. This pattern naturally occurs in some networks, for example in speech_transformer.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153598
Approved by: https://github.com/cyyever, https://github.com/dcci
ghstack dependencies: #153582
2025-05-15 16:48:25 +00:00
0cb48633d9 [ez][CI] Add linux aarch64 to upload test stats, change format of trigger for upload test stats (#153505)
Change from inline list to yml list
Add linux aarch64 for list of triggering workflows
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153505
Approved by: https://github.com/Skylion007
2025-05-15 15:33:59 +00:00
fa8543454a [dynamo][torch-function] Prevent unnecessary __torch_function__ tracing (#153551)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153551
Approved by: https://github.com/mlazos
2025-05-15 14:06:17 +00:00
4f4ecc583e [BE]: Enable RUFF TRY400 rule - log.exception (#153473)
Change logging.error to logging.exception to log additional information when relevant. A few logging.error calls have slipped into try/except blocks since I last did a cleanup here, and the rule is now stabilized, so I am enabling it codebase-wide. I have NOQA'd much of our custom exception stack trace handling for RPC calls and distributed, and tried to fix a few errors based on whether we immediately re-raised or whether we failed to print exception information where it could be useful.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153473
Approved by: https://github.com/albanD, https://github.com/cyyever
2025-05-15 13:36:59 +00:00
7482eb217c [Inductor-CPU] Faster int8 WoQ GEMM for small M with explicit prefetching and different outer loops (#149373)
### Summary

Fixes #148494

Explicitly prefetch the cache lines of the next `B` block to accelerate int8 WoQ (BF16 activation, int8 statically quantized weights) GEMM for small `M` dimension.

Some of this code (the outer loops of the GEMM) is being ported over from Intel Extension for PyTorch. The macro-kernel* and the micro-kernel* are essentially the same, but optionally prefetch a block of B. Templatization is used to avoid branching and a slowdown due to unnecessary prefetching.

\* - in [BLIS](https://dl.acm.org/doi/10.1145/2764454) parlance

### Performance data with BS 1

Machine: 32 cores of one socket of a Intel Xeon SP Gen 5 machine

| Model | input tokens | output tokens | next-token latency before this PR | Next-token latency after this change | Speedup |
|-----------|-------------|-----------------|--------------------------------------|------------------------------------------|-----------|
|GPT-J | 128 | 128 | 42 ms | 38 ms | 9.52 % |
| GPT-J | 1024 | 1024 | 48 ms | 45 ms | 6.25 % |
|LLaMA 3.1 8B Instruct | 128 | 128 | 52 ms | 47 ms|  9.61% |
|LLaMA 3.1 8B Instruct | 1024 | 1024 | 57 ms | 53 ms|  7.01% |

While the input shapes of the GEMMs corresponding to linear layers for next-token computation remain the same regardless of the number of input & output tokens, the difference in next-token latency is due to attention in those cases.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149373
Approved by: https://github.com/leslie-fang-intel, https://github.com/Xia-Weiwen

Co-authored-by: Xia Weiwen <xia.weiwen@hotmail.com>
2025-05-15 11:55:58 +00:00
cyy e5e06d9cab [submodule] Update kleidiai to v1.8.0 (#153592)
And cleans up some CMake instructions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153592
Approved by: https://github.com/malfet
2025-05-15 10:14:05 +00:00
22b124335e [BE] Update .pyi stub template to use Generic TypeAlias (PEP 585) and Union Type (PEP 604) (#150728)
https://github.com/pytorch/pytorch/pull/129001#discussion_r1645126801 is the motivation for the whole stack of PRs. In `torch/__init__.py`, `torch._C.Type` shadows `from typing import Type`, and there is no type stub for `torch._C.Type` in `torch/_C/__init__.pyi`. So we need to use `from typing import Type as _Type`. After enabling [Generic TypeAlias (PEP 585)](https://peps.python.org/pep-0585) in the `.pyi` type stub files, we can use `type` instead of `typing.Type` or `from typing import Type as _Type`.

------

- [Generic TypeAlias (PEP 585)](https://peps.python.org/pep-0585): e.g. `typing.List[T] -> list[T]`, `typing.Dict[KT, VT] -> dict[KT, VT]`, `typing.Type[T] -> type[T]`.
- [Union Type (PEP 604)](https://peps.python.org/pep-0604): e.g. `Union[X, Y] -> X | Y`, `Optional[X] -> X | None`, `Optional[Union[X, Y]] -> X | Y | None`.
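
For example, an illustrative stub signature rewritten with these two PEPs (not a specific PyTorch stub):

```py
# before (pre-PEP 585/604)
from typing import Dict, List, Optional, Union
def foo(x: Optional[List[int]], y: Union[str, int]) -> Dict[str, int]: ...

# after (PEP 585 + PEP 604)
def foo(x: list[int] | None, y: str | int) -> dict[str, int]: ...
```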

Note that in `.pyi` stub files, we do not need `from __future__ import annotations`. So this PR does not violate issue #117449:

- #117449

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150728
Approved by: https://github.com/cyyever, https://github.com/aorenste
ghstack dependencies: #150726, #150727
2025-05-15 09:36:42 +00:00
f7a5aa1d8d [torchgen] Refactor and simplify gen_pyi.py to use Generic TypeAlias (PEP 585) and Union Type (PEP 604) (#150727)
https://github.com/pytorch/pytorch/pull/129001#discussion_r1645126801 is the motivation for the whole stack of PRs. In `torch/__init__.py`, `torch._C.Type` shadows `from typing import Type`, and there is no type stub for `torch._C.Type` in `torch/_C/__init__.pyi`. So we need to use `from typing import Type as _Type`. After enabling [Generic TypeAlias (PEP 585)](https://peps.python.org/pep-0585) in the `.pyi` type stub files, we can use `type` instead of `typing.Type` or `from typing import Type as _Type`.

------

- [Generic TypeAlias (PEP 585)](https://peps.python.org/pep-0585): e.g. `typing.List[T] -> list[T]`, `typing.Dict[KT, VT] -> dict[KT, VT]`, `typing.Type[T] -> type[T]`.
- [Union Type (PEP 604)](https://peps.python.org/pep-0604): e.g. `Union[X, Y] -> X | Y`, `Optional[X] -> X | None`, `Optional[Union[X, Y]] -> X | Y | None`.

Note that in `.pyi` stub files, we do not need `from __future__ import annotations`. So this PR does not violate issue #117449:

- #117449

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150727
Approved by: https://github.com/aorenste
ghstack dependencies: #150726
2025-05-15 09:36:42 +00:00
129a2976a8 [ROCm] Improvements to non-vectorized elementwise kernels (#153184)
* Unroll loops manually to hide memory access latency

Co-authors: @akadutta @amd-hhashemi

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153184
Approved by: https://github.com/jeffdaily
2025-05-15 09:14:43 +00:00
6e107899da [Torch] Fix crash when comparing fp8 tensors that have more than 1 dimension (#153508)
Summary: `torch.nonzero` returns as many items as the number of dimensions, so we shouldn't expect a single element for the indices.

Test Plan: CI

Differential Revision: D74539233

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153508
Approved by: https://github.com/exclamaforte
2025-05-15 08:41:46 +00:00
b297e01f4b [ca][dtensor] run real PG dtensor tests under CA (#152689)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152689
Approved by: https://github.com/bdhirsh
ghstack dependencies: #153300
2025-05-15 08:10:35 +00:00
4863e5c843 [ca][dynamo] always run eager checkpoint region's recomputation in eager (#153300)
I slap disable on the recomputation hook, otherwise the partitioner may save less/more activations and mismatch with the expected eager count in checkpoint. See code comment `Note: [compiled autograd and checkpoint unpack hook]`.

This fixes all non-nested checkpointing tests. I also wrap nested checkpointing tests, and a few of them still fail.

This also seems to fix all PYTORCH_TEST_WITH_DYNAMO checkpointing tests except for `TestAutograd.test_checkpointing_without_reentrant_custom_function_works`. For those tests, it looks like we fail to HOPify the checkpointed region, and when the backward executes the unpack hooks, dynamo tries to trace them. This messes up the internal state tracking of checkpointing, with some tests raising _StopRecomputationError and others raising the same count mismatch error as CA.

FIXES https://github.com/pytorch/pytorch/issues/127115

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153300
Approved by: https://github.com/jansel
2025-05-15 08:10:35 +00:00
71027b13b2 Revert "[FlexAttention] Enforce Q,K,V memory layouts for fp8 flex attention to avoid perf degradation (#153357)"
This reverts commit 881a598a1e38ef06d4f51d1e3fd8e359fed0c3a0.

Reverted https://github.com/pytorch/pytorch/pull/153357 on behalf of https://github.com/jeanschmidt due to Might have introduced regressions in rocm testing for main: https://github.com/pytorch/pytorch/actions/runs/15035410497/job/42257000513 feel free to re-merge if this was a mistake ([comment](https://github.com/pytorch/pytorch/pull/153357#issuecomment-2882915691))
2025-05-15 07:58:27 +00:00
004dad48f7 Allow to set custom PYTHONPATH for torch.inductor (#152832)
When using Bazel, it’s common to encounter issues like [this](https://github.com/bazelbuild/bazel/issues/14640) and [this](https://github.com/bazel-contrib/rules_python/issues/792) where the `PYTHONPATH` environment variable becomes too long and results in an error such as `OSError: [Errno 7] Argument list too long`. To work around this, users often resort to custom logic to manipulate PYTHONPATH.

Currently, PyTorch Inductor constructs the PYTHONPATH for a subprocess using sys.path, which can lead to this issue in certain environments.

This PR introduces support for a new environment variable, `TORCH_CUSTOM_PYTHONPATH`, allowing users to override the default `PYTHONPATH` passed to the subprocess. This provides a clean way to avoid an exception when using PyTorch in Bazel.

Please let me know if I need to add some documentation to support this PR. I haven't found an open issue specific to this change, but I'm confident that this change (or a similar one) would be appreciated by a few users.
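
A minimal usage sketch, assuming the variable is read when Inductor launches its compile subprocess; the path shown is illustrative:

```py
import os

# set before compiling so the Inductor subprocess inherits a short, explicit path
os.environ["TORCH_CUSTOM_PYTHONPATH"] = "/opt/py_deps/site-packages"

import torch

@torch.compile
def f(x):
    return x * 2

print(f(torch.randn(4)))
```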

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152832
Approved by: https://github.com/masnesral
2025-05-15 06:35:41 +00:00
55784be01b [Quant][X86] add ops to compute uint8 pointwise add/add_relu (#152411)
**Summary**
This PR adds two new ops, `onednn.qadd.tensor` and `onednn.qadd_relu.tensor`, for int8 elementwise add, which accept inputs on the CPU device (instead of QuantizedCPU).
The new ops are implemented with AVX512 instructions and provide similar or better performance, depending on shape, than their counterparts for the QuantizedCPU device, `quantized.add` and `quantized.add_relu`.
The new ops support output dtypes other than uint8 (fp32, fp16, and bf16 are supported).

**Test plan**
```
pytest test/quantization/core/test_quantized_op.py -k test_int8_add_onednn
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152411
Approved by: https://github.com/leslie-fang-intel, https://github.com/jerryzh168
2025-05-15 06:23:01 +00:00
a762dd1f67 [Memento] On-demand mode using without torch api (#153171)
Summary:
CUDA Post: https://fb.workplace.com/groups/ai.efficiency.tools.users/permalink/2020094788475989/

# Context
In this diff, we want to enable the on-demand mode of memory snapshot to allow users to trace any remote process via the dyno command line.

# Design decision

**How do we send the on-demand signal to the remote process**
We leverage the dyno-Kineto approach.
Since dyno is running on every machine in Meta, it can send a request to the remote machine to start Kineto.
Kineto will start another thread for the memoryProfiler (https://fburl.com/code/dxsmmrok)

**Why we use a different approach than CUDA**

On the CUDA side, we are using pybind to load the torch module and invoke the python API to start/stop the profiling. However, this requires us to compile the whole torch binary into the predictor, which is not recommended by runtime (andruwang).

Thus, we decided to use the C++ API directly to avoid an unnecessary dependency.

**Why the snapshot is saved as a JSON string directly instead of pickle**
Pickle is primarily designed for use with Python and doesn't have good support in C++. Also, it is hard for the user to download the snapshot file and open it locally.
Due to the dependency issue, it is hard to import the gzip/pickle library to decode the data. Thus, let's use JSON for now. I will work on the visualizer to speed up rendering and support other formats later.

**Plan**:
* For now, we will encode the file as gzip for MTIA on-demand only and update the visualizer to support both types.
* Update auto-trace and CUDA side to encode in gzip as well
* Fully remove pickle dependency.

Test Plan:
# Remote cogwheel test
Servicelab: https://fburl.com/servicelab/pckux7a3
snapshot file manifold: https://fburl.com/manifold/fnotk18c
snapshot file in pastry: P1805522232

Visualization on D74399684
 {F1977786422}

# Local Predictor Test
url: https://fburl.com/pytorch_memory_visualizer/y06kskkm

 {F1977787329}

Differential Revision: D74179606

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153171
Approved by: https://github.com/sraikund16
2025-05-15 06:07:04 +00:00
181bfabb9e fix set_logs for a single child log file (#153580)
Tested via

```
+        import logging
+        torch._logging.set_logs(modules={"torch._functorch._aot_autograd.autograd_cache": logging.DEBUG})
```

```
python test/dynamo/test_aot_autograd_cache.py -k test_multi_graph_specialization
```
and verifying logs are printed
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153580
Approved by: https://github.com/ColinPeppler
2025-05-15 05:58:45 +00:00
9839ec1383 [dynamo][compile-time] Cache method on load builtin (#153524)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153524
Approved by: https://github.com/StrongerXi, https://github.com/jansel
ghstack dependencies: #153522
2025-05-15 05:54:15 +00:00
b47be23461 [dynamo][compile-time] Faster inspect getattr_static for torch.Tensor (#153522)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153522
Approved by: https://github.com/StrongerXi, https://github.com/jansel
2025-05-15 05:54:15 +00:00
910d2f96af [cutlass backend] forward fix cutlass backend A100 test (#153428)
Forward fix of https://github.com/pytorch/pytorch/pull/153006, which broke a test.

In the long run, we should get rid of CUDATemplateCaller.category.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153428
Approved by: https://github.com/ColinPeppler
2025-05-15 05:45:38 +00:00
0ca91af6b8 Define USE_C10D_XCCL and USE_XCCL in pytorch (#147593)
### Motivation:

Add `USE_XCCL` and `USE_C10D_XCCL` to enable building the XCCL backend in stock PyTorch, similar to `USE_NCCL` and `USE_C10D_NCCL`.
By default, `USE_XCCL` is OFF and can be set to ON explicitly.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147593
Approved by: https://github.com/guangyey, https://github.com/malfet, https://github.com/albanD, https://github.com/cyyever
2025-05-15 05:39:00 +00:00
ebd3268538 Removed duplicate patterns from gitignore (#153515)
Removed duplicate patterns from gitignore. These patterns are duplicated verbatim on lines 148-169.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153515
Approved by: https://github.com/soulitzer
2025-05-15 05:38:42 +00:00
b992a665d1 Fix AsyncMM not compiled with SM90a issue (#153519)
The CMakeLists.txt is wrong and doesn't enable SM90a for AsyncMM.cu
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153519
Approved by: https://github.com/drisspg, https://github.com/ngimel, https://github.com/cyyever
2025-05-15 05:23:29 +00:00
d5ddc5ab20 [MPS] Fix float64 scalar tensor handling (#153582)
The current implementation causes a silent correctness problem with torch.compile when someone tries to `torch.compile` a function where one of the arguments is, say, `np.exp(.3)`, which will be represented as a torch.float64 scalar tensor.

Add a regression test for this behavior.
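
A minimal sketch of the scenario described (assumes an MPS-capable machine):

```py
import numpy as np
import torch

@torch.compile
def scale(x, s):
    return x * s

x = torch.rand(4, device="mps")
# np.exp(.3) is a numpy float64 scalar; on MPS it must be downcast to float32
print(scale(x, np.exp(0.3)))
```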
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153582
Approved by: https://github.com/dcci
2025-05-15 05:15:14 +00:00
3e8bda4ad5 [pytorch][triton] flex attention fwd kernel with TMA loads (#151923) (#152460)
Summary:

Device side TMA for flex_attention fwd kernel, Q K V tensors

Test Plan:
Unit test:
```
buck test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:flex_attention -- test_tma_with_customer_kernel_options
```
https://www.internalfb.com/intern/testinfra/testrun/14355223891618726

Differential Revision: D71082691

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152460
Approved by: https://github.com/drisspg
2025-05-15 04:49:32 +00:00
756fd80734 [BE] Improve the typing related to model input argument of torch.compile() (#153559)
Summary: Match the `overload` typing with the original typing in function definition and adjust the corresponding comments.

Test Plan: contbuild & OSS CI

Differential Revision: D74746243

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153559
Approved by: https://github.com/Skylion007
2025-05-15 04:49:26 +00:00
d2f6c6df1d unbreak fb:operator_benchmark_test (#152049)
Summary: unbreak fb:operator_benchmark_test

Test Plan: works on my machine

Differential Revision: D73540912

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152049
Approved by: https://github.com/hl475
2025-05-15 03:38:48 +00:00
014726d9d3 [torchgen] Refactor torchgen.utils.FileManager to accept pathlib.Path (#150726)
This PR allows `FileManager` to accept `pathlib.Path` as arguments while keeping the original `str` path support.

This allows us to simplify the code such as:

1. `os.path.join(..., ...)` with `Path.__truediv__(..., ...)` (i.e. the `/` operator).

95a5958db4/torchgen/utils.py (L155)

95a5958db4/torchgen/utils.py (L176)

2. `os.path.basename(...)` with `Path(...).name`.
 95a5958db4/torchgen/utils.py (L161)

3. Manual file extension split with `Path(...).with_stem(new_stem)`

95a5958db4/torchgen/utils.py (L241-L256)
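A small standalone sketch of the pathlib equivalents (generic examples, not the torchgen code itself):

```
from pathlib import Path

joined = Path("torchgen") / "packaged" / "native_functions.yaml"   # instead of os.path.join(...)
name = Path("aten/src/ATen/native/native_functions.yaml").name     # instead of os.path.basename(...)
renamed = Path("Functions.h").with_stem("NativeFunctions")          # instead of splitting the extension manually (with_stem needs Python 3.9+)

print(joined, name, renamed)
```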

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150726
Approved by: https://github.com/aorenste
2025-05-15 02:52:24 +00:00
881a598a1e [FlexAttention] Enforce Q,K,V memory layouts for fp8 flex attention to avoid perf degradation (#153357)
Fixes #147336

## Context

NCU analysis of the fp8 flex attention perf issue in #147336 showed an unexpected increase in shared memory access bank conflicts when loading the V tensor from HBM to SRAM.

Triton developer @davidberard98, after looking into this, identified the memory layout of the tensor in HBM as the cause of non-pipelined loads into SRAM, and hence the slowdown.

To summarize:

In flex attention when performing the FP8 GEMM `softmax_scores @ V` the right operand V must be in column-major memory layout. However, the `tl.load` of V blocks from HBM to SRAM cannot be pipelined if the V tensor isn't column-major in HBM already, leading to substantial performance degradation.

This is because triton does not perform async copies with the `cp.async` PTX instruction if the number of contiguous bytes is less than 4 (see [here](81f93f2c8e/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp (L403))).

i.e., when loading 4 bytes of contiguous data from a tensor stored row-major in HBM, we have to perform 4 separate non-contiguous writes to SRAM to place those bytes in their new locations in the col-major layout. Thus the load is not a candidate for pipelining with cp.async; it just moves the data to registers and then performs a series of single-byte stores.

## Fix summary
- To fix this, we should enforce memory layouts for Q, K, V in FlexAttention when fp8 is being used, to ensure they each exist in HBM in the necessary memory layout to facilitate pipelined loads into SRAM ahead of the FP8 GEMMs
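
A hedged sketch of what enforcing the layout amounts to for the V operand (illustrative only, not the actual FlexAttention implementation):

```
import torch

def ensure_col_major_last_two(t: torch.Tensor) -> torch.Tensor:
    # "Column-major in the last two dims" here means stride(-2) == 1,
    # i.e. t is the transpose of a contiguous (..., D, S) tensor.
    if t.stride(-2) == 1:
        return t
    return t.transpose(-2, -1).contiguous().transpose(-2, -1)

# Row-major (B, H, S, D) value tensor; the same transformation is the kind of
# thing the fp8 path needs so that tl.load of V blocks can be pipelined.
v = torch.randn(2, 8, 1024, 64)
v = ensure_col_major_last_two(v)
assert v.stride(-2) == 1
```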

## Benchmarks
Rerunning the repro we see fp8 runtime is reduced from 120% of bf16 to 76% of bf16 runtime.

Before fix:

```
(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 19:07:33,402 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 19:07:35,885 - flex_bench - INFO - bf16: 424.87228804347734 us
2025-05-11 19:07:35,893 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 19:07:37,319 - flex_bench - INFO - fp8e4m3: 515.714000000001 us
```

After fix:
```
(flex) [danvm@devgpu007.eag6 ~/ml-perf-tools/flex_attention (main)]$ rm -rf /tmp/torchinductor_${USER}; python profile_flex.py --bf16 --fp8
2025-05-11 17:34:38,223 - flex_bench - INFO - Running benchmark: bf16
2025-05-11 17:34:41,157 - flex_bench - INFO - bf16: 423.4662032967036 us
2025-05-11 17:34:41,167 - flex_bench - INFO - Running benchmark: fp8e4m3
2025-05-11 17:34:42,917 - flex_bench - INFO - fp8e4m3: 326.3694803493453 us
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153357
Approved by: https://github.com/ngimel, https://github.com/davidberard98
2025-05-15 02:41:38 +00:00
eaf2dee10e don't run triton mm for k<32 (#153550)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153550
Approved by: https://github.com/suo

Co-authored-by: Natalia Gimelshein <ngimel@meta.com>
2025-05-15 02:36:44 +00:00
725bbb6b5f [inductor][dynamo] Include operator name in size/stride/alignment assertion (#152353)
Fixes #151930

This PR updates the `assert_size_stride` and `assert_alignment` functions in [guards.cpp](https://github.com/pytorch/pytorch/blob/main/torch/csrc/dynamo/guards.cpp) to accept an optional `op_name` argument and includes it in the error messages.

The corresponding type stubs in [guards.pyi](https://github.com/pytorch/pytorch/blob/main/torch/_C/_dynamo/guards.pyi) are updated to match the new function arg.

In [inductor/ir.py](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py), the operator name is extracted from the FX graph and passed into the `codegen_size_asserts` and `codegen_alignment_asserts` functions, so that generated assertions in Triton code include the op name for better debugging.
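
As a purely illustrative Python analogue of what the extra argument buys (the real checks live in the C++ guards, so the names and messages below are hypothetical):

```
from typing import Optional
import torch

def assert_size_stride_with_op(t: torch.Tensor, size, stride, op_name: Optional[str] = None):
    # Including the operator name in the message makes it obvious which op in
    # the generated code produced the mismatching tensor.
    prefix = f"{op_name}: " if op_name else ""
    assert tuple(t.size()) == tuple(size), f"{prefix}expected size {size}, got {tuple(t.size())}"
    assert tuple(t.stride()) == tuple(stride), f"{prefix}expected stride {stride}, got {tuple(t.stride())}"

assert_size_stride_with_op(torch.ones(2, 3), (2, 3), (3, 1), op_name="aten.ones")
```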

Added unit tests inside [test_torchinductor.py](https://github.com/pytorch/pytorch/blob/main/test/inductor/test_torchinductor.py).
- Verified both successful and failing assertion cases include the operator name.
- Verified that generated Triton code contains the op name inside the asserts.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152353
Approved by: https://github.com/jansel
2025-05-15 02:33:57 +00:00
f5e0806f34 [cutlass backend] Add back descriptive names for epilogue fusion (#153405)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153405
Approved by: https://github.com/mlazos
2025-05-15 01:47:52 +00:00
82dc3457e0 Add load_state_dict hint doc about invoke order work with lr_scheduler (#149942)
Fixes #119168

## Test Result

![image](https://github.com/user-attachments/assets/edb8124c-f103-475a-b903-20fbc71fdea6)
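
A minimal sketch of the restore ordering the hint is about (hedged: treat the exact wording added to the docs as the source of truth):

```
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
checkpoint = {"optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict()}

# Restore the optimizer state before the LR scheduler state, so the scheduler's
# restored bookkeeping (last_epoch, learning rates) is applied on top of it.
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
```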

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149942
Approved by: https://github.com/janeyx99

Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
2025-05-15 01:07:36 +00:00
cyy
781ba0ac9d Update CMake to 3.27 in Windows CI (#153380)
This is needed before it's possible to enable a newer CMake.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153380
Approved by: https://github.com/albanD
2025-05-15 00:19:32 +00:00
c2bc7e2827 API change for new enum in cusparseltsplitkmode-t for cusparseLT 0.7.0+ (#150536)
Change the bool to an int to express split_k_mode. Before 0.7.0 there were only two cusparseLtSplitKMode_t enum values, ONE_KERNEL and TWO_KERNELS, so a boolean was enough, but since 0.7.0 there are more.

For Blackwell, a minor change is needed to the split_k_one_kernel parameter (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp#L103), since new values have been introduced to the [cusparseLtSplitKMode_t](https://docs.nvidia.com/cuda/cusparselt/types.html#cusparseltsplitkmode-t) enum and a bool type is no longer enough for it; it has to be replaced with an integer.

Error we see without the change
```
RuntimeError: CUDA error: invalid value when calling `cusparseLtMatmulAlgSetAttribute( &handle, &alg_sel, CUSPARSELT_MATMUL_SPLIT_K_MODE, &splitKMode, sizeof(splitKMode))`

To execute this test, run the following from the base repo dir:
    python test/test_sparse_semi_structured.py TestSparseSemiStructuredCUSPARSELTCUDA.test_csrc_cslt_sparse_mm_search_cuda_int8
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150536
Approved by: https://github.com/jcaip, https://github.com/atalman
2025-05-14 23:36:53 +00:00
72fee137dd [ROCm] Maxpool forward NHWC Perf Improvement targeting Resnet scenarios (#151727)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151727
Approved by: https://github.com/seemethere

Co-authored-by: Eli Uriegas <1700823+seemethere@users.noreply.github.com>
2025-05-14 22:34:55 +00:00
e0dece510b [Ez][BE]: Remove accidental classvar (#153540)
Untyped variables become ClassVar in dataclasses; this type alias should just be a type alias, and there is no need for it to be a ClassVar.
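
For context, a small sketch of the dataclass behavior being referred to (generic example, not the actual PyTorch class):

```
from dataclasses import dataclass, fields

@dataclass
class Config:
    x: int = 0                  # annotated: a real dataclass field
    AliasType = dict[str, int]  # unannotated: a plain class attribute, effectively a class-level alias

print([f.name for f in fields(Config)])  # ['x']; AliasType is not a field
```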

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153540
Approved by: https://github.com/albanD, https://github.com/aorenste
2025-05-14 21:55:56 +00:00
7412b33e91 [inductor] Use get to avoid possible keyerror at the end of precompilation (#153417)
Shameful admission: I have encountered this error 1-2 times, but don't have a repro.

```
torch/_inductor/select_algorithm.py", line 2022, in wait_on_futures
    elapsed_times[future],
    ~~~~~~~~~~~~~^^^^^^^^
torch._inductor.exc.InductorError: KeyError: <Future at 0x7fc4e394fb90 state=finished returned tuple>
```
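
A minimal sketch of the defensive pattern (names are illustrative, not the actual select_algorithm code):

```
from concurrent.futures import Future
from typing import Dict, Optional

def elapsed_for(elapsed_times: Dict[Future, float], future: Future) -> Optional[float]:
    # dict.get avoids raising KeyError when no elapsed time was recorded for
    # this future; the caller can then log "unknown" instead of crashing.
    return elapsed_times.get(future)

print(elapsed_for({}, Future()))  # None instead of KeyError
```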

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153417
Approved by: https://github.com/Skylion007, https://github.com/ColinPeppler
2025-05-14 21:49:43 +00:00
f2e8e41855 [Easy][Inductor] Adds safety checks in get_estimated_runtime (#152821)
This PR adds checks on `gpu_memory_bandwidth` and `gpu_flops` in `get_estimated_runtime`. This will prevent division by zero and other potential incorrect values:
9210a98b92/torch/_inductor/scheduler.py (L864-L865)

9210a98b92/torch/_inductor/scheduler.py (L874)
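
A hedged sketch of the kind of guard being added (illustrative; parameter names are assumptions, not the scheduler's actual signature):

```
def estimated_runtime_s(bytes_moved: float, flop_count: float,
                        gpu_memory_bandwidth: float, gpu_flops: float) -> float:
    # Bail out instead of dividing by zero (or a bogus non-positive value)
    # when the hardware numbers could not be determined.
    if gpu_memory_bandwidth <= 0 or gpu_flops <= 0:
        return 0.0
    # Simple roofline-style estimate: whichever of memory traffic or compute dominates.
    return max(bytes_moved / gpu_memory_bandwidth, flop_count / gpu_flops)

print(estimated_runtime_s(1e9, 2e12, 0.0, 1e14))  # 0.0 rather than ZeroDivisionError
```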

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152821
Approved by: https://github.com/eellison, https://github.com/jansel
2025-05-14 21:46:59 +00:00
f887bfffda Fix typo (#153561)
Fix typo from #153386

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153561
Approved by: https://github.com/albanD
2025-05-14 21:38:51 +00:00
03d01860fd [dynamo][compile-time] Compute logging related flags once (#153426)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153426
Approved by: https://github.com/jansel
2025-05-14 21:19:06 +00:00
1bd6bc7190 [BE]: Enable ruff YTT linter for Python version checks (#153547)
Adds ruff YTT checks to help future-proof version checks and follow best practices here. Also makes it easier for static linters like mypy to detect Python version branching.
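For context, the kind of pattern the YTT (flake8-2020) rules flag versus the preferred form (illustrative):

```
import sys

# Flagged by YTT rules: slicing sys.version or indexing version_info breaks on
# future versions (e.g. "3.10" vs "3.1").
if sys.version[:3] == "3.9":
    pass

# Preferred: compare sys.version_info as a tuple.
if sys.version_info >= (3, 10):
    pass
```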
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153547
Approved by: https://github.com/albanD
2025-05-14 21:09:16 +00:00
f363a3f51a Revert "[cuDNN][SDPA] cuDNN SDPA refactor/cleanup, nested tensor backward, test priority bump for sm90, sm100 (#149282)"
This reverts commit 9386701b51aadce951bf38daf497b0257a3f2211.

Reverted https://github.com/pytorch/pytorch/pull/149282 on behalf of https://github.com/jeanschmidt due to Breaking internal builds, see [D74729259](https://www.internalfb.com/diff/D74729259). @drisspg may you help out the author have their PR merged? ([comment](https://github.com/pytorch/pytorch/pull/149282#issuecomment-2881546951))
2025-05-14 20:53:49 +00:00
c92ea3bc98 [BE] Upgrade XPU support package to 2025.1 in CICD (#151899)
Addresses #151097, including the changes below:

- Add XPU support package 2025.1 build and test in CI for both Linux and Windows
- Keep XPU support package 2025.0 build in CI to ensure no break issue until PyTorch 2.8 release
- Upgrade XPU support package from 2025.0 to 2025.1 in CD for both Linux and Windows
- Enable XCCL in Linux CD wheels and oneMKL integration in both Linux and Windows
- Update XPU runtime pypi packages of CD wheels
- Remove deprecated support package version docker image build
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151899
Approved by: https://github.com/EikanWang, https://github.com/atalman
2025-05-14 20:21:09 +00:00
5e6e52e7c9 [JIT] add GRAPH_DEBUG for setGraphExecutorOptimize (#153549)
Summary: Optionally log when setGraphExecutorOptimize is called, so we can get insight into the GraphExecutor behavior.

Differential Revision: D74692508

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153549
Approved by: https://github.com/PaulZhang12, https://github.com/SamGinzburg
2025-05-14 20:07:25 +00:00
dda2c7c8fc Pass inductor config for static cuda launcher to workers (#153382)
Async compile workers generally don't respect inductor configs that are changed in the middle of execution, because the workers warm up early. StaticCudaLauncher is especially susceptible to this because it affects Triton compilation without being part of the inductor meta. So we'll pass it in via extra configs on each worker run.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153382
Approved by: https://github.com/masnesral, https://github.com/jansel
2025-05-14 20:01:32 +00:00
6a28cc826f Add TEST_HPU flag to set device type (#153461)
MOTIVATION
This PR includes a minor change to also check for the TEST_HPU flag before falling back to CPU. Without this flag, some tests were falling back to CPU, causing them to fail.
Please refer to this RFC as well: https://github.com/pytorch/rfcs/pull/66

CHANGES
- Add the TEST_HPU flag to some of the conditions checking the environment.
- Use the DEVICE_COUNT variable instead of the torch.accelerator.device_count() API, since the latter is not supported on out-of-tree devices like Intel Gaudi.
@ankurneog, @EikanWang, @cyyever, @guangyey

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153461
Approved by: https://github.com/EikanWang, https://github.com/cyyever, https://github.com/albanD
2025-05-14 19:31:40 +00:00
a54bf43baa Fix support of MixtureSameFamily [bugfix]. (#151317)
Fixes https://github.com/pyro-ppl/pyro/issues/3419 which is actually a `torch` bug that can be replicated by the below code:

```
from torch import rand
from torch.distributions import MixtureSameFamily, Categorical, Binomial

max_count = 20
probs = rand(10, 5)
binom_probs = rand(10, 5)

d = MixtureSameFamily(Categorical(probs=probs), Binomial(max_count, binom_probs))
d.log_prob(d.sample())
```

which results in:

```
Traceback (most recent call last):
  File "test.py", line 11, in <module>
    d.log_prob(d.sample())
  File "pytorch\torch\distributions\mixture_same_family.py", line 168, in log_prob
    self._validate_sample(x)
  File "pytorch\torch\distributions\distribution.py", line 315, in _validate_sample
    valid = support.check(value)
            ^^^^^^^^^^^^^^^^^^^^
  File "pytorch\torch\distributions\constraints.py", line 307, in check
    (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)
                                                      ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The size of tensor a (10) must match the size of tensor b (5) at non-singleton dimension 1
```

### Fix explanation (only for cases when the component distribution contains parameters with batch dimensions)

- The failure is due to sample validation taking place before padding in `MixtureSameFamily.log_prob`, and hence the fix is to pad before doing sample validation.
- The fix itself does not alter the calculations at all. It only affects the sample validation process.
- The failure does not occur with the component distribution set to the `Normal` distribution, as its support check does not compare the sample against the distribution's (batched) parameters, even though the validation itself is elementwise.
- I've split the `test_mixture_same_family_log_prob` test into two tests based on the `Normal` and `Binomial` distributions.
- Initially, the `Binomial` version of the test did not fail, but this was due to the component distribution having equal batch dimensions of (5, 5) so I changed it to (10, 5).

### Updated fix explanation (for all cases)

- The previous fix caused a bug in sample shape validation (which had been done correctly before), due to the padding taking place before the sample validation.
- The updated fix corrects the support to reflect the fact that the support of `MixtureSameFamily` is equal to the support of its components distribution with the first event dimension removed.
- This issue was already anticipated in the [code](331423e5c2/torch/distributions/mixture_same_family.py (L127)).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151317
Approved by: https://github.com/albanD, https://github.com/fritzo
2025-05-14 19:24:36 +00:00
clr
534b66fe30 torch.compile: Remove reference to the unused dynamo_config.dynamic_shapes from (#153297)
tests

This config option is not set anywhere, and does nothing, so this should cause
no changes to tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153297
Approved by: https://github.com/Skylion007
2025-05-14 19:02:51 +00:00
bf0fe4f828 Revert "[CUDA][CUDNN] Dispatch to cuDNN for non-batch-splittable 64-bit NCHW convolutions (#153101)"
This reverts commit ced90d23d3dfff42379fa032fe6a83b764d12e9f.

Reverted https://github.com/pytorch/pytorch/pull/153101 on behalf of https://github.com/jeanschmidt due to Seems to have introduced breakages on main, tentative revert: https://github.com/pytorch/pytorch/actions/runs/15024667248/job/42224521705 ([comment](https://github.com/pytorch/pytorch/pull/153101#issuecomment-2881208171))
2025-05-14 18:52:07 +00:00
8749fe8439 [CI][MPS] Speedup test_large_bmm (#153562)
By computing matmuls of only one random non-zero batch on CPU

This reduces test runtime from 11 minutes to 14 sec
```
 % python3 test/test_mps.py -v -k test_large_bmm_
test_large_bmm_bfloat16 (__main__.TestMPS.test_large_bmm_bfloat16) ... ok
test_large_bmm_float16 (__main__.TestMPS.test_large_bmm_float16) ... ok

----------------------------------------------------------------------
Ran 2 tests in 27.495s

```

TODO: Compute it over two slices when https://github.com/pytorch/pytorch/issues/153560 is fixed
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153562
Approved by: https://github.com/Skylion007, https://github.com/clee2000
2025-05-14 18:49:42 +00:00
47d6feff7c [export] Support no inputs in unflattened module (#153474)
Encountered in this diff D74589491
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153474
Approved by: https://github.com/avikchaudhuri
2025-05-14 18:45:47 +00:00
6ef1cbc191 Revert "[ROCm] Maxpool forward NHWC Perf Improvement targeting Resnet scenarios (#151727)"
This reverts commit e6a90672601ad3d636145dd8a68952281a6d1199.

Reverted https://github.com/pytorch/pytorch/pull/151727 on behalf of https://github.com/jeanschmidt due to Seems to be breaking internal builds, @seemethere may you help the author? [D74729252](https://www.internalfb.com/diff/D74729252) ([comment](https://github.com/pytorch/pytorch/pull/151727#issuecomment-2881122917))
2025-05-14 18:18:17 +00:00
533fc58453 [BE]: Fix typing None override other optimizers (#153386)
Follow up to #153367 to fix other instances of it throughout the codebase

Also fully type NamedOptimizer since we were so close

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153386
Approved by: https://github.com/tsunghsienlee, https://github.com/janeyx99, https://github.com/jansel, https://github.com/cyyever
2025-05-14 17:48:47 +00:00
741 changed files with 11139 additions and 12590 deletions

View File

@ -251,14 +251,6 @@ case "$tag" in
UCC_COMMIT=${_UCC_COMMIT}
INDUCTOR_BENCHMARKS=yes
;;
pytorch-linux-jammy-xpu-2024.0-py3)
ANACONDA_PYTHON_VERSION=3.9
GCC_VERSION=11
VISION=yes
XPU_VERSION=0.5
NINJA_VERSION=1.9.0
TRITON=yes
;;
pytorch-linux-jammy-xpu-2025.0-py3)
ANACONDA_PYTHON_VERSION=3.9
GCC_VERSION=11
@ -267,6 +259,14 @@ case "$tag" in
NINJA_VERSION=1.9.0
TRITON=yes
;;
pytorch-linux-jammy-xpu-2025.1-py3)
ANACONDA_PYTHON_VERSION=3.9
GCC_VERSION=11
VISION=yes
XPU_VERSION=2025.1
NINJA_VERSION=1.9.0
TRITON=yes
;;
pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks)
ANACONDA_PYTHON_VERSION=3.9
GCC_VERSION=11

View File

@ -26,7 +26,7 @@ function install_ubuntu() {
wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
| gpg --dearmor > /usr/share/keyrings/oneapi-archive-keyring.gpg.gpg
echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg.gpg] \
https://apt.repos.intel.com/${XPU_REPO_NAME} all main" \
https://apt.repos.intel.com/oneapi all main" \
| tee /etc/apt/sources.list.d/oneAPI.list
# Update the packages list and repository index
@ -74,7 +74,7 @@ function install_rhel() {
tee > /etc/yum.repos.d/oneAPI.repo << EOF
[oneAPI]
name=Intel for Pytorch GPU dev repository
baseurl=https://yum.repos.intel.com/${XPU_REPO_NAME}
baseurl=https://yum.repos.intel.com/oneapi
enabled=1
gpgcheck=1
repo_gpgcheck=1
@ -118,7 +118,7 @@ function install_sles() {
https://repositories.intel.com/gpu/sles/${VERSION_SP}${XPU_DRIVER_VERSION}/unified/intel-gpu-${VERSION_SP}.repo
rpm --import https://repositories.intel.com/gpu/intel-graphics.key
# To add the online network network package repository for the Intel Support Packages
zypper addrepo https://yum.repos.intel.com/${XPU_REPO_NAME} oneAPI
zypper addrepo https://yum.repos.intel.com/oneapi oneAPI
rpm --import https://yum.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB
# The xpu-smi packages
@ -141,10 +141,10 @@ if [[ "${XPU_DRIVER_TYPE,,}" == "rolling" ]]; then
XPU_DRIVER_VERSION=""
fi
XPU_REPO_NAME="intel-for-pytorch-gpu-dev"
XPU_PACKAGES="intel-for-pytorch-gpu-dev-0.5 intel-pti-dev-0.9"
if [[ "$XPU_VERSION" == "2025.0" ]]; then
XPU_REPO_NAME="oneapi"
# Default use Intel® oneAPI Deep Learning Essentials 2025.0
if [[ "$XPU_VERSION" == "2025.1" ]]; then
XPU_PACKAGES="intel-deep-learning-essentials-2025.1"
else
XPU_PACKAGES="intel-deep-learning-essentials-2025.0"
fi

View File

@ -174,6 +174,6 @@ ENV XPU_DRIVER_TYPE ROLLING
RUN python3 -m pip install --upgrade pip && \
python3 -mpip install cmake==3.28.4
ADD ./common/install_xpu.sh install_xpu.sh
ENV XPU_VERSION 2025.0
ENV XPU_VERSION 2025.1
RUN bash ./install_xpu.sh && rm install_xpu.sh
RUN pushd /opt/_internal && tar -xJf static-libs-for-embedding-only.tar.xz && popd

View File

@ -20,7 +20,11 @@ fi
source /opt/intel/oneapi/compiler/latest/env/vars.sh
source /opt/intel/oneapi/pti/latest/env/vars.sh
source /opt/intel/oneapi/umf/latest/env/vars.sh
source /opt/intel/oneapi/ccl/latest/env/vars.sh
source /opt/intel/oneapi/mpi/latest/env/vars.sh
export USE_STATIC_MKL=1
export USE_ONEMKL=1
export USE_XCCL=1
WHEELHOUSE_DIR="wheelhousexpu"
LIBTORCH_HOUSE_DIR="libtorch_housexpu"

View File

@ -99,30 +99,6 @@ if [[ "$BUILD_ENVIRONMENT" == *aarch64* ]]; then
export ACL_ROOT_DIR=/ComputeLibrary
fi
if [[ "$BUILD_ENVIRONMENT" == *libtorch* ]]; then
POSSIBLE_JAVA_HOMES=()
POSSIBLE_JAVA_HOMES+=(/usr/local)
POSSIBLE_JAVA_HOMES+=(/usr/lib/jvm/java-8-openjdk-amd64)
POSSIBLE_JAVA_HOMES+=(/Library/Java/JavaVirtualMachines/*.jdk/Contents/Home)
# Add the Windows-specific JNI
POSSIBLE_JAVA_HOMES+=("$PWD/.circleci/windows-jni/")
for JH in "${POSSIBLE_JAVA_HOMES[@]}" ; do
if [[ -e "$JH/include/jni.h" ]] ; then
# Skip if we're not on Windows but haven't found a JAVA_HOME
if [[ "$JH" == "$PWD/.circleci/windows-jni/" && "$OSTYPE" != "msys" ]] ; then
break
fi
echo "Found jni.h under $JH"
export JAVA_HOME="$JH"
export BUILD_JNI=ON
break
fi
done
if [ -z "$JAVA_HOME" ]; then
echo "Did not find jni.h"
fi
fi
# Use special scripts for Android builds
if [[ "${BUILD_ENVIRONMENT}" == *-android* ]]; then
export ANDROID_NDK=/opt/ndk

View File

@ -232,7 +232,8 @@ test_torchbench_smoketest() {
mkdir -p "$TEST_REPORTS_DIR"
local device=mps
local models=(hf_T5 llama BERT_pytorch dcgan hf_GPT2 yolov3 resnet152 sam pytorch_unet stable_diffusion_text_encoder moco speech_transformer)
local models=(hf_T5 llama BERT_pytorch dcgan hf_GPT2 yolov3 resnet152 sam pytorch_unet stable_diffusion_text_encoder speech_transformer Super_SloMo)
local hf_models=(GoogleFnet YituTechConvBert)
for backend in eager inductor; do
@ -253,6 +254,16 @@ test_torchbench_smoketest() {
--output "$TEST_REPORTS_DIR/inductor_${backend}_torchbench_${dtype}_inference_${device}_accuracy.csv" || true
fi
done
for model in "${hf_models[@]}"; do
if [ "$backend" == "inductor" ]; then
PYTHONPATH="$(pwd)"/torchbench python benchmarks/dynamo/huggingface.py \
--performance --only "$model" --backend "$backend" --inference --devices "$device" "$dtype_arg" \
--output "$TEST_REPORTS_DIR/inductor_${backend}_torchbench_${dtype}_inference_${device}_performance.csv" || true
PYTHONPATH="$(pwd)"/torchbench python benchmarks/dynamo/huggingface.py \
--accuracy --only "$model" --backend "$backend" --inference --devices "$device" "$dtype_arg" \
--output "$TEST_REPORTS_DIR/inductor_${backend}_torchbench_${dtype}_inference_${device}_accuracy.csv" || true
fi
done
done
for dtype in notset amp; do

View File

@ -37,6 +37,11 @@ call %INSTALLER_DIR%\activate_miniconda3.bat
if errorlevel 1 goto fail
if not errorlevel 0 goto fail
:: Update CMake
call choco upgrade -y cmake --no-progress --installargs 'ADD_CMAKE_TO_PATH=System' --apply-install-arguments-to-dependencies --version=3.27.9
if errorlevel 1 goto fail
if not errorlevel 0 goto fail
call pip install mkl-include==2021.4.0 mkl-devel==2021.4.0
if errorlevel 1 goto fail
if not errorlevel 0 goto fail
@ -88,7 +93,7 @@ set PATH=%CUDA_PATH%\bin;%CUDA_PATH%\libnvvp;%PATH%
:cuda_build_end
set DISTUTILS_USE_SDK=1
set PATH=%TMP_DIR_WIN%\bin;%PATH%
set PATH=%TMP_DIR_WIN%\bin;C:\Program Files\CMake\bin;%PATH%
:: The latest Windows CUDA test is running on AWS G5 runner with A10G GPU
if "%TORCH_CUDA_ARCH_LIST%" == "" set TORCH_CUDA_ARCH_LIST=8.6

View File

@ -10,53 +10,23 @@ if not "%CUDA_VERSION%" == "xpu" (
set SRC_DIR=%NIGHTLIES_PYTORCH_ROOT%
if not exist "%SRC_DIR%\temp_build" mkdir "%SRC_DIR%\temp_build"
set XPU_INSTALL_MODE=%~1
if "%XPU_INSTALL_MODE%"=="" goto xpu_bundle_install_start
if "%XPU_INSTALL_MODE%"=="bundle" goto xpu_bundle_install_start
if "%XPU_INSTALL_MODE%"=="driver" goto xpu_driver_install_start
if "%XPU_INSTALL_MODE%"=="all" goto xpu_driver_install_start
:arg_error
echo Illegal XPU installation mode. The value can be "bundle"/"driver"/"all"
echo If keep the value as space, will use default "bundle" mode
exit /b 1
:xpu_driver_install_start
:: TODO Need more testing for driver installation
set XPU_DRIVER_LINK=https://downloadmirror.intel.com/830975/gfx_win_101.5972.exe
curl -o xpu_driver.exe --retry 3 --retry-all-errors -k %XPU_DRIVER_LINK%
echo "XPU Driver installing..."
start /wait "Intel XPU Driver Installer" "xpu_driver.exe"
if errorlevel 1 exit /b 1
del xpu_driver.exe
if "%XPU_INSTALL_MODE%"=="driver" goto xpu_install_end
:xpu_bundle_install_start
set XPU_BUNDLE_PARENT_DIR=C:\Program Files (x86)\Intel\oneAPI
set XPU_BUNDLE_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/9d1a91e2-e8b8-40a5-8c7f-5db768a6a60c/w_intel-for-pytorch-gpu-dev_p_0.5.3.37_offline.exe
set XPU_BUNDLE_PRODUCT_NAME=intel.oneapi.win.intel-for-pytorch-gpu-dev.product
set XPU_BUNDLE_VERSION=0.5.3+31
set XPU_BUNDLE_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/9d6d6c17-ca2d-4735-9331-99447e4a1280/intel-deep-learning-essentials-2025.0.1.28_offline.exe
set XPU_BUNDLE_PRODUCT_NAME=intel.oneapi.win.deep-learning-essentials.product
set XPU_BUNDLE_VERSION=2025.0.1+20
set XPU_BUNDLE_INSTALLED=0
set XPU_BUNDLE_UNINSTALL=0
set XPU_EXTRA_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/9d1a91e2-e8b8-40a5-8c7f-5db768a6a60c/w_intel-pti-dev_p_0.9.0.37_offline.exe
set XPU_EXTRA_PRODUCT_NAME=intel.oneapi.win.intel-pti-dev.product
set XPU_EXTRA_VERSION=0.9.0+36
set XPU_EXTRA_URL=NULL
set XPU_EXTRA_PRODUCT_NAME=intel.oneapi.win.compiler.product
set XPU_EXTRA_VERSION=2025.0.1+1226
set XPU_EXTRA_INSTALLED=0
set XPU_EXTRA_UNINSTALL=0
if not [%XPU_VERSION%]==[] if [%XPU_VERSION%]==[2025.0] (
set XPU_BUNDLE_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/9d6d6c17-ca2d-4735-9331-99447e4a1280/intel-deep-learning-essentials-2025.0.1.28_offline.exe
set XPU_BUNDLE_PRODUCT_NAME=intel.oneapi.win.deep-learning-essentials.product
set XPU_BUNDLE_VERSION=2025.0.1+20
set XPU_BUNDLE_INSTALLED=0
set XPU_BUNDLE_UNINSTALL=0
set XPU_EXTRA_URL=NULL
set XPU_EXTRA_PRODUCT_NAME=intel.oneapi.win.compiler.product
set XPU_EXTRA_VERSION=2025.0.1+1226
set XPU_EXTRA_INSTALLED=0
set XPU_EXTRA_UNINSTALL=0
if not [%XPU_VERSION%]==[] if [%XPU_VERSION%]==[2025.1] (
set XPU_BUNDLE_URL=https://registrationcenter-download.intel.com/akdlm/IRC_NAS/1a9fff3d-04c2-4d77-8861-3d86c774b66f/intel-deep-learning-essentials-2025.1.1.26_offline.exe
set XPU_BUNDLE_VERSION=2025.1.1+23
)
:: Check if XPU bundle is target version or already installed

View File

@ -26,6 +26,7 @@ set VS2022INSTALLDIR=%VS15INSTALLDIR%
set XPU_BUNDLE_ROOT=%ProgramFiles(x86)%\Intel\oneAPI
call "%XPU_BUNDLE_ROOT%\compiler\latest\env\vars.bat"
call "%XPU_BUNDLE_ROOT%\ocloc\latest\env\vars.bat"
set USE_ONEMKL=1
IF ERRORLEVEL 1 goto :eof
if exist "%NIGHTLIES_PYTORCH_ROOT%" cd %NIGHTLIES_PYTORCH_ROOT%\..

View File

@ -15,7 +15,7 @@ fi
if [[ "$DESIRED_CUDA" == 'xpu' ]]; then
export VC_YEAR=2022
export USE_SCCACHE=0
export XPU_VERSION=2025.0
export XPU_VERSION=2025.1
export XPU_ENABLE_KINETO=1
fi

View File

@ -8,7 +8,7 @@ export VC_YEAR=2019
if [[ "$DESIRED_CUDA" == 'xpu' ]]; then
export VC_YEAR=2022
export XPU_VERSION=2025.0
export XPU_VERSION=2025.1
fi
pushd "$PYTORCH_ROOT/.ci/pytorch/"

View File

@ -44,7 +44,7 @@ runs:
retry_wait_seconds: 30
command: |
set -eu
python3 -m pip install python-dateutil==2.8.2 boto3==1.35.42 pandas==2.1.3
python3 -m pip install python-dateutil==2.8.2 boto3==1.35.42 pandas==2.1.3 dataclasses_json==0.6.7
- name: Upload utilizatoin stats to s3
shell: bash
run: |

View File

@ -25,7 +25,6 @@ ciflow_push_tags:
- ciflow/unstable
- ciflow/xpu
- ciflow/torchbench
- ciflow/autoformat
- ciflow/op-benchmark
- ciflow/pull
- ciflow/h100

View File

@ -88,17 +88,26 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = {
"nvidia-cufile-cu12==1.13.0.11; platform_system == 'Linux' and platform_machine == 'x86_64'"
),
"xpu": (
"intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | "
"intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | "
"intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | "
"intel-sycl-rt==2025.0.4; platform_system == 'Linux' | "
"intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | "
"intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | "
"intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | "
"intel-sycl-rt==2025.0.5; platform_system == 'Windows' | "
"tcmlib==1.2.0 | "
"umf==0.9.1 | "
"intel-pti==0.10.1"
"intel-cmplr-lib-rt==2025.1.1 | "
"intel-cmplr-lib-ur==2025.1.1 | "
"intel-cmplr-lic-rt==2025.1.1 | "
"intel-sycl-rt==2025.1.1 | "
"oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"onemkl-sycl-blas==2025.1.0 | "
"onemkl-sycl-dft==2025.1.0 | "
"onemkl-sycl-lapack==2025.1.0 | "
"onemkl-sycl-rng==2025.1.0 | "
"onemkl-sycl-sparse==2025.1.0 | "
"dpcpp-cpp-rt==2025.1.1 | "
"intel-opencl-rt==2025.1.1 | "
"mkl==2025.1.0 | "
"intel-openmp==2025.1.1 | "
"tbb==2022.1.0 | "
"tcmlib==1.3.0 | "
"umf==0.10.0 | "
"intel-pti==0.12.0"
),
}

View File

@ -1938,6 +1938,7 @@ def get_ghstack_dependent_prs(
def do_revert_prs(
repo: GitRepo,
original_pr: GitHubPR,
shas_and_prs: list[tuple[str, GitHubPR]],
*,
author_login: str,
@ -1959,9 +1960,16 @@ def do_revert_prs(
# Comment/reopen PRs
for commit_sha, pr in shas_and_prs:
revert_message = (
f"@{pr.get_pr_creator_login()} your PR has been successfully reverted."
)
revert_message = ""
if pr.pr_num == original_pr.pr_num:
revert_message += (
f"@{pr.get_pr_creator_login()} your PR has been successfully reverted."
)
else:
revert_message += (
f"@{pr.get_pr_creator_login()} your PR has been reverted as part of the stack under "
f"#{original_pr.pr_num}.\n"
)
if (
pr.has_internal_changes()
and not pr.has_no_connected_diff()
@ -2013,6 +2021,7 @@ def try_revert(
do_revert_prs(
repo,
pr,
shas_and_prs,
author_login=author_login,
extra_msg=extra_msg,

View File

@ -67,8 +67,8 @@ jobs:
pytorch-linux-jammy-py3.9-gcc11,
pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks,
pytorch-linux-jammy-py3.12-halide,
pytorch-linux-jammy-xpu-2024.0-py3,
pytorch-linux-jammy-xpu-2025.0-py3,
pytorch-linux-jammy-xpu-2025.1-py3,
pytorch-linux-jammy-py3-clang15-asan,
pytorch-linux-jammy-py3-clang18-asan,
pytorch-linux-focal-py3-clang10-onnx,

View File

@ -565,7 +565,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_9-xpu
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_9-xpu-test: # Testing
@ -1178,7 +1178,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_10-xpu
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_10-xpu-test: # Testing
@ -1859,7 +1859,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_11-xpu
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_11-xpu-test: # Testing
@ -2472,7 +2472,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_12-xpu
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_12-xpu-test: # Testing
@ -3085,7 +3085,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_13-xpu
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_13-xpu-test: # Testing
@ -3698,7 +3698,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_13t-xpu
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_13t-xpu-test: # Testing

View File

@ -1004,7 +1004,7 @@ jobs:
GPU_ARCH_TYPE: xpu
SKIP_ALL_TESTS: 1
DESIRED_PYTHON: "3.9"
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
steps:
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
@ -2189,7 +2189,7 @@ jobs:
GPU_ARCH_TYPE: xpu
SKIP_ALL_TESTS: 1
DESIRED_PYTHON: "3.10"
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
steps:
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
@ -3374,7 +3374,7 @@ jobs:
GPU_ARCH_TYPE: xpu
SKIP_ALL_TESTS: 1
DESIRED_PYTHON: "3.11"
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
steps:
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
@ -4559,7 +4559,7 @@ jobs:
GPU_ARCH_TYPE: xpu
SKIP_ALL_TESTS: 1
DESIRED_PYTHON: "3.12"
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
steps:
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
@ -5744,7 +5744,7 @@ jobs:
GPU_ARCH_TYPE: xpu
SKIP_ALL_TESTS: 1
DESIRED_PYTHON: "3.13"
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
steps:
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
@ -6929,7 +6929,7 @@ jobs:
GPU_ARCH_TYPE: xpu
SKIP_ALL_TESTS: 1
DESIRED_PYTHON: "3.13t"
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-ur==2025.0.4; platform_system == 'Linux' | intel-cmplr-lic-rt==2025.0.4; platform_system == 'Linux' | intel-sycl-rt==2025.0.4; platform_system == 'Linux' | intel-cmplr-lib-rt==2025.0.5; platform_system == 'Windows' | intel-cmplr-lib-ur==2025.0.5; platform_system == 'Windows' | intel-cmplr-lic-rt==2025.0.5; platform_system == 'Windows' | intel-sycl-rt==2025.0.5; platform_system == 'Windows' | tcmlib==1.2.0 | umf==0.9.1 | intel-pti==0.10.1
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: intel-cmplr-lib-rt==2025.1.1 | intel-cmplr-lib-ur==2025.1.1 | intel-cmplr-lic-rt==2025.1.1 | intel-sycl-rt==2025.1.1 | oneccl-devel==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | oneccl==2021.15.1; platform_system == 'Linux' and platform_machine == 'x86_64' | impi-rt==2021.15.0; platform_system == 'Linux' and platform_machine == 'x86_64' | onemkl-sycl-blas==2025.1.0 | onemkl-sycl-dft==2025.1.0 | onemkl-sycl-lapack==2025.1.0 | onemkl-sycl-rng==2025.1.0 | onemkl-sycl-sparse==2025.1.0 | dpcpp-cpp-rt==2025.1.1 | intel-opencl-rt==2025.1.1 | mkl==2025.1.0 | intel-openmp==2025.1.1 | tbb==2022.1.0 | tcmlib==1.3.0 | umf==0.10.0 | intel-pti==0.12.0
steps:
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the

View File

@ -52,7 +52,7 @@ jobs:
docker-image: ${{ needs.linux-focal-cuda12_6-py3_10-gcc9-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-cuda12_6-py3_10-gcc9-inductor-build.outputs.test-matrix }}
# disable monitor in perf tests for more investigation
disable-monitor: true
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit

View File

@ -129,7 +129,9 @@ jobs:
test-matrix: ${{ needs.linux-jammy-aarch64-py3_10-inductor-build.outputs.test-matrix }}
timeout-minutes: 720
# disable monitor in perf tests for more investigation
disable-monitor: true
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit

View File

@ -122,7 +122,7 @@ jobs:
test-matrix: ${{ needs.build.outputs.test-matrix }}
timeout-minutes: 720
# disable monitor in perf tests, next step is to enable it
disable-monitor: true
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit
@ -139,7 +139,7 @@ jobs:
test-matrix: ${{ needs.build.outputs.test-matrix }}
timeout-minutes: 1440
# disable monitor in perf tests, next step is to enable it
disable-monitor: true
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit
@ -156,7 +156,7 @@ jobs:
test-matrix: ${{ needs.build.outputs.test-matrix }}
timeout-minutes: 720
# disable monitor in perf tests for more investigation
disable-monitor: true
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit

View File

@ -103,7 +103,7 @@ jobs:
test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }}
timeout-minutes: 720
# disable monitor in perf tests
disable-monitor: true
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit
@ -121,7 +121,7 @@ jobs:
test-matrix: ${{ needs.linux-jammy-cpu-py3_9-gcc11-inductor-build.outputs.test-matrix }}
timeout-minutes: 720
# disable monitor in perf tests
disable-monitor: true
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit

View File

@ -123,8 +123,7 @@ jobs:
docker-image: ${{ needs.linux-focal-cuda12_6-py3_10-gcc9-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-cuda12_6-py3_10-gcc9-inductor-build.outputs.test-matrix }}
timeout-minutes: 720
# disable monitor in perf tests, next step is to enable it
disable-monitor: true
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit
@ -141,7 +140,7 @@ jobs:
test-matrix: ${{ needs.linux-focal-cuda12_6-py3_10-gcc9-inductor-build.outputs.test-matrix }}
timeout-minutes: 1440
# disable monitor in perf tests, next step is to enable it
disable-monitor: true
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit
@ -157,8 +156,7 @@ jobs:
docker-image: ${{ needs.linux-focal-cuda12_6-py3_10-gcc9-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-cuda12_6-py3_10-gcc9-inductor-build.outputs.test-matrix }}
timeout-minutes: 720
# disable monitor in perf tests, next step is to enable it
disable-monitor: true
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit

View File

@ -133,10 +133,6 @@ jobs:
build-environment: linux-focal-cuda12.6-py3.10-gcc9-sm80
docker-image: ${{ needs.linux-focal-cuda12_6-py3_10-gcc9-inductor-build-gcp.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-cuda12_6-py3_10-gcc9-inductor-build-gcp.outputs.test-matrix }}
# disable monitor in perf tests, next step is to enable it
disable-monitor: true
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit
linux-jammy-cpu-py3_9-gcc11-periodic-dynamo-benchmarks-build:

View File

@ -1,10 +1,8 @@
name: Apply lint suggestions
on:
push:
tags:
- ciflow/autoformat/*
pull_request:
types: [opened, synchronize, reopened, labeled, unlabeled]
jobs:
lintrunner-autoformat:
@ -12,7 +10,7 @@ jobs:
contents: read
pull-requests: write
runs-on: lf.linux.2xlarge
if: ${{ github.repository_owner == 'pytorch' && github.event.pull_request.user.login != 'ezyang' && github.event.pull_request.user.login != 'malfet' && !startsWith(github.head_ref, 'export-') }}
if: ${{ github.repository_owner == 'pytorch' && contains(github.event.pull_request.labels.*.name, 'autoformat') }}
steps:
- name: Checkout pytorch
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
@ -21,12 +19,11 @@ jobs:
fetch-depth: 0
- name: Run lintrunner (nonretryable)
continue-on-error: true
# we can't run all files here because only changes around where the diff are shown in the PR UI
run: |
set -ex
python3 -m venv /tmp/venv
source /tmp/venv/bin/activate
export ADDITIONAL_LINTRUNNER_ARGS="format"
export ADDITIONAL_LINTRUNNER_ARGS="format --all-files"
bash .github/scripts/lintrunner.sh
- name: Check for changes
id: git-check

View File

@ -532,15 +532,15 @@ jobs:
test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-xpu-2025_0-py3_9-build:
name: linux-jammy-xpu-2025.0-py3.9
linux-jammy-xpu-2025_1-py3_9-build:
name: linux-jammy-xpu-2025.1-py3.9
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
sync-tag: linux-xpu-2025-0-build
sync-tag: linux-xpu-2025-1-build
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
build-environment: linux-jammy-xpu-2025.0-py3.9
docker-image-name: ci-image:pytorch-linux-jammy-xpu-2025.0-py3
build-environment: linux-jammy-xpu-2025.1-py3.9
docker-image-name: ci-image:pytorch-linux-jammy-xpu-2025.1-py3
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 4, runner: "linux.idc.xpu" },

View File

@ -138,6 +138,7 @@ jobs:
build-environment: win-vs2022-cpu-py3
cuda-version: cpu
test-matrix: ${{ needs.win-vs2022-cpu-py3-build.outputs.test-matrix }}
disable-monitor: false
secrets: inherit
win-vs2022-cuda12_6-py3-build:


@ -2,7 +2,25 @@ name: Upload test stats
on:
workflow_run:
workflows: [pull, trunk, periodic, periodic-rocm-mi300, inductor, unstable, slow, unstable-periodic, inductor-periodic, rocm, rocm-mi300, inductor-micro-benchmark, inductor-micro-benchmark-x86, inductor-cu124, inductor-rocm, inductor-rocm-mi300, mac-mps]
workflows:
- pull
- trunk
- periodic
- periodic-rocm-mi300
- inductor
- unstable
- slow
- unstable-periodic
- inductor-periodic
- rocm
- rocm-mi300
- inductor-micro-benchmark
- inductor-micro-benchmark-x86
- inductor-cu124
- inductor-rocm
- inductor-rocm-mi300
- mac-mps
- linux-aarch64
types:
- completed


@ -43,17 +43,38 @@ jobs:
]}
secrets: inherit
linux-jammy-xpu-2025_0-py3_9-test:
name: linux-jammy-xpu-2025.0-py3.9
linux-jammy-xpu-2025_1-py3_9-build:
name: linux-jammy-xpu-2025.1-py3.9
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
sync-tag: linux-xpu-2025-1-build
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
build-environment: linux-jammy-xpu-2025.1-py3.9
docker-image-name: ci-image:pytorch-linux-jammy-xpu-2025.1-py3
runner: linux.12xlarge
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "default", shard: 2, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "default", shard: 3, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "default", shard: 4, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "default", shard: 5, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "default", shard: 6, num_shards: 6, runner: "linux.idc.xpu" },
]}
secrets: inherit
linux-jammy-xpu-2025_1-py3_9-test:
name: linux-jammy-xpu-2025.1-py3.9
uses: ./.github/workflows/_xpu-test.yml
needs: linux-jammy-xpu-2025_0-py3_9-build
needs: linux-jammy-xpu-2025_1-py3_9-build
permissions:
id-token: write
contents: read
with:
build-environment: linux-jammy-xpu-2025.0-py3.9
docker-image: ${{ needs.linux-jammy-xpu-2025_0-py3_9-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-xpu-2025_0-py3_9-build.outputs.test-matrix }}
build-environment: linux-jammy-xpu-2025.1-py3.9
docker-image: ${{ needs.linux-jammy-xpu-2025_1-py3_9-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-xpu-2025_1-py3_9-build.outputs.test-matrix }}
secrets: inherit
windows-xpu-2025_0-build:
@ -67,3 +88,15 @@ jobs:
xpu-version: '2025.0'
vc-year: '2022'
secrets: inherit
windows-xpu-2025_1-build:
if: github.repository_owner == 'pytorch'
name: win-vs2022-xpu-2025_1-py3
uses: ./.github/workflows/_win-build.yml
with:
build-environment: win-vs2022-xpu-py3
cuda-version: cpu
use-xpu: true
xpu-version: '2025.1'
vc-year: '2022'
secrets: inherit

.gitignore

@ -212,15 +212,6 @@ docs/source/scripts/lr_scheduler_images/
# Compiled MATLAB
*.mex*
# IPython notebook checkpoints
.ipynb_checkpoints
# Editor temporaries
*.swn
*.swo
*.swp
*~
# NFS handle files
**/.nfs*


@ -1719,3 +1719,15 @@ include_patterns = [
'torch/_dynamo/**',
]
is_formatter = false
[[linter]]
code = 'TEST_DEVICE_BIAS'
command = [
'python3',
'tools/linter/adapters/test_device_bias_linter.py',
'--',
'@{{PATHSFILE}}',
]
include_patterns = [
'test/**/test_*.py',
]
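The new `TEST_DEVICE_BIAS` entry follows the same shape as the other `.lintrunner.toml` linter definitions, so it can be exercised through the usual `lintrunner` workflow. A minimal sketch, assuming `lintrunner` is installed and initialized and using an illustrative test-file path:
```bash
# One-time setup, then run only the new check against an example test file
# (the --take filter name comes from the config above; the file path is illustrative).
pip install lintrunner
lintrunner init
lintrunner --take TEST_DEVICE_BIAS test/test_example.py
```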


@ -222,7 +222,6 @@ option(
BUILD_MOBILE_TEST
"Build C++ test binaries for mobile (ARM) targets(need gtest and gbenchmark)"
OFF)
option(BUILD_JNI "Build JNI bindings" OFF)
option(BUILD_MOBILE_AUTOGRAD
"Build autograd function in mobile build (in development)" OFF)
cmake_dependent_option(INSTALL_TEST "Install test binaries if BUILD_TEST is on"
@ -259,6 +258,8 @@ option(USE_NATIVE_ARCH "Use -march=native" OFF)
cmake_dependent_option(USE_MPS "Use MPS for macOS build" ON "MPS_FOUND" OFF)
cmake_dependent_option(USE_NCCL "Use NCCL" ON
"USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF)
cmake_dependent_option(USE_XCCL "Use XCCL" OFF
"USE_XPU;UNIX;NOT APPLE" OFF)
cmake_dependent_option(USE_RCCL "Use RCCL" ON USE_NCCL OFF)
cmake_dependent_option(USE_STATIC_NCCL "Use static NCCL" OFF "USE_NCCL" OFF)
cmake_dependent_option(USE_SYSTEM_NCCL "Use system-wide NCCL" OFF "USE_NCCL"
@ -338,6 +339,8 @@ cmake_dependent_option(
USE_C10D_GLOO "USE C10D GLOO" ON "USE_DISTRIBUTED;USE_GLOO" OFF)
cmake_dependent_option(
USE_C10D_NCCL "USE C10D NCCL" ON "USE_DISTRIBUTED;USE_NCCL" OFF)
cmake_dependent_option(
USE_C10D_XCCL "USE C10D XCCL" ON "USE_DISTRIBUTED;USE_XCCL" OFF)
cmake_dependent_option(
USE_C10D_MPI "USE C10D MPI" ON "USE_DISTRIBUTED;USE_MPI" OFF)
cmake_dependent_option(
@ -1346,16 +1349,6 @@ if(BUILD_BINARY)
add_subdirectory(binaries)
endif()
# ---[ JNI
if(BUILD_JNI)
if(NOT MSVC)
string(APPEND CMAKE_CXX_FLAGS " -Wno-unused-variable")
endif()
set(BUILD_LIBTORCH_WITH_JNI 1)
set(FBJNI_SKIP_TESTS 1)
add_subdirectory(android/pytorch_android)
endif()
include(cmake/Summary.cmake)
caffe2_print_configuration_summary()
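The `USE_XCCL` and `USE_C10D_XCCL` options added above are dependent options: `USE_XCCL` only takes effect when `USE_XPU` is enabled on a non-Apple UNIX build, and `USE_C10D_XCCL` additionally requires `USE_DISTRIBUTED`. A minimal sketch of how the flags compose at configure time, assuming an out-of-source build directory and default values for everything else:
```bash
# Enable the XPU collective backend at configure time.
# USE_XCCL is gated on USE_XPU; USE_C10D_XCCL additionally needs USE_DISTRIBUTED.
cmake -S . -B build \
  -DUSE_XPU=ON \
  -DUSE_DISTRIBUTED=ON \
  -DUSE_XCCL=ON
cmake --build build --parallel
```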


@ -1,240 +1,5 @@
# Android
# Android Demo App
## Demo applications and tutorials
Please refer to [pytorch-labs/executorch-examples](https://github.com/pytorch-labs/executorch-examples/tree/main/dl3/android/DeepLabV3Demo) for the Android demo app based on [ExecuTorch](https://github.com/pytorch/executorch).
Demo applications with code walk-throughs can be found in [this github repo](https://github.com/pytorch/android-demo-app).
## Publishing
##### Release
Release artifacts are published to jcenter:
```groovy
repositories {
jcenter()
}
// lite interpreter build
dependencies {
implementation 'org.pytorch:pytorch_android_lite:1.10.0'
implementation 'org.pytorch:pytorch_android_torchvision_lite:1.10.0'
}
// full jit build
dependencies {
implementation 'org.pytorch:pytorch_android:1.10.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'
}
```
##### Nightly
Nightly (snapshot) builds are published every night from the `master` branch to the [nexus sonatype snapshots repository](https://oss.sonatype.org/#nexus-search;quick~pytorch_android)
To use them, the repository must be specified explicitly:
```groovy
repositories {
maven {
url "https://oss.sonatype.org/content/repositories/snapshots"
}
}
// lite interpreter build
dependencies {
...
implementation 'org.pytorch:pytorch_android_lite:1.12.0-SNAPSHOT'
implementation 'org.pytorch:pytorch_android_torchvision_lite:1.12.0-SNAPSHOT'
...
}
// full jit build
dependencies {
...
implementation 'org.pytorch:pytorch_android:1.12.0-SNAPSHOT'
implementation 'org.pytorch:pytorch_android_torchvision:1.12.0-SNAPSHOT'
...
}
```
The current nightly (snapshot) version is the value of `VERSION_NAME` in `gradle.properties` in the current folder; at this moment it is `1.8.0-SNAPSHOT`.
## Building PyTorch Android from Source
In some cases you might want to use a local build of PyTorch Android, for example to build a custom libtorch binary with another set of operators or to make local changes.
For this you can use the `./scripts/build_pytorch_android.sh` script.
```bash
git clone https://github.com/pytorch/pytorch.git
cd pytorch
git submodule update --init --recursive
bash ./scripts/build_pytorch_android.sh
```
The workflow contains several steps:
1. Build libtorch for Android for all 4 Android ABIs (armeabi-v7a, arm64-v8a, x86, x86_64)
2. Create symbolic links to the results of those builds:
`android/pytorch_android/src/main/jniLibs/${abi}` to the directory with the output libraries
`android/pytorch_android/src/main/cpp/libtorch_include/${abi}` to the directory with the headers. These directories are used to build the `libpytorch.so` library that will be loaded on the Android device.
3. Finally, run `gradle` in the `android/pytorch_android` directory with the task `assembleRelease`
The script requires that the Android SDK, Android NDK and Gradle are installed.
They are specified as environment variables (a short example follows):
`ANDROID_HOME` - path to the [Android SDK](https://developer.android.com/studio/command-line/sdkmanager.html)
`ANDROID_NDK` - path to the [Android NDK](https://developer.android.com/studio/projects/install-ndk). It's recommended to use NDK 21.x.
`GRADLE_HOME` - path to [Gradle](https://gradle.org/releases/)
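As an illustration, the variables might be exported like this before running the script; the install locations and version numbers below are placeholders for whatever is installed locally:
```bash
# Hypothetical install locations; point these at your own SDK, NDK and Gradle.
export ANDROID_HOME="$HOME/Android/Sdk"
export ANDROID_NDK="$HOME/Android/Sdk/ndk/21.4.7075529"   # an NDK 21.x release
export GRADLE_HOME="$HOME/gradle-6.8.3"
export PATH="$GRADLE_HOME/bin:$PATH"
bash ./scripts/build_pytorch_android.sh
```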
After a successful build you should see the result as an AAR file:
```bash
$ find pytorch_android/build/ -type f -name *aar
pytorch_android/build/outputs/aar/pytorch_android.aar
pytorch_android_torchvision/build/outputs/aar/pytorch_android.aar
```
It can be used directly in android projects, as a gradle dependency:
```groovy
allprojects {
repositories {
flatDir {
dirs 'libs'
}
}
}
dependencies {
implementation(name:'pytorch_android', ext:'aar')
implementation(name:'pytorch_android_torchvision', ext:'aar')
...
implementation 'com.facebook.soloader:nativeloader:0.10.5'
implementation 'com.facebook.fbjni:fbjni-java-only:0.2.2'
}
```
We also have to add all transitive dependencies of our AARs.
As `pytorch_android` [depends](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/build.gradle#L76-L77) on `'com.facebook.soloader:nativeloader:0.10.5'` and `'com.facebook.fbjni:fbjni-java-only:0.2.2'`, we need to add them.
(When using Maven dependencies, they are added automatically from `pom.xml`.)
You can check out [test app example](https://github.com/pytorch/pytorch/blob/master/android/test_app/app/build.gradle) that uses aars directly.
## Linking to prebuilt libtorch library from gradle dependency
In some cases, you may want to use libtorch from your Android native build.
You can do this without building libtorch for Android yourself, by using the native libraries from the PyTorch Android Gradle dependency.
For that, you will need to add the following lines to your Gradle build.
```groovy
android {
...
configurations {
extractForNativeBuild
}
...
compileOptions {
externalNativeBuild {
cmake {
arguments "-DANDROID_STL=c++_shared"
}
}
}
...
externalNativeBuild {
cmake {
path "CMakeLists.txt"
}
}
}
dependencies {
extractForNativeBuild('org.pytorch:pytorch_android:1.10.0')
}
task extractAARForNativeBuild {
doLast {
configurations.extractForNativeBuild.files.each {
def file = it.absoluteFile
copy {
from zipTree(file)
into "$buildDir/$file.name"
include "headers/**"
include "jni/**"
}
}
}
}
tasks.whenTaskAdded { task ->
if (task.name.contains('externalNativeBuild')) {
task.dependsOn(extractAARForNativeBuild)
}
}
```
The pytorch_android AAR contains headers to link against in the `headers` folder and native libraries in `jni/$ANDROID_ABI/`.
As the PyTorch native libraries use `ANDROID_STL`, we should use `ANDROID_STL=c++_shared` so that only one binary of the STL is loaded.
The added task will unpack them to the Gradle build directory.
In your native build you can link to them by adding these lines to your CMakeLists.txt:
```cmake
# Relative path of gradle build directory to CMakeLists.txt
set(build_DIR ${CMAKE_SOURCE_DIR}/build)
file(GLOB PYTORCH_INCLUDE_DIRS "${build_DIR}/pytorch_android*.aar/headers")
file(GLOB PYTORCH_LINK_DIRS "${build_DIR}/pytorch_android*.aar/jni/${ANDROID_ABI}")
set(BUILD_SUBDIR ${ANDROID_ABI})
target_include_directories(${PROJECT_NAME} PRIVATE
${PYTORCH_INCLUDE_DIRS}
)
find_library(PYTORCH_LIBRARY pytorch_jni
PATHS ${PYTORCH_LINK_DIRS}
NO_CMAKE_FIND_ROOT_PATH)
find_library(FBJNI_LIBRARY fbjni
PATHS ${PYTORCH_LINK_DIRS}
NO_CMAKE_FIND_ROOT_PATH)
target_link_libraries(${PROJECT_NAME}
${PYTORCH_LIBRARY}
${FBJNI_LIBRARY})
```
If your CMakeLists.txt file is located in the same directory as your build.gradle, `set(build_DIR ${CMAKE_SOURCE_DIR}/build)` should work for you. But if it is located elsewhere, you may need to adjust the path.
After that, you can use libtorch C++ API from your native code.
```cpp
#include <string>
#include <ATen/NativeFunctions.h>
#include <torch/script.h>
namespace pytorch_testapp_jni {
namespace {
struct JITCallGuard {
c10::InferenceMode guard;
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
};
}
void loadAndForwardModel(const std::string& modelPath) {
JITCallGuard guard;
torch::jit::Module module = torch::jit::load(modelPath);
module.eval();
torch::Tensor t = torch::randn({1, 3, 224, 224});
c10::IValue t_out = module.forward({t});
}
}
```
To load a TorchScript model on mobile we need some special setup, which is placed in `struct JITCallGuard` in this example. It may change in the future; you can track the latest changes by keeping an eye on our [pytorch android jni code](https://github.com/pytorch/pytorch/blob/master/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp#L28).
[Example of linking to libtorch from aar](https://github.com/pytorch/pytorch/tree/master/android/test_app)
## PyTorch Android API Javadoc
You can find more details about the PyTorch Android API in the [Javadoc](https://pytorch.org/javadoc/).
Please join our [Discord](https://discord.com/channels/1334270993966825602/1349854760299270284) for any questions.


@ -1,40 +0,0 @@
allprojects {
buildscript {
ext {
minSdkVersion = 21
targetSdkVersion = 28
compileSdkVersion = 28
buildToolsVersion = '28.0.3'
coreVersion = "1.2.0"
extJUnitVersion = "1.1.1"
runnerVersion = "1.2.0"
rulesVersion = "1.2.0"
junitVersion = "4.12"
fbjniJavaOnlyVersion = "0.2.2"
soLoaderNativeLoaderVersion = "0.10.5"
}
repositories {
google()
mavenLocal()
mavenCentral()
jcenter()
}
dependencies {
classpath 'com.android.tools.build:gradle:4.1.2'
classpath 'com.vanniktech:gradle-maven-publish-plugin:0.14.2'
}
}
repositories {
google()
jcenter()
}
}
ext.deps = [
jsr305: 'com.google.code.findbugs:jsr305:3.0.1',
]


@ -1,30 +0,0 @@
#!/bin/bash
set -eux
PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)"
PYTORCH_ANDROID_DIR=$PYTORCH_DIR/android
echo "PYTORCH_DIR:$PYTORCH_DIR"
source "$PYTORCH_ANDROID_DIR/common.sh"
check_android_sdk
check_gradle
parse_abis_list "$@"
build_android
# To set proxy for gradle add following lines to ./gradle/gradle.properties:
# systemProp.http.proxyHost=...
# systemProp.http.proxyPort=8080
# systemProp.https.proxyHost=...
# systemProp.https.proxyPort=8080
if [ "$CUSTOM_ABIS_LIST" = true ]; then
NDK_DEBUG=1 $GRADLE_PATH -PnativeLibsDoNotStrip=true -PABI_FILTERS=$ABIS_LIST -p $PYTORCH_ANDROID_DIR clean test_app:assembleDebug
else
NDK_DEBUG=1 $GRADLE_PATH -PnativeLibsDoNotStrip=true -p $PYTORCH_ANDROID_DIR clean test_app:assembleDebug
fi
find $PYTORCH_ANDROID_DIR -type f -name *apk
find $PYTORCH_ANDROID_DIR -type f -name *apk | xargs echo "To install apk run: $ANDROID_HOME/platform-tools/adb install -r "


@ -1,32 +0,0 @@
#!/bin/bash
###############################################################################
# This script tests the custom selective build flow for PyTorch Android, which
# optimizes library size by only including ops used by a specific model.
###############################################################################
set -eux
PYTORCH_DIR="$(cd $(dirname $0)/..; pwd -P)"
PYTORCH_ANDROID_DIR="${PYTORCH_DIR}/android"
BUILD_ROOT="${PYTORCH_DIR}/build_pytorch_android_custom"
source "${PYTORCH_ANDROID_DIR}/common.sh"
prepare_model_and_dump_root_ops() {
cd "${BUILD_ROOT}"
MODEL="${BUILD_ROOT}/MobileNetV2.pt"
ROOT_OPS="${BUILD_ROOT}/MobileNetV2.yaml"
python "${PYTORCH_ANDROID_DIR}/test_app/make_assets_custom.py"
cp "${MODEL}" "${PYTORCH_ANDROID_DIR}/test_app/app/src/main/assets/mobilenet2.pt"
}
# Start building
mkdir -p "${BUILD_ROOT}"
check_android_sdk
check_gradle
parse_abis_list "$@"
prepare_model_and_dump_root_ops
SELECTED_OP_LIST="${ROOT_OPS}" build_android
# TODO: change this to build test_app instead
$GRADLE_PATH -PABI_FILTERS=$ABIS_LIST -p $PYTORCH_ANDROID_DIR clean assembleRelease


@ -1,74 +0,0 @@
#!/bin/bash
set -eux
##############################################################################
# Common util functions for Android build scripts.
##############################################################################
if [ -z "$PYTORCH_DIR" ]; then
echo "PYTORCH_DIR not set!"
exit 1
fi
retry () {
"$@" || (sleep 10 && "$@") || (sleep 20 && "$@") || (sleep 40 && "$@")
}
check_android_sdk() {
if [ -z "$ANDROID_HOME" ]; then
echo "ANDROID_HOME not set; please set it to Android sdk directory"
exit 1
fi
if [ ! -d "$ANDROID_HOME" ]; then
echo "ANDROID_HOME not a directory; did you install it under $ANDROID_HOME?"
exit 1
fi
echo "ANDROID_HOME:$ANDROID_HOME"
}
check_gradle() {
GRADLE_PATH=$PYTORCH_DIR/android/gradlew
echo "GRADLE_PATH:$GRADLE_PATH"
}
parse_abis_list() {
# sync with https://github.com/pytorch/pytorch/blob/0ca0e02685a9d033ac4f04e2fa5c8ba6dbc5ae50/android/gradle.properties#L1
ABIS_LIST="armeabi-v7a,arm64-v8a,x86,x86_64"
CUSTOM_ABIS_LIST=false
if [ $# -gt 0 ]; then
ABIS_LIST=$1
CUSTOM_ABIS_LIST=true
fi
echo "ABIS_LIST:$ABIS_LIST"
echo "CUSTOM_ABIS_LIST:$CUSTOM_ABIS_LIST"
}
build_android() {
PYTORCH_ANDROID_DIR="$PYTORCH_DIR/android"
BUILD_ROOT="${BUILD_ROOT:-$PYTORCH_DIR}"
echo "BUILD_ROOT:$BUILD_ROOT"
LIB_DIR="$PYTORCH_ANDROID_DIR/pytorch_android/src/main/jniLibs"
INCLUDE_DIR="$PYTORCH_ANDROID_DIR/pytorch_android/src/main/cpp/libtorch_include"
# These directories only contain symbolic links.
rm -rf "$LIB_DIR" && mkdir -p "$LIB_DIR"
rm -rf "$INCLUDE_DIR" && mkdir -p "$INCLUDE_DIR"
for abi in $(echo "$ABIS_LIST" | tr ',' '\n')
do
echo "abi:$abi"
ANDROID_BUILD_ROOT="$BUILD_ROOT/build_android_$abi"
ANDROID_ABI="$abi" \
BUILD_ROOT="$ANDROID_BUILD_ROOT" \
"$PYTORCH_DIR/scripts/build_android.sh" \
-DANDROID_CCACHE="$(which ccache)" \
-DUSE_LITE_INTERPRETER_PROFILER="OFF"
echo "$abi build output lib,include at $ANDROID_BUILD_ROOT/install"
ln -s "$ANDROID_BUILD_ROOT/install/lib" "$LIB_DIR/$abi"
ln -s "$ANDROID_BUILD_ROOT/install/include" "$INCLUDE_DIR/$abi"
done
}


@ -1,26 +0,0 @@
ABI_FILTERS=armeabi-v7a,arm64-v8a,x86,x86_64
VERSION_NAME=2.2.0-SNAPSHOT
GROUP=org.pytorch
MAVEN_GROUP=org.pytorch
SONATYPE_STAGING_PROFILE=orgpytorch
POM_URL=https://github.com/pytorch/pytorch/tree/master/android
POM_SCM_URL=https://github.com/pytorch/pytorch.git
POM_SCM_CONNECTION=scm:git:https://github.com/pytorch/pytorch
POM_SCM_DEV_CONNECTION=scm:git:git@github.com:pytorch/pytorch.git
POM_LICENSE_NAME=BSD 3-Clause
POM_LICENSE_URL=https://github.com/pytorch/pytorch/blob/master/LICENSE
POM_ISSUES_URL=https://github.com/pytorch/pytorch/issues
POM_LICENSE_DIST=repo
POM_DEVELOPER_ID=pytorch
POM_DEVELOPER_NAME=pytorch
# Gradle internals
org.gradle.internal.repository.max.retries=1
org.gradle.jvmargs=-XX:MaxMetaspaceSize=1024m
android.useAndroidX=true
android.enableJetifier=true
nativeLibsDoNotStrip=false
testAppAllVariantsEnabled=false


@ -1,11 +0,0 @@
afterEvaluate { project ->
if (POM_PACKAGING == 'aar') {
task headersJar(type: Jar) {
archiveClassifier.set('headers')
from("$rootDir/cxx/") {
include '**/*.h'
}
}
artifacts.add('archives', headersJar)
}
}


@ -1,3 +0,0 @@
apply from: rootProject.file('gradle/android_tasks.gradle')
apply plugin: 'com.vanniktech.maven.publish'

Binary file not shown.


@ -1,5 +0,0 @@
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-6.8.3-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists

android/gradlew

@ -1,172 +0,0 @@
#!/usr/bin/env sh
##############################################################################
##
## Gradle start up script for UN*X
##
##############################################################################
# Attempt to set APP_HOME
# Resolve links: $0 may be a link
PRG="$0"
# Need this for relative symlinks.
while [ -h "$PRG" ] ; do
ls=`ls -ld "$PRG"`
link=`expr "$ls" : '.*-> \(.*\)$'`
if expr "$link" : '/.*' > /dev/null; then
PRG="$link"
else
PRG=`dirname "$PRG"`"/$link"
fi
done
SAVED="`pwd`"
cd "`dirname \"$PRG\"`/" >/dev/null
APP_HOME="`pwd -P`"
cd "$SAVED" >/dev/null
APP_NAME="Gradle"
APP_BASE_NAME=`basename "$0"`
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS=""
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD="maximum"
warn () {
echo "$*"
}
die () {
echo
echo "$*"
echo
exit 1
}
# OS specific support (must be 'true' or 'false').
cygwin=false
msys=false
darwin=false
nonstop=false
case "`uname`" in
CYGWIN* )
cygwin=true
;;
Darwin* )
darwin=true
;;
MINGW* )
msys=true
;;
NONSTOP* )
nonstop=true
;;
esac
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
# Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ] ; then
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
# IBM's JDK on AIX uses strange locations for the executables
JAVACMD="$JAVA_HOME/jre/sh/java"
else
JAVACMD="$JAVA_HOME/bin/java"
fi
if [ ! -x "$JAVACMD" ] ; then
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
else
JAVACMD="java"
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
# Increase the maximum file descriptors if we can.
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
MAX_FD_LIMIT=`ulimit -H -n`
if [ $? -eq 0 ] ; then
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
MAX_FD="$MAX_FD_LIMIT"
fi
ulimit -n $MAX_FD
if [ $? -ne 0 ] ; then
warn "Could not set maximum file descriptor limit: $MAX_FD"
fi
else
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
fi
fi
# For Darwin, add options to specify how the application appears in the dock
if $darwin; then
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
fi
# For Cygwin, switch paths to Windows format before running java
if $cygwin ; then
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
JAVACMD=`cygpath --unix "$JAVACMD"`
# We build the pattern for arguments to be converted via cygpath
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
SEP=""
for dir in $ROOTDIRSRAW ; do
ROOTDIRS="$ROOTDIRS$SEP$dir"
SEP="|"
done
OURCYGPATTERN="(^($ROOTDIRS))"
# Add a user-defined pattern to the cygpath arguments
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
fi
# Now convert the arguments - kludge to limit ourselves to /bin/sh
i=0
for arg in "$@" ; do
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
else
eval `echo args$i`="\"$arg\""
fi
i=$((i+1))
done
case $i in
(0) set -- ;;
(1) set -- "$args0" ;;
(2) set -- "$args0" "$args1" ;;
(3) set -- "$args0" "$args1" "$args2" ;;
(4) set -- "$args0" "$args1" "$args2" "$args3" ;;
(5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
(6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
(7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
(8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
(9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
esac
fi
# Escape application args
save () {
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
echo " "
}
APP_ARGS=$(save "$@")
# Collect all arguments for the java command, following the shell quoting and substitution rules
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong
if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then
cd "$(dirname "$0")"
fi
exec "$JAVACMD" "$@"

android/gradlew.bat

@ -1,84 +0,0 @@
@if "%DEBUG%" == "" @echo off
@rem ##########################################################################
@rem
@rem Gradle startup script for Windows
@rem
@rem ##########################################################################
@rem Set local scope for the variables with windows NT shell
if "%OS%"=="Windows_NT" setlocal
set DIRNAME=%~dp0
if "%DIRNAME%" == "" set DIRNAME=.
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
set DEFAULT_JVM_OPTS=
@rem Find java.exe
if defined JAVA_HOME goto findJavaFromJavaHome
set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if "%ERRORLEVEL%" == "0" goto init
echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:findJavaFromJavaHome
set JAVA_HOME=%JAVA_HOME:"=%
set JAVA_EXE=%JAVA_HOME%/bin/java.exe
if exist "%JAVA_EXE%" goto init
echo.
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
echo.
echo Please set the JAVA_HOME variable in your environment to match the
echo location of your Java installation.
goto fail
:init
@rem Get command-line arguments, handling Windows variants
if not "%OS%" == "Windows_NT" goto win9xME_args
:win9xME_args
@rem Slurp the command line arguments.
set CMD_LINE_ARGS=
set _SKIP=2
:win9xME_args_slurp
if "x%~1" == "x" goto execute
set CMD_LINE_ARGS=%*
:execute
@rem Setup the command line
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
@rem Execute Gradle
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
:end
@rem End local scope for the variables with windows NT shell
if "%ERRORLEVEL%"=="0" goto mainEnd
:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
exit /b 1
:mainEnd
if "%OS%"=="Windows_NT" endlocal
:omega


@ -1,184 +0,0 @@
cmake_minimum_required(VERSION 3.5)
option(BUILD_LITE_INTERPRETER "Master flag to build pytorch_jni_lite" ON)
message(
STATUS
"BUILD_LITE_INTERPRETER (pytorch_jni_lite): ${BUILD_LITE_INTERPRETER}")
if(BUILD_LITE_INTERPRETER)
project(pytorch_jni_lite CXX)
set(PYTORCH_JNI_TARGET pytorch_jni_lite)
else()
project(pytorch_jni CXX)
set(PYTORCH_JNI_TARGET pytorch_jni)
endif()
include(GNUInstallDirs)
set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.")
set(CMAKE_VERBOSE_MAKEFILE ON)
message(STATUS "ANDROID_STL:${ANDROID_STL}")
set(TRACE_ENABLED OFF)
if(DEFINED ENV{TRACE_ENABLED})
if($ENV{TRACE_ENABLED} STREQUAL "1")
message(STATUS "TRACE_ENABLED ON")
set(TRACE_ENABLED ON)
endif()
endif()
if(NOT TRACE_ENABLED)
message(STATUS "TRACE_ENABLED OFF")
endif()
set(USE_VULKAN OFF)
set(pytorch_android_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp)
if(ANDROID_ABI)
set(USE_VULKAN ON)
set(libtorch_include_DIR ${pytorch_android_DIR}/libtorch_include/${ANDROID_ABI})
set(BUILD_SUBDIR ${ANDROID_ABI})
elseif(BUILD_LIBTORCH_WITH_JNI)
# Don't need LIBTORCH_HOME if we're building from within PyTorch.
else()
# Building against a pre-built libtorch.
if(NOT LIBTORCH_HOME)
message(FATAL_ERROR
"pytorch_android requires LIBTORCH_HOME to be defined for non-Android builds.")
endif()
set(libtorch_include_DIR ${LIBTORCH_HOME}/include)
link_directories(${LIBTORCH_HOME}/lib)
set(BUILD_SUBDIR host)
endif()
message(STATUS "libtorch dir:${libtorch_DIR}")
configure_file(
${pytorch_android_DIR}/cmake_macros.h.in
${pytorch_android_DIR}/cmake_macros.h)
if(BUILD_LITE_INTERPRETER)
file(GLOB pytorch_android_SOURCES
${pytorch_android_DIR}/pytorch_jni_lite.cpp
${pytorch_android_DIR}/pytorch_jni_common.cpp
${pytorch_android_DIR}/pytorch_jni_common.h
)
else()
file(GLOB pytorch_android_SOURCES
${pytorch_android_DIR}/pytorch_jni_jit.cpp
${pytorch_android_DIR}/pytorch_jni_common.cpp
${pytorch_android_DIR}/pytorch_jni_common.h
)
endif()
add_library(${PYTORCH_JNI_TARGET} SHARED ${pytorch_android_SOURCES})
if(APPLE)
# Need to add rpath so dlopen can find dependencies.
add_custom_command(TARGET pytorch_jni
POST_BUILD COMMAND
${CMAKE_INSTALL_NAME_TOOL} -add_rpath "@loader_path"
$<TARGET_FILE:pytorch_jni>)
endif()
target_compile_options(${PYTORCH_JNI_TARGET} PRIVATE
-fexceptions
)
target_include_directories(${PYTORCH_JNI_TARGET} BEFORE
PUBLIC $<BUILD_INTERFACE:${libtorch_include_DIR}>)
set(fbjni_DIR ${CMAKE_CURRENT_LIST_DIR}/../libs/fbjni/)
set(fbjni_BUILD_DIR ${CMAKE_BINARY_DIR}/fbjni/${BUILD_SUBDIR})
add_subdirectory(${fbjni_DIR} ${fbjni_BUILD_DIR})
# ---[ Vulkan deps
if(USE_VULKAN)
set(Vulkan_LIBS)
set(Vulkan_INCLUDES)
include(${CMAKE_CURRENT_LIST_DIR}/../../cmake/VulkanDependencies.cmake)
endif()
if(ANDROID_ABI)
function(import_static_lib name)
add_library(${name} STATIC IMPORTED)
set_property(
TARGET ${name}
PROPERTY IMPORTED_LOCATION
${CMAKE_CURRENT_LIST_DIR}/src/main/jniLibs/${ANDROID_ABI}/${name}.a)
endfunction(import_static_lib)
import_static_lib(libtorch)
import_static_lib(libtorch_cpu)
import_static_lib(libc10)
import_static_lib(libnnpack)
import_static_lib(libXNNPACK)
import_static_lib(libmicrokernels-prod)
import_static_lib(libpytorch_qnnpack)
import_static_lib(libpthreadpool)
import_static_lib(libeigen_blas)
import_static_lib(libcpuinfo)
import_static_lib(libclog)
# Link most things statically on Android.
set(pytorch_jni_LIBS
fbjni
-Wl,--gc-sections
-Wl,--whole-archive
libtorch
libtorch_cpu
-Wl,--no-whole-archive
libc10
libnnpack
libXNNPACK
libmicrokernels-prod
libpytorch_qnnpack
libpthreadpool
libeigen_blas
libcpuinfo
libclog
)
else()
# Prefer dynamic linking on the host
set(pytorch_jni_LIBS
fbjni
torch
torch_cpu
c10
cpuinfo
)
if(USE_NNPACK)
list(APPEND pytorch_jni_LIBS nnpack)
endif()
if(USE_XNNPACK)
list(APPEND pytorch_jni_LIBS XNNPACK)
list(APPEND pytorch_jni_LIBS microkernels-prod)
endif()
if(USE_SYSTEM_PTHREADPOOL)
list(APPEND pytorch_jni_LIBS pthreadpool)
endif()
if(USE_PYTORCH_QNNPACK)
list(APPEND pytorch_jni_LIBS pytorch_qnnpack)
list(APPEND pytorch_jni_LIBS clog)
endif()
endif()
if(USE_VULKAN)
list(APPEND pytorch_jni_LIBS ${Vulkan_LIBS})
endif()
target_link_libraries(${PYTORCH_JNI_TARGET} ${pytorch_jni_LIBS})
install(TARGETS ${PYTORCH_JNI_TARGET}
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) #For windows
if(MSVC)
install(FILES $<TARGET_PDB_FILE:pytorch_jni> DESTINATION ${CMAKE_INSTALL_LIBDIR} OPTIONAL)
install(TARGETS ${PYTORCH_JNI_TARGET} DESTINATION ${CMAKE_INSTALL_LIBDIR})
endif()


@ -1,163 +0,0 @@
apply plugin: 'com.android.library'
apply plugin: 'maven'
android {
compileSdkVersion rootProject.compileSdkVersion
buildToolsVersion rootProject.buildToolsVersion
defaultConfig {
minSdkVersion rootProject.minSdkVersion
targetSdkVersion rootProject.targetSdkVersion
versionCode 0
versionName "0.1"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
ndk {
abiFilters ABI_FILTERS.split(",")
}
externalNativeBuild {
cmake {
if(System.env.BUILD_LITE_INTERPRETER == '0') {
arguments "-DANDROID_STL=c++_shared", "-DBUILD_LITE_INTERPRETER=OFF", "-DUSE_LITE_INTERPRETER_PROFILER=OFF"
} else {
arguments "-DANDROID_STL=c++_shared", "-DUSE_LITE_INTERPRETER_PROFILER=OFF"
}
}
}
}
buildTypes {
debug {
minifyEnabled false
debuggable true
}
release {
minifyEnabled false
}
}
sourceSets {
main {
java {
if(System.env.BUILD_LITE_INTERPRETER == '0') {
println 'Build pytorch_jni'
exclude 'org/pytorch/LiteModuleLoader.java'
exclude 'org/pytorch/LiteNativePeer.java'
} else {
println 'Build pytorch_jni_lite'
}
}
jniLibs.srcDirs = ['src/main/jniLibs']
manifest.srcFile 'src/main/AndroidManifest.xml'
}
androidTest {
java {
if(System.env.BUILD_LITE_INTERPRETER == '0') {
println 'Build test for full jit (pytorch_jni)'
exclude 'org/pytorch/PytorchHostTests.java'
exclude 'org/pytorch/PytorchLiteInstrumentedTests.java'
exclude 'org/pytorch/suite/PytorchLiteInstrumentedTestSuite.java'
} else {
println 'Build test for lite interpreter (pytorch_jni_lite)'
exclude 'org/pytorch/PytorchHostTests.java'
exclude 'org/pytorch/PytorchInstrumentedTests.java'
exclude 'org/pytorch/suite/PytorchInstrumentedTestSuite.java'
}
}
}
}
externalNativeBuild {
cmake {
path "CMakeLists.txt"
}
}
packagingOptions {
if (nativeLibsDoNotStrip.toBoolean()) {
doNotStrip "**/*.so"
logger.warn('WARNING: nativeLibsDoNotStrip==true; debug symbols included')
}
}
useLibrary 'android.test.runner'
useLibrary 'android.test.base'
useLibrary 'android.test.mock'
}
dependencies {
implementation 'com.facebook.fbjni:fbjni-java-only:' + rootProject.fbjniJavaOnlyVersion
implementation 'com.facebook.soloader:nativeloader:' + rootProject.soLoaderNativeLoaderVersion
testImplementation 'junit:junit:' + rootProject.junitVersion
testImplementation 'androidx.test:core:' + rootProject.coreVersion
androidTestImplementation 'junit:junit:' + rootProject.junitVersion
androidTestImplementation 'androidx.test:core:' + rootProject.coreVersion
androidTestImplementation 'androidx.test.ext:junit:' + rootProject.extJUnitVersion
androidTestImplementation 'androidx.test:rules:' + rootProject.rulesVersion
androidTestImplementation 'androidx.test:runner:' + rootProject.runnerVersion
}
apply from: rootProject.file('gradle/release.gradle')
task sourcesJar(type: Jar) {
from android.sourceSets.main.java.srcDirs
classifier = 'sources'
}
def getLibtorchHeadersDir() {
def abi = ABI_FILTERS.split(",")[0]
return "$rootDir/pytorch_android/src/main/cpp/libtorch_include/$abi"
}
afterEvaluate {
if (POM_PACKAGING == 'aar') {
android.libraryVariants.all { variant ->
variant.outputs.each { output ->
File f = output.outputFile
if (f.name.endsWith(".aar")) {
output.assemble.finalizedBy addFolderToAarTask(
"addHeadersToAar" + variant.name,
f.path,
getLibtorchHeadersDir(),
"headers")
}
}
}
}
}
tasks.whenTaskAdded { task ->
if (task.name.startsWith("bundle") && task.name.endsWith("Aar")) {
doLast {
addFolderToAar("addHeadersTo" + task.name, task.archivePath, getLibtorchHeadersDir(), 'headers')
}
}
}
def addFolderToAarTask(taskName, aarPath, folderPath, folderPathInAar) {
return tasks.register(taskName) {
doLast {
addFolderToAar(taskName, aarPath, folderPath, folderPathInAar)
}
}
}
def addFolderToAar(taskName, aarPath, folderPath, folderPathInAar) {
def tmpDir = file("${buildDir}/${taskName}")
tmpDir.mkdir()
def tmpDirFolder = file("${tmpDir.path}/${folderPathInAar}")
tmpDirFolder.mkdir()
copy {
from zipTree(aarPath)
into tmpDir
}
copy {
from fileTree(folderPath)
into tmpDirFolder
}
ant.zip(destfile: aarPath) {
fileset(dir: tmpDir.path)
}
delete tmpDir
}
artifacts.add('archives', sourcesJar)


@ -1,20 +0,0 @@
#include <torch/csrc/jit/api/module.h>
#include <torch/jit.h>
#include <torch/script.h>
#include <fstream>
#include <iostream>
#include <string>
int main(int argc, char* argv[]) {
std::string input_file_path{argv[1]};
std::string output_file_path{argv[2]};
std::ifstream ifs(input_file_path);
std::stringstream buffer;
buffer << ifs.rdbuf();
torch::jit::Module m("TestModule");
m.define(buffer.str());
m.save(output_file_path);
}


@ -1,151 +0,0 @@
from typing import Optional
import torch
from torch import Tensor
OUTPUT_DIR = "src/androidTest/assets/"
def scriptAndSave(module, fileName):
print("-" * 80)
script_module = torch.jit.script(module)
print(script_module.graph)
outputFileName = OUTPUT_DIR + fileName
# note that the lite interpreter model can also be used in full JIT
script_module._save_for_lite_interpreter(outputFileName)
print("Saved to " + outputFileName)
print("=" * 80)
class Test(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, input):
return None
@torch.jit.script_method
def eqBool(self, input: bool) -> bool:
return input
@torch.jit.script_method
def eqInt(self, input: int) -> int:
return input
@torch.jit.script_method
def eqFloat(self, input: float) -> float:
return input
@torch.jit.script_method
def eqStr(self, input: str) -> str:
return input
@torch.jit.script_method
def eqTensor(self, input: Tensor) -> Tensor:
return input
@torch.jit.script_method
def eqDictStrKeyIntValue(self, input: dict[str, int]) -> dict[str, int]:
return input
@torch.jit.script_method
def eqDictIntKeyIntValue(self, input: dict[int, int]) -> dict[int, int]:
return input
@torch.jit.script_method
def eqDictFloatKeyIntValue(self, input: dict[float, int]) -> dict[float, int]:
return input
@torch.jit.script_method
def listIntSumReturnTuple(self, input: list[int]) -> tuple[list[int], int]:
sum = 0
for x in input:
sum += x
return (input, sum)
@torch.jit.script_method
def listBoolConjunction(self, input: list[bool]) -> bool:
res = True
for x in input:
res = res and x
return res
@torch.jit.script_method
def listBoolDisjunction(self, input: list[bool]) -> bool:
res = False
for x in input:
res = res or x
return res
@torch.jit.script_method
def tupleIntSumReturnTuple(
self, input: tuple[int, int, int]
) -> tuple[tuple[int, int, int], int]:
sum = 0
for x in input:
sum += x
return (input, sum)
@torch.jit.script_method
def optionalIntIsNone(self, input: Optional[int]) -> bool:
return input is None
@torch.jit.script_method
def intEq0None(self, input: int) -> Optional[int]:
if input == 0:
return None
return input
@torch.jit.script_method
def str3Concat(self, input: str) -> str:
return input + input + input
@torch.jit.script_method
def newEmptyShapeWithItem(self, input):
return torch.tensor([int(input.item())])[0]
@torch.jit.script_method
def testAliasWithOffset(self) -> list[Tensor]:
x = torch.tensor([100, 200])
a = [x[0], x[1]]
return a
@torch.jit.script_method
def testNonContiguous(self):
x = torch.tensor([100, 200, 300])[::2]
assert not x.is_contiguous()
assert x[0] == 100
assert x[1] == 300
return x
@torch.jit.script_method
def conv2d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
r = torch.nn.functional.conv2d(x, w)
if toChannelsLast:
r = r.contiguous(memory_format=torch.channels_last)
else:
r = r.contiguous()
return r
@torch.jit.script_method
def conv3d(self, x: Tensor, w: Tensor, toChannelsLast: bool) -> Tensor:
r = torch.nn.functional.conv3d(x, w)
if toChannelsLast:
r = r.contiguous(memory_format=torch.channels_last_3d)
else:
r = r.contiguous()
return r
@torch.jit.script_method
def contiguous(self, x: Tensor) -> Tensor:
return x.contiguous()
@torch.jit.script_method
def contiguousChannelsLast(self, x: Tensor) -> Tensor:
return x.contiguous(memory_format=torch.channels_last)
@torch.jit.script_method
def contiguousChannelsLast3d(self, x: Tensor) -> Tensor:
return x.contiguous(memory_format=torch.channels_last_3d)
scriptAndSave(Test(), "test.pt")


@ -1,4 +0,0 @@
POM_NAME=pytorch_android_lite pytorch android api
POM_DESCRIPTION=pytorch_android_lite pytorch android api
POM_ARTIFACT_ID=pytorch_android_lite
POM_PACKAGING=aar


@ -1,42 +0,0 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the Apache-2 license found in the
// LICENSE file in the root directory of this source tree.
plugins {
id 'java-library'
}
repositories {
mavenLocal()
jcenter()
}
sourceSets {
main {
java {
srcDir '../src/main/java'
exclude 'org/pytorch/PyTorchAndroid.java'
exclude 'org/pytorch/LitePyTorchAndroid.java'
exclude 'org/pytorch/LiteModuleLoader.java'
exclude 'org/pytorch/LiteNativePeer.java'
}
}
test {
java {
srcDir '../src/androidTest/java'
exclude '**/PytorchInstrumented*'
exclude '**/PytorchLiteInstrumented*'
}
resources.srcDirs = ["../src/androidTest/assets"]
}
}
dependencies {
compileOnly 'com.google.code.findbugs:jsr305:3.0.1'
implementation 'com.facebook.soloader:nativeloader:0.10.1'
implementation 'com.facebook.fbjni:fbjni-java-only:0.2.2'
testImplementation 'junit:junit:4.12'
}
apply from: rootProject.file('gradle/release.gradle')


@ -1,4 +0,0 @@
POM_NAME=pytorch_java_only pytorch java api
POM_DESCRIPTION=pytorch_java_only pytorch java api
POM_ARTIFACT_ID=pytorch_java_only
POM_PACKAGING=jar


@ -1,21 +0,0 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <gtest/gtest.h>
#include <ATen/core/type_factory.h>
#include "caffe2/android/pytorch_android/src/main/cpp/pytorch_jni_common.h"
using namespace ::testing;
TEST(pytorch_jni_common_test, newJIValueFromAtIValue) {
auto dict = c10::impl::GenericDict(
c10::dynT<c10::IntType>(), c10::dynT<c10::StringType>());
auto dictCallback = [](auto&&) {
return facebook::jni::local_ref<pytorch_jni::JIValue>{};
};
EXPECT_NO_THROW(pytorch_jni::JIValue::newJIValueFromAtIValue(
dict, dictCallback, dictCallback));
}


@ -1,25 +0,0 @@
package org.pytorch;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.util.Objects;
public class PytorchHostTests extends PytorchTestBase {
@Override
protected Module loadModel(String path) throws IOException {
return Module.load(assetFilePath(path));
}
private String assetFilePath(String assetName) throws IOException {
Path tempFile = Files.createTempFile("test", ".pt");
try (InputStream resource =
Objects.requireNonNull(getClass().getClassLoader().getResourceAsStream("test.pt"))) {
Files.copy(resource, tempFile, StandardCopyOption.REPLACE_EXISTING);
}
return tempFile.toAbsolutePath().toString();
}
}


@ -1,42 +0,0 @@
package org.pytorch;
import android.content.Context;
import androidx.test.InstrumentationRegistry;
import androidx.test.runner.AndroidJUnit4;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import org.junit.runner.RunWith;
@RunWith(AndroidJUnit4.class)
public class PytorchInstrumentedTests extends PytorchTestBase {
@Override
protected Module loadModel(String path) throws IOException {
return Module.load(assetFilePath(path));
}
private String assetFilePath(String assetName) throws IOException {
final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
File file = new File(appContext.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}
try (InputStream is = appContext.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
} catch (IOException e) {
throw e;
}
}
}


@ -1,42 +0,0 @@
package org.pytorch;
import android.content.Context;
import androidx.test.InstrumentationRegistry;
import androidx.test.runner.AndroidJUnit4;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import org.junit.runner.RunWith;
@RunWith(AndroidJUnit4.class)
public class PytorchLiteInstrumentedTests extends PytorchTestBase {
@Override
protected Module loadModel(String path) throws IOException {
return LiteModuleLoader.load(assetFilePath(path));
}
private String assetFilePath(String assetName) throws IOException {
final Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
File file = new File(appContext.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}
try (InputStream is = appContext.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
} catch (IOException e) {
throw e;
}
}
}


@ -1,694 +0,0 @@
package org.pytorch;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.junit.Test;
import org.junit.Ignore;
public abstract class PytorchTestBase {
private static final String TEST_MODULE_ASSET_NAME = "android_api_module.ptl";
@Test
public void testForwardNull() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue input = IValue.from(Tensor.fromBlob(Tensor.allocateByteBuffer(1), new long[] {1}));
assertTrue(input.isTensor());
final IValue output = module.forward(input);
assertTrue(output.isNull());
}
@Test
public void testEqBool() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
for (boolean value : new boolean[] {false, true}) {
final IValue input = IValue.from(value);
assertTrue(input.isBool());
assertTrue(value == input.toBool());
final IValue output = module.runMethod("eqBool", input);
assertTrue(output.isBool());
assertTrue(value == output.toBool());
}
}
@Test
public void testEqInt() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
for (long value : new long[] {Long.MIN_VALUE, -1024, -1, 0, 1, 1024, Long.MAX_VALUE}) {
final IValue input = IValue.from(value);
assertTrue(input.isLong());
assertTrue(value == input.toLong());
final IValue output = module.runMethod("eqInt", input);
assertTrue(output.isLong());
assertTrue(value == output.toLong());
}
}
@Test
public void testEqFloat() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
double[] values =
new double[] {
-Double.MAX_VALUE,
Double.MAX_VALUE,
-Double.MIN_VALUE,
Double.MIN_VALUE,
-Math.exp(1.d),
-Math.sqrt(2.d),
-3.1415f,
3.1415f,
-1,
0,
1,
};
for (double value : values) {
final IValue input = IValue.from(value);
assertTrue(input.isDouble());
assertTrue(value == input.toDouble());
final IValue output = module.runMethod("eqFloat", input);
assertTrue(output.isDouble());
assertTrue(value == output.toDouble());
}
}
@Test
public void testEqTensor() throws IOException {
final long[] inputTensorShape = new long[] {1, 3, 224, 224};
final long numElements = Tensor.numel(inputTensorShape);
final float[] inputTensorData = new float[(int) numElements];
for (int i = 0; i < numElements; ++i) {
inputTensorData[i] = i;
}
final Tensor inputTensor = Tensor.fromBlob(inputTensorData, inputTensorShape);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue input = IValue.from(inputTensor);
assertTrue(input.isTensor());
assertTrue(inputTensor == input.toTensor());
final IValue output = module.runMethod("eqTensor", input);
assertTrue(output.isTensor());
final Tensor outputTensor = output.toTensor();
assertNotNull(outputTensor);
assertArrayEquals(inputTensorShape, outputTensor.shape());
float[] outputData = outputTensor.getDataAsFloatArray();
for (int i = 0; i < numElements; i++) {
assertTrue(inputTensorData[i] == outputData[i]);
}
}
@Test
public void testEqDictIntKeyIntValue() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final Map<Long, IValue> inputMap = new HashMap<>();
inputMap.put(Long.MIN_VALUE, IValue.from(-Long.MIN_VALUE));
inputMap.put(Long.MAX_VALUE, IValue.from(-Long.MAX_VALUE));
inputMap.put(0l, IValue.from(0l));
inputMap.put(1l, IValue.from(-1l));
inputMap.put(-1l, IValue.from(1l));
final IValue input = IValue.dictLongKeyFrom(inputMap);
assertTrue(input.isDictLongKey());
final IValue output = module.runMethod("eqDictIntKeyIntValue", input);
assertTrue(output.isDictLongKey());
final Map<Long, IValue> outputMap = output.toDictLongKey();
assertTrue(inputMap.size() == outputMap.size());
for (Map.Entry<Long, IValue> entry : inputMap.entrySet()) {
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
}
}
@Test
public void testEqDictStrKeyIntValue() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final Map<String, IValue> inputMap = new HashMap<>();
inputMap.put("long_min_value", IValue.from(Long.MIN_VALUE));
inputMap.put("long_max_value", IValue.from(Long.MAX_VALUE));
inputMap.put("long_0", IValue.from(0l));
inputMap.put("long_1", IValue.from(1l));
inputMap.put("long_-1", IValue.from(-1l));
final IValue input = IValue.dictStringKeyFrom(inputMap);
assertTrue(input.isDictStringKey());
final IValue output = module.runMethod("eqDictStrKeyIntValue", input);
assertTrue(output.isDictStringKey());
final Map<String, IValue> outputMap = output.toDictStringKey();
assertTrue(inputMap.size() == outputMap.size());
for (Map.Entry<String, IValue> entry : inputMap.entrySet()) {
assertTrue(outputMap.get(entry.getKey()).toLong() == entry.getValue().toLong());
}
}
@Test
public void testListIntSumReturnTuple() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
for (int n : new int[] {0, 1, 128}) {
long[] a = new long[n];
long sum = 0;
for (int i = 0; i < n; i++) {
a[i] = i;
sum += a[i];
}
final IValue input = IValue.listFrom(a);
assertTrue(input.isLongList());
final IValue output = module.runMethod("listIntSumReturnTuple", input);
assertTrue(output.isTuple());
assertTrue(2 == output.toTuple().length);
IValue output0 = output.toTuple()[0];
IValue output1 = output.toTuple()[1];
assertArrayEquals(a, output0.toLongList());
assertTrue(sum == output1.toLong());
}
}
@Test
public void testOptionalIntIsNone() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
assertFalse(module.runMethod("optionalIntIsNone", IValue.from(1l)).toBool());
assertTrue(module.runMethod("optionalIntIsNone", IValue.optionalNull()).toBool());
}
@Test
public void testIntEq0None() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
assertTrue(module.runMethod("intEq0None", IValue.from(0l)).isNull());
assertTrue(module.runMethod("intEq0None", IValue.from(1l)).toLong() == 1l);
}
@Test(expected = IllegalArgumentException.class)
public void testRunUndefinedMethod() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
module.runMethod("test_undefined_method_throws_exception");
}
@Test
public void testTensorMethods() {
long[] shape = new long[] {1, 3, 224, 224};
final int numel = (int) Tensor.numel(shape);
int[] ints = new int[numel];
float[] floats = new float[numel];
byte[] bytes = new byte[numel];
for (int i = 0; i < numel; i++) {
bytes[i] = (byte) ((i % 255) - 128);
ints[i] = i;
floats[i] = i / 1000.f;
}
Tensor tensorBytes = Tensor.fromBlob(bytes, shape);
assertTrue(tensorBytes.dtype() == DType.INT8);
assertArrayEquals(bytes, tensorBytes.getDataAsByteArray());
Tensor tensorInts = Tensor.fromBlob(ints, shape);
assertTrue(tensorInts.dtype() == DType.INT32);
assertArrayEquals(ints, tensorInts.getDataAsIntArray());
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
float[] floatsOut = tensorFloats.getDataAsFloatArray();
assertTrue(floatsOut.length == numel);
for (int i = 0; i < numel; i++) {
assertTrue(floats[i] == floatsOut[i]);
}
}
@Test(expected = IllegalStateException.class)
public void testTensorIllegalStateOnWrongType() {
long[] shape = new long[] {1, 3, 224, 224};
final int numel = (int) Tensor.numel(shape);
float[] floats = new float[numel];
Tensor tensorFloats = Tensor.fromBlob(floats, shape);
assertTrue(tensorFloats.dtype() == DType.FLOAT32);
tensorFloats.getDataAsByteArray();
}
@Test
public void testEqString() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
String[] values =
new String[] {
"smoketest",
"проверка не латинских символов", // not latin symbols check
"#@$!@#)($*!@#$)(!@*#$"
};
for (String value : values) {
final IValue input = IValue.from(value);
assertTrue(input.isString());
assertTrue(value.equals(input.toStr()));
final IValue output = module.runMethod("eqStr", input);
assertTrue(output.isString());
assertTrue(value.equals(output.toStr()));
}
}
@Test
public void testStr3Concat() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
String[] values =
new String[] {
"smoketest",
"проверка не латинских символов", // not latin symbols check
"#@$!@#)($*!@#$)(!@*#$"
};
for (String value : values) {
final IValue input = IValue.from(value);
assertTrue(input.isString());
assertTrue(value.equals(input.toStr()));
final IValue output = module.runMethod("str3Concat", input);
assertTrue(output.isString());
String expectedOutput =
new StringBuilder().append(value).append(value).append(value).toString();
assertTrue(expectedOutput.equals(output.toStr()));
}
}
@Test
public void testEmptyShape() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final long someNumber = 43;
final IValue input = IValue.from(Tensor.fromBlob(new long[] {someNumber}, new long[] {}));
final IValue output = module.runMethod("newEmptyShapeWithItem", input);
assertTrue(output.isTensor());
Tensor value = output.toTensor();
assertArrayEquals(new long[] {}, value.shape());
assertArrayEquals(new long[] {someNumber}, value.getDataAsLongArray());
}
@Test
public void testAliasWithOffset() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue output = module.runMethod("testAliasWithOffset");
assertTrue(output.isTensorList());
Tensor[] tensors = output.toTensorList();
assertEquals(100, tensors[0].getDataAsLongArray()[0]);
assertEquals(200, tensors[1].getDataAsLongArray()[0]);
}
@Test
public void testNonContiguous() throws IOException {
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue output = module.runMethod("testNonContiguous");
assertTrue(output.isTensor());
Tensor value = output.toTensor();
assertArrayEquals(new long[] {2}, value.shape());
assertArrayEquals(new long[] {100, 300}, value.getDataAsLongArray());
}
@Test
public void testChannelsLast() throws IOException {
long[] inputShape = new long[] {1, 3, 2, 2};
long[] data = new long[] {1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104};
Tensor inputNHWC = Tensor.fromBlob(data, inputShape, MemoryFormat.CHANNELS_LAST);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue outputNCHW = module.runMethod("contiguous", IValue.from(inputNHWC));
assertIValueTensor(
outputNCHW,
MemoryFormat.CONTIGUOUS,
new long[] {1, 3, 2, 2},
new long[] {1, 2, 3, 4, 11, 12, 13, 14, 101, 102, 103, 104});
final IValue outputNHWC = module.runMethod("contiguousChannelsLast", IValue.from(inputNHWC));
assertIValueTensor(outputNHWC, MemoryFormat.CHANNELS_LAST, inputShape, data);
}
@Test
public void testChannelsLast3d() throws IOException {
long[] shape = new long[] {1, 2, 2, 2, 2};
long[] dataNCHWD = new long[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
long[] dataNHWDC = new long[] {1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15, 8, 16};
Tensor inputNHWDC = Tensor.fromBlob(dataNHWDC, shape, MemoryFormat.CHANNELS_LAST_3D);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue outputNCHWD = module.runMethod("contiguous", IValue.from(inputNHWDC));
assertIValueTensor(outputNCHWD, MemoryFormat.CONTIGUOUS, shape, dataNCHWD);
Tensor inputNCHWD = Tensor.fromBlob(dataNCHWD, shape, MemoryFormat.CONTIGUOUS);
final IValue outputNHWDC =
module.runMethod("contiguousChannelsLast3d", IValue.from(inputNCHWD));
assertIValueTensor(outputNHWDC, MemoryFormat.CHANNELS_LAST_3D, shape, dataNHWDC);
}
@Test
public void testChannelsLastConv2d() throws IOException {
long[] inputShape = new long[] {1, 3, 2, 2};
long[] dataNCHW = new long[] {
111, 112,
121, 122,
211, 212,
221, 222,
311, 312,
321, 322};
Tensor inputNCHW = Tensor.fromBlob(dataNCHW, inputShape, MemoryFormat.CONTIGUOUS);
long[] dataNHWC = new long[] {
111, 211, 311, 112, 212, 312,
121, 221, 321, 122, 222, 322};
Tensor inputNHWC = Tensor.fromBlob(dataNHWC, inputShape, MemoryFormat.CHANNELS_LAST);
long[] weightShape = new long[] {3, 3, 1, 1};
long[] dataWeightOIHW = new long[] {
2, 0, 0,
0, 1, 0,
0, 0, -1};
Tensor wNCHW = Tensor.fromBlob(dataWeightOIHW, weightShape, MemoryFormat.CONTIGUOUS);
long[] dataWeightOHWI = new long[] {
2, 0, 0,
0, 1, 0,
0, 0, -1};
Tensor wNHWC = Tensor.fromBlob(dataWeightOHWI, weightShape, MemoryFormat.CHANNELS_LAST);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue outputNCHW =
module.runMethod("conv2d", IValue.from(inputNCHW), IValue.from(wNCHW), IValue.from(false));
assertIValueTensor(
outputNCHW,
MemoryFormat.CONTIGUOUS,
new long[] {1, 3, 2, 2},
new long[] {
2*111, 2*112,
2*121, 2*122,
211, 212,
221, 222,
-311, -312,
-321, -322});
final IValue outputNHWC =
module.runMethod("conv2d", IValue.from(inputNHWC), IValue.from(wNHWC), IValue.from(true));
assertIValueTensor(
outputNHWC,
MemoryFormat.CHANNELS_LAST,
new long[] {1, 3, 2, 2},
new long[] {
2*111, 211, -311, 2*112, 212, -312,
2*121, 221, -321, 2*122, 222, -322});
}
@Test
public void testChannelsLastConv3d() throws IOException {
long[] inputShape = new long[] {1, 3, 2, 2, 2};
long[] dataNCDHW = new long[] {
1111, 1112,
1121, 1122,
1211, 1212,
1221, 1222,
2111, 2112,
2121, 2122,
2211, 2212,
2221, 2222,
3111, 3112,
3121, 3122,
3211, 3212,
3221, 3222};
Tensor inputNCDHW = Tensor.fromBlob(dataNCDHW, inputShape, MemoryFormat.CONTIGUOUS);
long[] dataNDHWC = new long[] {
1111, 2111, 3111,
1112, 2112, 3112,
1121, 2121, 3121,
1122, 2122, 3122,
1211, 2211, 3211,
1212, 2212, 3212,
1221, 2221, 3221,
1222, 2222, 3222};
Tensor inputNDHWC = Tensor.fromBlob(dataNDHWC, inputShape, MemoryFormat.CHANNELS_LAST_3D);
long[] weightShape = new long[] {3, 3, 1, 1, 1};
long[] dataWeightOIDHW = new long[] {
2, 0, 0,
0, 1, 0,
0, 0, -1,
};
Tensor wNCDHW = Tensor.fromBlob(dataWeightOIDHW, weightShape, MemoryFormat.CONTIGUOUS);
long[] dataWeightODHWI = new long[] {
2, 0, 0,
0, 1, 0,
0, 0, -1,
};
Tensor wNDHWC = Tensor.fromBlob(dataWeightODHWI, weightShape, MemoryFormat.CHANNELS_LAST_3D);
final Module module = loadModel(TEST_MODULE_ASSET_NAME);
final IValue outputNCDHW =
module.runMethod("conv3d", IValue.from(inputNCDHW), IValue.from(wNCDHW), IValue.from(false));
assertIValueTensor(
outputNCDHW,
MemoryFormat.CONTIGUOUS,
new long[] {1, 3, 2, 2, 2},
new long[] {
2*1111, 2*1112, 2*1121, 2*1122,
2*1211, 2*1212, 2*1221, 2*1222,
2111, 2112, 2121, 2122,
2211, 2212, 2221, 2222,
-3111, -3112, -3121, -3122,
-3211, -3212, -3221, -3222});
final IValue outputNDHWC =
module.runMethod("conv3d", IValue.from(inputNDHWC), IValue.from(wNDHWC), IValue.from(true));
assertIValueTensor(
outputNDHWC,
MemoryFormat.CHANNELS_LAST_3D,
new long[] {1, 3, 2, 2, 2},
new long[] {
2*1111, 2111, -3111, 2*1112, 2112, -3112,
2*1121, 2121, -3121, 2*1122, 2122, -3122,
2*1211, 2211, -3211, 2*1212, 2212, -3212,
2*1221, 2221, -3221, 2*1222, 2222, -3222});
}
@Test
public void testMobileNetV2() throws IOException {
try {
final Module module = loadModel("mobilenet_v2.ptl");
final IValue inputs = module.runMethod("get_all_bundled_inputs");
assertTrue(inputs.isList());
final IValue input = inputs.toList()[0];
assertTrue(input.isTuple());
module.forward(input.toTuple()[0]);
assertTrue(true);
} catch (Exception ex) {
assertTrue("failed to run MobileNetV2 " + ex.getMessage(), false);
}
}
@Test
public void testPointwiseOps() throws IOException {
runModel("pointwise_ops");
}
@Test
public void testReductionOps() throws IOException {
runModel("reduction_ops");
}
@Test
public void testComparisonOps() throws IOException {
runModel("comparison_ops");
}
@Test
public void testOtherMathOps() throws IOException {
runModel("other_math_ops");
}
@Test
@Ignore
public void testSpectralOps() throws IOException {
// NB: This model fails without lite interpreter. The error is as follows:
// RuntimeError: stft requires the return_complex parameter be given for real inputs
runModel("spectral_ops");
}
@Test
public void testBlasLapackOps() throws IOException {
runModel("blas_lapack_ops");
}
@Test
public void testSamplingOps() throws IOException {
runModel("sampling_ops");
}
@Test
public void testTensorOps() throws IOException {
runModel("tensor_general_ops");
}
@Test
public void testTensorCreationOps() throws IOException {
runModel("tensor_creation_ops");
}
@Test
public void testTensorIndexingOps() throws IOException {
runModel("tensor_indexing_ops");
}
@Test
public void testTensorTypingOps() throws IOException {
runModel("tensor_typing_ops");
}
@Test
public void testTensorViewOps() throws IOException {
runModel("tensor_view_ops");
}
@Test
public void testConvolutionOps() throws IOException {
runModel("convolution_ops");
}
@Test
public void testPoolingOps() throws IOException {
runModel("pooling_ops");
}
@Test
public void testPaddingOps() throws IOException {
runModel("padding_ops");
}
@Test
public void testActivationOps() throws IOException {
runModel("activation_ops");
}
@Test
public void testNormalizationOps() throws IOException {
runModel("normalization_ops");
}
@Test
public void testRecurrentOps() throws IOException {
runModel("recurrent_ops");
}
@Test
public void testTransformerOps() throws IOException {
runModel("transformer_ops");
}
@Test
public void testLinearOps() throws IOException {
runModel("linear_ops");
}
@Test
public void testDropoutOps() throws IOException {
runModel("dropout_ops");
}
@Test
public void testSparseOps() throws IOException {
runModel("sparse_ops");
}
@Test
public void testDistanceFunctionOps() throws IOException {
runModel("distance_function_ops");
}
@Test
public void testLossFunctionOps() throws IOException {
runModel("loss_function_ops");
}
@Test
public void testVisionFunctionOps() throws IOException {
runModel("vision_function_ops");
}
@Test
public void testShuffleOps() throws IOException {
runModel("shuffle_ops");
}
@Test
public void testNNUtilsOps() throws IOException {
runModel("nn_utils_ops");
}
@Test
public void testQuantOps() throws IOException {
runModel("general_quant_ops");
}
@Test
public void testDynamicQuantOps() throws IOException {
runModel("dynamic_quant_ops");
}
@Test
public void testStaticQuantOps() throws IOException {
runModel("static_quant_ops");
}
@Test
public void testFusedQuantOps() throws IOException {
runModel("fused_quant_ops");
}
@Test
public void testTorchScriptBuiltinQuantOps() throws IOException {
runModel("torchscript_builtin_ops");
}
@Test
public void testTorchScriptCollectionQuantOps() throws IOException {
runModel("torchscript_collection_ops");
}
static void assertIValueTensor(
final IValue ivalue,
final MemoryFormat memoryFormat,
final long[] expectedShape,
final long[] expectedData) {
assertTrue(ivalue.isTensor());
Tensor t = ivalue.toTensor();
assertEquals(memoryFormat, t.memoryFormat());
assertArrayEquals(expectedShape, t.shape());
assertArrayEquals(expectedData, t.getDataAsLongArray());
}
void runModel(final String name) throws IOException {
final Module storage_module = loadModel(name + ".ptl");
storage_module.forward();
// TODO enable this once the on-the-fly script is ready
// final Module on_the_fly_module = loadModel(name + "_temp.ptl");
// on_the_fly_module.forward();
assertTrue(true);
}
protected abstract Module loadModel(String assetName) throws IOException;
}
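
The abstract loadModel hook is what lets the same test body above run against both the full JIT runtime and the lite interpreter. A minimal sketch of a concrete instrumented-test subclass follows; the base-class name PytorchTestBase, the asset-copying helper and the use of LiteModuleLoader are assumptions based on the usual PyTorch Android test setup, not something spelled out in this diff.

import android.content.Context;
import androidx.test.platform.app.InstrumentationRegistry;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;

// Hypothetical subclass: copies a bundled asset to local storage and loads it
// with the lite interpreter. Names outside org.pytorch are illustrative.
public class ExampleLiteInstrumentedTests extends PytorchTestBase {
  @Override
  protected Module loadModel(String assetName) throws IOException {
    Context context = InstrumentationRegistry.getInstrumentation().getTargetContext();
    File file = new File(context.getFilesDir(), assetName);
    try (InputStream is = context.getAssets().open(assetName);
        OutputStream os = new FileOutputStream(file)) {
      byte[] chunk = new byte[8 * 1024];
      int read;
      while ((read = is.read(chunk)) != -1) {
        os.write(chunk, 0, read);
      }
    }
    return LiteModuleLoader.load(file.getAbsolutePath());
  }
}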

View File

@@ -1,9 +0,0 @@
package org.pytorch.suite;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import org.pytorch.PytorchInstrumentedTests;
@RunWith(Suite.class)
@Suite.SuiteClasses({PytorchInstrumentedTests.class})
public class PytorchInstrumentedTestSuite {}

View File

@@ -1,9 +0,0 @@
package org.pytorch.suite;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
import org.pytorch.PytorchLiteInstrumentedTests;
@RunWith(Suite.class)
@Suite.SuiteClasses({PytorchLiteInstrumentedTests.class})
public class PytorchLiteInstrumentedTestSuite {}

View File

@@ -1 +0,0 @@
<manifest package="org.pytorch" />

View File

@@ -1,3 +0,0 @@
#pragma once
/* #undef TRACE_ENABLED */

View File

@@ -1,3 +0,0 @@
#pragma once
#cmakedefine TRACE_ENABLED

View File

@@ -1,697 +0,0 @@
#include <cassert>
#include <iostream>
#include <memory>
#include <string>
#include <c10/core/MemoryFormat.h>
#include <c10/util/irange.h>
#include <fbjni/ByteBuffer.h>
#include <fbjni/fbjni.h>
#include "pytorch_jni_common.h"
#if defined(__ANDROID__)
#ifndef USE_PTHREADPOOL
#define USE_PTHREADPOOL
#endif /* USE_PTHREADPOOL */
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
#endif
namespace pytorch_jni {
c10::DeviceType deviceJniCodeToDeviceType(jint deviceJniCode) {
if (deviceJniCode == kDeviceCPU) {
return at::kCPU;
} else if (deviceJniCode == kDeviceVulkan) {
return at::kVulkan;
}
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException, "Unknown device");
}
bool Trace::is_initialized_ = false;
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
Trace::fp_ATrace_beginSection Trace::ATrace_beginSection;
Trace::fp_ATrace_endSection Trace::ATrace_endSection;
#endif
void Trace::init() {
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
  void* lib = dlopen("libandroid.so", RTLD_NOW | RTLD_LOCAL);
if (lib != NULL) {
Trace::ATrace_beginSection = reinterpret_cast<fp_ATrace_beginSection>(
dlsym(lib, "ATrace_beginSection"));
Trace::ATrace_endSection =
reinterpret_cast<fp_ATrace_endSection>(dlsym(lib, "ATrace_endSection"));
}
#endif
}
// NOTE: Codes must be kept in sync with DType.java.
// NOTE: Never serialize these, because they can change between releases.
constexpr static int kTensorDTypeUInt8 = 1;
constexpr static int kTensorDTypeInt8 = 2;
constexpr static int kTensorDTypeInt32 = 3;
constexpr static int kTensorDTypeFloat32 = 4;
constexpr static int kTensorDTypeInt64 = 5;
constexpr static int kTensorDTypeFloat64 = 6;
constexpr static int kTensorMemoryFormatContiguous = 1;
constexpr static int kTensorMemoryFormatChannelsLast = 2;
constexpr static int kTensorMemoryFormatChannelsLast3d = 3;
template <typename K = jobject, typename V = jobject>
struct JHashMap
: facebook::jni::JavaClass<JHashMap<K, V>, facebook::jni::JMap<K, V>> {
constexpr static auto kJavaDescriptor = "Ljava/util/HashMap;";
using Super =
facebook::jni::JavaClass<JHashMap<K, V>, facebook::jni::JMap<K, V>>;
static facebook::jni::local_ref<JHashMap<K, V>> create() {
return Super::newInstance();
}
void put(
facebook::jni::alias_ref<facebook::jni::JObject::javaobject> key,
facebook::jni::alias_ref<facebook::jni::JObject::javaobject> value) {
static auto putMethod =
Super::javaClassStatic()
->template getMethod<facebook::jni::alias_ref<
facebook::jni::JObject::javaobject>(
facebook::jni::alias_ref<facebook::jni::JObject::javaobject>,
facebook::jni::alias_ref<facebook::jni::JObject::javaobject>)>(
"put");
putMethod(Super::self(), key, value);
}
};
static at::Tensor newAtTensor(
facebook::jni::alias_ref<facebook::jni::JBuffer> jbuffer,
facebook::jni::alias_ref<jlongArray> jshape,
jint jdtype,
jint jmemoryFormat) {
const auto rank = jshape->size();
const auto shapeArr = jshape->getRegion(0, rank);
std::vector<int64_t> shapeVec{};
shapeVec.reserve(rank);
auto numel = 1;
for (const auto i : c10::irange(rank)) {
shapeVec.push_back(shapeArr[i]);
numel *= shapeArr[i];
}
JNIEnv* jni = facebook::jni::Environment::current();
caffe2::TypeMeta typeMeta{};
int dataElementSizeBytes = 0;
if (kTensorDTypeFloat32 == jdtype) {
dataElementSizeBytes = 4;
typeMeta = caffe2::TypeMeta::Make<float>();
} else if (kTensorDTypeInt32 == jdtype) {
dataElementSizeBytes = 4;
typeMeta = caffe2::TypeMeta::Make<int32_t>();
} else if (kTensorDTypeInt8 == jdtype) {
dataElementSizeBytes = 1;
typeMeta = caffe2::TypeMeta::Make<int8_t>();
} else if (kTensorDTypeUInt8 == jdtype) {
dataElementSizeBytes = 1;
typeMeta = caffe2::TypeMeta::Make<uint8_t>();
} else if (kTensorDTypeFloat64 == jdtype) {
dataElementSizeBytes = 8;
typeMeta = caffe2::TypeMeta::Make<double>();
} else if (kTensorDTypeInt64 == jdtype) {
dataElementSizeBytes = 8;
typeMeta = caffe2::TypeMeta::Make<int64_t>();
} else {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unknown Tensor jdtype %d",
jdtype);
}
const auto dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get());
if (dataCapacity != numel) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Tensor dimensions(elements number:%d, element byte size:%d, total "
"bytes:%d) inconsistent with buffer capacity(%d)",
numel,
dataElementSizeBytes,
numel * dataElementSizeBytes,
dataCapacity);
}
if (jmemoryFormat == kTensorMemoryFormatChannelsLast) {
auto sizes = torch::IntArrayRef(shapeVec);
return torch::from_blob(
jni->GetDirectBufferAddress(jbuffer.get()),
sizes,
torch::IntArrayRef(c10::get_channels_last_strides_2d(sizes)),
at::TensorOptions(typeMeta).memory_format(
at::MemoryFormat::ChannelsLast));
} else if (jmemoryFormat == kTensorMemoryFormatChannelsLast3d) {
auto sizes = torch::IntArrayRef(shapeVec);
return torch::from_blob(
jni->GetDirectBufferAddress(jbuffer.get()),
sizes,
torch::IntArrayRef(c10::get_channels_last_strides_3d(sizes)),
at::TensorOptions(typeMeta).memory_format(
at::MemoryFormat::ChannelsLast3d));
}
return torch::from_blob(
jni->GetDirectBufferAddress(jbuffer.get()),
torch::IntArrayRef(shapeVec),
at::TensorOptions(typeMeta));
}
class TensorHybrid : public facebook::jni::HybridClass<TensorHybrid> {
public:
constexpr static const char* kJavaDescriptor = "Lorg/pytorch/Tensor;";
explicit TensorHybrid(at::Tensor tensor) : tensor_(tensor) {}
static facebook::jni::local_ref<TensorHybrid::jhybriddata> initHybrid(
facebook::jni::alias_ref<TensorHybrid::javaobject> jTensorThis) {
static auto cls = TensorHybrid::javaClassStatic();
static const auto jMethodDTypeCode = cls->getMethod<jint()>("dtypeJniCode");
static const auto jMethodMemoryFormatCode =
cls->getMethod<jint()>("memoryFormatJniCode");
static const auto jFieldShape = cls->getField<jlongArray>("shape");
static const auto jMethodGetDataBuffer = cls->getMethod<
facebook::jni::local_ref<facebook::jni::JBuffer::javaobject>()>(
"getRawDataBuffer");
at::Tensor tensor = newAtTensor(
jMethodGetDataBuffer(jTensorThis),
jTensorThis->getFieldValue(jFieldShape),
jMethodDTypeCode(jTensorThis),
jMethodMemoryFormatCode(jTensorThis));
return makeCxxInstance(std::move(tensor));
}
static facebook::jni::local_ref<TensorHybrid::javaobject>
newJTensorFromAtTensor(const at::Tensor& input_tensor) {
    // The Java wrapper supports contiguous, channels-last and channels-last-3d
    // tensors; any other layout is converted to contiguous before wrapping.
int jmemoryFormat = 0;
at::Tensor tensor{};
if (input_tensor.is_contiguous(at::MemoryFormat::ChannelsLast)) {
tensor = input_tensor;
jmemoryFormat = kTensorMemoryFormatChannelsLast;
} else if (input_tensor.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
tensor = input_tensor;
jmemoryFormat = kTensorMemoryFormatChannelsLast3d;
} else {
tensor = input_tensor.contiguous();
jmemoryFormat = kTensorMemoryFormatContiguous;
}
const auto scalarType = tensor.scalar_type();
int jdtype = 0;
if (at::kFloat == scalarType) {
jdtype = kTensorDTypeFloat32;
} else if (at::kInt == scalarType) {
jdtype = kTensorDTypeInt32;
} else if (at::kByte == scalarType) {
jdtype = kTensorDTypeUInt8;
} else if (at::kChar == scalarType) {
jdtype = kTensorDTypeInt8;
} else if (at::kLong == scalarType) {
jdtype = kTensorDTypeInt64;
} else if (at::kDouble == scalarType) {
jdtype = kTensorDTypeFloat64;
} else {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"at::Tensor scalar type %s is not supported on java side",
c10::toString(scalarType));
}
const auto& tensorShape = tensor.sizes();
std::vector<jlong> tensorShapeVec;
for (const auto& s : tensorShape) {
tensorShapeVec.push_back(s);
}
facebook::jni::local_ref<jlongArray> jTensorShape =
facebook::jni::make_long_array(tensorShapeVec.size());
jTensorShape->setRegion(0, tensorShapeVec.size(), tensorShapeVec.data());
static auto cls = TensorHybrid::javaClassStatic();
facebook::jni::local_ref<facebook::jni::JByteBuffer> jTensorBuffer =
facebook::jni::JByteBuffer::wrapBytes(
(uint8_t*)tensor.data_ptr(), tensor.nbytes());
jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder());
static const auto jMethodNewTensor =
cls->getStaticMethod<facebook::jni::local_ref<TensorHybrid::javaobject>(
facebook::jni::alias_ref<facebook::jni::JByteBuffer>,
facebook::jni::alias_ref<jlongArray>,
jint,
jint,
facebook::jni::alias_ref<jhybriddata>)>("nativeNewTensor");
return jMethodNewTensor(
cls,
jTensorBuffer,
jTensorShape,
jdtype,
jmemoryFormat,
makeCxxInstance(tensor));
}
static at::Tensor newAtTensorFromJTensor(
facebook::jni::alias_ref<TensorHybrid::javaobject> jtensor) {
static auto cls = TensorHybrid::javaClassStatic();
static const auto dtypeMethod = cls->getMethod<jint()>("dtypeJniCode");
jint jdtype = dtypeMethod(jtensor);
static const auto memoryFormatMethod =
cls->getMethod<jint()>("memoryFormatJniCode");
jint jmemoryFormat = memoryFormatMethod(jtensor);
static const auto shapeField = cls->getField<jlongArray>("shape");
auto jshape = jtensor->getFieldValue(shapeField);
static auto dataBufferMethod = cls->getMethod<
facebook::jni::local_ref<facebook::jni::JBuffer::javaobject>()>(
"getRawDataBuffer");
facebook::jni::local_ref<facebook::jni::JBuffer> jbuffer =
dataBufferMethod(jtensor);
return newAtTensor(jbuffer, jshape, jdtype, jmemoryFormat);
}
at::Tensor tensor() const {
return tensor_;
}
private:
friend HybridBase;
at::Tensor tensor_;
};
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromStringDict(
c10::Dict<c10::IValue, c10::IValue> dict) {
static auto jMethodDictStringKey =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JMap<
facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
"dictStringKeyFrom");
auto jmap = JHashMap<
facebook::jni::alias_ref<facebook::jni::JString::javaobject>,
facebook::jni::alias_ref<JIValue::javaobject>>::create();
for (auto& pair : dict) {
jmap->put(
facebook::jni::make_jstring(pair.key().toStringRef()),
JIValue::newJIValueFromAtIValue(pair.value()));
}
return jMethodDictStringKey(JIValue::javaClassStatic(), jmap);
}
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromIntDict(
c10::Dict<c10::IValue, c10::IValue> dict) {
static auto jMethodDictLongKey =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JMap<
facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
facebook::jni::alias_ref<JIValue::javaobject>>>)>(
"dictLongKeyFrom");
auto jmap = JHashMap<
facebook::jni::alias_ref<facebook::jni::JLong::javaobject>,
facebook::jni::alias_ref<JIValue::javaobject>>::create();
for (auto& pair : dict) {
jmap->put(
facebook::jni::JLong::valueOf(pair.key().toInt()),
JIValue::newJIValueFromAtIValue(pair.value()));
}
return jMethodDictLongKey(JIValue::javaClassStatic(), jmap);
}
facebook::jni::local_ref<JIValue> JIValue::newJIValueFromAtIValue(
const at::IValue& ivalue,
DictCallback stringDictCallback,
DictCallback intDictCallback) {
Trace _s{"jni::JIValue::newJIValueFromAtIValue"};
if (ivalue.isNone()) {
static auto jMethodOptionalNull =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>()>(
"optionalNull");
return jMethodOptionalNull(JIValue::javaClassStatic());
} else if (ivalue.isTensor()) {
static auto jMethodTensor =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::local_ref<TensorHybrid::javaobject>)>("from");
const auto& tensor = ivalue.toTensor();
return jMethodTensor(
JIValue::javaClassStatic(),
TensorHybrid::newJTensorFromAtTensor(tensor));
} else if (ivalue.isBool()) {
static auto jMethodBool =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(jboolean)>(
"from");
return jMethodBool(JIValue::javaClassStatic(), ivalue.toBool());
} else if (ivalue.isInt()) {
static auto jMethodInt =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(jlong)>("from");
return jMethodInt(JIValue::javaClassStatic(), ivalue.toInt());
} else if (ivalue.isDouble()) {
static auto jMethodDouble =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(jdouble)>(
"from");
return jMethodDouble(JIValue::javaClassStatic(), ivalue.toDouble());
} else if (ivalue.isString()) {
static auto jMethodString =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JString::javaobject>)>(
"from");
return jMethodString(
JIValue::javaClassStatic(),
facebook::jni::make_jstring(ivalue.toStringRef()));
} else if (ivalue.isTuple()) {
auto elementsVec = ivalue.toTupleRef().elements();
static auto jMethodTupleArr =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JArrayClass<
JIValue::javaobject>::javaobject>)>("tupleFrom");
auto jElementsArray =
facebook::jni::JArrayClass<JIValue::javaobject>::newArray(
elementsVec.size());
auto index = 0;
for (const auto& e : elementsVec) {
(*jElementsArray)[index++] = JIValue::newJIValueFromAtIValue(e);
}
return jMethodTupleArr(JIValue::javaClassStatic(), jElementsArray);
} else if (ivalue.isBoolList()) {
auto list = ivalue.toBoolList();
static auto jMethodBoolListArr =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<jbooleanArray>)>("listFrom");
size_t n = list.size();
auto jArray = facebook::jni::make_boolean_array(n);
auto jArrayPinned = jArray->pin();
auto index = 0;
for (const auto& e : list) {
jArrayPinned[index++] = e;
}
return jMethodBoolListArr(JIValue::javaClassStatic(), jArray);
} else if (ivalue.isIntList()) {
auto list = ivalue.toIntList();
static auto jMethodLongListArr =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<jlongArray>)>("listFrom");
size_t n = list.size();
auto jArray = facebook::jni::make_long_array(n);
auto jArrayPinned = jArray->pin();
auto index = 0;
for (const auto& e : list) {
jArrayPinned[index++] = e;
}
return jMethodLongListArr(JIValue::javaClassStatic(), jArray);
} else if (ivalue.isDoubleList()) {
auto list = ivalue.toDoubleList();
static auto jMethoDoubleListArr =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<jdoubleArray>)>("listFrom");
size_t n = list.size();
auto jArray = facebook::jni::make_double_array(n);
auto jArrayPinned = jArray->pin();
auto index = 0;
for (const auto& e : list) {
jArrayPinned[index++] = e;
}
return jMethoDoubleListArr(JIValue::javaClassStatic(), jArray);
} else if (ivalue.isTensorList()) {
auto list = ivalue.toTensorList();
static auto jMethodTensorListArr =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JArrayClass<
TensorHybrid::javaobject>::javaobject>)>("listFrom");
auto jArray =
facebook::jni::JArrayClass<TensorHybrid::javaobject>::newArray(
list.size());
auto index = 0;
for (const auto& e : list) {
(*jArray)[index++] = TensorHybrid::newJTensorFromAtTensor(e);
}
return jMethodTensorListArr(JIValue::javaClassStatic(), jArray);
} else if (ivalue.isList()) {
auto list = ivalue.toList();
static auto jMethodListArr =
JIValue::javaClassStatic()
->getStaticMethod<facebook::jni::local_ref<JIValue>(
facebook::jni::alias_ref<facebook::jni::JArrayClass<
JIValue::javaobject>::javaobject>)>("listFrom");
auto jArray =
facebook::jni::JArrayClass<JIValue::javaobject>::newArray(list.size());
auto index = 0;
for (const auto& e : list) {
(*jArray)[index++] = JIValue::newJIValueFromAtIValue(e);
}
return jMethodListArr(JIValue::javaClassStatic(), jArray);
} else if (ivalue.isGenericDict()) {
auto dict = ivalue.toGenericDict();
const auto keyType = dict.keyType();
if (!keyType) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unknown IValue-Dict key type");
}
if (*keyType == *c10::StringType::get()) {
return stringDictCallback(std::move(dict));
} else if (*keyType == *c10::IntType::get()) {
return intDictCallback(std::move(dict));
}
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unsupported IValue-Dict key type: %s",
keyType->str().c_str());
}
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unsupported IValue type %s",
ivalue.tagKind().c_str());
}
at::IValue JIValue::JIValueToAtIValue(
facebook::jni::alias_ref<JIValue> jivalue) {
Trace _s{"jni::JIValue::JIValueToAtIValue"};
static const auto typeCodeField =
JIValue::javaClassStatic()->getField<jint>("mTypeCode");
const auto typeCode = jivalue->getFieldValue(typeCodeField);
if (JIValue::kTypeCodeNull == typeCode) {
return at::IValue{};
} else if (JIValue::kTypeCodeTensor == typeCode) {
static const auto jMethodGetTensor =
JIValue::javaClassStatic()
->getMethod<facebook::jni::alias_ref<TensorHybrid::javaobject>()>(
"toTensor");
return TensorHybrid::newAtTensorFromJTensor(jMethodGetTensor(jivalue));
} else if (JIValue::kTypeCodeBool == typeCode) {
static const auto jMethodGetBool =
JIValue::javaClassStatic()->getMethod<jboolean()>("toBool");
    // Explicit cast to bool: jboolean is defined as uint8_t, so without the
    // cast the IValue constructor for int would be selected.
bool b = jMethodGetBool(jivalue);
return at::IValue{b};
} else if (JIValue::kTypeCodeLong == typeCode) {
static const auto jMethodGetLong =
JIValue::javaClassStatic()->getMethod<jlong()>("toLong");
return at::IValue{(int64_t)jMethodGetLong(jivalue)};
} else if (JIValue::kTypeCodeDouble == typeCode) {
static const auto jMethodGetDouble =
JIValue::javaClassStatic()->getMethod<jdouble()>("toDouble");
return at::IValue{jMethodGetDouble(jivalue)};
} else if (JIValue::kTypeCodeString == typeCode) {
static const auto jMethodGetString =
JIValue::javaClassStatic()->getMethod<jstring()>("toStr");
return at::IValue{jMethodGetString(jivalue)->toStdString()};
} else if (JIValue::kTypeCodeTuple == typeCode) {
static const auto jMethodGetTuple =
JIValue::javaClassStatic()
->getMethod<
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject()>(
"toTuple");
auto jarray = jMethodGetTuple(jivalue);
size_t n = jarray->size();
std::vector<at::IValue> elements;
elements.reserve(n);
for (const auto i : c10::irange(n)) {
auto jivalue_element = jarray->getElement(i);
auto element = JIValue::JIValueToAtIValue(jivalue_element);
elements.push_back(std::move(element));
}
return c10::ivalue::Tuple::create(std::move(elements));
} else if (JIValue::kTypeCodeBoolList == typeCode) {
static const auto jMethodGetBoolList =
JIValue::javaClassStatic()->getMethod<jbooleanArray()>("toBoolList");
auto jArray = jMethodGetBoolList(jivalue);
auto jArrayPinned = jArray->pin();
size_t n = jArrayPinned.size();
c10::List<bool> list{};
list.reserve(n);
for (const auto i : c10::irange(n)) {
list.push_back(jArrayPinned[i]);
}
return at::IValue{std::move(list)};
} else if (JIValue::kTypeCodeLongList == typeCode) {
static const auto jMethodGetLongList =
JIValue::javaClassStatic()->getMethod<jlongArray()>("toLongList");
auto jArray = jMethodGetLongList(jivalue);
auto jArrayPinned = jArray->pin();
size_t n = jArrayPinned.size();
c10::List<int64_t> list{};
list.reserve(n);
for (const auto i : c10::irange(n)) {
list.push_back(jArrayPinned[i]);
}
return at::IValue{std::move(list)};
} else if (JIValue::kTypeCodeDoubleList == typeCode) {
static const auto jMethodGetDoubleList =
JIValue::javaClassStatic()->getMethod<jdoubleArray()>("toDoubleList");
auto jArray = jMethodGetDoubleList(jivalue);
auto jArrayPinned = jArray->pin();
size_t n = jArrayPinned.size();
c10::List<double> list{};
list.reserve(n);
for (const auto i : c10::irange(n)) {
list.push_back(jArrayPinned[i]);
}
return at::IValue{std::move(list)};
} else if (JIValue::kTypeCodeTensorList == typeCode) {
static const auto jMethodGetTensorList =
JIValue::javaClassStatic()
->getMethod<facebook::jni::JArrayClass<
TensorHybrid::javaobject>::javaobject()>("toTensorList");
auto jArray = jMethodGetTensorList(jivalue);
size_t n = jArray->size();
c10::List<at::Tensor> list{};
list.reserve(n);
for (const auto i : c10::irange(n)) {
list.push_back(
TensorHybrid::newAtTensorFromJTensor(jArray->getElement(i)));
}
return at::IValue{std::move(list)};
} else if (JIValue::kTypeCodeList == typeCode) {
static const auto jMethodGetList =
JIValue::javaClassStatic()
->getMethod<
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject()>(
"toList");
auto jarray = jMethodGetList(jivalue);
size_t n = jarray->size();
if (n == 0) {
return at::IValue{c10::impl::GenericList(c10::TensorType::get())};
}
auto jivalue_first_element = jarray->getElement(0);
auto first_element = JIValue::JIValueToAtIValue(jivalue_first_element);
c10::impl::GenericList list{c10::unshapedType(first_element.type())};
list.reserve(n);
list.push_back(first_element);
for (const auto i : c10::irange(1, n)) {
auto jivalue_element = jarray->getElement(i);
auto element = JIValue::JIValueToAtIValue(jivalue_element);
list.push_back(element);
}
return at::IValue{list};
} else if (JIValue::kTypeCodeDictStringKey == typeCode) {
static const auto jMethodGetDictStringKey =
JIValue::javaClassStatic()
->getMethod<facebook::jni::JMap<jstring, JIValue::javaobject>::
javaobject()>("toDictStringKey");
auto jmap = jMethodGetDictStringKey(jivalue);
auto it = jmap->begin();
if (it == jmap->end()) {
return at::IValue{c10::impl::GenericDict(
c10::StringType::get(), c10::TensorType::get())};
}
auto firstEntryValue = JIValue::JIValueToAtIValue(it->second);
c10::impl::GenericDict dict{
c10::StringType::get(), c10::unshapedType(firstEntryValue.type())};
dict.insert(it->first->toStdString(), firstEntryValue);
it++;
for (; it != jmap->end(); it++) {
dict.insert(
it->first->toStdString(), JIValue::JIValueToAtIValue(it->second));
}
return at::IValue{dict};
} else if (JIValue::kTypeCodeDictLongKey == typeCode) {
static const auto jMethodGetDictLongKey =
JIValue::javaClassStatic()
->getMethod<facebook::jni::JMap<
facebook::jni::JLong::javaobject,
JIValue::javaobject>::javaobject()>("toDictLongKey");
auto jmap = jMethodGetDictLongKey(jivalue);
auto it = jmap->begin();
if (it == jmap->end()) {
return at::IValue{
c10::impl::GenericDict(c10::IntType::get(), c10::TensorType::get())};
}
auto firstEntryValue = JIValue::JIValueToAtIValue(it->second);
c10::impl::GenericDict dict{
c10::IntType::get(), c10::unshapedType(firstEntryValue.type())};
dict.insert((int64_t)it->first->longValue(), firstEntryValue);
it++;
for (; it != jmap->end(); it++) {
dict.insert(
(int64_t)it->first->longValue(),
JIValue::JIValueToAtIValue(it->second));
}
return at::IValue{dict};
}
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unknown IValue typeCode %d",
typeCode);
}
#if defined(__ANDROID__)
class PyTorchAndroidJni : public facebook::jni::JavaClass<PyTorchAndroidJni> {
public:
constexpr static auto kJavaDescriptor = "Lorg/pytorch/PyTorchAndroid;";
static void registerNatives() {
javaClassStatic()->registerNatives({
makeNativeMethod(
"nativeSetNumThreads", PyTorchAndroidJni::setNumThreads),
});
}
static void setNumThreads(facebook::jni::alias_ref<jclass>, jint numThreads) {
caffe2::pthreadpool()->set_thread_count(numThreads);
}
};
#endif
void common_registerNatives() {
static const int once = []() {
#if defined(__ANDROID__)
pytorch_jni::PyTorchAndroidJni::registerNatives();
#endif
return 0;
}();
((void)once);
}
} // namespace pytorch_jni
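
The file above is the native half of the Tensor/IValue bridge: newAtTensor wraps the Java-side direct buffer with torch::from_blob, using channels-last strides when the Java memory-format code asks for them, and TensorHybrid::newJTensorFromAtTensor wraps results going the other way. A hedged Java-side round trip, reusing the "contiguous" method exercised by the tests earlier in this diff (the module is assumed to be loaded elsewhere), looks like this:

// Sketch: push a channels-last tensor through the JNI bridge above and read it
// back; the expected result is the same data reported as CONTIGUOUS.
static Tensor roundTripChannelsLast(Module module) {
  long[] nhwcData = new long[] {1, 11, 101, 2, 12, 102, 3, 13, 103, 4, 14, 104};
  Tensor nhwc = Tensor.fromBlob(nhwcData, new long[] {1, 3, 2, 2}, MemoryFormat.CHANNELS_LAST);
  IValue out = module.runMethod("contiguous", IValue.from(nhwc));
  return out.toTensor();
}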

View File

@@ -1,137 +0,0 @@
#pragma once
#include <c10/util/FunctionRef.h>
#include <fbjni/fbjni.h>
#include <torch/csrc/api/include/torch/types.h>
#include "caffe2/serialize/read_adapter_interface.h"
#include "cmake_macros.h"
#ifdef __ANDROID__
#include <android/log.h>
#define ALOGI(...) \
__android_log_print(ANDROID_LOG_INFO, "pytorch-jni", __VA_ARGS__)
#define ALOGE(...) \
__android_log_print(ANDROID_LOG_ERROR, "pytorch-jni", __VA_ARGS__)
#endif
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
#include <android/trace.h>
#include <dlfcn.h>
#endif
namespace pytorch_jni {
constexpr static int kDeviceCPU = 1;
constexpr static int kDeviceVulkan = 2;
c10::DeviceType deviceJniCodeToDeviceType(jint deviceJniCode);
class Trace {
public:
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
typedef void* (*fp_ATrace_beginSection)(const char* sectionName);
typedef void* (*fp_ATrace_endSection)(void);
static fp_ATrace_beginSection ATrace_beginSection;
static fp_ATrace_endSection ATrace_endSection;
#endif
static void ensureInit() {
if (!Trace::is_initialized_) {
init();
Trace::is_initialized_ = true;
}
}
static void beginSection(const char* name) {
Trace::ensureInit();
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
ATrace_beginSection(name);
#endif
}
static void endSection() {
#if defined(TRACE_ENABLED) && defined(__ANDROID__)
ATrace_endSection();
#endif
}
Trace(const char* name) {
ensureInit();
beginSection(name);
}
~Trace() {
endSection();
}
private:
static void init();
static bool is_initialized_;
};
class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface {
public:
explicit MemoryReadAdapter(const void* data, off_t size)
      : data_(data), size_(size) {}
size_t size() const override {
return size_;
}
size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
const override {
memcpy(buf, (int8_t*)(data_) + pos, n);
return n;
}
~MemoryReadAdapter() {}
private:
const void* data_;
off_t size_;
};
class JIValue : public facebook::jni::JavaClass<JIValue> {
using DictCallback = c10::function_ref<facebook::jni::local_ref<JIValue>(
c10::Dict<c10::IValue, c10::IValue>)>;
public:
constexpr static const char* kJavaDescriptor = "Lorg/pytorch/IValue;";
constexpr static int kTypeCodeNull = 1;
constexpr static int kTypeCodeTensor = 2;
constexpr static int kTypeCodeBool = 3;
constexpr static int kTypeCodeLong = 4;
constexpr static int kTypeCodeDouble = 5;
constexpr static int kTypeCodeString = 6;
constexpr static int kTypeCodeTuple = 7;
constexpr static int kTypeCodeBoolList = 8;
constexpr static int kTypeCodeLongList = 9;
constexpr static int kTypeCodeDoubleList = 10;
constexpr static int kTypeCodeTensorList = 11;
constexpr static int kTypeCodeList = 12;
constexpr static int kTypeCodeDictStringKey = 13;
constexpr static int kTypeCodeDictLongKey = 14;
static facebook::jni::local_ref<JIValue> newJIValueFromAtIValue(
const at::IValue& ivalue,
DictCallback stringDictCallback = newJIValueFromStringDict,
DictCallback intDictCallback = newJIValueFromIntDict);
static at::IValue JIValueToAtIValue(
facebook::jni::alias_ref<JIValue> jivalue);
private:
static facebook::jni::local_ref<JIValue> newJIValueFromStringDict(
c10::Dict<c10::IValue, c10::IValue>);
static facebook::jni::local_ref<JIValue> newJIValueFromIntDict(
c10::Dict<c10::IValue, c10::IValue>);
};
void common_registerNatives();
} // namespace pytorch_jni
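
The kTypeCode constants above have to match the tag that each org.pytorch.IValue factory method stores on the Java side; JIValueToAtIValue in the previous file dispatches on that tag. A small, hedged illustration of how those tags are produced from Java (the keys and values are arbitrary placeholders; imports from java.util and org.pytorch are omitted for brevity):

// Builds a tuple IValue whose elements exercise several of the type codes
// listed above: a string-keyed dict (kTypeCodeDictStringKey), a tensor
// (kTypeCodeTensor), a double and a bool.
static IValue buildExampleInput() {
  Map<String, IValue> dict = new HashMap<>();
  dict.put("logits", IValue.from(Tensor.fromBlob(new long[] {1, 2}, new long[] {2})));
  dict.put("scale", IValue.from(0.5));
  return IValue.tupleFrom(IValue.dictStringKeyFrom(dict), IValue.from(true));
}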

View File

@@ -1,245 +0,0 @@
#include <cassert>
#include <iostream>
#include <memory>
#include <string>
#include <fbjni/ByteBuffer.h>
#include <fbjni/fbjni.h>
#include <ATen/record_function.h>
#include <torch/csrc/jit/runtime/print_handler.h>
#include <torch/script.h>
#include "caffe2/serialize/read_adapter_interface.h"
#include "pytorch_jni_common.h"
#ifdef __ANDROID__
#include <android/asset_manager.h>
#include <android/asset_manager_jni.h>
#include <android/log.h>
#endif
namespace pytorch_jni {
namespace {
struct JITCallGuard {
// Inference only workload.
c10::InferenceMode guard;
  // Disable the graph optimizer to ensure the list of unused ops is not
  // changed for custom mobile builds.
torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
};
} // namespace
class PytorchJni : public facebook::jni::HybridClass<PytorchJni> {
private:
friend HybridBase;
torch::jit::Module module_;
c10::DeviceType deviceType_;
public:
constexpr static auto kJavaDescriptor = "Lorg/pytorch/NativePeer;";
static facebook::jni::local_ref<jhybriddata> initHybrid(
facebook::jni::alias_ref<jclass>,
facebook::jni::alias_ref<jstring> modelPath,
facebook::jni::alias_ref<
facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>>
extraFiles,
jint device) {
return makeCxxInstance(modelPath, extraFiles, device);
}
#ifdef __ANDROID__
static facebook::jni::local_ref<jhybriddata> initHybridAndroidAsset(
facebook::jni::alias_ref<jclass>,
facebook::jni::alias_ref<jstring> assetName,
facebook::jni::alias_ref<jobject> assetManager,
jint device) {
return makeCxxInstance(assetName, assetManager, device);
}
#endif
#ifdef TRACE_ENABLED
static std::unique_ptr<at::ObserverContext> onFunctionEnter(
const at::RecordFunction& fn) {
Trace::beginSection(fn.name().str());
return nullptr;
}
static void onFunctionExit(const at::RecordFunction&, at::ObserverContext*) {
Trace::endSection();
}
#endif
static void preModuleLoadSetupOnce() {
auto qengines = at::globalContext().supportedQEngines();
if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) !=
qengines.end()) {
at::globalContext().setQEngine(at::QEngine::QNNPACK);
}
#ifdef __ANDROID__
torch::jit::setPrintHandler([](const std::string& s) {
__android_log_print(ANDROID_LOG_DEBUG, "pytorch-print", "%s", s.c_str());
});
#endif
#ifdef TRACE_ENABLED
at::addGlobalCallback(
at::RecordFunctionCallback(&onFunctionEnter, &onFunctionExit)
.scopes({RecordScope::FUNCTION, RecordScope::USER_SCOPE}));
#endif
}
void preModuleLoadSetup() {
static const int once = []() {
preModuleLoadSetupOnce();
return 0;
}();
((void)once);
}
PytorchJni(
facebook::jni::alias_ref<jstring> modelPath,
facebook::jni::alias_ref<
facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>>
extraFiles,
jint device) {
preModuleLoadSetup();
JITCallGuard guard;
std::unordered_map<std::string, std::string> extra_files;
const auto has_extra = extraFiles && extraFiles->size() > 0;
if (has_extra) {
for (const auto& e : *extraFiles) {
extra_files[e.first->toStdString()] = "";
}
}
deviceType_ = deviceJniCodeToDeviceType(device);
module_ = torch::jit::load(
        modelPath->toStdString(), std::nullopt, extra_files);
if (has_extra) {
static auto putMethod =
facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>::
javaClassStatic()
->template getMethod<facebook::jni::alias_ref<jobject>(
facebook::jni::alias_ref<jobject>,
facebook::jni::alias_ref<jobject>)>("put");
for (const auto& ef : extra_files) {
putMethod(
extraFiles,
facebook::jni::make_jstring(ef.first),
facebook::jni::make_jstring(ef.second));
}
}
module_.eval();
}
#ifdef __ANDROID__
PytorchJni(
facebook::jni::alias_ref<jstring> assetName,
facebook::jni::alias_ref<jobject> assetManager,
jint device) {
preModuleLoadSetup();
JNIEnv* env = facebook::jni::Environment::current();
AAssetManager* mgr = AAssetManager_fromJava(env, assetManager.get());
if (!mgr) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Unable to get asset manager");
}
AAsset* asset = AAssetManager_open(
mgr, assetName->toStdString().c_str(), AASSET_MODE_BUFFER);
if (!asset) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Failed to open asset '%s'",
assetName->toStdString().c_str());
}
auto assetBuffer = AAsset_getBuffer(asset);
if (!assetBuffer) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Could not get buffer for asset '%s'",
assetName->toStdString().c_str());
}
JITCallGuard guard;
module_ = torch::jit::load(std::make_unique<MemoryReadAdapter>(
assetBuffer, AAsset_getLength(asset)));
AAsset_close(asset);
module_.eval();
deviceType_ = deviceJniCodeToDeviceType(device);
}
#endif
static void registerNatives() {
registerHybrid({
makeNativeMethod("initHybrid", PytorchJni::initHybrid),
#ifdef __ANDROID__
makeNativeMethod(
"initHybridAndroidAsset", PytorchJni::initHybridAndroidAsset),
#endif
makeNativeMethod("forward", PytorchJni::forward),
makeNativeMethod("runMethod", PytorchJni::runMethod),
});
}
facebook::jni::local_ref<JIValue> forward(
facebook::jni::alias_ref<
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject>
jinputs) {
Trace _s{"jni::Module::forward"};
std::vector<at::IValue> inputs{};
size_t n = jinputs->size();
inputs.reserve(n);
for (const auto i : c10::irange(n)) {
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
inputs.push_back(std::move(atIValue));
}
auto output = [&]() {
JITCallGuard guard;
return module_.forward(std::move(inputs));
}();
return JIValue::newJIValueFromAtIValue(output);
}
facebook::jni::local_ref<JIValue> runMethod(
facebook::jni::alias_ref<facebook::jni::JString::javaobject> jmethodName,
facebook::jni::alias_ref<
facebook::jni::JArrayClass<JIValue::javaobject>::javaobject>
jinputs) {
std::string methodName = jmethodName->toStdString();
std::vector<at::IValue> inputs{};
size_t n = jinputs->size();
inputs.reserve(n);
for (const auto i : c10::irange(n)) {
at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i));
inputs.push_back(std::move(atIValue));
}
if (auto method = module_.find_method(methodName)) {
auto output = [&]() {
JITCallGuard guard;
return (*method)(std::move(inputs));
}();
return JIValue::newJIValueFromAtIValue(output);
}
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Undefined method %s",
methodName.c_str());
}
};
} // namespace pytorch_jni
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
return facebook::jni::initialize(vm, [] {
pytorch_jni::common_registerNatives();
pytorch_jni::PytorchJni::registerNatives();
});
}
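
This last file is the native peer behind org.pytorch.Module for the full JIT runtime: initHybrid receives the model path, an optional extra-files map and a device code, forward/runMethod convert arguments through JIValue in both directions, and JNI_OnLoad registers the natives. A hedged sketch of the matching Java entry point (the path and the extra-file key are placeholders, and Module.load with extra files and a Device is assumed to be available in this version of the API):

// Loads a TorchScript model on CPU and lets the native constructor above fill
// in the requested extra file from the model archive.
static Module loadWithMetadata(String modelPath) {
  Map<String, String> extraFiles = new HashMap<>();
  extraFiles.put("metadata.json", ""); // value is populated by the native side
  return Module.load(modelPath, extraFiles, Device.CPU);
}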

Some files were not shown because too many files have changed in this diff.