Fixes https://github.com/pytorch/pytorch/issues/118129
Suppressions automatically added with
```
import re
with open("error_file.txt", "r") as f:
errors = f.readlines()
error_lines = {}
for error in errors:
match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
if match:
file_path, line_number, error_type = match.groups()
if file_path not in error_lines:
error_lines[file_path] = {}
error_lines[file_path][int(line_number)] = error_type
for file_path, lines in error_lines.items():
with open(file_path, "r") as f:
code = f.readlines()
for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
with open(file_path, "w") as f:
f.writelines(code)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Co-authored-by: Catherine Lee <csl@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
Fixes https://github.com/pytorch/pytorch/issues/118129
Suppressions automatically added with
```
import re
with open("error_file.txt", "r") as f:
errors = f.readlines()
error_lines = {}
for error in errors:
match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
if match:
file_path, line_number, error_type = match.groups()
if file_path not in error_lines:
error_lines[file_path] = {}
error_lines[file_path][int(line_number)] = error_type
for file_path, lines in error_lines.items():
with open(file_path, "r") as f:
code = f.readlines()
for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
code[line_number - 1] = code[line_number - 1].rstrip() + f" # type: ignore[{error_type}]\n"
with open(file_path, "w") as f:
f.writelines(code)
```
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
Needs https://github.com/microsoft/onnxscript/pull/721
The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim.
This PR utilizes newly introduced [ONNX OpSchema](https://github.com/microsoft/onnxscript/pull/626) to match the input arguments of an ATen operator to find the correct overload.
### OnnxRegistry
Heavily reference on [TorchScript Registry](https://github.com/pytorch/pytorch/pull/84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions.
* No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future.
* Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18.
* Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address:
- How to allow users to remove/override specific overload? Using OpSchema to differentiate?
- User registers a new overload with the same OpSchema as one of registered overload.
### OnnxDispatcher
Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments.
* `OpSchemaWrapper` wrap the onnx schema, and record matching score.
* `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics.
* `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18.
* Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`.
* To include more supports: the follow-up PRs should address:
- How to add op.Cast with autocast? In torchlib or converter?
- The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information.
### OpSchemaWrapper - Matching Score Mechanism
#### The matching score system:
This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto.
1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings.
2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF.
#### Example of overloads
1. Different types: Caused by the difference between the ONNX spec and PyTorch.
The matching system finds the correct one.
```python
@torch_op("aten::mul")
def aten_mul(self: TReal, other: TReal) -> TReal:
...
@torch_op("aten::mul")
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
...
```
2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). dim could be "None"
```python
@torch_op("aten::argmax", trace_only=True)
def aten_argmax(
self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False
) -> TrealOrUInt8:
...
@torch_op("aten::argmax", private=True)
def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8:
...
```
This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function.
3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None.
```python
@torch_op("aten::new_full")
def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor:
...
@torch_op("aten::new_full")
def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor:
...
```
Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one.
4. `None` and `[]` and `NoneType` are considered failing the match.
5. Two functions have the same score is recorded into SARIFs.
### TODOs
1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in https://github.com/microsoft/onnxscript/issues/563.
5. torchlib should provide the "opset version" to OnnxRegistry.
7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed.
Co-authored-by: Justin Chu <justinchuby@microsoft.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100660
Approved by: https://github.com/thiagocrepaldi
Continuation after https://github.com/pytorch/pytorch/pull/90163.
Here is a script I used to find all the non-existing arguments in the docstrings (the script can give false positives in presence of *args/**kwargs or decorators):
_Edit:_
I've realized that the indentation is wrong for the last `break` in the script, so the script only gives output for a function if the first docstring argument is wrong. I'll create a separate PR if I find more issues with corrected script.
``` python
import ast
import os
import docstring_parser
for root, dirs, files in os.walk('.'):
for name in files:
if root.startswith("./.git/") or root.startswith("./third_party/"):
continue
if name.endswith(".py"):
full_name = os.path.join(root, name)
with open(full_name, "r") as source:
tree = ast.parse(source.read())
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
all_node_args = node.args.args
if node.args.vararg is not None:
all_node_args.append(node.args.vararg)
if node.args.kwarg is not None:
all_node_args.append(node.args.kwarg)
if node.args.posonlyargs is not None:
all_node_args.extend(node.args.posonlyargs)
if node.args.kwonlyargs is not None:
all_node_args.extend(node.args.kwonlyargs)
args = [a.arg for a in all_node_args]
docstring = docstring_parser.parse(ast.get_docstring(node))
doc_args = [a.arg_name for a in docstring.params]
clean_doc_args = []
for a in doc_args:
clean_a = ""
for c in a.split()[0]:
if c.isalnum() or c == '_':
clean_a += c
if clean_a:
clean_doc_args.append(clean_a)
doc_args = clean_doc_args
for a in doc_args:
if a not in args:
print(full_name, node.lineno, args, doc_args)
break
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90505
Approved by: https://github.com/malfet, https://github.com/ZainRizvi
Extend `register_custom_op` to support onnx-script local function. The FunctionProto from onnx-script is represented by custom op and inserted into ModelProto for op execution.
NOTE: I did experiments on >2GB case of a simple model with large initializers:
```python
import torch
class Net(torch.nn.Module):
def __init__(self, B, C):
super().__init__()
self.layer_norm = torch.nn.LayerNorm((B, C), eps=1e-3)
def forward(self, x):
return self.layer_norm(x)
N, B, C = 3, 25000, 25000
model = Net(B, C)
x = torch.randn(N, B, C)
torch.onnx.export(model, x, "large_model.onnx", opset_version=12)
```
And it turns out we won't get model_bytes > 2GB after `_export_onnx` pybind cpp function, as we split initializer in external files in that function, and have serialization before return the model bytes, which protobuf is not allowed to be larger than 2GB at any circumstances.
The test cases can be found in the next PR #86907 .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86906
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
This PR create the `GraphContext` class and relays all graph methods to _C.Graph as well as implements the `g.op` method. The GraphContext object is passed into the symbolic functions in place of _C.Graph for compatibility with existing symbolic functions.
This way (1) we can type annotate all `g` args because the method is defined and (2) we can use additional context information in symbolic functions. (3) no more monkey patching on `_C.Graph`
Also
- Fix return type of `_jit_pass_fixup_onnx_controlflow_node`
- Create `torchscript.py` to house torch.Graph related functions
- Change `GraphContext.op` to create nodes in the Block instead of the Graph
- Create `add_op_with_blocks` to handle scenarios where we need to directly manipulate sub-blocks. Update loop and if symbolic functions to use this function.
## Discussion
Should we put all the context inside `SymbolicContext` and make it an attribute in the `GraphContext` class? This way we only define two attributes `GraphContext.graph` and `GraphContext.context`. Currently all context attributes are directly defined in the class.
### Decision
Keep GraphContext flatand note that it will change in the future.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84728
Approved by: https://github.com/AllenTiTaiWang, https://github.com/BowenBao