Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-11-04 16:04:58 +08:00

Compare commits

5 commits: v1.12.1-rc ... release/1.
| Author | SHA1 | Date |
|---|---|---|
|  | 664058fa83 |  |
|  | efc2d08eac |  |
|  | 9a9dcebfd5 |  |
|  | 617c4fe52c |  |
|  | f469bc1fe1 |  |
.github/actionlint.yaml (1 line changed, vendored)

@@ -12,5 +12,6 @@ self-hosted-runner:
     - windows.8xlarge.nvidia.gpu
     - bm-runner
     - linux.rocm.gpu
+    - macos-12-xl
     - macos-12
     - macos12.3-m1
@@ -64,7 +64,7 @@ jobs:
     {%- if config["package_type"] == "libtorch" %}
     runs-on: macos-10.15
     {%- else %}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     {%- endif %}
 {%- if config["package_type"] == "libtorch" %}
     # libtorch builds take a long time on github hosted runners
.github/workflows/generated-macos-arm64-binary-conda-nightly.yml (6 lines changed, generated, vendored)

@@ -39,7 +39,7 @@ concurrency:
 jobs:
   conda-py3_8-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -213,7 +213,7 @@ jobs:
           docker system prune -af
   conda-py3_9-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -387,7 +387,7 @@ jobs:
           docker system prune -af
   conda-py3_10-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml (8 lines changed, generated, vendored)

@@ -39,7 +39,7 @@ concurrency:
 jobs:
   wheel-py3_7-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -213,7 +213,7 @@ jobs:
           docker system prune -af
   wheel-py3_8-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -387,7 +387,7 @@ jobs:
           docker system prune -af
   wheel-py3_9-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -561,7 +561,7 @@ jobs:
           docker system prune -af
   wheel-py3_10-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
.github/workflows/generated-macos-binary-conda-nightly.yml (8 lines changed, generated, vendored)

@@ -37,7 +37,7 @@ concurrency:
 jobs:
   conda-py3_7-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -211,7 +211,7 @@ jobs:
           docker system prune -af
   conda-py3_8-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -385,7 +385,7 @@ jobs:
           docker system prune -af
   conda-py3_9-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -559,7 +559,7 @@ jobs:
           docker system prune -af
   conda-py3_10-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
.github/workflows/generated-macos-binary-wheel-nightly.yml (8 lines changed, generated, vendored)

@@ -37,7 +37,7 @@ concurrency:
 jobs:
   wheel-py3_7-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -211,7 +211,7 @@ jobs:
           docker system prune -af
   wheel-py3_8-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -385,7 +385,7 @@ jobs:
           docker system prune -af
   wheel-py3_9-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -559,7 +559,7 @@ jobs:
           docker system prune -af
   wheel-py3_10-cpu-build:
     if: ${{ github.repository_owner == 'pytorch' }}
-    runs-on: macos-12
+    runs-on: macos-12-xl
     timeout-minutes: 240
     env:
       PYTORCH_ROOT: ${{ github.workspace }}/pytorch
@@ -19,7 +19,7 @@ if "%INSTALL_FRESH_CONDA%"=="1" (

 call %CONDA_PARENT_DIR%\Miniconda3\Scripts\activate.bat %CONDA_PARENT_DIR%\Miniconda3
 if "%INSTALL_FRESH_CONDA%"=="1" (
-  call conda install -y -q python=%PYTHON_VERSION% numpy cffi pyyaml boto3 libuv
+  call conda install -y -q python=%PYTHON_VERSION% numpy"<1.23" cffi pyyaml boto3 libuv
   if errorlevel 1 exit /b
   if not errorlevel 0 exit /b
   call conda install -y -q -c conda-forge cmake=3.22.3
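The conda line above pins numpy below 1.23 for these Windows builds. A quick way to assert the same constraint at runtime (a sketch, not part of the patch; assumes the `packaging` module is available):

```python
from packaging import version

import numpy as np

# Mirror of the build-script pin: numpy"<1.23"
assert version.parse(np.__version__) < version.parse("1.23")
```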
@@ -1712,7 +1712,8 @@ Tensor _matmul_impl(
   } else if (dim_tensor1 == 2 && dim_tensor2 == 1) {
     return has_out ? at::mv_out(out, tensor1, tensor2) : tensor1.mv(tensor2);
   } else if (dim_tensor1 == 1 && dim_tensor2 == 2) {
-    return has_out ? at::mv_out(out, tensor2.t(), tensor1) : tensor2.t().mv(tensor1);
+    return has_out ? at::mm_out(out, tensor1.unsqueeze(0), tensor2).squeeze_(0)
+                   : tensor1.unsqueeze(0).mm(tensor2).squeeze_(0);
   } else if (dim_tensor1 == 2 && dim_tensor2 == 2) {
     return has_out ? at::mm_out(out, tensor1, tensor2) : tensor1.mm(tensor2);
   } else if (should_fold(tensor1, dim_tensor2) || should_fold(tensor2, dim_tensor1)) {
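The 1-D × 2-D branch now routes through `mm` instead of transposing and calling `mv`. A minimal PyTorch sketch (illustrative, not part of the patch) of why the two lowerings agree:

```python
import torch

vec = torch.randn(3)      # tensor1, 1-D
mat = torch.randn(3, 4)   # tensor2, 2-D

old = mat.t().mv(vec)                       # previous lowering: (4,)
new = vec.unsqueeze(0).mm(mat).squeeze(0)   # new lowering: (1, 3) @ (3, 4) -> (1, 4) -> (4,)

assert torch.allclose(old, new)
assert torch.allclose(new, vec @ mat)
```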
@@ -123,12 +123,14 @@ __global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batc
         for (int it = 0;  it < WARP_ITERATIONS;  ++it) {
             if (is_masked) {
                 int idx = it*WARP_SIZE;
-                if (!is_transformer_mask) {
-                    idx += i*element_count;
-                }
-                if (!mask[idx]) {
-                    max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
-                    is_meaningful_max = true;
+                if ((idx + local_idx) < element_count) {
+                    if (!is_transformer_mask) {
+                        idx += i*element_count;
+                    }
+                    if (!mask[idx]) {
+                        max_value[i] = (is_meaningful_max && max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
+                        is_meaningful_max = true;
+                    }
                 }
             } else {
                 max_value[i] = max_value[i] > elements[i][it] ? max_value[i] : elements[i][it];
@@ -156,22 +158,28 @@ __global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batc
                 }
             } else {
                 int idx = it*WARP_SIZE;
+                bool valid = (idx + local_idx) < element_count;
                 if (!is_transformer_mask) {
                     idx += i*element_count;
                 }

-                if (!mask[idx]) {
-                    if (is_log_softmax) {
-                        sum[i] += std::exp(elements[i][it] - max_value[i]);
-                    } else {
-                        elements[i][it] = std::exp(elements[i][it] - max_value[i]);
-                        sum[i] += elements[i][it];
-                    }
-                } else {
-                  if (!is_log_softmax) {
-                    // Masked values are treated as -infinity, and std::exp(-infinity) is 0.
-                    elements[i][it] = 0;
-                  }
+                if (valid) {
+                    if (!mask[idx]) {
+                        if (is_log_softmax) {
+                            sum[i] += std::exp(elements[i][it] - max_value[i]);
+                        } else {
+                            elements[i][it] = std::exp(elements[i][it] - max_value[i]);
+                            sum[i] += elements[i][it];
+                        }
+                    } else {
+                        if (!is_log_softmax) {
+                            // Masked values are treated as -infinity, and std::exp(-infinity) is 0.
+                            elements[i][it] = 0;
+                        }
+                    }
+                } else {
+                    if (!is_log_softmax) {
+                        elements[i][it] = 0.;
+                    }
                 }
             }
         }
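The kernel comment's semantics — masked positions behave as -infinity, so exp(-infinity) contributes 0 to the normalizing sum — can be checked with a small eager-mode sketch (illustrative only, not the CUDA path):

```python
import torch

x = torch.randn(8)
mask = torch.tensor([False, False, True, True, False, False, False, False])

# Masked entries are filled with -inf; exp(-inf) == 0, so they get zero
# probability and drop out of the normalizing sum.
probs = torch.softmax(x.masked_fill(mask, float("-inf")), dim=0)

assert torch.all(probs[mask] == 0)
assert torch.isclose(probs.sum(), torch.tensor(1.0))
```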
@@ -1,68 +0,0 @@
-ir_version: 3
-producer_name: "pytorch"
-producer_version: "CURRENT_VERSION"
-graph {
-  node {
-    input: "onnx::Shape_0"
-    output: "onnx::ConstantFill_1"
-    name: "Shape_0"
-    op_type: "Shape"
-  }
-  node {
-    input: "onnx::ConstantFill_1"
-    output: "2"
-    name: "ConstantFill_1"
-    op_type: "ConstantFill"
-    attribute {
-      name: "dtype"
-      i: 1
-      type: INT
-    }
-    attribute {
-      name: "input_as_shape"
-      i: 1
-      type: INT
-    }
-    attribute {
-      name: "value"
-      f: 0
-      type: FLOAT
-    }
-  }
-  name: "torch_jit"
-  input {
-    name: "onnx::Shape_0"
-    type {
-      tensor_type {
-        elem_type: 1
-        shape {
-          dim {
-            dim_value: 5
-          }
-          dim {
-            dim_value: 8
-          }
-        }
-      }
-    }
-  }
-  output {
-    name: "2"
-    type {
-      tensor_type {
-        elem_type: 1
-        shape {
-          dim {
-            dim_value: 5
-          }
-          dim {
-            dim_value: 8
-          }
-        }
-      }
-    }
-  }
-}
-opset_import {
-  version: 7
-}
test/onnx/symbolic_opsets/test_symbolic_opset9.py (32 lines added, new file)

@@ -0,0 +1,32 @@
+# Owner(s): ["module: onnx"]
+
+"""Tests for `torch.onnx.symbolic_opset9`."""
+import torch
+from torch import _C
+from torch.onnx import symbolic_opset9 as opset9
+from torch.testing._internal import common_utils
+
+
+class TestPrim(common_utils.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.graph = _C.Graph()
+
+    def test_list_unpack_returns_all_list_elements_when_previous_node_is_list_construct(
+        self,
+    ):
+        # Build the graph
+        input_1 = self.graph.addInput()
+        input_1.setType(input_1.type().with_dtype(torch.float).with_sizes([2, 42]))
+        input_2 = self.graph.addInput()
+        input_2.setType(input_2.type().with_dtype(torch.float).with_sizes([3, 42]))
+        constructed_list = self.graph.op("prim::ListConstruct", input_1, input_2)
+        # Test the op
+        outputs = opset9.Prim.ListUnpack(self.graph, constructed_list)
+        self.assertNotEqual(outputs, None)
+        self.assertEqual(outputs[0], input_1)
+        self.assertEqual(outputs[1], input_2)
+
+
+if __name__ == "__main__":
+    common_utils.run_tests()
@@ -737,10 +737,6 @@ class TestOperators(TestCase):
         x = torch.randn(5, 8, requires_grad=True)
         self.assertONNX(lambda x: torch.empty_like(x), x)

-    def test_empty_like_opset7(self):
-        x = torch.randn(5, 8, requires_grad=True)
-        self.assertONNX(lambda x: torch.empty_like(x), x, opset_version=7)
-
     def test_zeros_like(self):
         x = torch.randn(5, 8, requires_grad=True)
         self.assertONNX(lambda x: torch.zeros_like(x), x)
@@ -7,6 +7,7 @@ import unittest
 from typing import Optional, Type

 import onnx
+import onnx.numpy_helper

 import torch
 from torch import Tensor
@@ -106,9 +107,75 @@ class TestOptionalOutput(unittest.TestCase):
             input_names=["y"],
         )

+    def test_maintain_dynamic_shapes_of_unreliable_nodes(self):
+        def symbolic_pythonop(ctx: torch.onnx.SymbolicContext, g, *args, **kwargs):
+            return g.op("com.microsoft::PythonOp")
+
+        torch.onnx.register_custom_op_symbolic("prim::PythonOp", symbolic_pythonop, 1)
+        self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "prim::PythonOp", 1)
+
+        # necessary parameters for transformer embeddings
+        hidden_size = 48
+        max_position_embeddings = 32
+        batch_size = 2
+
+        # issue found where autograd.Function makes a downstream
+        # node unreliable but with a static shape. The issue was first
+        # discovered when using Apex FusedLayerNorm in Transformers
+        class CustomLayerNorm(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, embedding):
+                layer_norm = torch.nn.LayerNorm(hidden_size, eps=1e-12)
+                return layer_norm(embedding)
+
+        class EmbeddingModule(torch.nn.Module):
+            def forward(
+                self,
+                embeddings=None,
+            ):
+                embedding_output = CustomLayerNorm.apply(embeddings)
+                query = embedding_output.transpose(0, 1)
+                target_len, batch_size, embedding_dim = query.size()
+                # Reshape is used for consuming batch_size, and if it is static,
+                # this will be a Constant node in the graph
+                query = query.reshape(target_len, batch_size, embedding_dim)
+                return query
+
+        embeddings = torch.randn(batch_size, max_position_embeddings, hidden_size)
+
+        f = io.BytesIO()
+        torch.onnx.export(
+            EmbeddingModule().eval(),
+            (embeddings,),
+            f,
+            input_names=["embeddings"],
+            dynamic_axes={
+                "embeddings": {
+                    0: "batch_size",
+                    1: "max_position_embeddings",
+                    2: "hidden_size",
+                }
+            },
+            custom_opsets={"com.microsoft": 1},
+        )
+        model = onnx.load(io.BytesIO(f.getvalue()))
+
+        # If there is a constant node with dim=3 and max_position_embeddings,
+        # batch_size, hidden_size as shape, it means the shape has become static.
+        # Normally, with dynamic batch size, this constant node should not exist.
+        const_node = [n for n in model.graph.node if n.op_type == "Constant"]
+        self.assertNotEqual(len(const_node), 0)
+        for node in const_node:
+            for a in node.attribute:
+                if a.name == "value":
+                    shape = onnx.numpy_helper.to_array(a.t)
+                    self.assertNotEqual(
+                        shape.tolist(),
+                        [max_position_embeddings, batch_size, hidden_size],
+                    )
+

 instantiate_parametrized_tests(TestOptionalOutput)


 if __name__ == "__main__":
     unittest.main()
		||||
@ -3456,28 +3456,68 @@ class _TestONNXRuntime:
 | 
			
		||||
        self.run_test(model, x)
 | 
			
		||||
 | 
			
		||||
    @skipIfUnsupportedMinOpsetVersion(9)
 | 
			
		||||
    def test_listunpack(self):
 | 
			
		||||
        class ListUnpack(torch.jit.ScriptModule):
 | 
			
		||||
            @torch.jit.script_method
 | 
			
		||||
    def test_list_unpack_scripted(self):
 | 
			
		||||
        class ListUnpack(torch.nn.Module):
 | 
			
		||||
            def forward(self, x):
 | 
			
		||||
                a, b = x.shape
 | 
			
		||||
                return x.new_zeros((a, b))
 | 
			
		||||
 | 
			
		||||
        x = torch.randn(2, 3)
 | 
			
		||||
        self.run_test(ListUnpack(), x, input_names=["x"], dynamic_axes={"x": [0, 1]})
 | 
			
		||||
        self.run_test(ListUnpack(), x, remained_onnx_input_idx=[])
 | 
			
		||||
        self.run_test(
 | 
			
		||||
            torch.jit.script(ListUnpack()),
 | 
			
		||||
            x,
 | 
			
		||||
            input_names=["x"],
 | 
			
		||||
            dynamic_axes={"x": [0, 1]},
 | 
			
		||||
        )
 | 
			
		||||
        self.run_test(torch.jit.script(ListUnpack()), x, remained_onnx_input_idx=[])
 | 
			
		||||
 | 
			
		||||
        class ListUnpackSlice(torch.jit.ScriptModule):
 | 
			
		||||
            @torch.jit.script_method
 | 
			
		||||
    @skipIfUnsupportedMinOpsetVersion(9)
 | 
			
		||||
    def test_list_unpack_scripted_runs_without_error_with_constructed_list_as_input(
 | 
			
		||||
        self,
 | 
			
		||||
    ):
 | 
			
		||||
        class PackUnpack(torch.nn.Module):
 | 
			
		||||
            """Create and unpack a list of tensors.
 | 
			
		||||
 | 
			
		||||
            When scripted, it should produce a graph similar to
 | 
			
		||||
 | 
			
		||||
            ```
 | 
			
		||||
            graph(%self : __torch__.PackUnpack,
 | 
			
		||||
                %a.1 : Tensor,
 | 
			
		||||
                %b.1 : Tensor):
 | 
			
		||||
            %packed.1 : Tensor[] = prim::ListConstruct(%a.1, %b.1)
 | 
			
		||||
            %c.1 : Tensor, %8 : Tensor = prim::ListUnpack(%packed.1)
 | 
			
		||||
            return (%c.1)
 | 
			
		||||
            ```
 | 
			
		||||
            """
 | 
			
		||||
 | 
			
		||||
            def forward(self, a, b):
 | 
			
		||||
                packed = [a, b]
 | 
			
		||||
                c, _ = packed
 | 
			
		||||
                return c
 | 
			
		||||
 | 
			
		||||
        self.run_test(
 | 
			
		||||
            torch.jit.script(PackUnpack()),
 | 
			
		||||
            (torch.tensor(0), torch.tensor([42])),
 | 
			
		||||
            remained_onnx_input_idx=[0],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @skipIfUnsupportedMinOpsetVersion(9)
 | 
			
		||||
    def test_list_unpack_slice_scripted(self):
 | 
			
		||||
        class ListUnpackSlice(torch.nn.Module):
 | 
			
		||||
            def forward(self, x):
 | 
			
		||||
                a, b = x.shape[2:]
 | 
			
		||||
                return x.new_zeros((a, b))
 | 
			
		||||
 | 
			
		||||
        x = torch.randn(2, 3, 4, 5)
 | 
			
		||||
        self.run_test(
 | 
			
		||||
            ListUnpackSlice(), x, input_names=["x"], dynamic_axes={"x": [0, 1, 2, 3]}
 | 
			
		||||
            torch.jit.script(ListUnpackSlice()),
 | 
			
		||||
            x,
 | 
			
		||||
            input_names=["x"],
 | 
			
		||||
            dynamic_axes={"x": [0, 1, 2, 3]},
 | 
			
		||||
        )
 | 
			
		||||
        self.run_test(
 | 
			
		||||
            torch.jit.script(ListUnpackSlice()), x, remained_onnx_input_idx=[]
 | 
			
		||||
        )
 | 
			
		||||
        self.run_test(ListUnpackSlice(), x, remained_onnx_input_idx=[])
 | 
			
		||||
 | 
			
		||||
    def test_pow(self):
 | 
			
		||||
        class PowModule(torch.nn.Module):
 | 
			
		||||
@@ -12366,6 +12406,13 @@ class _TestONNXRuntime:
         q_input = torch.quantize_per_tensor(input, 0.26, 128, torch.quint8)
         self.run_test(model, q_input)

+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_quantized_sigmoid(self):
+        model = torch.nn.Sigmoid()
+        input = torch.randn(2, 6)
+        q_input = torch.quantize_per_tensor(input, 0.26, 128, torch.quint8)
+        self.run_test(model, q_input)
+
     @skipIfUnsupportedMinOpsetVersion(10)
     def test_quantized_flatten(self):
         class FlattenModel(torch.nn.Module):
@@ -12375,6 +12422,19 @@ class _TestONNXRuntime:
         x = torch.quantize_per_tensor(torch.randn(1, 2, 3, 4), 1, 0, torch.quint8)
         self.run_test(FlattenModel(), x)

+    @unittest.skip(
+        "ONNX Runtime 1.11 does not support quantized cat. Enable after ORT 1.12 is enabled in CI."
+    )
+    @skipIfUnsupportedMinOpsetVersion(10)
+    @skipScriptTest()  # torch.jit.frontend.FrontendError: Cannot instantiate class 'QFunctional' in a script function:
+    def test_quantized_cat(self):
+        class QuantizedConcatenationModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.nn.quantized.QFunctional().cat((x, x), dim=1)
+
+        q_input = torch.quantize_per_tensor(torch.ones(2, 3), 0.26, 128, torch.quint8)
+        self.run_test(QuantizedConcatenationModel(), q_input)
+
     @skipIfUnsupportedMinOpsetVersion(10)
     @skipScriptTest()  # torch.jit.frontend.FrontendError: Cannot instantiate class 'QFunctional' in a script function:
     def test_quantized_arithmetic_qfunctional(self):
@@ -12596,6 +12656,32 @@ class _TestONNXRuntime:
         input = _construct_tensor_for_quantization_test((4, 4, 3, 2))
         self.run_test(model, input)

+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_qat_avg_pool2d(self):
+        model = torch.nn.Sequential(
+            torch.quantization.QuantStub(),
+            torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=1),
+            torch.quantization.DeQuantStub(),
+        )
+        model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
+        model = torch.quantization.prepare_qat(model.train())
+        model = torch.quantization.convert(model)
+        input = _construct_tensor_for_quantization_test((4, 4, 3, 2))
+        self.run_test(model, input)
+
+    @skipIfUnsupportedMinOpsetVersion(11)
+    def test_qat_upsample_nearest2d(self):
+        model = torch.nn.Sequential(
+            torch.quantization.QuantStub(),
+            torch.nn.UpsamplingNearest2d(scale_factor=1.5),
+            torch.quantization.DeQuantStub(),
+        )
+        model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
+        model = torch.quantization.prepare_qat(model.train())
+        model = torch.quantization.convert(model)
+        input = _construct_tensor_for_quantization_test((4, 3, 2, 2))
+        self.run_test(model, input)
+
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_convolution_allow_tf32(self):
         class Module(torch.nn.Module):
@@ -451,11 +451,34 @@ class _InsertPoint:
    def __enter__(self) -> None: ...
    def __exit__(self, *args) -> None: ...

# Defined in torch/csrc/jit/ir/ir.h
class Use:
    @property
    def user(self) -> Node: ...
    @property
    def offset(self) -> _int: ...
    def isAfter(self, other: Use) -> _bool: ...
    ...

# Defined in torch/csrc/jit/ir/ir.h
class Value:
    def type(self) -> JitType: ...
    def setType(self, t: JitType) -> Value: ...
    def setTypeAs(self, other: Value) -> Value: ...
    def inferTypeFrom(self, t: Tensor) -> None: ...
    def debugName(self) -> str: ...
    def setDebugName(self, name: str) -> None: ...
    def unique(self) -> _int: ...
    def offset(self) -> _int: ...
    def node(self) -> Node: ...
    def uses(self) -> List[Use]: ...
    def replaceAllUsesWith(self, val: Value) -> None: ...
    def replaceAllUsesAfterNodeWith(self, node: Node, val: Value) -> None: ...
    def requires_grad(self) -> _bool: ...
    def requiresGrad(self) -> _bool: ...
    def copyMetadata(self, other: Value) -> Value: ...
    def isCompleteTensor(self) -> _bool: ...
    def toIValue(self) -> IValue: ...
    ...

# Defined in torch/csrc/jit/ir/ir.h
@@ -467,14 +490,80 @@ class Block:

# Defined in torch/csrc/jit/ir/ir.h
class Node:
    def schema(self) -> str: ...
    def input(self) -> Value: ...
    def inputs(self) -> List[Value]: ...
    def inputsAt(self, idx: _int) -> Value: ...
    def inputsSize(self) -> _int: ...
    def output(self) -> Value: ...
    def outputs(self) -> List[Value]: ...
    def outputsAt(self, idx: _int) -> Value: ...
    def outputsSize(self) -> _int: ...
    def hasMultipleOutputs(self) -> _bool: ...
    def blocks(self) -> List[Block]: ...
    def mustBeNone(self) -> _bool: ...
    def __getitem__(self, key: str) -> Any: ...
    def matches(self, pattern: str) -> _bool: ...
    def kind(self) -> str: ...
    def kindOf(self, name: str) -> str: ...
    def addInput(self, name: str) -> Value: ...
    def replaceInput(self, i: _int, newValue: Value) -> Value: ...
    def replaceInputWith(self, from_: Value, to: Value) -> None: ...
    def replaceAllUsesWith(self, n: Node) -> None: ...
    def insertBefore(self, n: Node) -> Node: ...
    def insertAfter(self, n: Node) -> Node: ...
    def isBefore(self, n: Node) -> _bool: ...
    def isAfter(self, n: Node) -> _bool: ...
    def moveBefore(self, n: Node) -> None: ...
    def moveAfter(self, n: Node) -> None: ...
    def removeInput(self, i: _int) -> None: ...
    def removeAllInputs(self, i: _int) -> None: ...
    def hasUses(self) -> _bool: ...
    def eraseOutput(self, i: _int) -> None: ...
    def addOutput(self) -> Value: ...
    def scopeName(self) -> str: ...
    def isNondeterministic(self) -> _bool: ...
    def copyAttributes(self, rhs: Node) -> Node: ...
    def copyMetadata(self, rhs: Node) -> Node: ...
    def hasAttributes(self) -> _bool: ...
    def hasAttribute(self, name: str) -> _bool: ...
    def removeAttribute(self, attr: str) -> Node: ...
    def namedInput(self, name: str) -> Value: ...
    def sourceRange(self) -> SourceRange: ...
    def owningBlock(self) -> Block: ...
    def findNode(self, kind: str, recurse: _bool = True) -> Node: ...
    def findAllNodes(self, kind: str, recurse: _bool = True) -> List[Node]: ...
    def getModuleHierarchy(self) -> str: ...
    def prev(self) -> Node: ...
    def destroy(self) -> None: ...

    # Accessors for attributes as types.
    def f(self, name: str) -> _float: ...
    def f_(self, name: str, val: _float) -> Node: ...
    def fs(self, name: str) -> List[_float]: ...
    def fs_(self, name: str, val: List[_float]) -> Node: ...
    def c(self, name: str) -> complex: ...
    def c_(self, name: str, val: complex) -> Node: ...
    def s(self, name: str) -> str: ...
    def s_(self, name: str, val: str) -> Node: ...
    def ss(self, name: str) -> List[str]: ...
    def ss_(self, name: str, val: List[str]) -> Node: ...
    def i(self, name: str) -> _int: ...
    def i_(self, name: str, val: _int) -> Node: ...
    # Cannot define "is" like this because it's a reserved keyword in python.
    # def is(self, name: str) -> List[_int]: ...
    # def is_(self, name: str, val: List[_int]) -> Node: ...
    def g(self, name: str) -> Graph: ...
    def g_(self, name: str, val: Graph) -> Node: ...
    def gs(self, name: str) -> List[Graph]: ...
    def gs_(self, name: str, val: List[Graph]) -> Node: ...
    def ival(self, name: str) -> IValue: ...
    def ival_(self, name: str, val: IValue) -> Node: ...
    def t(self, name: str) -> Tensor: ...
    def t_(self, name: str, val: Tensor) -> Node: ...
    def ts(self, name: str) -> List[Tensor]: ...
    def ts_(self, name: str, val: List[Tensor]) -> Node: ...
    def ty_(self, name: str, val: JitType) -> Node: ...
    def tys_(self, name: str, val: List[JitType]) -> Node: ...
    ...

# Defined in torch/torch/csrc/jit/ir/ir.h
@@ -1562,6 +1562,12 @@ void ProcessConstantValueMap(Node* n, int opset_version) {
   // For outputs, only update static shapes. For input, we update symbolic
   // shapes also. ONNX If can have different types on different branches, skip
   // here.
+
+  // Update the shape reliability for each node before processing
+  // ConstantValueMap to prevent unreliable nodes from producing static
+  // shapes
+  UpdateReliable(n);
+
   auto static_input_shape = AllGraphInputsStatic(n->owningGraph());
   for (auto i : c10::irange(n->outputs().size())) {
     if (TensorTypePtr output_type = n->output(i)->type()->cast<TensorType>()) {
@@ -82,5 +82,7 @@ void UpdateReliable(
     torch::jit::Value* output,
     const std::pair<bool, bool>& input_reliable);

+void UpdateReliable(torch::jit::Node* n);
+
 } // namespace jit
 } // namespace torch
@@ -8,6 +8,7 @@ import torch._C._onnx as _C_onnx
 from torch.onnx._globals import GLOBALS


+# TODO(#78694): Refactor the patching process to make it more transparent to users.
 def _graph_op(
     g: torch._C.Graph,
     opname: str,
@@ -1,9 +1,11 @@
 from __future__ import annotations

 import enum
 import functools
 import inspect
 import sys
 import warnings
-from typing import Set
+from typing import Any, Callable, List, Optional, Sequence, Set, Tuple, Union

 import torch
 import torch._C._onnx as _C_onnx
@@ -140,7 +142,7 @@ def _get_const(value, desc, arg_name):
     return _parse_arg(value, desc)


-def _unpack_list(list_value):
+def _unpack_list(list_value: _C.Value) -> List[_C.Value]:
     list_node = list_value.node()
     assert list_node.kind() == "prim::ListConstruct"
     return list(list_node.inputs())
@@ -236,15 +238,19 @@ def parse_args(*arg_descriptors):
     return decorator


-def quantized_args(*arg_q_descriptors, scale=None, zero_point=None):
+def quantized_args(
+    *arg_q_descriptors: bool,
+    scale: Optional[float] = None,
+    zero_point: Optional[int] = None,
+):
     """A decorator which extends support for quantized version of the base operator.

     Quantization is detected by examining the arguments that are annotated by
     `arg_q_descriptors`.

     If quantization is detected, the base operator symbolic function will be wrapped with
-    argument dequantization and output quantization.
+    argument de-quantization and output quantization.

-    Otherwise, only base symbolic function will be invoked.
+    Otherwise, only the base symbolic function will be invoked.

     For example:

@@ -267,11 +273,12 @@ def quantized_args(*arg_q_descriptors, scale=None, zero_point=None):
     ```

     Args:
-        arg_q_descriptors: list of bool, where each element represents if the
-          argument is QTensor for quantized version of this operator.
-        scale: float default None, quantized output scale. If None, derive from
+        arg_q_descriptors: A sequence of bool, where each element represents if the
+          argument is QTensor for quantized version of this operator. It defaults
+          to False for unspecified (variable length) arguments.
+        scale: Quantized output scale. If None, derive from
           the first quantized input scale.
-        zero_point: int default None, quantized output zero point. If None,
+        zero_point: Quantized output zero point. If None,
           derive from the first quantized input zero point.
     """
@@ -288,19 +295,21 @@ def quantized_args(*arg_q_descriptors, scale=None, zero_point=None):
             if _zero_point is not None:
                 _zero_point = g.op("Constant", value_t=torch.tensor(_zero_point))

-            # some args may be optional, so the length may be smaller
-            assert len(arg_q_descriptors) >= len(args)
-            desc_args = tuple(zip(arg_q_descriptors[: len(args)], args))
+            # Support variable length arguments by marking unspecified ones as non-quantized
+            arg_q_descriptors_extended = arg_q_descriptors + (False,) * (
+                len(args) - len(arg_q_descriptors)
+            )
+            descriptor_args = tuple(zip(arg_q_descriptors_extended, args))

             # Run regular symbolic function if none of the argument is QTensor.
             if not any(
-                (desc and arg.node().kind() == "prim::TupleConstruct")
-                for desc, arg in desc_args
+                (descriptor and arg.node().kind() == "prim::TupleConstruct")
+                for descriptor, arg in descriptor_args
             ):
                 return fn(g, *args, **kwargs)

             dequantized_args = []
-            for desc, arg in desc_args:
-                if desc:
+            for descriptor, arg in descriptor_args:
+                if descriptor:
                     dequantized_arg, scale, zero_point, _ = dequantize_helper(g, arg)
                     dequantized_args.append(dequantized_arg)
                     if _scale is None:
@@ -309,7 +318,8 @@ def quantized_args(*arg_q_descriptors, scale=None, zero_point=None):
                         _zero_point = zero_point
                 else:
                     dequantized_args.append(arg)
-            # TODO: only support single output
+            # TODO(justinchuby): Only single output is supported for now. We may want to
+            # support multiple outputs in the future.
             output = fn(g, *dequantized_args, **kwargs)

             return quantize_helper(g, output, _scale, _zero_point)
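The descriptor tuple is now padded with `False` so variable-length trailing arguments are treated as non-quantized instead of tripping the old assert. A standalone sketch of that padding step (function and variable names are local to the example):

```python
def extend_descriptors(arg_q_descriptors, args):
    # Pad with False so every runtime argument has a descriptor; unspecified
    # trailing arguments are treated as non-quantized.
    extended = arg_q_descriptors + (False,) * (len(args) - len(arg_q_descriptors))
    return tuple(zip(extended, args))

# Three descriptors, five runtime arguments: the last two default to False.
pairs = extend_descriptors((True, False, False), ("a", "b", "c", "d", "e"))
assert pairs == ((True, "a"), (False, "b"), (False, "c"), (False, "d"), (False, "e"))
```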
@@ -786,6 +796,7 @@ def _interpolate_get_scales_and_mode(g, input, size, scale_factor, mode, align_c


 def _interpolate_helper(name, dim, interpolate_mode):
+    @quantized_args(True, False, False)
     def symbolic_fn(g, input, output_size, *args):
         scales, align_corners = _get_interpolate_attributes(g, interpolate_mode, args)
         align_corners = _maybe_get_scalar(align_corners)
@@ -1113,13 +1124,17 @@ def _batchnorm_helper(g, input, weight, bias, running_mean, running_var):
     return weight, bias, running_mean, running_var


-def _avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override, name):
+def _avgpool_helper(
+    tuple_fn: Callable[[Any], Sequence[int]],
+    padding: Union[int, Sequence[int]],
+    kernel_size,
+    stride,
+    divisor_override,
+    name,
+) -> Tuple[int, ...]:
     if divisor_override and divisor_override.node().kind() != "prim::Constant":
-        return _unimplemented(name, "divisor_override")
-    if not stride:
-        stride = kernel_size
-    padding = tuple(tuple_fn(padding))
-    return padding
+        _unimplemented(name, "divisor_override")
+    return tuple(tuple_fn(padding))


 def check_training_mode(op_train_mode, op_name):
@@ -1193,7 +1208,11 @@ def _handle_reduce_dim_none(g, self, op_name):
     return g.op(op_name, self, keepdims_i=0)


-def dequantize_helper(g, qtensor, qdtype=None):
+def dequantize_helper(
+    g,
+    qtensor: _C.Value,
+    qdtype: Optional[torch.onnx.TensorProtoDataType] = None,
+) -> Tuple[_C.Value, _C.Value, _C.Value, Optional[_C.Value]]:
     """Appends to graph `g` ONNX nodes that dequantizes `qtensor` into `tensor`.

     Args:
@@ -1234,7 +1253,13 @@ def dequantize_helper(g, qtensor, qdtype=None):
     )


-def quantize_helper(g, tensor, scale, zero_point, axis=None):
+def quantize_helper(
+    g,
+    tensor: _C.Value,
+    scale: _C.Value,
+    zero_point: _C.Value,
+    axis: Optional[_C.Value] = None,
+) -> _C.Value:
     """Appends to graph `g` ONNX nodes that quantizes `tensor` based on `scale`, `zero_point` and `axis`.

     Args:
@@ -1258,11 +1283,13 @@ def quantize_helper(g, tensor, scale, zero_point, axis=None):
         )

     assert scale is not None
-    if scale.type().scalarType() != "Float":
+    if scale.type().scalarType() != "Float":  # type: ignore[attr-defined]
+        # TODO(justinchuby): Remove type ignore after #81112 is checked in.
         scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)

     assert zero_point is not None
-    if zero_point.type().scalarType() not in ("Byte", "Char"):
+    if zero_point.type().scalarType() not in ("Byte", "Char"):  # type: ignore[attr-defined]
+        # TODO(justinchuby): Remove type ignore after #81112 is checked in.
         zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
     output = g.op(
         "QuantizeLinear",
@@ -1,9 +1,11 @@
 import sys
 import warnings
+from typing import Sequence

 import torch
 import torch._C._onnx as _C_onnx
 import torch.onnx
+from torch import _C

 # This import monkey-patches graph manipulation methods on Graph, used for the
 # ONNX symbolics
@@ -141,15 +143,16 @@ max_pool3d_with_indices = _max_pool(


 def _avg_pool(name, tuple_fn):
+    @symbolic_helper.quantized_args(True, False, False, False, False, False, False)
     @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
     def symbolic_fn(
         g,
-        input,
-        kernel_size,
-        stride,
-        padding,
-        ceil_mode,
-        count_include_pad,
+        input: _C.Value,
+        kernel_size: Sequence[int],
+        stride: Sequence[int],
+        padding: Sequence[int],
+        ceil_mode: int,
+        count_include_pad: int,
         divisor_override=None,
     ):
         if not stride:
@@ -187,6 +190,7 @@ avg_pool3d = _avg_pool("avg_pool3d", torch.nn.modules.utils._triple)


 def _interpolate(name, dim, interpolate_mode):
+    @symbolic_helper.quantized_args(True, False, False)
     def symbolic_fn(g, input, output_size, *args):
         scales, align_corners = symbolic_helper._get_interpolate_attributes(
             g, interpolate_mode, args
@@ -543,6 +547,16 @@ class Quantized:

         return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)

+    @staticmethod
+    def add_relu(g, x, y, op_scale, op_zero_point):
+        x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
+        y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
+
+        output = opset9.add(g, x, y)
+        output = opset9.relu(g, output)
+
+        return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
+
     @staticmethod
     def mul(g, x, y, op_scale, op_zero_point):
         x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
@@ -612,3 +626,19 @@ class Quantized:
         )

         return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
+
+    @staticmethod
+    @symbolic_helper.parse_args("v", "i", "v", "v")
+    def cat(
+        g,
+        q_inputs: _C.Value,
+        dim: int,
+        op_scale: _C.Value,
+        op_zero_point: _C.Value,
+    ) -> _C.Value:
+        unpacked_inputs = symbolic_helper._unpack_list(q_inputs)
+        dequantized = [
+            symbolic_helper.dequantize_helper(g, input)[0] for input in unpacked_inputs
+        ]
+        concatenated = g.op("Concat", *dequantized, axis_i=dim)
+        return symbolic_helper.quantize_helper(g, concatenated, op_scale, op_zero_point)
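The new `Quantized.cat` follows the same dequantize → float op → requantize pattern as the other `Quantized` symbolics. An eager-mode analogue of that pattern (a sketch, not the graph-building code above):

```python
import torch

q = torch.quantize_per_tensor(torch.ones(2, 3), scale=0.26, zero_point=128, dtype=torch.quint8)

dequantized = [q.dequantize(), q.dequantize()]  # unpack and dequantize each input
concatenated = torch.cat(dequantized, dim=1)    # run the op in float
requantized = torch.quantize_per_tensor(concatenated, 0.26, 128, torch.quint8)

assert requantized.shape == (2, 6)
```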
@@ -1,7 +1,11 @@
 """This file exports ONNX ops for opset 11."""

 import sys
 import warnings
+from typing import Tuple, Union

 import torch
+from torch import _C
 from torch.onnx import symbolic_helper
 from torch.onnx import symbolic_opset9 as opset9
 from torch.onnx import symbolic_opset10 as opset10
@@ -143,7 +147,8 @@ def index_put(g, self, indices_list_value, values, accumulate=False):

     if len(indices_list) > 1:
         for idx_ in range(len(indices_list)):
-            if indices_list[idx_].type().scalarType() == "Bool":
+            if indices_list[idx_].type().scalarType() == "Bool":  # type: ignore[attr-defined]
+                # TODO(justinchuby): Remove type ignore after #81112 is checked in.
                 indices_list[idx_] = g.op("NonZero", indices_list[idx_])
         index = indices_list[0]

@@ -198,7 +203,8 @@ def index_put(g, self, indices_list_value, values, accumulate=False):
         #   return (%33)
         index = indices_list[0]
         bool_inp = index
-        if bool_inp.type() is not None and bool_inp.type().scalarType() == "Bool":
+        if bool_inp.type() is not None and bool_inp.type().scalarType() == "Bool":  # type: ignore[attr-defined]
+            # TODO(justinchuby): Remove type ignore after #81112 is checked in.
             rank = symbolic_helper._get_tensor_rank(values)
             if rank is not None and rank == 0:
                 return opset9.masked_fill(g, self, bool_inp, values)
@@ -258,6 +264,7 @@ upsample_trilinear3d = _interpolate("upsample_trilinear3d", 5, "linear")
 upsample_bicubic2d = _interpolate("upsample_bicubic2d", 4, "cubic")


+@symbolic_helper.quantized_args(True, False, False, False, False, False, False)
 def __interpolate(
     g, input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias
 ):
@@ -415,15 +422,16 @@ def _unique2(g, self, sorted, return_inverse, return_counts):


 def _avg_pool(name, tuple_fn):
+    @symbolic_helper.quantized_args(True, False, False, False, False, False, False)
     @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
     def symbolic_fn(
         g,
-        input,
-        kernel_size,
-        stride,
-        padding,
-        ceil_mode,
-        count_include_pad,
+        input: _C.Value,
+        kernel_size: Tuple[int, ...],
+        stride: Tuple[int, ...],
+        padding: Union[int, Tuple[int, ...]],
+        ceil_mode: int,
+        count_include_pad: int,
         divisor_override=None,
     ):
         padding = symbolic_helper._avgpool_helper(
@@ -8,7 +8,7 @@ import functools
 import math
 import sys
 import warnings
-from typing import Optional
+from typing import List, Optional, Tuple, Union

 import torch
 import torch._C._onnx as _C_onnx
@@ -361,6 +361,8 @@ def atan(g, self):
     return g.op("Atan", self)


+# Fixed scale and zero_point, discovered from aten/src/ATen/native/quantized/cpu/qsigmoid.cpp
+@symbolic_helper.quantized_args(True, scale=1.0 / 256.0, zero_point=0)
 def sigmoid(g, self):
     return g.op("Sigmoid", self)
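Sigmoid outputs always lie in [0, 1], so a fixed quint8 mapping with scale 1/256 and zero_point 0 covers the whole output range. A quick sketch of why those qparams lose at most one quantization step (illustrative only):

```python
import torch

y = torch.sigmoid(torch.randn(16))  # values in (0, 1)
q = torch.quantize_per_tensor(y, scale=1.0 / 256.0, zero_point=0, dtype=torch.quint8)

# Dequantized values are q_int * (1/256) for q_int in 0..255, spanning [0, 255/256].
assert torch.allclose(q.dequantize(), y, atol=1.0 / 256.0)
```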
@@ -1031,6 +1033,7 @@ def get_pool_ceil_padding(input, kernel_size, stride, padding):


 def _max_pool(name, tuple_fn, ndims, return_indices):
+    @symbolic_helper.quantized_args(True, False, False, False, False, False)
     @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
     def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
         if set(tuple_fn(dilation)) != {1}:
@@ -1117,15 +1120,16 @@ max_pool3d_with_indices = _max_pool(


 def _avg_pool(name, tuple_fn):
+    @symbolic_helper.quantized_args(True)
     @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
     def symbolic_fn(
         g,
-        input,
-        kernel_size,
-        stride,
-        padding,
-        ceil_mode,
-        count_include_pad,
+        input: _C.Value,
+        kernel_size: Tuple[int, ...],
+        stride: Tuple[int, ...],
+        padding: Union[int, Tuple[int, ...]],
+        ceil_mode: int,
+        count_include_pad: int,
         divisor_override=None,
     ):
         if not stride:
@@ -1133,8 +1137,7 @@ def _avg_pool(name, tuple_fn):
         padding = symbolic_helper._avgpool_helper(
             tuple_fn, padding, kernel_size, stride, divisor_override, name
         )
-        if ceil_mode:
-            padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
+        adjusted_padding = padding
         if count_include_pad:
             input = g.op(
                 "Pad",
@@ -1143,17 +1146,20 @@ def _avg_pool(name, tuple_fn):
                 mode_s="constant",
                 value_f=0.0,
             )
-            padding = (0,) * len(padding)
+            adjusted_padding = (0,) * len(padding)
         if ceil_mode:
-            padding = padding + tuple(a + b for (a, b) in zip(padding_ceil, padding))
+            padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)
+            adjusted_padding = adjusted_padding + tuple(
+                a + b for (a, b) in zip(padding_ceil, adjusted_padding)
+            )
         else:
-            padding = padding * 2
+            adjusted_padding = adjusted_padding * 2
         output = g.op(
             "AveragePool",
             input,
             kernel_shape_i=tuple_fn(kernel_size),
             strides_i=tuple_fn(stride),
-            pads_i=padding,
+            pads_i=adjusted_padding,
         )
         return output
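The `adjusted_padding * 2` in the non-ceil branch reflects ONNX's pads layout: `AveragePool` expects begin-pads for every spatial axis followed by end-pads, so symmetric padding is the per-axis tuple repeated. A tiny sketch (illustrative):

```python
# Per-axis padding as computed by _avgpool_helper for a 2-D pool:
padding = (1, 2)

# ONNX pads layout: (x1_begin, x2_begin, ..., x1_end, x2_end, ...)
pads_i = padding * 2
assert pads_i == (1, 2, 1, 2)  # pad 1 on both ends of axis 1, 2 on both ends of axis 2
```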
@@ -2510,7 +2516,8 @@ def tensor(g, data, dtype=None, device=None, requires_grad=False):
     dtype = symbolic_helper._get_const(dtype, "i", "dtype")
     if symbolic_helper._is_packed_list(data):
         if dtype is None:
-            dtype = symbolic_helper._unpack_list(data)[0].type().scalarType()
+            dtype = symbolic_helper._unpack_list(data)[0].type().scalarType()  # type: ignore[attr-defined]
+            # TODO(justinchuby): Remove type ignore after #81112 is checked in.
             dtype = symbolic_helper.scalar_type_to_onnx.index(
                 symbolic_helper.cast_pytorch_to_onnx[dtype]
             )
@@ -4963,7 +4970,12 @@ class Prim:
         return None

     @staticmethod
-    def ListUnpack(g, *inputs, **kwargs):
+    def ListUnpack(g, *inputs, **kwargs) -> Optional[List[_C.Value]]:
         if len(inputs) == 1 and inputs[0].node().kind() == "prim::ListConstruct":
+            # Cancel the previous node if it is ListConstruct by returning its inputs
+            # TODO(justinchuby): Use a public method in the helper module
             return symbolic_helper._unpack_list(inputs[0])

+        return None
+
     @staticmethod
@@ -12192,6 +12192,8 @@ op_db: List[OpInfo] = [
                             'TestCommon', 'test_noncontiguous_samples',
                             device_type='cpu'), ],
            skips=(
+               # Strides are not the same!
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
                # https://github.com/pytorch/pytorch/issues/67470
                DecorateInfo(unittest.skip("67470!"),
                             'TestCommon', 'test_noncontiguous_samples',
@@ -13517,6 +13519,8 @@ op_db: List[OpInfo] = [
                DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'),
            ),
            decorators=(
+               # Strides are not the same!
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
                DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-02, rtol=1e-02)}),
                             'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
            )),