The problem:
- The new CustomOp API depends on torchgen.model
- torchgen.model imports `yaml`
- `yaml` is not a PyTorch runtime dependency
To unblock myself, because I'm not sure how long it'll take to
convince people yaml should be a PyTorch runtime dependency
(unless one of you wants to approve #100166), this PR removes the
yaml dependency from torchgen.model.
It does so by splitting torchgen.utils (the offender) into
torchgen.utils (no yaml) and torchgen.yaml (which uses yaml).
Test Plan:
- CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100203
Approved by: https://github.com/ezyang, https://github.com/Skylion007
Preferring dash over underscore in command-line options. Add `--command-arg-name` to the argument parser. The old arguments with underscores `--command_arg_name` are kept for backward compatibility.
Both dashes and underscores are used in the PyTorch codebase. Some argument parsers only have dashes or only have underscores in arguments. For example, the `torchrun` utility for distributed training only accepts underscore arguments (e.g., `--master_port`). The dashes are more common in other command-line tools. And it looks to be the default choice in the Python standard library:
`argparse.BooleanOptionalAction`: 4a9dff0e5a/Lib/argparse.py (L893-L895)
```python
class BooleanOptionalAction(Action):
def __init__(...):
if option_string.startswith('--'):
option_string = '--no-' + option_string[2:]
_option_strings.append(option_string)
```
It adds `--no-argname`, not `--no_argname`. Also typing `_` need to press the shift or the caps-lock key than `-`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94505
Approved by: https://github.com/ezyang, https://github.com/seemethere
As part of the ongoing LTC migration effort, PyTorch/XLA is updating its codegen to use `xla::Shape` instead of `torch::lazy::Shape`. To achieve this, this PR updates the codegen to make the `GenLazyNativeFuncDefinition` generator customizable.
The existing `GenLazyNativeFuncDefinition` is kept by using the initial default values, so this change should not introduce any new behaviors to the existing codegen in PyTorch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87823
Approved by: https://github.com/alanwaketan, https://github.com/wconstab
Partially fixes: #66328
This PR:
- adds support for `ITensorList` to the dispatcher for:
- computing the dispatch key
- boxing and unboxing `ITensorList`
- modified the codegen for structured kernels:
- codegen APIs use `ITensorList` instead of `ArrayRef<Tensor>`
**Changes summary:**
- Signature changes due to the different APIs:
- dispatcher API (e.g. `BatchingRegistrations.cpp`)
- C++ API (e.g. `TensorShape.cpp`)
- Miscelaneous functions used by codegen'd functions (e.g. `FunctionalTensorWrapper.*`)
- Dispatcher changes for handling `ITensorList` correctly (e.g. `DispatchKeyExtractor.h`)
- Signature changes of `at::cat` due to the need of `const` inside `TensorBody.h`
- Forward declarations of `ITensorList` (e.g. `MethodOperators.h`)
- Codegen changes, special casing structured kernels (e.g. `gen.py`)
**Short description of structured kernels special casing:**
I introduced, mainly, 5 types of changes to the codegen for generating code depending on
whether the kernel is structured or not:
1. Added a `structured_type_override` flag to the `argument_type` function definition of
the affected APIs (mainly the dispatcher and C++ APIs).
- `api/cpp.py`, `api/dispatcher.py`, `api/native.py`
2. Added a `structured_type_override` member to the signature
classes (e.g. `CppSignature`), since `FunctionSchema` doesn't really know whether the
function is structured or not
- `api/types.py`
3. Added a `part_of_structured_group` to `NativeFunction` class, which is just a
convenient function to forward to `structured_type_override` wherever needed
- `model.py`
4. Appropriately changed the rest of the codegen, whenever it used either the signature
classes or the `arguments` function directly
5. Added a check for `const ITensorList&` type wherever there was a check for `TensorList`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73350
Approved by: https://github.com/bdhirsh
This is to get a conversation started.
* @JackCaoG we could add attributes to items in `ir_codegen` section to customize IR generation logic (e.g. not generating `::Lower`). Though it could be a bit tricky to thread it through.
* Adding an extra argument to `map_codegen` to filter native functions out seems like a step in the right direction. Otherwise, it's a bit confusing how do we go from a full list to a codegen list.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81847
Approved by: https://github.com/JackCaoG, https://github.com/wconstab, https://github.com/bdhirsh
Summary:
Adding a feature to allow user to specify namespaces for operator and kernels.
# Feature
There's a feature request to allow DSL to:
1. take in an operator namespace other than `aten`.
2. take in a kernel that is in a different namespace than `at::native`.
For both features, we only allow user to have a single layer of namespace for the sake of simplicity. If user specify `custom::function` as kernel, the codegen will depend on `custom::native::function` where `native` is hardcoded.
# Proposal
For feature 1, add a `namespace` attribute to data class `NativeFunction`. The namespace will be extract out by matching pattern "::" on the `func` variable. For `NativeFunctionsGroup` there's an assumption that all variants (function, inplace, out) will have the same namespace. By default (if not specified) the namespace will be "aten".
For feature 2, add a `namespace` attribute to `BackendMetadata` class, similarly match pattern "::" on the kernel field. Remove the `cpp_namespace` field from `register_dispatch_key` data class. By default (if not specified) the namespace for a kernel would be "at::native".
Test Plan:
Example yaml entries:
```
- func: custom::gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!)
structured: True
structured_inherits: TensorIteratorBase
device_check: NoCheck # TensorIterator
python_module: nn
dispatch:
CPU: custom::gelu_out_cpu
CUDA: custom::gelu_out_cuda
MPS: custom::gelu_out_mps
- func: custom::gelu_(Tensor(a!) self, *, str approximate='none') -> Tensor(a!)
structured_delegate: gelu.out
device_check: NoCheck # TensorIterator
python_module: nn
dispatch:
NestedTensorCPU, NestedTensorCUDA: custom::NestedTensor_gelu_
- func: custom::gelu(Tensor self, *, str approximate='none') -> Tensor
structured_delegate: gelu.out
device_check: NoCheck # TensorIterator
python_module: nn
dispatch:
MkldnnCPU: custom::mkldnn_gelu
QuantizedCPU: custom::gelu_quantized_cpu
NestedTensorCPU, NestedTensorCUDA: custom::NestedTensor_gelu
```
see generated code:
`RegisterCPU.cpp`:
```
TORCH_LIBRARY_IMPL(aten, CPU, m) {
...
}
TORCH_LIBRARY_IMPL(custom, CPU, m) {
m.impl("gelu", TORCH_FN(wrapper_gelu));
m.impl("gelu.out", TORCH_FN(wrapper_gelu_out_out));
m.impl("gelu_", TORCH_FN(wrapper_gelu_));
};
```
```
struct structured_gelu_out_cpu_inplace final : public custom::native::structured_gelu_out_cpu {
structured_gelu_out_cpu_inplace(Tensor& self) : outputs_{std::ref(self)} {}
void set_output_strided(
int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
TensorOptions options, DimnameList names
) override {
const auto& out = outputs_[output_idx].get();
check_inplace(out, sizes, options);
auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options);
if (C10_UNLIKELY(maybe_proxy.has_value())) {
proxy_outputs_[output_idx] = c10::ExclusivelyOwned<Tensor>(std::move(maybe_proxy).value());
}
if (!names.empty()) {
namedinference::propagate_names(outputs_[output_idx], names);
}
// super must happen after, so that downstream can use maybe_get_output
// to retrieve the output
custom::native::structured_gelu_out_cpu::set_output_raw_strided(output_idx, sizes, strides, options, names);
}
void set_output_raw_strided(
int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
TensorOptions options, DimnameList names
) override {
const auto& out = outputs_[output_idx].get();
check_inplace(out, sizes, options);
if (!names.empty()) {
namedinference::propagate_names(outputs_[output_idx], names);
}
// super must happen after, so that downstream can use maybe_get_output
// to retrieve the output
custom::native::structured_gelu_out_cpu::set_output_raw_strided(output_idx, sizes, strides, options, names);
}
const Tensor& maybe_get_output(int64_t output_idx) override {
return proxy_outputs_[output_idx].has_value() ? **proxy_outputs_[output_idx] : outputs_[output_idx].get();
}
std::array<std::reference_wrapper<Tensor>, 1> outputs_;
std::array<c10::optional<c10::ExclusivelyOwned<Tensor>>, 1> proxy_outputs_;
};
```
`RegisterSchema.cpp`
```
TORCH_LIBRARY(aten, m) {
...
}
TORCH_LIBRARY(custom, m) {
m.def("gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!)");
m.def("gelu_(Tensor(a!) self, *, str approximate='none') -> Tensor(a!)");
m.def("gelu(Tensor self, *, str approximate='none') -> Tensor");
};
```
Differential Revision: D36558459
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78015
Approved by: https://github.com/bdhirsh
Add codegen infrastructure to generate IR nodes for non-native ops.
The proposed change is to add a `non_native` key to the `{backend}_native_functions.yaml` file that contains schema definitions similar to what is found in `native_functions.yaml`. e.g.
```
non_native:
...
- func: expand(Tensor input, int[] size, bool is_scalar_expand) -> Tensor
...
```
these definitions are parsed into a `LazyIrSchema` that can be used for generating IR nodes using `GenLazyIR`.
Fixes#74628
CC: @wconstab @desertfire @henrytwo
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76535
Approved by: https://github.com/wconstab
Summary: Currently OpKind is stored as an object field called op_ for each IR
node, and one usage of op_ is to avoid dynamic_cast in NodeCast when we
need to downcast a base-node pointer into a concrete sub-node pointer.
As a result, we need to construct and pass in an op when downcasting
nodes, and this becomes quite anonnying when we start to implement the
trie-based IR node reusing. More importantly, the op for each subclass
should be unique for that subclass and thus making it a const static field
is a more logical design.
In this PR, we still keep the object-level op_ for easier XLA adoption. As
furture work, we can come back to remove op_, make the op() method
virtual, and get rid of OpKind in all the node constructors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76711
Approved by: https://github.com/wconstab, https://github.com/JackCaoG