mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-29 03:04:55 +08:00

Compare commits: 5 commits, v1.12.1-rc ... release/1.

| Author | SHA1 | Date |
|---|---|---|
|  | 664058fa83 |  |
|  | efc2d08eac |  |
|  | 9a9dcebfd5 |  |
|  | 617c4fe52c |  |
|  | f469bc1fe1 |  |
.github/actionlint.yaml: 1 change (vendored)

```diff
@@ -12,5 +12,6 @@ self-hosted-runner:
   - windows.8xlarge.nvidia.gpu
   - bm-runner
   - linux.rocm.gpu
+  - macos-12-xl
   - macos-12
   - macos12.3-m1
```
```diff
@@ -64,7 +64,7 @@ jobs:
     {%- if config["package_type"] == "libtorch" %}
     runs-on: macos-10.15
     {%- else %}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     {%- endif %}
     {%- if config["package_type"] == "libtorch" %}
     # libtorch builds take a long time on github hosted runners
```
.github/workflows/generated-macos-arm64-binary-conda-nightly.yml: 6 changes (generated, vendored)

```diff
@@ -39,7 +39,7 @@ concurrency:
 jobs:
   conda-py3_8-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -213,7 +213,7 @@ jobs:
       docker system prune -af
   conda-py3_9-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -387,7 +387,7 @@ jobs:
       docker system prune -af
   conda-py3_10-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
```
.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml: 8 changes (generated, vendored)

```diff
@@ -39,7 +39,7 @@ concurrency:
 jobs:
   wheel-py3_7-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -213,7 +213,7 @@ jobs:
       docker system prune -af
   wheel-py3_8-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -387,7 +387,7 @@ jobs:
       docker system prune -af
   wheel-py3_9-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -561,7 +561,7 @@ jobs:
       docker system prune -af
   wheel-py3_10-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
```
.github/workflows/generated-macos-binary-conda-nightly.yml: 8 changes (generated, vendored)

```diff
@@ -37,7 +37,7 @@ concurrency:
 jobs:
   conda-py3_7-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -211,7 +211,7 @@ jobs:
       docker system prune -af
   conda-py3_8-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -385,7 +385,7 @@ jobs:
       docker system prune -af
   conda-py3_9-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -559,7 +559,7 @@ jobs:
      docker system prune -af
   conda-py3_10-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
```
.github/workflows/generated-macos-binary-wheel-nightly.yml: 8 changes (generated, vendored)

```diff
@@ -37,7 +37,7 @@ concurrency:
 jobs:
   wheel-py3_7-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -211,7 +211,7 @@ jobs:
       docker system prune -af
   wheel-py3_8-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -385,7 +385,7 @@ jobs:
       docker system prune -af
   wheel-py3_9-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -559,7 +559,7 @@ jobs:
       docker system prune -af
   wheel-py3_10-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
```
```diff
@@ -19,7 +19,7 @@ if "%INSTALL_FRESH_CONDA%"=="1" (
 
 call %CONDA_PARENT_DIR%\Miniconda3\Scripts\activate.bat %CONDA_PARENT_DIR%\Miniconda3
 if "%INSTALL_FRESH_CONDA%"=="1" (
-    call conda install -y -q python=%PYTHON_VERSION% numpy cffi pyyaml boto3 libuv
-    if errorlevel 1 exit /b
+    call conda install -y -q python=%PYTHON_VERSION% numpy"<1.23" cffi pyyaml boto3 libuv
+    if not errorlevel 0 exit /b
     call conda install -y -q -c conda-forge cmake=3.22.3
```
```diff
@@ -1712,7 +1712,8 @@ Tensor _matmul_impl(
   } else if (dim_tensor1 == 2 && dim_tensor2 == 1) {
     return has_out ? at::mv_out(out, tensor1, tensor2) : tensor1.mv(tensor2);
   } else if (dim_tensor1 == 1 && dim_tensor2 == 2) {
-    return has_out ? at::mv_out(out, tensor2.t(), tensor1) : tensor2.t().mv(tensor1);
+    return has_out ? at::mm_out(out, tensor1.unsqueeze(0), tensor2).squeeze_(0)
+                   : tensor1.unsqueeze(0).mm(tensor2).squeeze_(0);
   } else if (dim_tensor1 == 2 && dim_tensor2 == 2) {
     return has_out ? at::mm_out(out, tensor1, tensor2) : tensor1.mm(tensor2);
   } else if (should_fold(tensor1, dim_tensor2) || should_fold(tensor2, dim_tensor1)) {
```
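As a quick eager-mode illustration of the vector-matrix case this hunk rewrites (a sketch, not the ATen code): promoting the 1-D operand to a one-row matrix, multiplying, and squeezing matches both the old transpose-and-mv formulation and `torch.matmul` itself.

```python
import torch

v = torch.randn(3)     # dim_tensor1 == 1
m = torch.randn(3, 4)  # dim_tensor2 == 2

old = m.t().mv(v)                      # previous formulation
new = v.unsqueeze(0).mm(m).squeeze(0)  # formulation from the hunk

assert torch.allclose(old, new)
assert torch.allclose(torch.matmul(v, m), new)
```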
```diff
@@ -123,12 +123,14 @@ __global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batc
   for (int it = 0; it < WARP_ITERATIONS; ++it) {
     if (is_masked) {
       int idx = it*WARP_SIZE;
-      if (!is_transformer_mask) {
-        idx += i*element_count;
-      }
-      if (!mask[idx]) {
-        max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
-        is_meaningful_max = true;
+      if ((idx + local_idx) < element_count) {
+        if (!is_transformer_mask) {
+          idx += i*element_count;
+        }
+        if (!mask[idx]) {
+          max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
+          is_meaningful_max = true;
+        }
       }
     } else {
       max_value[i] = max_value[i] > elements[i][it] ? max_value[i] : elements[i][it];
@@ -156,22 +158,28 @@ __global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batc
       }
     } else {
       int idx = it*WARP_SIZE;
+      bool valid = (idx + local_idx) < element_count;
       if (!is_transformer_mask) {
         idx += i*element_count;
       }
 
-      if (!mask[idx]) {
-        if (is_log_softmax) {
-          sum[i] += std::exp(elements[i][it] - max_value[i]);
+      if (valid) {
+        if (!mask[idx]) {
+          if (is_log_softmax) {
+            sum[i] += std::exp(elements[i][it] - max_value[i]);
+          } else {
+            elements[i][it] = std::exp(elements[i][it] - max_value[i]);
+            sum[i] += elements[i][it];
+          }
         } else {
-          elements[i][it] = std::exp(elements[i][it] - max_value[i]);
-          sum[i] += elements[i][it];
+          if (!is_log_softmax) {
+            // Masked values are treated as -infinity, and std::exp(-infinity) is 0.
+            elements[i][it] = 0;
+          }
         }
       } else {
         if (!is_log_softmax) {
-          // Masked values are treated as -infinity, and std::exp(-infinity) is 0.
-          elements[i][it] = 0;
+          elements[i][it] = 0.;
        }
       }
     }
```
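For intuition about what the kernel computes (a plain-PyTorch sketch of the semantics, not the CUDA code): a masked position behaves like -infinity before the exp, so it contributes zero probability; the bounds check added above keeps threads from reading `mask` past `element_count`.

```python
import torch

x = torch.randn(2, 5)
mask = torch.tensor([[False, False, True, False, True],
                     [False, True, False, False, False]])  # True = masked out

probs = torch.softmax(x.masked_fill(mask, float("-inf")), dim=-1)
assert torch.all(probs[mask] == 0)  # exp(-inf) == 0
assert torch.allclose(probs.sum(dim=-1), torch.ones(2))
```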
```diff
@@ -1,68 +0,0 @@
-ir_version: 3
-producer_name: "pytorch"
-producer_version: "CURRENT_VERSION"
-graph {
-  node {
-    input: "onnx::Shape_0"
-    output: "onnx::ConstantFill_1"
-    name: "Shape_0"
-    op_type: "Shape"
-  }
-  node {
-    input: "onnx::ConstantFill_1"
-    output: "2"
-    name: "ConstantFill_1"
-    op_type: "ConstantFill"
-    attribute {
-      name: "dtype"
-      i: 1
-      type: INT
-    }
-    attribute {
-      name: "input_as_shape"
-      i: 1
-      type: INT
-    }
-    attribute {
-      name: "value"
-      f: 0
-      type: FLOAT
-    }
-  }
-  name: "torch_jit"
-  input {
-    name: "onnx::Shape_0"
-    type {
-      tensor_type {
-        elem_type: 1
-        shape {
-          dim {
-            dim_value: 5
-          }
-          dim {
-            dim_value: 8
-          }
-        }
-      }
-    }
-  }
-  output {
-    name: "2"
-    type {
-      tensor_type {
-        elem_type: 1
-        shape {
-          dim {
-            dim_value: 5
-          }
-          dim {
-            dim_value: 8
-          }
-        }
-      }
-    }
-  }
-}
-opset_import {
-  version: 7
-}
```
test/onnx/symbolic_opsets/test_symbolic_opset9.py: 32 lines (new file)

```diff
@@ -0,0 +1,32 @@
+# Owner(s): ["module: onnx"]
+
+"""Tests for `torch.onnx.symbolic_opset9`."""
+import torch
+from torch import _C
+from torch.onnx import symbolic_opset9 as opset9
+from torch.testing._internal import common_utils
+
+
+class TestPrim(common_utils.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.graph = _C.Graph()
+
+    def test_list_unpack_returns_all_list_elements_when_previous_node_is_list_construct(
+        self,
+    ):
+        # Build the graph
+        input_1 = self.graph.addInput()
+        input_1.setType(input_1.type().with_dtype(torch.float).with_sizes([2, 42]))
+        input_2 = self.graph.addInput()
+        input_2.setType(input_2.type().with_dtype(torch.float).with_sizes([3, 42]))
+        constructed_list = self.graph.op("prim::ListConstruct", input_1, input_2)
+        # Test the op
+        outputs = opset9.Prim.ListUnpack(self.graph, constructed_list)
+        self.assertNotEqual(outputs, None)
+        self.assertEqual(outputs[0], input_1)
+        self.assertEqual(outputs[1], input_2)
+
+
+if __name__ == "__main__":
+    common_utils.run_tests()
```
```diff
@@ -737,10 +737,6 @@ class TestOperators(TestCase):
         x = torch.randn(5, 8, requires_grad=True)
         self.assertONNX(lambda x: torch.empty_like(x), x)
 
-    def test_empty_like_opset7(self):
-        x = torch.randn(5, 8, requires_grad=True)
-        self.assertONNX(lambda x: torch.empty_like(x), x, opset_version=7)
-
     def test_zeros_like(self):
         x = torch.randn(5, 8, requires_grad=True)
         self.assertONNX(lambda x: torch.zeros_like(x), x)
```
```diff
@@ -7,6 +7,7 @@ import unittest
 from typing import Optional, Type
 
 import onnx
+import onnx.numpy_helper
 
 import torch
 from torch import Tensor
```
```diff
@@ -106,9 +107,75 @@ class TestOptionalOutput(unittest.TestCase):
             input_names=["y"],
         )
 
+    def test_maintain_dynamic_shapes_of_unreliable_nodes(self):
+        def symbolic_pythonop(ctx: torch.onnx.SymbolicContext, g, *args, **kwargs):
+            return g.op("com.microsoft::PythonOp")
+
+        torch.onnx.register_custom_op_symbolic("prim::PythonOp", symbolic_pythonop, 1)
+        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "prim::PythonOp", 1)
+
+        # necessary parameters for transformer embeddings
+        hidden_size = 48
+        max_position_embeddings = 32
+        batch_size = 2
+
+        # An autograd.Function makes the downstream node unreliable but
+        # with a static shape. The issue was first discovered when using
+        # Apex FusedLayerNorm in Transformers.
+        class CustomLayerNorm(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, embedding):
+                layer_norm = torch.nn.LayerNorm(hidden_size, eps=1e-12)
+                return layer_norm(embedding)
+
+        class EmbeddingModule(torch.nn.Module):
+            def forward(
+                self,
+                embeddings=None,
+            ):
+                embedding_output = CustomLayerNorm.apply(embeddings)
+                query = embedding_output.transpose(0, 1)
+                target_len, batch_size, embedding_dim = query.size()
+                # Reshape is used for consuming batch_size, and if it is static,
+                # this will be a Constant node in the graph
+                query = query.reshape(target_len, batch_size, embedding_dim)
+                return query
+
+        embeddings = torch.randn(batch_size, max_position_embeddings, hidden_size)
+
+        f = io.BytesIO()
+        torch.onnx.export(
+            EmbeddingModule().eval(),
+            (embeddings,),
+            f,
+            input_names=["embeddings"],
+            dynamic_axes={
+                "embeddings": {
+                    0: "batch_size",
+                    1: "max_position_embeddings",
+                    2: "hidden_size",
+                }
+            },
+            custom_opsets={"com.microsoft": 1},
+        )
+        model = onnx.load(io.BytesIO(f.getvalue()))
+
+        # If there is a constant node with dim=3 and max_position_embeddings,
+        # batch_size, hidden_size as shape, it means the shape becomes static.
+        # Normally, with dynamic batch size, this constant node should not exist.
+        const_node = [n for n in model.graph.node if n.op_type == "Constant"]
+        self.assertNotEqual(len(const_node), 0)
+        for node in const_node:
+            for a in node.attribute:
+                if a.name == "value":
+                    shape = onnx.numpy_helper.to_array(a.t)
+                    self.assertNotEqual(
+                        shape.tolist(),
+                        [max_position_embeddings, batch_size, hidden_size],
+                    )
+
 
 instantiate_parametrized_tests(TestOptionalOutput)
 
 
 if __name__ == "__main__":
     unittest.main()
```
````diff
@@ -3456,28 +3456,68 @@ class _TestONNXRuntime:
         self.run_test(model, x)
 
     @skipIfUnsupportedMinOpsetVersion(9)
-    def test_listunpack(self):
-        class ListUnpack(torch.jit.ScriptModule):
-            @torch.jit.script_method
+    def test_list_unpack_scripted(self):
+        class ListUnpack(torch.nn.Module):
             def forward(self, x):
                 a, b = x.shape
                 return x.new_zeros((a, b))
 
         x = torch.randn(2, 3)
-        self.run_test(ListUnpack(), x, input_names=["x"], dynamic_axes={"x": [0, 1]})
-        self.run_test(ListUnpack(), x, remained_onnx_input_idx=[])
+        self.run_test(
+            torch.jit.script(ListUnpack()),
+            x,
+            input_names=["x"],
+            dynamic_axes={"x": [0, 1]},
+        )
+        self.run_test(torch.jit.script(ListUnpack()), x, remained_onnx_input_idx=[])
 
-        class ListUnpackSlice(torch.jit.ScriptModule):
-            @torch.jit.script_method
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_list_unpack_scripted_runs_without_error_with_constructed_list_as_input(
+        self,
+    ):
+        class PackUnpack(torch.nn.Module):
+            """Create and unpack a list of tensors.
+
+            When scripted, it should produce a graph similar to
+
+            ```
+            graph(%self : __torch__.PackUnpack,
+                  %a.1 : Tensor,
+                  %b.1 : Tensor):
+              %packed.1 : Tensor[] = prim::ListConstruct(%a.1, %b.1)
+              %c.1 : Tensor, %8 : Tensor = prim::ListUnpack(%packed.1)
+              return (%c.1)
+            ```
+            """
+
+            def forward(self, a, b):
+                packed = [a, b]
+                c, _ = packed
+                return c
+
+        self.run_test(
+            torch.jit.script(PackUnpack()),
+            (torch.tensor(0), torch.tensor([42])),
+            remained_onnx_input_idx=[0],
+        )
+
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_list_unpack_slice_scripted(self):
+        class ListUnpackSlice(torch.nn.Module):
             def forward(self, x):
                 a, b = x.shape[2:]
                 return x.new_zeros((a, b))
 
         x = torch.randn(2, 3, 4, 5)
         self.run_test(
-            ListUnpackSlice(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2, 3]}
+            torch.jit.script(ListUnpackSlice()),
+            x,
+            input_names=["x"],
+            dynamic_axes={"x": [0, 1, 2, 3]},
         )
-        self.run_test(ListUnpackSlice(), x, remained_onnx_input_idx=[])
+        self.run_test(
+            torch.jit.script(ListUnpackSlice()), x, remained_onnx_input_idx=[]
+        )
 
     def test_pow(self):
         class PowModule(torch.nn.Module):
````
```diff
@@ -12366,6 +12406,13 @@ class _TestONNXRuntime:
         q_input = torch.quantize_per_tensor(input, 0.26, 128, torch.quint8)
         self.run_test(model, q_input)
 
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_quantized_sigmoid(self):
+        model = torch.nn.Sigmoid()
+        input = torch.randn(2, 6)
+        q_input = torch.quantize_per_tensor(input, 0.26, 128, torch.quint8)
+        self.run_test(model, q_input)
+
     @skipIfUnsupportedMinOpsetVersion(10)
     def test_quantized_flatten(self):
         class FlattenModel(torch.nn.Module):
@@ -12375,6 +12422,19 @@ class _TestONNXRuntime:
         x = torch.quantize_per_tensor(torch.randn(1, 2, 3, 4), 1, 0, torch.quint8)
         self.run_test(FlattenModel(), x)
 
+    @unittest.skip(
+        "ONNX Runtime 1.11 does not support quantized cat. Enable after ORT 1.12 is enabled in CI."
+    )
+    @skipIfUnsupportedMinOpsetVersion(10)
+    @skipScriptTest()  # torch.jit.frontend.FrontendError: Cannot instantiate class 'QFunctional' in a script function:
+    def test_quantized_cat(self):
+        class QuantizedConcatenationModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.nn.quantized.QFunctional().cat((x, x), dim=1)
+
+        q_input = torch.quantize_per_tensor(torch.ones(2, 3), 0.26, 128, torch.quint8)
+        self.run_test(QuantizedConcatenationModel(), q_input)
+
     @skipIfUnsupportedMinOpsetVersion(10)
     @skipScriptTest()  # torch.jit.frontend.FrontendError: Cannot instantiate class 'QFunctional' in a script function:
     def test_quantized_arithmetic_qfunctional(self):
```
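A minimal sketch of the quantize-per-tensor round trip these tests build on, using the same 0.26/128 parameters (eager mode only; `run_test` additionally exports the model and checks it against ONNX Runtime):

```python
import torch

x = torch.randn(2, 6)
q = torch.quantize_per_tensor(x, scale=0.26, zero_point=128, dtype=torch.quint8)

y = torch.sigmoid(q.dequantize())   # float reference for the quantized model
print(q.int_repr().dtype, y.shape)  # torch.uint8 torch.Size([2, 6])
```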
```diff
@@ -12596,6 +12656,32 @@ class _TestONNXRuntime:
         input = _construct_tensor_for_quantization_test((4, 4, 3, 2))
         self.run_test(model, input)
 
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_qat_avg_pool2d(self):
+        model = torch.nn.Sequential(
+            torch.quantization.QuantStub(),
+            torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=1),
+            torch.quantization.DeQuantStub(),
+        )
+        model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
+        model = torch.quantization.prepare_qat(model.train())
+        model = torch.quantization.convert(model)
+        input = _construct_tensor_for_quantization_test((4, 4, 3, 2))
+        self.run_test(model, input)
+
+    @skipIfUnsupportedMinOpsetVersion(11)
+    def test_qat_upsample_nearest2d(self):
+        model = torch.nn.Sequential(
+            torch.quantization.QuantStub(),
+            torch.nn.UpsamplingNearest2d(scale_factor=1.5),
+            torch.quantization.DeQuantStub(),
+        )
+        model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
+        model = torch.quantization.prepare_qat(model.train())
+        model = torch.quantization.convert(model)
+        input = _construct_tensor_for_quantization_test((4, 3, 2, 2))
+        self.run_test(model, input)
+
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_convolution_allow_tf32(self):
         class Module(torch.nn.Module):
```
@@ -451,11 +451,34 @@ class _InsertPoint:

Resulting stubs:

```python
class _InsertPoint:
    def __enter__(self) -> None: ...
    def __exit__(self, *args) -> None: ...

# Defined in torch/csrc/jit/ir/ir.h
class Use:
    @property
    def user(self) -> Node: ...
    @property
    def offset(self) -> _int: ...
    def isAfter(self, other: Use) -> _bool: ...
    ...

# Defined in torch/csrc/jit/ir/ir.h
class Value:
    def type(self) -> JitType: ...
    def setType(self, t: JitType) -> Value: ...
    def setTypeAs(self, other: Value) -> Value: ...
    def inferTypeFrom(self, t: Tensor) -> None: ...
    def debugName(self) -> str: ...
    def setDebugName(self, name: str) -> None: ...
    def unique(self) -> _int: ...
    def offset(self) -> _int: ...
    def node(self) -> Node: ...
    def uses(self) -> List[Use]: ...
    def replaceAllUsesWith(self, val: Value) -> None: ...
    def replaceAllUsesAfterNodeWith(self, node: Node, val: Value) -> None: ...
    def requires_grad(self) -> _bool: ...
    def requiresGrad(self) -> _bool: ...
    def copyMetadata(self, other: Value) -> Value: ...
    def isCompleteTensor(self) -> _bool: ...
    def toIValue(self) -> IValue: ...
    ...
```
@@ -467,14 +490,80 @@ class Block:

Resulting stubs:

```python
# Defined in torch/csrc/jit/ir/ir.h
class Node:
    def schema(self) -> str: ...
    def input(self) -> Value: ...
    def inputs(self) -> List[Value]: ...
    def inputsAt(self, idx: _int) -> Value: ...
    def inputsSize(self) -> _int: ...
    def output(self) -> Value: ...
    def outputs(self) -> List[Value]: ...
    def outputsAt(self, idx: _int) -> Value: ...
    def outputsSize(self) -> _int: ...
    def hasMultipleOutputs(self) -> _bool: ...
    def blocks(self) -> List[Block]: ...
    def mustBeNone(self) -> _bool: ...
    def __getitem__(self, key: str) -> Any: ...
    def matches(self, pattern: str) -> _bool: ...
    def kind(self) -> str: ...
    def kindOf(self, name: str) -> str: ...
    def addInput(self, name: str) -> Value: ...
    def replaceInput(self, i: _int, newValue: Value) -> Value: ...
    def replaceInputWith(self, from_: Value, to: Value) -> None: ...
    def replaceAllUsesWith(self, n: Node) -> None: ...
    def insertBefore(self, n: Node) -> Node: ...
    def insertAfter(self, n: Node) -> Node: ...
    def isBefore(self, n: Node) -> _bool: ...
    def isAfter(self, n: Node) -> _bool: ...
    def moveBefore(self, n: Node) -> None: ...
    def moveAfter(self, n: Node) -> None: ...
    def removeInput(self, i: _int) -> None: ...
    def removeAllInputs(self, i: _int) -> None: ...
    def hasUses(self) -> _bool: ...
    def eraseOutput(self, i: _int) -> None: ...
    def addOutput(self) -> Value: ...
    def scopeName(self) -> str: ...
    def isNondeterministic(self) -> _bool: ...
    def copyAttributes(self, rhs: Node) -> Node: ...
    def copyMetadata(self, rhs: Node) -> Node: ...
    def hasAttributes(self) -> _bool: ...
    def hasAttribute(self, name: str) -> _bool: ...
    def removeAttribute(self, attr: str) -> Node: ...
    def namedInput(self, name: str) -> Value: ...
    def sourceRange(self) -> SourceRange: ...
    def owningBlock(self) -> Block: ...
    def findNode(self, kind: str, recurse: _bool = True) -> Node: ...
    def findAllNodes(self, kind: str, recurse: _bool = True) -> List[Node]: ...
    def getModuleHierarchy(self) -> str: ...
    def prev(self) -> Node: ...
    def destroy(self) -> None: ...

    # Accessors for attributes as types.
    def f(self, name: str) -> _float: ...
    def f_(self, name: str, val: _float) -> Node: ...
    def fs(self, name: str) -> List[_float]: ...
    def fs_(self, name: str, val: List[_float]) -> Node: ...
    def c(self, name: str) -> complex: ...
    def c_(self, name: str, val: complex) -> Node: ...
    def s(self, name: str) -> str: ...
    def s_(self, name: str, val: str) -> Node: ...
    def ss(self, name: str) -> List[str]: ...
    def ss_(self, name: str, val: List[str]) -> Node: ...
    def i(self, name: str) -> _int: ...
    def i_(self, name: str, val: _int) -> Node: ...
    # Cannot define "is" like this because it's a reserved keyword in python.
    # def is(self, name: str) -> List[_int]: ...
    # def is_(self, name: str, val: List[_int]) -> Node: ...
    def g(self, name: str) -> Graph: ...
    def g_(self, name: str, val: Graph) -> Node: ...
    def gs(self, name: str) -> List[Graph]: ...
    def gs_(self, name: str, val: List[Graph]) -> Node: ...
    def ival(self, name: str) -> IValue: ...
    def ival_(self, name: str, val: IValue) -> Node: ...
    def t(self, name: str) -> Tensor: ...
    def t_(self, name: str, val: Tensor) -> Node: ...
    def ts(self, name: str) -> List[Tensor]: ...
    def ts_(self, name: str, val: List[Tensor]) -> Node: ...
    def ty_(self, name: str, val: JitType) -> Node: ...
    def tys_(self, name: str, val: List[JitType]) -> Node: ...
    ...

# Defined in torch/torch/csrc/jit/ir/ir.h
```
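A small sketch of how these `Graph`/`Node`/`Value` bindings are typically driven from Python (standard TorchScript API; the stub file above only adds type annotations for it):

```python
import torch

@torch.jit.script
def f(x: torch.Tensor) -> torch.Tensor:
    return x * 2 + 1

graph = f.graph
for node in graph.nodes():
    outs = [v.debugName() for v in node.outputs()]
    print(node.kind(), outs, node.hasUses())
```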
```diff
@@ -1562,6 +1562,12 @@ void ProcessConstantValueMap(Node* n, int opset_version) {
   // For outputs, only update static shapes. For input, we update symbolic
   // shapes also. ONNX If can have different types on different branches, skip
   // here.
+
+  // Update the shape reliability for each node before processing
+  // ConstantValueMap to prevent unreliable nodes from producing static
+  // shapes
+  UpdateReliable(n);
+
   auto static_input_shape = AllGraphInputsStatic(n->owningGraph());
   for (auto i : c10::irange(n->outputs().size())) {
     if (TensorTypePtr output_type = n->output(i)->type()->cast<TensorType>()) {
@@ -82,5 +82,7 @@ void UpdateReliable(
     torch::jit::Value* output,
     const std::pair<bool, bool>& input_reliable);
 
+void UpdateReliable(torch::jit::Node* n);
+
 } // namespace jit
 } // namespace torch
```
```diff
@@ -8,6 +8,7 @@ import torch._C._onnx as _C_onnx
 from torch.onnx._globals import GLOBALS
 
 
+# TODO(#78694): Refactor the patching process to make it more transparent to users.
 def _graph_op(
     g: torch._C.Graph,
     opname: str,
```
```diff
@@ -1,9 +1,11 @@
 from __future__ import annotations
 
 import enum
+import functools
 import inspect
 import sys
 import warnings
-from typing import Set
+from typing import Any, Callable, List, Optional, Sequence, Set, Tuple, Union
 
 import torch
+import torch._C._onnx as _C_onnx
@@ -140,7 +142,7 @@ def _get_const(value, desc, arg_name):
     return _parse_arg(value, desc)
 
 
-def _unpack_list(list_value):
+def _unpack_list(list_value: _C.Value) -> List[_C.Value]:
     list_node = list_value.node()
     assert list_node.kind() == "prim::ListConstruct"
     return list(list_node.inputs())
```
````diff
@@ -236,15 +238,19 @@ def parse_args(*arg_descriptors):
     return decorator
 
 
-def quantized_args(*arg_q_descriptors, scale=None, zero_point=None):
+def quantized_args(
+    *arg_q_descriptors: bool,
+    scale: Optional[float] = None,
+    zero_point: Optional[int] = None,
+):
     """A decorator which extends support for quantized version of the base operator.
     Quantization is detected by examining the arguments that are annotated by
     `arg_q_descriptors`.
 
     If quantization is detected, the base operator symbolic function will be wrapped with
-    argument dequantization and output quantization.
+    argument de-quantization and output quantization.
 
-    Otherwise, only base symbolic function will be invoked.
+    Otherwise, only the base symbolic function will be invoked.
 
     For example:
@@ -267,11 +273,12 @@ def quantized_args(*arg_q_descriptors, scale=None, zero_point=None):
     ```
 
     Args:
-        arg_q_descriptors: list of bool, where each element represents if the
-            argument is QTensor for quantized version of this operator.
-        scale: float default None, quantized output scale. If None, derive from
+        arg_q_descriptors: A sequence of bool, where each element represents if the
+            argument is QTensor for quantized version of this operator. It defaults
+            to False for unspecified (variable length) arguments.
+        scale: Quantized output scale. If None, derive from
             the first quantized input scale.
-        zero_point: int default None, quantized output zero point. If None,
+        zero_point: Quantized output zero point. If None,
             derive from the first quantized input zero point.
     """
````
```diff
@@ -288,19 +295,21 @@ def quantized_args(*arg_q_descriptors, scale=None, zero_point=None):
             if _zero_point is not None:
                 _zero_point = g.op("Constant", value_t=torch.tensor(_zero_point))
 
-            # some args may be optional, so the length may be smaller
-            assert len(arg_q_descriptors) >= len(args)
-            desc_args = tuple(zip(arg_q_descriptors[: len(args)], args))
+            # Support variable length arguments by marking unspecified ones as non-quantized
+            arg_q_descriptors_extended = arg_q_descriptors + (False,) * (
+                len(args) - len(arg_q_descriptors)
+            )
+            descriptor_args = tuple(zip(arg_q_descriptors_extended, args))
             # Run regular symbolic function if none of the argument is QTensor.
             if not any(
-                (desc and arg.node().kind() == "prim::TupleConstruct")
-                for desc, arg in desc_args
+                (descriptor and arg.node().kind() == "prim::TupleConstruct")
+                for descriptor, arg in descriptor_args
             ):
                 return fn(g, *args, **kwargs)
 
             dequantized_args = []
-            for desc, arg in desc_args:
-                if desc:
+            for descriptor, arg in descriptor_args:
+                if descriptor:
                     dequantized_arg, scale, zero_point, _ = dequantize_helper(g, arg)
                     dequantized_args.append(dequantized_arg)
                     if _scale is None:
```
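The descriptor-extension logic above can be illustrated standalone (a sketch; the names are illustrative): extra trailing arguments simply get a `False` descriptor, i.e. they are treated as non-quantized.

```python
arg_q_descriptors = (True, False)            # declared for the first two args
args = ("q_input", "weight", "bias", "eps")  # call site passed four args

extended = arg_q_descriptors + (False,) * (len(args) - len(arg_q_descriptors))
print(tuple(zip(extended, args)))
# ((True, 'q_input'), (False, 'weight'), (False, 'bias'), (False, 'eps'))
```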
```diff
@@ -309,7 +318,8 @@ def quantized_args(*arg_q_descriptors, scale=None, zero_point=None):
                     _zero_point = zero_point
                 else:
                     dequantized_args.append(arg)
-            # TODO: only support single output
+            # TODO(justinchuby): Only single output is supported for now. We may want to
+            # support multiple outputs in the future.
             output = fn(g, *dequantized_args, **kwargs)
 
             return quantize_helper(g, output, _scale, _zero_point)
```
```diff
@@ -786,6 +796,7 @@ def _interpolate_get_scales_and_mode(g, input, size, scale_factor, mode, align_c
 
 
 def _interpolate_helper(name, dim, interpolate_mode):
+    @quantized_args(True, False, False)
     def symbolic_fn(g, input, output_size, *args):
         scales, align_corners = _get_interpolate_attributes(g, interpolate_mode, args)
         align_corners = _maybe_get_scalar(align_corners)
```
```diff
@@ -1113,13 +1124,17 @@ def _batchnorm_helper(g, input, weight, bias, running_mean, running_var):
     return weight, bias, running_mean, running_var
 
 
-def _avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override, name):
+def _avgpool_helper(
+    tuple_fn: Callable[[Any], Sequence[int]],
+    padding: Union[int, Sequence[int]],
+    kernel_size,
+    stride,
+    divisor_override,
+    name,
+) -> Tuple[int, ...]:
     if divisor_override and divisor_override.node().kind() != "prim::Constant":
-        return _unimplemented(name, "divisor_override")
-    if not stride:
-        stride = kernel_size
-    padding = tuple(tuple_fn(padding))
-    return padding
+        _unimplemented(name, "divisor_override")
+    return tuple(tuple_fn(padding))
 
 
 def check_training_mode(op_train_mode, op_name):
```
```diff
@@ -1193,7 +1208,11 @@ def _handle_reduce_dim_none(g, self, op_name):
     return g.op(op_name, self, keepdims_i=0)
 
 
-def dequantize_helper(g, qtensor, qdtype=None):
+def dequantize_helper(
+    g,
+    qtensor: _C.Value,
+    qdtype: Optional[torch.onnx.TensorProtoDataType] = None,
+) -> Tuple[_C.Value, _C.Value, _C.Value, Optional[_C.Value]]:
     """Appends to graph `g` ONNX nodes that dequantizes `qtensor` into `tensor`.
 
     Args:
```
```diff
@@ -1234,7 +1253,13 @@ def dequantize_helper(g, qtensor, qdtype=None):
     )
 
 
-def quantize_helper(g, tensor, scale, zero_point, axis=None):
+def quantize_helper(
+    g,
+    tensor: _C.Value,
+    scale: _C.Value,
+    zero_point: _C.Value,
+    axis: Optional[_C.Value] = None,
+) -> _C.Value:
     """Appends to graph `g` ONNX nodes that quantizes `tensor` based on `scale`, `zero_point` and `axis`.
 
     Args:
```
```diff
@@ -1258,11 +1283,13 @@ def quantize_helper(g, tensor, scale, zero_point, axis=None):
         )
 
     assert scale is not None
-    if scale.type().scalarType() != "Float":
+    if scale.type().scalarType() != "Float":  # type: ignore[attr-defined]
+        # TODO(justinchuby): Remove type ignore after #81112 is checked in.
         scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
 
     assert zero_point is not None
-    if zero_point.type().scalarType() not in ("Byte", "Char"):
+    if zero_point.type().scalarType() not in ("Byte", "Char"):  # type: ignore[attr-defined]
+        # TODO(justinchuby): Remove type ignore after #81112 is checked in.
         zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
     output = g.op(
         "QuantizeLinear",
```
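Numerically, the `DequantizeLinear`/`QuantizeLinear` pair these helpers emit computes the affine mapping below (an eager-mode sketch of the math, not the exported graph):

```python
import torch

scale, zero_point = 0.26, 128
q = torch.quantize_per_tensor(torch.randn(2, 3), scale, zero_point, torch.quint8)

# DequantizeLinear: (int_repr - zero_point) * scale
x = (q.int_repr().float() - zero_point) * scale
assert torch.allclose(x, q.dequantize())

# QuantizeLinear: saturate(round(x / scale) + zero_point) into uint8
requant = torch.clamp(torch.round(x / scale) + zero_point, 0, 255).to(torch.uint8)
assert torch.equal(requant, q.int_repr())
```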
```diff
@@ -1,9 +1,11 @@
 import sys
 import warnings
+from typing import Sequence
 
 import torch
 import torch._C._onnx as _C_onnx
 import torch.onnx
+from torch import _C
 
 # This import monkey-patches graph manipulation methods on Graph, used for the
 # ONNX symbolics
```
```diff
@@ -141,15 +143,16 @@ max_pool3d_with_indices = _max_pool(
 
 
 def _avg_pool(name, tuple_fn):
+    @symbolic_helper.quantized_args(True, False, False, False, False, False, False)
     @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
     def symbolic_fn(
         g,
-        input,
-        kernel_size,
-        stride,
-        padding,
-        ceil_mode,
-        count_include_pad,
+        input: _C.Value,
+        kernel_size: Sequence[int],
+        stride: Sequence[int],
+        padding: Sequence[int],
+        ceil_mode: int,
+        count_include_pad: int,
         divisor_override=None,
     ):
         if not stride:
```
```diff
@@ -187,6 +190,7 @@ avg_pool3d = _avg_pool("avg_pool3d", torch.nn.modules.utils._triple)
 
 
 def _interpolate(name, dim, interpolate_mode):
+    @symbolic_helper.quantized_args(True, False, False)
     def symbolic_fn(g, input, output_size, *args):
         scales, align_corners = symbolic_helper._get_interpolate_attributes(
             g, interpolate_mode, args
```
```diff
@@ -543,6 +547,16 @@ class Quantized:
 
         return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
 
+    @staticmethod
+    def add_relu(g, x, y, op_scale, op_zero_point):
+        x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
+        y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
+
+        output = opset9.add(g, x, y)
+        output = opset9.relu(g, output)
+
+        return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
+
     @staticmethod
     def mul(g, x, y, op_scale, op_zero_point):
         x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
```
```diff
@@ -612,3 +626,19 @@ class Quantized:
         )
 
         return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
+
+    @staticmethod
+    @symbolic_helper.parse_args("v", "i", "v", "v")
+    def cat(
+        g,
+        q_inputs: _C.Value,
+        dim: int,
+        op_scale: _C.Value,
+        op_zero_point: _C.Value,
+    ) -> _C.Value:
+        unpacked_inputs = symbolic_helper._unpack_list(q_inputs)
+        dequantized = [
+            symbolic_helper.dequantize_helper(g, input)[0] for input in unpacked_inputs
+        ]
+        concatenated = g.op("Concat", *dequantized, axis_i=dim)
+        return symbolic_helper.quantize_helper(g, concatenated, op_scale, op_zero_point)
```
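In eager mode, the operation this symbolic covers looks like the following (it mirrors `test_quantized_cat` above; output qparams come from the `QFunctional` module):

```python
import torch

qf = torch.nn.quantized.QFunctional()
x = torch.quantize_per_tensor(torch.ones(2, 3), 0.26, 128, torch.quint8)

out = qf.cat((x, x), dim=1)
print(out.shape)  # torch.Size([2, 6])
```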
```diff
@@ -1,7 +1,11 @@
 """This file exports ONNX ops for opset 11."""
 
 import sys
 import warnings
+from typing import Tuple, Union
 
 import torch
+from torch import _C
 from torch.onnx import symbolic_helper
+from torch.onnx import symbolic_opset9 as opset9
+from torch.onnx import symbolic_opset10 as opset10
```
```diff
@@ -143,7 +147,8 @@ def index_put(g, self, indices_list_value, values, accumulate=False):
 
     if len(indices_list) > 1:
         for idx_ in range(len(indices_list)):
-            if indices_list[idx_].type().scalarType() == "Bool":
+            if indices_list[idx_].type().scalarType() == "Bool":  # type: ignore[attr-defined]
+                # TODO(justinchuby): Remove type ignore after #81112 is checked in.
                 indices_list[idx_] = g.op("NonZero", indices_list[idx_])
         index = indices_list[0]
 
@@ -198,7 +203,8 @@ def index_put(g, self, indices_list_value, values, accumulate=False):
     # return (%33)
     index = indices_list[0]
     bool_inp = index
-    if bool_inp.type() is not None and bool_inp.type().scalarType() == "Bool":
+    if bool_inp.type() is not None and bool_inp.type().scalarType() == "Bool":  # type: ignore[attr-defined]
+        # TODO(justinchuby): Remove type ignore after #81112 is checked in.
         rank = symbolic_helper._get_tensor_rank(values)
         if rank is not None and rank == 0:
             return opset9.masked_fill(g, self, bool_inp, values)
```
```diff
@@ -258,6 +264,7 @@ upsample_trilinear3d = _interpolate("upsample_trilinear3d", 5, "linear")
 upsample_bicubic2d = _interpolate("upsample_bicubic2d", 4, "cubic")
 
 
+@symbolic_helper.quantized_args(True, False, False, False, False, False, False)
 def __interpolate(
     g, input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias
 ):
```
```diff
@@ -415,15 +422,16 @@ def _unique2(g, self, sorted, return_inverse, return_counts):
 
 
 def _avg_pool(name, tuple_fn):
+    @symbolic_helper.quantized_args(True, False, False, False, False, False, False)
     @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
     def symbolic_fn(
         g,
-        input,
-        kernel_size,
-        stride,
-        padding,
-        ceil_mode,
-        count_include_pad,
+        input: _C.Value,
+        kernel_size: Tuple[int, ...],
+        stride: Tuple[int, ...],
+        padding: Union[int, Tuple[int, ...]],
+        ceil_mode: int,
+        count_include_pad: int,
         divisor_override=None,
     ):
         padding = symbolic_helper._avgpool_helper(
```
```diff
@@ -8,7 +8,7 @@ import functools
 import math
 import sys
 import warnings
-from typing import Optional
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch._C._onnx as _C_onnx
```
```diff
@@ -361,6 +361,8 @@ def atan(g, self):
     return g.op("Atan", self)
 
 
+# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qsigmoid.cpp
+@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0)
 def sigmoid(g, self):
     return g.op("Sigmoid", self)
```
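Why fixed qparams work here (a hedged eager-mode check, not the ONNX path): sigmoid outputs lie in [0, 1], so scale 1/256 with zero_point 0 covers the whole range, and the error is at most one quantization step because values near 1.0 saturate at 255/256.

```python
import torch

y = torch.sigmoid(torch.randn(1000))
q = torch.quantize_per_tensor(y, scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8)
assert (q.dequantize() - y).abs().max() <= 1.0 / 256.0
```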
```diff
@@ -1031,6 +1033,7 @@ def get_pool_ceil_padding(input, kernel_size, stride, padding):
 
 
 def _max_pool(name, tuple_fn, ndims, return_indices):
+    @symbolic_helper.quantized_args(True, False, False, False, False, False)
     @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
     def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
         if set(tuple_fn(dilation)) != {1}:
```
```diff
@@ -1117,15 +1120,16 @@ max_pool3d_with_indices = _max_pool(
 
 
 def _avg_pool(name, tuple_fn):
+    @symbolic_helper.quantized_args(True)
     @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
     def symbolic_fn(
         g,
-        input,
-        kernel_size,
-        stride,
-        padding,
-        ceil_mode,
-        count_include_pad,
+        input: _C.Value,
+        kernel_size: Tuple[int, ...],
+        stride: Tuple[int, ...],
+        padding: Union[int, Tuple[int, ...]],
+        ceil_mode: int,
+        count_include_pad: int,
         divisor_override=None,
     ):
         if not stride:
```
```diff
@@ -1133,8 +1137,7 @@ def _avg_pool(name, tuple_fn):
         padding = symbolic_helper._avgpool_helper(
             tuple_fn, padding, kernel_size, stride, divisor_override, name
         )
-        if ceil_mode:
-            padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
+        adjusted_padding = padding
         if count_include_pad:
             input = g.op(
                 "Pad",
@@ -1143,17 +1146,20 @@ def _avg_pool(name, tuple_fn):
                 mode_s="constant",
                 value_f=0.0,
             )
-            padding = (0,) * len(padding)
+            adjusted_padding = (0,) * len(padding)
         if ceil_mode:
-            padding = padding + tuple(a + b for (a, b) in zip(padding_ceil, padding))
+            padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
+            adjusted_padding = adjusted_padding + tuple(
+                a + b for (a, b) in zip(padding_ceil, adjusted_padding)
+            )
         else:
-            padding = padding * 2
+            adjusted_padding = adjusted_padding * 2
         output = g.op(
             "AveragePool",
             input,
             kernel_shape_i=tuple_fn(kernel_size),
             strides_i=tuple_fn(stride),
-            pads_i=padding,
+            pads_i=adjusted_padding,
         )
         return output
```
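The reason the export pre-pads with an explicit `Pad` when `count_include_pad` is set can be seen in eager mode (a sketch): with `count_include_pad=True` the padded zeros enter the average, so the graph zero-pads first and then runs `AveragePool` with zeroed `pads_i`.

```python
import torch
import torch.nn.functional as F

x = torch.ones(1, 1, 4, 4)
a = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1, count_include_pad=True)
b = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1, count_include_pad=False)
print(a[0, 0, 0, 0].item(), b[0, 0, 0, 0].item())  # 0.444... vs 1.0
```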
```diff
@@ -2510,7 +2516,8 @@ def tensor(g, data, dtype=None, device=None, requires_grad=False):
     dtype = symbolic_helper._get_const(dtype, "i", "dtype")
     if symbolic_helper._is_packed_list(data):
         if dtype is None:
-            dtype = symbolic_helper._unpack_list(data)[0].type().scalarType()
+            dtype = symbolic_helper._unpack_list(data)[0].type().scalarType()  # type: ignore[attr-defined]
+            # TODO(justinchuby): Remove type ignore after #81112 is checked in.
             dtype = symbolic_helper.scalar_type_to_onnx.index(
                 symbolic_helper.cast_pytorch_to_onnx[dtype]
             )
```
```diff
@@ -4963,7 +4970,12 @@ class Prim:
         return None
 
     @staticmethod
-    def ListUnpack(g, *inputs, **kwargs):
+    def ListUnpack(g, *inputs, **kwargs) -> Optional[List[_C.Value]]:
         if len(inputs) == 1 and inputs[0].node().kind() == "prim::ListConstruct":
             # Cancel the previous node if it is ListConstruct by returning its inputs
+            # TODO(justinchuby): Use a public method in the helper module
             return symbolic_helper._unpack_list(inputs[0])
+
+        return None
 
     @staticmethod
```
```diff
@@ -12192,6 +12192,8 @@ op_db: List[OpInfo] = [
                    'TestCommon', 'test_noncontiguous_samples',
                    device_type='cpu'), ],
        skips=(
+           # Strides are not the same!
+           DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
            # https://github.com/pytorch/pytorch/issues/67470
            DecorateInfo(unittest.skip("67470!"),
                         'TestCommon', 'test_noncontiguous_samples',
@@ -13517,6 +13519,8 @@ op_db: List[OpInfo] = [
            DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
        ),
        decorators=(
+           # Strides are not the same!
+           DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
            DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
                         'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
        )),
```