Compare commits

...

5 Commits

Author SHA1 Message Date
664058fa83 Pin windows numpy (#82652) (#82686)
### Description
Pinned numpy for Windows temporarily. Numpy shouldn't need to be pinned in principle, but its upgrade to 1.23 has wreaked havoc on CI, so the pin stays for now.

### Issue
Related to https://github.com/pytorch/pytorch/issues/82653

### Testing
CI -- the Windows CPU jobs should pass
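
For local environments that mirror CI, a quick version guard might look like this (a sketch; the actual change only edits the Windows conda install script shown in the diffs below):

```python
import numpy as np

# CI pins numpy below 1.23 on Windows; assert the same constraint locally.
major, minor = (int(part) for part in np.__version__.split(".")[:2])
assert (major, minor) < (1, 23), f"numpy {np.__version__} hits the 1.23 breakage"
```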
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82652
Approved by: https://github.com/malfet

Co-authored-by: Jane Xu <janeyx@fb.com>
2022-08-02 19:50:01 -04:00
efc2d08eac Revert #75195 (#82504) (#82662)
This is a short-term fix for a serious regression in functorch
(https://github.com/pytorch/functorch/issues/989).

Additional things this PR does:
- the out= tests for nn.functional.linear fail after the revert, so I added
some xfails; these xfails were also present in the original PR (#75195). A
sketch of the pattern follows this list.
- the profiler tests fail after the revert, so I updated the profiler expect
tests
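
For reference, the xfails use the `DecorateInfo` pattern from the OpInfo database (a sketch; the exact entries appear in the `op_db` hunks at the bottom of this compare view):

```python
import unittest
from torch.testing._internal.common_methods_invocations import DecorateInfo

# Marks TestCommon.test_out as an expected failure for the affected OpInfo
# entry, matching the "Strides are not the same!" xfails added in this PR.
out_xfail = DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out')
```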

Test Plan:
- test offline that the functorch regression was fixed
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82504
Approved by: https://github.com/ngimel, https://github.com/ezyang, https://github.com/atalman
2022-08-02 18:17:27 -04:00
9a9dcebfd5 ONNX cherry picks for 1.12.1 (#82435)
* [ONNX] Variable length argument support for quantized_args (#78775)

Add support for decorating functions with variable length arguments in `quantized_args`. This is needed to decorate functions like `symbolic_fn` in `_interpolate_helper` which takes `*args`.

Previously it was not possible to decorate such functions. Now we can write

```python
@quantized_args(True)
def symbolic_fn(g, input, output_size, *args):
    ...
```

and the remaining parameters default to non-quantized.
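
Under the hood, the descriptor tuple is padded with `False` so trailing `*args` are treated as non-quantized; a minimal sketch of that behavior (argument names assumed):

```python
# Pad the quantization descriptors to cover variable-length arguments.
arg_q_descriptors = (True,)                       # only `input` is quantized
args = ("input", "output_size", "scale", "mode")  # placeholder argument names
extended = arg_q_descriptors + (False,) * (len(args) - len(arg_q_descriptors))
assert extended == (True, False, False, False)
```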
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78775
Approved by: https://github.com/garymm

* [ONNX] Quantization support for five ops (#78103)

- Add quantization support for `interpolate`, `avgpool`, `sigmoid` and `add_relu`
- Return the inputs to `ListUnpack` when the previous node is a `ListConstruct`, so that the `ListConstruct`/`ListUnpack` pair cancels out and is removed in the JIT passes; ONNX supports neither op (see the sketch below).
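
A sketch of the cancellation logic (it mirrors the `Prim.ListUnpack` change further down in this diff):

```python
def list_unpack(g, *inputs):
    # If ListUnpack's sole input comes straight from a ListConstruct,
    # return the construct's inputs so the JIT passes can erase both nodes.
    if len(inputs) == 1 and inputs[0].node().kind() == "prim::ListConstruct":
        return list(inputs[0].node().inputs())
    return None  # fall through to the default handling
```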
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78103
Approved by: https://github.com/garymm
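
Of these, quantized sigmoid is the easiest to exercise in eager mode (a sketch mirroring the new `test_quantized_sigmoid` in the diff below; export itself runs through `run_test` in CI):

```python
import torch

# Quantize an input and run sigmoid on it; the new symbolic allows the
# same module/input pair to export to ONNX at opset >= 10.
model = torch.nn.Sigmoid()
q_input = torch.quantize_per_tensor(torch.randn(2, 6), 0.26, 128, torch.quint8)
output = model(q_input)
```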

* [ONNX] Add quantization support to _avg_pool opset 9 and clean up (#79793)

- Add quantization support to _avg_pool opset 9
- Clean up reused / unused variables in avgpool helper
- Add types
- Sort `__all__`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79793
Approved by: https://github.com/BowenBao

* [ONNX] Quantization support for quantized::cat (#79826)

- Add support for quantized `cat`
- Add type annotations for helper functions

Now we can export

```python
import torchvision.models.quantization as models
from torchvision import transforms

torch_model = models.inception_v3(pretrained=True, quantize=True)
```
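
With the model in hand, a dummy-input export is one more call (a sketch; the input size, opset version, and in-memory buffer are assumptions, not part of the commit):

```python
import io

import torch

# inception_v3 expects 299x299 inputs; export the quantized model to ONNX.
dummy = torch.randn(1, 3, 299, 299)
torch.onnx.export(torch_model.eval(), (dummy,), io.BytesIO(), opset_version=13)
```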

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79826
Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao

* [torch] Add more functions to __init__.pyi.in for torch._C for Node and Value (#79654)

Summary:
https://github.com/pytorch/pytorch/pull/78757 recently added
a lot of functions to the type stub, but it missed a few of them.

This change makes sure every function is included by keeping the list up to
date with `torch/csrc/jit/python/python_ir.cpp`.

This change only does this for Node and Value.
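
With the added stubs, code that walks a JIT graph type-checks instead of degrading to `Any`; a small sketch:

```python
import torch

def node_kinds(fn, example):
    traced = torch.jit.trace(fn, example)
    # Each `node` is typed as torch._C.Node, so kind() resolves for mypy.
    return [node.kind() for node in traced.graph.nodes()]

print(node_kinds(lambda x: x.relu(), torch.randn(2)))
```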

Differential Revision: D37189713

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79654
Approved by: https://github.com/ezyang

* [ONNX] Test and address autograd.func (FusedLayerNorm) shape inference (#81931)

This PR addresses an ONNX exporter issue where unreliable nodes cause downstream shapes to be wrongly inferred as static:

1. Specifically, this unblocks the use of apex `FusedLayerNorm` (an autograd.Function) in transformers. Before this PR, the nodes downstream of apex `FusedLayerNorm` were inferred with static shapes even though they are unreliable (they should be dynamic). A minimal sketch of the pattern follows this list.

2. Add a general test case that wraps `torch.nn.LayerNorm` in an autograd function, reproducing the same issue apex `FusedLayerNorm` hit in the transformers embedding layer.

3. Remove the legacy test `test_empty_like_opset7`, which still used the deprecated ConstantFill op. Since the ONNX checker no longer supports this node, its shape inference produces unexpected results, which this PR exposes:
```
Warning: Checker does not support models with experimental ops: ConstantFill
```

Please advise if there is a better place for the test case.
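
A minimal stand-in for the problematic pattern (assumed names; the full version is `test_maintain_dynamic_shapes_of_unreliable_nodes` in the diff below):

```python
import torch

class OpaqueLayerNorm(torch.autograd.Function):
    """The exporter cannot see inside an autograd.Function, so its output
    shape is unreliable and downstream shapes must remain dynamic."""

    @staticmethod
    def forward(ctx, x):
        return torch.nn.functional.layer_norm(x, x.shape[-1:])
```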
Fixes #82330
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81931
Approved by: https://github.com/justinchuby, https://github.com/BowenBao

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Co-authored-by: Justin Chu <justinchu@microsoft.com>
Co-authored-by: Riley Dulin <dulinr@fb.com>
Co-authored-by: titaiwang <titaiwang@microsoft.com>
2022-08-02 13:44:32 -04:00
617c4fe52c Fix invalid read in masked softmax (#82272) (#82405)
Summary:
Per title; unfortunately, testing invalid reads with the caching allocator is hard.
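
The semantics being guarded are the usual masked-softmax ones (a sketch of the contract, not the kernel): masked positions behave as `-inf`, contributing `exp(-inf) == 0` to the sum:

```python
import torch

x = torch.randn(2, 5)
mask = torch.tensor([[False, False, True, True, True],
                     [False, True, True, True, True]])
# Masked entries become -inf before the softmax, hence exactly 0 after it.
out = torch.softmax(x.masked_fill(mask, float("-inf")), dim=-1)
assert torch.all(out[mask] == 0)
```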

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82272
Approved by: https://github.com/cpuhrsch

Test Plan:
contbuild & OSS CI, see 24d702d38e

Original Phabricator Test Plan:
Imported from GitHub, without a `Test Plan:` line.

Reviewed By: ajtulloch, osalpekar, cpuhrsch

Differential Revision: D38183160

Pulled By: ngimel

fbshipit-source-id: 0ea59868d4829bc540c1277a93daa029519d05b4

Co-authored-by: Natalia Gimelshein (Meta Employee) <ngimel@fb.com>
2022-07-28 13:08:39 -04:00
f469bc1fe1 [ci] Release only change: bump macos worker instance type (#82113)
* [ci] Release only change: bump macos worker instance type

* Applying bump for nightly

* Add macos-12-xl to actionlint
2022-07-25 18:22:51 +01:00
23 changed files with 476 additions and 174 deletions

View File

@ -12,5 +12,6 @@ self-hosted-runner:
- windows.8xlarge.nvidia.gpu
- bm-runner
- linux.rocm.gpu
- macos-12-xl
- macos-12
- macos12.3-m1

View File

@ -64,7 +64,7 @@ jobs:
{%- if config["package_type"] == "libtorch" %}
runs-on: macos-10.15
{%- else %}
runs-on: macos-12
runs-on: macos-12-xl
{%- endif %}
{%- if config["package_type"] == "libtorch" %}
# libtorch builds take a long time on github hosted runners

View File

@ -39,7 +39,7 @@ concurrency:
jobs:
conda-py3_8-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
runs-on: macos-12
runs-on: macos-12-xl
timeout-minutes: 240
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -213,7 +213,7 @@ jobs:
docker system prune -af
conda-py3_9-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
runs-on: macos-12
runs-on: macos-12-xl
timeout-minutes: 240
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -387,7 +387,7 @@ jobs:
docker system prune -af
conda-py3_10-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
runs-on: macos-12
runs-on: macos-12-xl
timeout-minutes: 240
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch

View File

@ -39,7 +39,7 @@ concurrency:
jobs:
wheel-py3_7-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
runs-on: macos-12
runs-on: macos-12-xl
timeout-minutes: 240
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -213,7 +213,7 @@ jobs:
docker system prune -af
wheel-py3_8-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
runs-on: macos-12
runs-on: macos-12-xl
timeout-minutes: 240
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -387,7 +387,7 @@ jobs:
docker system prune -af
wheel-py3_9-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
runs-on: macos-12
runs-on: macos-12-xl
timeout-minutes: 240
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -561,7 +561,7 @@ jobs:
docker system prune -af
wheel-py3_10-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
runs-on: macos-12
runs-on: macos-12-xl
timeout-minutes: 240
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch

View File

@ -37,7 +37,7 @@ concurrency:
jobs:
conda-py3_7-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
runs-on: macos-12
runs-on: macos-12-xl
timeout-minutes: 240
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -211,7 +211,7 @@ jobs:
docker system prune -af
conda-py3_8-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
runs-on: macos-12
runs-on: macos-12-xl
timeout-minutes: 240
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -385,7 +385,7 @@ jobs:
docker system prune -af
conda-py3_9-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
runs-on: macos-12
runs-on: macos-12-xl
timeout-minutes: 240
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -559,7 +559,7 @@ jobs:
docker system prune -af
conda-py3_10-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
runs-on: macos-12
runs-on: macos-12-xl
timeout-minutes: 240
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch

View File

@ -37,7 +37,7 @@ concurrency:
jobs:
wheel-py3_7-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
runs-on: macos-12
runs-on: macos-12-xl
timeout-minutes: 240
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -211,7 +211,7 @@ jobs:
docker system prune -af
wheel-py3_8-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
runs-on: macos-12
runs-on: macos-12-xl
timeout-minutes: 240
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -385,7 +385,7 @@ jobs:
docker system prune -af
wheel-py3_9-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
runs-on: macos-12
runs-on: macos-12-xl
timeout-minutes: 240
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@ -559,7 +559,7 @@ jobs:
docker system prune -af
wheel-py3_10-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
runs-on: macos-12
runs-on: macos-12-xl
timeout-minutes: 240
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch

View File

@ -19,7 +19,7 @@ if "%INSTALL_FRESH_CONDA%"=="1" (
call %CONDA_PARENT_DIR%\Miniconda3\Scripts\activate.bat %CONDA_PARENT_DIR%\Miniconda3
if "%INSTALL_FRESH_CONDA%"=="1" (
call conda install -y -q python=%PYTHON_VERSION% numpy cffi pyyaml boto3 libuv
call conda install -y -q python=%PYTHON_VERSION% numpy"<1.23" cffi pyyaml boto3 libuv
if errorlevel 1 exit /b
if not errorlevel 0 exit /b
call conda install -y -q -c conda-forge cmake=3.22.3

View File

@ -1712,7 +1712,8 @@ Tensor _matmul_impl(
} else if (dim_tensor1 == 2 && dim_tensor2 == 1) {
return has_out ? at::mv_out(out, tensor1, tensor2) : tensor1.mv(tensor2);
} else if (dim_tensor1 == 1 && dim_tensor2 == 2) {
return has_out ? at::mv_out(out, tensor2.t(), tensor1) : tensor2.t().mv(tensor1);
return has_out ? at::mm_out(out, tensor1.unsqueeze(0), tensor2).squeeze_(0)
: tensor1.unsqueeze(0).mm(tensor2).squeeze_(0);
} else if (dim_tensor1 == 2 && dim_tensor2 == 2) {
return has_out ? at::mm_out(out, tensor1, tensor2) : tensor1.mm(tensor2);
} else if (should_fold(tensor1, dim_tensor2) || should_fold(tensor2, dim_tensor1)) {
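
The rewritten branch covers `vector @ matrix`; both formulations agree, as a quick check shows (illustrative sketch only):

```python
import torch

v = torch.randn(3)     # 1-D operand
m = torch.randn(3, 4)  # 2-D operand
# New path: lift v to a 1x3 matrix, mm, then squeeze the batch dim back out.
new_path = v.unsqueeze(0).mm(m).squeeze(0)
old_path = m.t().mv(v)  # previous formulation via mv on the transpose
assert torch.allclose(new_path, old_path)
```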

View File

@ -123,12 +123,14 @@ __global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batc
for (int it = 0; it < WARP_ITERATIONS; ++it) {
if (is_masked) {
int idx = it*WARP_SIZE;
if (!is_transformer_mask) {
idx += i*element_count;
}
if (!mask[idx]) {
max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
is_meaningful_max = true;
if ((idx + local_idx) < element_count) {
if (!is_transformer_mask) {
idx += i*element_count;
}
if (!mask[idx]) {
max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
is_meaningful_max = true;
}
}
} else {
max_value[i] = max_value[i] > elements[i][it] ? max_value[i] : elements[i][it];
@ -156,22 +158,28 @@ __global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batc
}
} else {
int idx = it*WARP_SIZE;
bool valid = (idx + local_idx) < element_count;
if (!is_transformer_mask) {
idx += i*element_count;
}
if (!mask[idx]) {
if (is_log_softmax) {
sum[i] += std::exp(elements[i][it] - max_value[i]);
if (valid) {
if (!mask[idx]) {
if (is_log_softmax) {
sum[i] += std::exp(elements[i][it] - max_value[i]);
} else {
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
}
} else {
elements[i][it] = std::exp(elements[i][it] - max_value[i]);
sum[i] += elements[i][it];
if (!is_log_softmax) {
// Masked values are treated as -infinity, and std::exp(-infinity) is 0.
elements[i][it] = 0;
}
}
} else {
if (!is_log_softmax) {
// Masked values are treated as -infinity, and std::exp(-infinity) is 0.
elements[i][it] = 0;
}
if (!is_log_softmax) {
elements[i][it] = 0.;
}
}
}
}

View File

@ -1,68 +0,0 @@
ir_version: 3
producer_name: "pytorch"
producer_version: "CURRENT_VERSION"
graph {
node {
input: "onnx::Shape_0"
output: "onnx::ConstantFill_1"
name: "Shape_0"
op_type: "Shape"
}
node {
input: "onnx::ConstantFill_1"
output: "2"
name: "ConstantFill_1"
op_type: "ConstantFill"
attribute {
name: "dtype"
i: 1
type: INT
}
attribute {
name: "input_as_shape"
i: 1
type: INT
}
attribute {
name: "value"
f: 0
type: FLOAT
}
}
name: "torch_jit"
input {
name: "onnx::Shape_0"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
dim {
dim_value: 8
}
}
}
}
}
output {
name: "2"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 5
}
dim {
dim_value: 8
}
}
}
}
}
}
opset_import {
version: 7
}

View File

@ -0,0 +1,32 @@
# Owner(s): ["module: onnx"]
"""Tests for `torch.onnx.symbolic_opset9`."""
import torch
from torch import _C
from torch.onnx import symbolic_opset9 as opset9
from torch.testing._internal import common_utils
class TestPrim(common_utils.TestCase):
def setUp(self):
super().setUp()
self.graph = _C.Graph()
def test_list_unpack_returns_all_list_elements_when_previous_node_is_list_construct(
self,
):
# Build the graph
input_1 = self.graph.addInput()
input_1.setType(input_1.type().with_dtype(torch.float).with_sizes([2, 42]))
input_2 = self.graph.addInput()
input_2.setType(input_2.type().with_dtype(torch.float).with_sizes([3, 42]))
constructed_list = self.graph.op("prim::ListConstruct", input_1, input_2)
# Test the op
outputs = opset9.Prim.ListUnpack(self.graph, constructed_list)
self.assertNotEqual(outputs, None)
self.assertEqual(outputs[0], input_1)
self.assertEqual(outputs[1], input_2)
if __name__ == "__main__":
common_utils.run_tests()

View File

@ -737,10 +737,6 @@ class TestOperators(TestCase):
x = torch.randn(5, 8, requires_grad=True)
self.assertONNX(lambda x: torch.empty_like(x), x)
def test_empty_like_opset7(self):
x = torch.randn(5, 8, requires_grad=True)
self.assertONNX(lambda x: torch.empty_like(x), x, opset_version=7)
def test_zeros_like(self):
x = torch.randn(5, 8, requires_grad=True)
self.assertONNX(lambda x: torch.zeros_like(x), x)

View File

@ -7,6 +7,7 @@ import unittest
from typing import Optional, Type
import onnx
import onnx.numpy_helper
import torch
from torch import Tensor
@ -106,9 +107,75 @@ class TestOptionalOutput(unittest.TestCase):
input_names=["y"],
)
def test_maintain_dynamic_shapes_of_unreliable_nodes(self):
def symbolic_pythonop(ctx: torch.onnx.SymbolicContext, g, *args, **kwargs):
return g.op("com.microsoft::PythonOp")
torch.onnx.register_custom_op_symbolic("prim::PythonOp", symbolic_pythonop, 1)
self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "prim::PythonOp", 1)
# necessary parameters for transformer embeddings
hidden_size = 48
max_position_embeddings = 32
batch_size = 2
# issue found: an autograd.function makes its downstream
# nodes unreliable while still giving them static shapes. The issue was
# first discovered when using Apex FusedLayerNorm in Transformers
class CustomLayerNorm(torch.autograd.Function):
@staticmethod
def forward(ctx, embedding):
layer_norm = torch.nn.LayerNorm(hidden_size, eps=1e-12)
return layer_norm(embedding)
class EmbeddingModule(torch.nn.Module):
def forward(
self,
embeddings=None,
):
embedding_output = CustomLayerNorm.apply(embeddings)
query = embedding_output.transpose(0, 1)
target_len, batch_size, embedding_dim = query.size()
# Reshape is used for consuming batch_size, and if it is static,
# this will be a Constant node in the graph
query = query.reshape(target_len, batch_size, embedding_dim)
return query
embeddings = torch.randn(batch_size, max_position_embeddings, hidden_size)
f = io.BytesIO()
torch.onnx.export(
EmbeddingModule().eval(),
(embeddings,),
f,
input_names=["embeddings"],
dynamic_axes={
"embeddings": {
0: "batch_size",
1: "max_position_embeddings",
2: "hidden_size",
}
},
custom_opsets={"com.microsoft": 1},
)
model = onnx.load(io.BytesIO(f.getvalue()))
# If there is a constant node with dim=3 and max_position_embeddings,
# batch_size, hidden_size as shape, it means the shape becomes static.
# Normally, with dynamic batch size, this constant node should not exist.
const_node = [n for n in model.graph.node if n.op_type == "Constant"]
self.assertNotEqual(len(const_node), 0)
for node in const_node:
for a in node.attribute:
if a.name == "value":
shape = onnx.numpy_helper.to_array(a.t)
self.assertNotEqual(
shape.tolist(),
[max_position_embeddings, batch_size, hidden_size],
)
instantiate_parametrized_tests(TestOptionalOutput)
if __name__ == "__main__":
unittest.main()

View File

@ -3456,28 +3456,68 @@ class _TestONNXRuntime:
self.run_test(model, x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_listunpack(self):
class ListUnpack(torch.jit.ScriptModule):
@torch.jit.script_method
def test_list_unpack_scripted(self):
class ListUnpack(torch.nn.Module):
def forward(self, x):
a, b = x.shape
return x.new_zeros((a, b))
x = torch.randn(2, 3)
self.run_test(ListUnpack(), x, input_names=["x"], dynamic_axes={"x": [0, 1]})
self.run_test(ListUnpack(), x, remained_onnx_input_idx=[])
self.run_test(
torch.jit.script(ListUnpack()),
x,
input_names=["x"],
dynamic_axes={"x": [0, 1]},
)
self.run_test(torch.jit.script(ListUnpack()), x, remained_onnx_input_idx=[])
class ListUnpackSlice(torch.jit.ScriptModule):
@torch.jit.script_method
@skipIfUnsupportedMinOpsetVersion(9)
def test_list_unpack_scripted_runs_without_error_with_constructed_list_as_input(
self,
):
class PackUnpack(torch.nn.Module):
"""Create and unpack a list of tensors.
When scripted, it should produce a graph similar to
```
graph(%self : __torch__.PackUnpack,
%a.1 : Tensor,
%b.1 : Tensor):
%packed.1 : Tensor[] = prim::ListConstruct(%a.1, %b.1)
%c.1 : Tensor, %8 : Tensor = prim::ListUnpack(%packed.1)
return (%c.1)
```
"""
def forward(self, a, b):
packed = [a, b]
c, _ = packed
return c
self.run_test(
torch.jit.script(PackUnpack()),
(torch.tensor(0), torch.tensor([42])),
remained_onnx_input_idx=[0],
)
@skipIfUnsupportedMinOpsetVersion(9)
def test_list_unpack_slice_scripted(self):
class ListUnpackSlice(torch.nn.Module):
def forward(self, x):
a, b = x.shape[2:]
return x.new_zeros((a, b))
x = torch.randn(2, 3, 4, 5)
self.run_test(
ListUnpackSlice(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2, 3]}
torch.jit.script(ListUnpackSlice()),
x,
input_names=["x"],
dynamic_axes={"x": [0, 1, 2, 3]},
)
self.run_test(
torch.jit.script(ListUnpackSlice()), x, remained_onnx_input_idx=[]
)
self.run_test(ListUnpackSlice(), x, remained_onnx_input_idx=[])
def test_pow(self):
class PowModule(torch.nn.Module):
@ -12366,6 +12406,13 @@ class _TestONNXRuntime:
q_input = torch.quantize_per_tensor(input, 0.26, 128, torch.quint8)
self.run_test(model, q_input)
@skipIfUnsupportedMinOpsetVersion(10)
def test_quantized_sigmoid(self):
model = torch.nn.Sigmoid()
input = torch.randn(2, 6)
q_input = torch.quantize_per_tensor(input, 0.26, 128, torch.quint8)
self.run_test(model, q_input)
@skipIfUnsupportedMinOpsetVersion(10)
def test_quantized_flatten(self):
class FlattenModel(torch.nn.Module):
@ -12375,6 +12422,19 @@ class _TestONNXRuntime:
x = torch.quantize_per_tensor(torch.randn(1, 2, 3, 4), 1, 0, torch.quint8)
self.run_test(FlattenModel(), x)
@unittest.skip(
"ONNX Runtime 1.11 does not support quantized cat. Enable after ORT 1.12 is enabled in CI."
)
@skipIfUnsupportedMinOpsetVersion(10)
@skipScriptTest() # torch.jit.frontend.FrontendError: Cannot instantiate class 'QFunctional' in a script function:
def test_quantized_cat(self):
class QuantizedConcatenationModel(torch.nn.Module):
def forward(self, x):
return torch.nn.quantized.QFunctional().cat((x, x), dim=1)
q_input = torch.quantize_per_tensor(torch.ones(2, 3), 0.26, 128, torch.quint8)
self.run_test(QuantizedConcatenationModel(), q_input)
@skipIfUnsupportedMinOpsetVersion(10)
@skipScriptTest() # torch.jit.frontend.FrontendError: Cannot instantiate class 'QFunctional' in a script function:
def test_quantized_arithmetic_qfunctional(self):
@ -12596,6 +12656,32 @@ class _TestONNXRuntime:
input = _construct_tensor_for_quantization_test((4, 4, 3, 2))
self.run_test(model, input)
@skipIfUnsupportedMinOpsetVersion(10)
def test_qat_avg_pool2d(self):
model = torch.nn.Sequential(
torch.quantization.QuantStub(),
torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=1),
torch.quantization.DeQuantStub(),
)
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
model = torch.quantization.prepare_qat(model.train())
model = torch.quantization.convert(model)
input = _construct_tensor_for_quantization_test((4, 4, 3, 2))
self.run_test(model, input)
@skipIfUnsupportedMinOpsetVersion(11)
def test_qat_upsample_nearest2d(self):
model = torch.nn.Sequential(
torch.quantization.QuantStub(),
torch.nn.UpsamplingNearest2d(scale_factor=1.5),
torch.quantization.DeQuantStub(),
)
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
model = torch.quantization.prepare_qat(model.train())
model = torch.quantization.convert(model)
input = _construct_tensor_for_quantization_test((4, 3, 2, 2))
self.run_test(model, input)
@skipIfUnsupportedMinOpsetVersion(9)
def test_convolution_allow_tf32(self):
class Module(torch.nn.Module):

View File

@ -451,11 +451,34 @@ class _InsertPoint:
def __enter__(self) -> None: ...
def __exit__(self, *args) -> None: ...
# Defined in torch/csrc/jit/ir/ir.h
class Use:
@property
def user(self) -> Node: ...
@property
def offset(self) -> _int: ...
def isAfter(self, other: Use) -> _bool: ...
...
# Defined in torch/csrc/jit/ir/ir.h
class Value:
def type(self)-> JitType: ...
def setType(self, t: JitType) -> Value: ...
def setTypeAs(self, other: Value) -> Value: ...
def inferTypeFrom(self, t: Tensor) -> None: ...
def debugName(self) -> str: ...
def setDebugName(self, name: str) -> None: ...
def unique(self) -> _int: ...
def offset(self) -> _int: ...
def node(self) -> Node: ...
def uses(self) -> List[Use]: ...
def replaceAllUsesWith(self, val: Value) -> None: ...
def replaceAllUsesAfterNodeWith(self, node: Node, val: Value) -> None: ...
def requires_grad(self) -> _bool: ...
def requiresGrad(self) -> _bool: ...
def copyMetadata(self, other: Value) -> Value: ...
def isCompleteTensor(self) -> _bool: ...
def toIValue(self) -> IValue: ...
...
# Defined in torch/csrc/jit/ir/ir.h
@ -467,14 +490,80 @@ class Block:
# Defined in torch/csrc/jit/ir/ir.h
class Node:
def schema(self) -> str: ...
def input(self) -> Value: ...
def inputs(self) -> List[Value]: ...
def inputsAt(self, idx: _int) -> Value: ...
def inputsSize(self) -> _int: ...
def output(self) -> Value: ...
def outputs(self) -> List[Value]: ...
def outputsAt(self, idx: _int) -> Value: ...
def outputsSize(self) -> _int: ...
def hasMultipleOutputs(self) -> _bool: ...
def blocks(self) -> List[Block]: ...
def mustBeNone(self) -> _bool: ...
def kindOf(self, str) -> str: ...
def __getitem__(self, key: str) -> Any: ...
def namedInput(self, str) -> Value: ...
def matches(self, pattern: str) -> _bool: ...
def kind(self) -> str: ...
def kindOf(self, name: str) -> str: ...
def addInput(self, name: str) -> Value: ...
def replaceInput(self, i: _int, newValue: Value) -> Value: ...
def replaceInputWith(self, from_: Value, to: Value) -> None: ...
def replaceAllUsesWith(self, n: Node) -> None: ...
def insertBefore(self, n: Node) -> Node: ...
def insertAfter(self, n: Node) -> Node: ...
def isBefore(self, n: Node) -> _bool: ...
def isAfter(self, n: Node) -> _bool: ...
def moveBefore(self, n: Node) -> None: ...
def moveAfter(self, n: Node) -> None: ...
def removeInput(self, i: _int) -> None: ...
def removeAllInputs(self, i: _int) -> None: ...
def hasUses(self) -> _bool: ...
def eraseOutput(self, i: _int) -> None: ...
def addOutput(self) -> Value: ...
def scopeName(self) -> str: ...
def isNondeterministic(self) -> _bool: ...
def copyAttributes(self, rhs: Node) -> Node: ...
def copyMetadata(self, rhs: Node) -> Node: ...
def hasAttributes(self) -> _bool: ...
def hasAttribute(self, name: str) -> _bool: ...
def removeAttribute(self, attr: str) -> Node: ...
def namedInput(self, name: str) -> Value: ...
def sourceRange(self) -> SourceRange: ...
def owningBlock(self) -> Block: ...
def findNode(self, kind: str, recurse: _bool = True) -> Node: ...
def findAllNodes(self, kind: str, recurse: _bool = True) -> List[Node]: ...
def getModuleHierarchy(self) -> str: ...
def prev(self) -> Node: ...
def destroy(self) -> None: ...
# Accessors for attributes as types.
def f(self, name: str) -> _float: ...
def f_(self, name: str, val: _float) -> Node: ...
def fs(self, name: str) -> List[_float]: ...
def fs_(self, name: str, val: List[_float]) -> Node: ...
def c(self, name: str) -> complex: ...
def c_(self, name: str, val: complex) -> Node: ...
def s(self, name: str) -> str: ...
def s_(self, name: str, val: str) -> Node: ...
def ss(self, name: str) -> List[str]: ...
def ss_(self, name: str, val: List[str]) -> Node: ...
def i(self, name: str) -> _int: ...
def i_(self, name: str, val: _int) -> Node: ...
# Cannot define "is" like this because it's a reserved keyword in python.
# def is(self, name: str) -> List[_int]: ...
# def is_(self, name: str, val: List[_int]) -> Node: ...
def g(self, name: str) -> Graph: ...
def g_(self, name: str, val: Graph) -> Node: ...
def gs(self, name: str) -> List[Graph]: ...
def gs_(self, name: str, val: List[Graph]) -> Node: ...
def ival(self, name: str) -> IValue: ...
def ival_(self, name: str, val: IValue) -> Node: ...
def t(self, name: str) -> Tensor: ...
def t_(self, name: str, val: Tensor) -> Node: ...
def ts(self, name: str) -> List[Tensor]: ...
def ts_(self, name: str, val: List[Tensor]) -> Node: ...
def ty_(self, name: str, val: JitType) -> Node: ...
def tys_(self, name: str, val: List[JitType]) -> Node: ...
...
# Defined in torch/torch/csrc/jit/ir/ir.h

View File

@ -1562,6 +1562,12 @@ void ProcessConstantValueMap(Node* n, int opset_version) {
// For outputs, only update static shapes. For input, we update symbolic
// shapes also. ONNX If can have different types on different branches, skip
// here.
// Update the shape reliability for each node before processing
// ConstantValueMap to prevent unreliable nodes from producing static
// shapes
UpdateReliable(n);
auto static_input_shape = AllGraphInputsStatic(n->owningGraph());
for (auto i : c10::irange(n->outputs().size())) {
if (TensorTypePtr output_type = n->output(i)->type()->cast<TensorType>()) {

View File

@ -82,5 +82,7 @@ void UpdateReliable(
torch::jit::Value* output,
const std::pair<bool, bool>& input_reliable);
void UpdateReliable(torch::jit::Node* n);
} // namespace jit
} // namespace torch

View File

@ -8,6 +8,7 @@ import torch._C._onnx as _C_onnx
from torch.onnx._globals import GLOBALS
# TODO(#78694): Refactor the patching process to make it more transparent to users.
def _graph_op(
g: torch._C.Graph,
opname: str,

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import enum
import functools
import inspect
import sys
import warnings
from typing import Set
from typing import Any, Callable, List, Optional, Sequence, Set, Tuple, Union
import torch
import torch._C._onnx as _C_onnx
@ -140,7 +142,7 @@ def _get_const(value, desc, arg_name):
return _parse_arg(value, desc)
def _unpack_list(list_value):
def _unpack_list(list_value: _C.Value) -> List[_C.Value]:
list_node = list_value.node()
assert list_node.kind() == "prim::ListConstruct"
return list(list_node.inputs())
@ -236,15 +238,19 @@ def parse_args(*arg_descriptors):
return decorator
def quantized_args(*arg_q_descriptors, scale=None, zero_point=None):
def quantized_args(
*arg_q_descriptors: bool,
scale: Optional[float] = None,
zero_point: Optional[int] = None,
):
"""A decorator which extends support for quantized version of the base operator.
Quantization is detected by examining the arguments that are annotated by
`arg_q_descriptors`.
If quantization is detected, the base operator symbolic function will be wrapped with
argument dequantization and output quantization.
argument de-quantization and output quantization.
Otherwise, only base symbolic function will be invoked.
Otherwise, only the base symbolic function will be invoked.
For example:
@ -267,11 +273,12 @@ def quantized_args(*arg_q_descriptors, scale=None, zero_point=None):
```
Args:
arg_q_descriptors: list of bool, where each element represents if the
argument is QTensor for quantized version of this operator.
scale: float default None, quantized output scale. If None, derive from
arg_q_descriptors: A sequence of bool, where each element represents if the
argument is QTensor for quantized version of this operator. It defaults
to False for unspecified (variable length) arguments.
scale: Quantized output scale. If None, derive from
the first quantized input scale.
zero_point: int default None, quantized output zero point. If None,
zero_point: Quantized output zero point. If None,
derive from the first quantized input zero point.
"""
@ -288,19 +295,21 @@ def quantized_args(*arg_q_descriptors, scale=None, zero_point=None):
if _zero_point is not None:
_zero_point = g.op("Constant", value_t=torch.tensor(_zero_point))
# some args may be optional, so the length may be smaller
assert len(arg_q_descriptors) >= len(args)
desc_args = tuple(zip(arg_q_descriptors[: len(args)], args))
# Support variable length arguments by marking unspecified ones as non-quantized
arg_q_descriptors_extended = arg_q_descriptors + (False,) * (
len(args) - len(arg_q_descriptors)
)
descriptor_args = tuple(zip(arg_q_descriptors_extended, args))
# Run regular symbolic function if none of the argument is QTensor.
if not any(
(desc and arg.node().kind() == "prim::TupleConstruct")
for desc, arg in desc_args
(descriptor and arg.node().kind() == "prim::TupleConstruct")
for descriptor, arg in descriptor_args
):
return fn(g, *args, **kwargs)
dequantized_args = []
for desc, arg in desc_args:
if desc:
for descriptor, arg in descriptor_args:
if descriptor:
dequantized_arg, scale, zero_point, _ = dequantize_helper(g, arg)
dequantized_args.append(dequantized_arg)
if _scale is None:
@ -309,7 +318,8 @@ def quantized_args(*arg_q_descriptors, scale=None, zero_point=None):
_zero_point = zero_point
else:
dequantized_args.append(arg)
# TODO: only support single output
# TODO(justinchuby): Only single output is supported for now. We may want to
# support multiple outputs in the future.
output = fn(g, *dequantized_args, **kwargs)
return quantize_helper(g, output, _scale, _zero_point)
@ -786,6 +796,7 @@ def _interpolate_get_scales_and_mode(g, input, size, scale_factor, mode, align_c
def _interpolate_helper(name, dim, interpolate_mode):
@quantized_args(True, False, False)
def symbolic_fn(g, input, output_size, *args):
scales, align_corners = _get_interpolate_attributes(g, interpolate_mode, args)
align_corners = _maybe_get_scalar(align_corners)
@ -1113,13 +1124,17 @@ def _batchnorm_helper(g, input, weight, bias, running_mean, running_var):
return weight, bias, running_mean, running_var
def _avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override, name):
def _avgpool_helper(
tuple_fn: Callable[[Any], Sequence[int]],
padding: Union[int, Sequence[int]],
kernel_size,
stride,
divisor_override,
name,
) -> Tuple[int, ...]:
if divisor_override and divisor_override.node().kind() != "prim::Constant":
return _unimplemented(name, "divisor_override")
if not stride:
stride = kernel_size
padding = tuple(tuple_fn(padding))
return padding
_unimplemented(name, "divisor_override")
return tuple(tuple_fn(padding))
def check_training_mode(op_train_mode, op_name):
@ -1193,7 +1208,11 @@ def _handle_reduce_dim_none(g, self, op_name):
return g.op(op_name, self, keepdims_i=0)
def dequantize_helper(g, qtensor, qdtype=None):
def dequantize_helper(
g,
qtensor: _C.Value,
qdtype: Optional[torch.onnx.TensorProtoDataType] = None,
) -> Tuple[_C.Value, _C.Value, _C.Value, Optional[_C.Value]]:
"""Appends to graph `g` ONNX nodes that dequantizes `qtensor` into `tensor`.
Args:
@ -1234,7 +1253,13 @@ def dequantize_helper(g, qtensor, qdtype=None):
)
def quantize_helper(g, tensor, scale, zero_point, axis=None):
def quantize_helper(
g,
tensor: _C.Value,
scale: _C.Value,
zero_point: _C.Value,
axis: Optional[_C.Value] = None,
) -> _C.Value:
"""Appends to graph `g` ONNX nodes that quantizes `tensor` based on `scale`, `zero_point` and `axis`.
Args:
@ -1258,11 +1283,13 @@ def quantize_helper(g, tensor, scale, zero_point, axis=None):
)
assert scale is not None
if scale.type().scalarType() != "Float":
if scale.type().scalarType() != "Float": # type: ignore[attr-defined]
# TODO(justinchuby): Remove type ignore after #81112 is checked in.
scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
assert zero_point is not None
if zero_point.type().scalarType() not in ("Byte", "Char"):
if zero_point.type().scalarType() not in ("Byte", "Char"): # type: ignore[attr-defined]
# TODO(justinchuby): Remove type ignore after #81112 is checked in.
zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
output = g.op(
"QuantizeLinear",

View File

@ -1,9 +1,11 @@
import sys
import warnings
from typing import Sequence
import torch
import torch._C._onnx as _C_onnx
import torch.onnx
from torch import _C
# This import monkey-patches graph manipulation methods on Graph, used for the
# ONNX symbolics
@ -141,15 +143,16 @@ max_pool3d_with_indices = _max_pool(
def _avg_pool(name, tuple_fn):
@symbolic_helper.quantized_args(True, False, False, False, False, False, False)
@symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
def symbolic_fn(
g,
input,
kernel_size,
stride,
padding,
ceil_mode,
count_include_pad,
input: _C.Value,
kernel_size: Sequence[int],
stride: Sequence[int],
padding: Sequence[int],
ceil_mode: int,
count_include_pad: int,
divisor_override=None,
):
if not stride:
@ -187,6 +190,7 @@ avg_pool3d = _avg_pool("avg_pool3d", torch.nn.modules.utils._triple)
def _interpolate(name, dim, interpolate_mode):
@symbolic_helper.quantized_args(True, False, False)
def symbolic_fn(g, input, output_size, *args):
scales, align_corners = symbolic_helper._get_interpolate_attributes(
g, interpolate_mode, args
@ -543,6 +547,16 @@ class Quantized:
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@staticmethod
def add_relu(g, x, y, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
output = opset9.add(g, x, y)
output = opset9.relu(g, output)
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@staticmethod
def mul(g, x, y, op_scale, op_zero_point):
x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
@ -612,3 +626,19 @@ class Quantized:
)
return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
@staticmethod
@symbolic_helper.parse_args("v", "i", "v", "v")
def cat(
g,
q_inputs: _C.Value,
dim: int,
op_scale: _C.Value,
op_zero_point: _C.Value,
) -> _C.Value:
unpacked_inputs = symbolic_helper._unpack_list(q_inputs)
dequantized = [
symbolic_helper.dequantize_helper(g, input)[0] for input in unpacked_inputs
]
concatenated = g.op("Concat", *dequantized, axis_i=dim)
return symbolic_helper.quantize_helper(g, concatenated, op_scale, op_zero_point)

View File

@ -1,7 +1,11 @@
"""This file exports ONNX ops for opset 11."""
import sys
import warnings
from typing import Tuple, Union
import torch
from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx import symbolic_opset9 as opset9
from torch.onnx import symbolic_opset10 as opset10
@ -143,7 +147,8 @@ def index_put(g, self, indices_list_value, values, accumulate=False):
if len(indices_list) > 1:
for idx_ in range(len(indices_list)):
if indices_list[idx_].type().scalarType() == "Bool":
if indices_list[idx_].type().scalarType() == "Bool": # type: ignore[attr-defined]
# TODO(justinchuby): Remove type ignore after #81112 is checked in.
indices_list[idx_] = g.op("NonZero", indices_list[idx_])
index = indices_list[0]
@ -198,7 +203,8 @@ def index_put(g, self, indices_list_value, values, accumulate=False):
# return (%33)
index = indices_list[0]
bool_inp = index
if bool_inp.type() is not None and bool_inp.type().scalarType() == "Bool":
if bool_inp.type() is not None and bool_inp.type().scalarType() == "Bool": # type: ignore[attr-defined]
# TODO(justinchuby): Remove type ignore after #81112 is checked in.
rank = symbolic_helper._get_tensor_rank(values)
if rank is not None and rank == 0:
return opset9.masked_fill(g, self, bool_inp, values)
@ -258,6 +264,7 @@ upsample_trilinear3d = _interpolate("upsample_trilinear3d", 5, "linear")
upsample_bicubic2d = _interpolate("upsample_bicubic2d", 4, "cubic")
@symbolic_helper.quantized_args(True, False, False, False, False, False, False)
def __interpolate(
g, input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias
):
@ -415,15 +422,16 @@ def _unique2(g, self, sorted, return_inverse, return_counts):
def _avg_pool(name, tuple_fn):
@symbolic_helper.quantized_args(True, False, False, False, False, False, False)
@symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
def symbolic_fn(
g,
input,
kernel_size,
stride,
padding,
ceil_mode,
count_include_pad,
input: _C.Value,
kernel_size: Tuple[int, ...],
stride: Tuple[int, ...],
padding: Union[int, Tuple[int, ...]],
ceil_mode: int,
count_include_pad: int,
divisor_override=None,
):
padding = symbolic_helper._avgpool_helper(

View File

@ -8,7 +8,7 @@ import functools
import math
import sys
import warnings
from typing import Optional
from typing import List, Optional, Tuple, Union
import torch
import torch._C._onnx as _C_onnx
@ -361,6 +361,8 @@ def atan(g, self):
return g.op("Atan", self)
# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qsigmoid.cpp
@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0)
def sigmoid(g, self):
return g.op("Sigmoid", self)
@ -1031,6 +1033,7 @@ def get_pool_ceil_padding(input, kernel_size, stride, padding):
def _max_pool(name, tuple_fn, ndims, return_indices):
@symbolic_helper.quantized_args(True, False, False, False, False, False)
@symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
if set(tuple_fn(dilation)) != {1}:
@ -1117,15 +1120,16 @@ max_pool3d_with_indices = _max_pool(
def _avg_pool(name, tuple_fn):
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
def symbolic_fn(
g,
input,
kernel_size,
stride,
padding,
ceil_mode,
count_include_pad,
input: _C.Value,
kernel_size: Tuple[int, ...],
stride: Tuple[int, ...],
padding: Union[int, Tuple[int, ...]],
ceil_mode: int,
count_include_pad: int,
divisor_override=None,
):
if not stride:
@ -1133,8 +1137,7 @@ def _avg_pool(name, tuple_fn):
padding = symbolic_helper._avgpool_helper(
tuple_fn, padding, kernel_size, stride, divisor_override, name
)
if ceil_mode:
padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
adjusted_padding = padding
if count_include_pad:
input = g.op(
"Pad",
@ -1143,17 +1146,20 @@ def _avg_pool(name, tuple_fn):
mode_s="constant",
value_f=0.0,
)
padding = (0,) * len(padding)
adjusted_padding = (0,) * len(padding)
if ceil_mode:
padding = padding + tuple(a + b for (a, b) in zip(padding_ceil, padding))
padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
adjusted_padding = adjusted_padding + tuple(
a + b for (a, b) in zip(padding_ceil, adjusted_padding)
)
else:
padding = padding * 2
adjusted_padding = adjusted_padding * 2
output = g.op(
"AveragePool",
input,
kernel_shape_i=tuple_fn(kernel_size),
strides_i=tuple_fn(stride),
pads_i=padding,
pads_i=adjusted_padding,
)
return output
@ -2510,7 +2516,8 @@ def tensor(g, data, dtype=None, device=None, requires_grad=False):
dtype = symbolic_helper._get_const(dtype, "i", "dtype")
if symbolic_helper._is_packed_list(data):
if dtype is None:
dtype = symbolic_helper._unpack_list(data)[0].type().scalarType()
dtype = symbolic_helper._unpack_list(data)[0].type().scalarType() # type: ignore[attr-defined]
# TODO(justinchuby): Remove type ignore after #81112 is checked in.
dtype = symbolic_helper.scalar_type_to_onnx.index(
symbolic_helper.cast_pytorch_to_onnx[dtype]
)
@ -4963,7 +4970,12 @@ class Prim:
return None
@staticmethod
def ListUnpack(g, *inputs, **kwargs):
def ListUnpack(g, *inputs, **kwargs) -> Optional[List[_C.Value]]:
if len(inputs) == 1 and inputs[0].node().kind() == "prim::ListConstruct":
# Cancel the previous node if it is ListConstruct by returning its inputs
# TODO(justinchuby): Use a public method in the helper module
return symbolic_helper._unpack_list(inputs[0])
return None
@staticmethod

View File

@ -12192,6 +12192,8 @@ op_db: List[OpInfo] = [
'TestCommon', 'test_noncontiguous_samples',
device_type='cpu'), ],
skips=(
# Strides are not the same!
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
# https://github.com/pytorch/pytorch/issues/67470
DecorateInfo(unittest.skip("67470!"),
'TestCommon', 'test_noncontiguous_samples',
@ -13517,6 +13519,8 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
),
decorators=(
# Strides are not the same!
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
)),