140 Commits

Author SHA1 Message Date
e3e33cfae0 Enable codegen of per-dispatch key derivative formulas in derivatives.yaml (#82801)
`derivatives.yaml` can now take a `dispatch` entry which registers per-autograd dispatch key derivatives such as
```
name: foo(Tensor self, Tensor y) -> Tensor
dispatch:
  Default:
    x: grad
    y: grad.expand(y.sizes())
  AutogradNestedTensor:
    x: grad
    y:  NestedTensor_foo_backward(grad, y)
output_differentiabilty: [True]
```

However the old schema where there is no `dispatch` entry is still supported.

Would greatly appreciate feedback on *how to improve the testing strategy* of this PR, currently have registered an aten test op in TestOps.cpp with dummy gradients in derivatives.yaml and have some tests in test_autograd.py:TestAutogradMultipleDispatch but I am not sure whether these are sufficiently rigorous.

Additionally, this PR also makes the assumption that sets like [VIEW_FUNCTIONS](ff5399e528/tools/autograd/gen_inplace_or_view_type.py (L60)) are per-native-function and not per-native-function-and-dispatch-key. I'm not sure whether this is necessarily the case, *would there ever be a situation where (e.g. a nested_tensor op is a view op but the aten function is not or vice versa?)*

* __->__ #82801
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82801
Approved by: https://github.com/bhosmer, https://github.com/albanD
2022-08-10 19:26:29 +00:00
943553965e support custom class in torchgen schema parser (#82925)
Differential Revision: [D38480514](https://our.internmc.facebook.com/intern/diff/D38480514/)

torchgen schema parser does not support parsing function schemas using custom class so far. Here is an example:
```
quantized::conv2d_relu.new(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> (Tensor)
```

This PR parse custom class name and encapsulate that into an object of CustomClassType. The only thing we need right now is just store the string class name and return that in `__str__` method.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82925
Approved by: https://github.com/ezyang, https://github.com/bdhirsh
2022-08-08 22:24:43 +00:00
4f255dbfb3 Remove manual bindings for arange (#81380)
The functional variant of one of the `arange` overloads has a schema mismatch with the out variant. The functional one has `Scalar step`, but the corresponding out variant has `Scalar step=1`. This isn't allowed, so it had to be special-cased in the python codegen and manually bound. This adds the default `step` value to the functional overload and removes the special-casing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81380
Approved by: https://github.com/ngimel
2022-08-07 00:10:27 +00:00
684ce1b0bc add inplace_view tag to resize_() (#82667)
`resize_()` is annoying because it needs special casing for functionalization. It's technically an inplace-view op, but it can't really have a pure view variant, since calling resize_() might bust the old storage. I gave it an `inplace_view` tag so that stuff like `FakeTensor` that relies on tags will pick it up properly, which required  jumping through some codegen hoops.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82667
Approved by: https://github.com/eellison
2022-08-03 18:13:00 +00:00
ba84e9662e Use OrderedSet in ufunc codegen (#82567)
Follow up from https://github.com/pytorch/pytorch/pull/82536#discussion_r934000916
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82567
Approved by: https://github.com/ezyang
2022-08-01 18:11:06 +00:00
301fe8c27d [torchgen] Fix multiple backends with custom namespace (#82133)
Summary:
Some quantized operators needs `QuantizedCPU` backend, due to an issue in namespace checking, currently if we have two backends as well as a custom namespaces in native function, codegen will hit assertion error. This PR fixes this issue

The root cause is that codegen right now asserts that a native function should only have one namespace. The current behavior is that If a native function is not found in a `BackendIndex`, we will use default namespace for that backend, for fallback kernels. However that default namespace may not be listed in the yaml file and it should not be counted when checking if we have two different namespaces for that backend. In our error case, we have 2 `BackendIndex`, one for `QuantizedCPU` and one for `CPU`. The native function doesn't have a kernel in `QuantizedCPU` but we still use a default namespace (`at::native`) for it. Since we have a custom namespace for dispatch key `CPU`, we ran into the assertion error.

This PR changes the assertion criteria. We only error out if a namespace has two or more kernels and they have two or more different namespaces.

Test Plan: rely on newly added unit test

Differential Revision: D38101345

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82133
Approved by: https://github.com/iseeyuan
2022-07-29 22:53:58 +00:00
2f95b61cea Revert "Revert "Make factory functions CompositeExplicitAutograd (#82251)"" (#82470)
This reverts commit 1df307f3343085697b4086336fe8936d108e277e.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82470
Approved by: https://github.com/zou3519
2022-07-29 17:06:07 +00:00
1df307f334 Revert "Make factory functions CompositeExplicitAutograd (#82251)"
This reverts commit 9943ca3ce67e06ca0e28892eed72d6cc4666351c.

Reverted https://github.com/pytorch/pytorch/pull/82251 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally
2022-07-29 03:05:59 +00:00
9943ca3ce6 Make factory functions CompositeExplicitAutograd (#82251)
This also makes them not decompose when we switch Python key.
Note that CompositeExplicitAutogradNonFunctional maybe be overly
conservative for some implementations (which actually call into
other functional ops), but for now I just uniformly apply this
everywhere to avoid errors.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82251
Approved by: https://github.com/bdhirsh, https://github.com/zou3519
2022-07-28 18:18:51 +00:00
d38ffa6a4c Make all of new_/_like factory functions composite explicit autograd (#82238)
Once CompositeImplicitAutograd gets registered to Python key, this will
ensure that tensor subclasses can interpose on these functions directly
rather than getting decomposed.  We prefer not decomposing as these
functions are functional, but their implementations use inplace
operations (and are thus more difficult to deal with, unless you use
functionalization.)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82238
Approved by: https://github.com/zou3519, https://github.com/bdhirsh
2022-07-27 18:33:46 +00:00
d2c47d559c Revert "Revert "Enabling SymInt in autograd; take 3 (#81145)"" ; make sure is_intlist checks for symintnodes (#82189)
### Description
<!-- What did you change and why was it needed? -->

### Issue
<!-- Link to Issue ticket or RFP -->

### Testing
<!-- How did you test your change? -->

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82189
Approved by: https://github.com/ezyang
2022-07-26 20:47:11 +00:00
c078476eb0 Revert "Enabling SymInt in autograd; take 3 (#81145)"
This reverts commit 032facd6e6020a86556a1e8c8e6e1b414c9d14d6.

Reverted https://github.com/pytorch/pytorch/pull/81145 on behalf of https://github.com/jeanschmidt due to breaking internal builds
2022-07-22 11:15:20 +00:00
032facd6e6 Enabling SymInt in autograd; take 3 (#81145)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81145
Approved by: https://github.com/ezyang
2022-07-22 00:14:50 +00:00
6f0c253956 Add sparse, quantized and nested tensor meta support to codegen (#81793)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81793
Approved by: https://github.com/cpuhrsch, https://github.com/bdhirsh
2022-07-21 21:23:56 +00:00
ec91f42d79 Sync up DispatchKey in model with C++ (#81770)
I also dropped some keys that are not useful for codegen.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81770
Approved by: https://github.com/bdhirsh
2022-07-21 21:23:54 +00:00
8f07b7a069 Fix circular import error in torchgen (#81355)
This also formats `tools/pyi/gen_pyi.py` with `usort` to test the fix because that is how the bug was discovered. The
usort-formatted `gen_pyi.py` should work now without any issues

Fixes #81294

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81355
Approved by: https://github.com/ezyang
2022-07-13 03:16:38 +00:00
5c8a9803c8 [torchgen] Support multiple namespace in NativeFunctions.h (#79733)
Summary:
This is a follow up to #78015. This PR
* introduces namespace logic for generating `NativeFunctions.h`.
* adds helper function to extract namespace from string
* relaxes the constraint on the levels we support for custom kernel namespace to 2

Test Plan:
Yaml entry:
```
- func: unsqueeze.out(Tensor(a) self, int dim, *, Tensor(a!) out) -> Tensor(a!)
  variants: function
  device_check: NoCheck
  dispatch:
    CPU: custom_1::custom_2::unsqueeze
```

Generated `NativeFunctions.h`:

```
namespace custom_1 {
namespace custom_2 {
namespace native {
    TORCH_API at::Tensor & unsqueeze(const at::Tensor & self, int64_t dim, at::Tensor & out);
} // namespace native
} // namespace custom_2
} // namespace custom_1

```

Differential Revision: D37198111

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79733
Approved by: https://github.com/bdhirsh
2022-07-08 21:56:52 +00:00
960758b0b7 fix overload ambiguity with functional ops; fix _foreach op grouping (#80556)
This should fix the last issue that @anijain2305 hit when running ResNet with TorchDynamo <> functionalization.

Today if you try to call an `OpOverloadPacket` from python with some arguments, we will use the types of those arguments to perform overload resolution. With some functional variants of ops, this can be ambiguous.

Today this affects just one op: `_fused_moving_avg_obs_fq_helper`, although it would potentially affect e.g. `native_batch_norm` in the future.

Example:
```
# There are technically two overloads:
# torch.ops.aten._fused_moving_avg_obs_fq_helper.default (returns 2 argument, mutates 4 of its inputs inplace)
# torch.ops.aten._fused_moving_avg_obs_fq_helper.functional (returns 6 argument, mutates none of its inputs)

# We pick the wrong one - no way to know that we should pick the functional one, just from the call site.
outs = torch.ops.aten._fused_moving_avg_obs_fq_helper(a, a, a, a, a, a, a, 1.0, 0, 1, 0)
# raises an error - tries to call the overload with only 2 returns
return _fused_moving_avg_obs_fq_helper_functional[5]
```

Specifically, functionalization will bake `_fused_moving_avg_obs_fq_helper.functional` into the graph, but when AOTAutograd tries to compile with TorchScript, it needs to remove the overload name (TS doesn't know how to parse overload names directly, so we need to remove the overload name and let it infer the right overload at runtime later- so it picks the wrong one).

The situation is pretty similar to inplace; `ops.aten.add` and `ops.aten.add_` represent two different `OverloadPacket` objects; they can't be overloads of the same op, because their schemas would be ambiguous - the alias annotations are different, but that isn't enough to disambiguate).

In this PR, I try to fix the situation in a pretty similar way to how we handle `inplace` in the data model: `inplace` ops get their own base operator name, but they are represented as a flag inside of `BaseOperatorName` in the data model.

Two other important changes that I made as part of this PR:

(1) Originally, there were ~100 different `*_functional` operators: e.g. we had operators named `resize.functional` and `zero.functional`. The `_functional` bit isn't actually necessary in most cases: it's only necessary for operators that **also** have a `SchemaKind.mutable` variant, where `_fused_moving_avg_obs_fq_helper` is the only op that fits that description today. So I removed the unnecessary notion of "functional" from those other ops. I also added a bunch of assertions to force this restriction.

I think that makes more sense in the long run, because it eliminates an unnecessary difference in the model. E.g. we don't have `add_.Tensor` and `add.Tensor_functional`. We just have `add_.Tensor` and `add.Tensor`.

(2) I noticed that we actually still weren't pairing up a bunch of `_foreach` operators correctly, because their input arguments were different (`self` vs. `tensors`). Since they're private API's, I went ahead and changed the argument names directly so they get matched up. Before this PR, we were generating a separate `_foreach_add` and `_foreach_add.functional` variant in a bunch of cases, that really did the same thing (but happened to have a different name for the first argument).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80556
Approved by: https://github.com/ezyang, https://github.com/albanD
2022-07-06 12:45:11 +00:00
efc7343743 Revert "Revert "Put symint overloads on a different name"" (#79680)
This relands https://github.com/pytorch/pytorch/pull/79281

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79680
Approved by: https://github.com/malfet
2022-06-21 07:06:33 +00:00
0d909d3cff add a new FunctionSchema kind called scratch
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79659

The scratch op expose intermetidate/scratch tensors used in kernel implementation
as kernel input arguments so a memory planning algorithm can plan memory for those
tensors.

Differential Revision: [D37194429](https://our.internmc.facebook.com/intern/diff/D37194429/)

Approved by: https://github.com/bdhirsh
2022-06-20 16:34:16 +00:00
adf8060600 add a new alias key for functional to view op decompositions
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79615

Approved by: https://github.com/zou3519
2022-06-15 23:18:09 +00:00
b9bb52d97b Revert "Put symint overloads on a different name"
This reverts commit 213a8fc9927e2c90bd28a37f714950e4aa367ca6.

Reverted https://github.com/pytorch/pytorch/pull/79281 on behalf of https://github.com/bigfootjon due to Diff reverted internally
2022-06-15 17:15:21 +00:00
213a8fc992 Put symint overloads on a different name
Due to implicit conversion shenanigans, having both IntArrayRef
and SymIntArrayRef overloads makes {} ambiguous.  While we could
fix this by making a single unified type that accepts all the overloads
we want, an easier fix was to just push the SymIntArrayRef overload
to its own name.

Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79281

Approved by: https://github.com/suo
2022-06-12 14:36:39 +00:00
24050a5801 [RFC][Codegen] Add custom namespace support (#78015)
Summary:
Adding a feature to allow user to specify namespaces for operator and kernels.

# Feature
There's a feature request to allow DSL to:
1. take in an operator namespace other than `aten`.
2. take in a kernel that is in a different namespace than `at::native`.

For both features, we only allow user to have a single layer of namespace for the sake of simplicity. If user specify `custom::function` as kernel, the codegen will depend on `custom::native::function` where `native` is hardcoded.

# Proposal

For feature 1, add a `namespace` attribute to data class `NativeFunction`. The namespace will be extract out by matching pattern "::" on the `func` variable. For `NativeFunctionsGroup` there's an assumption that all variants (function, inplace, out) will have the same namespace. By default (if not specified) the namespace will be "aten".

For feature 2, add a `namespace` attribute to `BackendMetadata` class, similarly match pattern "::" on the kernel field. Remove the `cpp_namespace` field from `register_dispatch_key` data class. By default (if not specified) the namespace for a kernel would be "at::native".

Test Plan:
Example yaml entries:
```
- func: custom::gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  device_check: NoCheck   # TensorIterator
  python_module: nn
  dispatch:
    CPU: custom::gelu_out_cpu
    CUDA: custom::gelu_out_cuda
    MPS: custom::gelu_out_mps

- func: custom::gelu_(Tensor(a!) self, *, str approximate='none') -> Tensor(a!)
  structured_delegate: gelu.out
  device_check: NoCheck   # TensorIterator
  python_module: nn
  dispatch:
    NestedTensorCPU, NestedTensorCUDA: custom::NestedTensor_gelu_

- func: custom::gelu(Tensor self, *, str approximate='none') -> Tensor
  structured_delegate: gelu.out
  device_check: NoCheck   # TensorIterator
  python_module: nn
  dispatch:
    MkldnnCPU: custom::mkldnn_gelu
    QuantizedCPU: custom::gelu_quantized_cpu
    NestedTensorCPU, NestedTensorCUDA: custom::NestedTensor_gelu
```

see generated code:

`RegisterCPU.cpp`:
```
TORCH_LIBRARY_IMPL(aten, CPU, m) {
  ...
}
TORCH_LIBRARY_IMPL(custom, CPU, m) {
    m.impl("gelu", TORCH_FN(wrapper_gelu));
    m.impl("gelu.out", TORCH_FN(wrapper_gelu_out_out));
    m.impl("gelu_", TORCH_FN(wrapper_gelu_));
};
```
```
struct structured_gelu_out_cpu_inplace final : public custom::native::structured_gelu_out_cpu {
    structured_gelu_out_cpu_inplace(Tensor& self) : outputs_{std::ref(self)} {}

    void set_output_strided(
        int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
        TensorOptions options, DimnameList names
    ) override {

        const auto& out = outputs_[output_idx].get();
        check_inplace(out, sizes, options);

        auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options);
        if (C10_UNLIKELY(maybe_proxy.has_value())) {
            proxy_outputs_[output_idx] = c10::ExclusivelyOwned<Tensor>(std::move(maybe_proxy).value());
        }

        if (!names.empty()) {
          namedinference::propagate_names(outputs_[output_idx], names);
        }
        // super must happen after, so that downstream can use maybe_get_output
        // to retrieve the output
        custom::native::structured_gelu_out_cpu::set_output_raw_strided(output_idx, sizes, strides, options, names);
    }

    void set_output_raw_strided(
        int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
        TensorOptions options, DimnameList names
    ) override {

        const auto& out = outputs_[output_idx].get();
        check_inplace(out, sizes, options);

        if (!names.empty()) {
          namedinference::propagate_names(outputs_[output_idx], names);
        }
        // super must happen after, so that downstream can use maybe_get_output
        // to retrieve the output
        custom::native::structured_gelu_out_cpu::set_output_raw_strided(output_idx, sizes, strides, options, names);
    }

    const Tensor& maybe_get_output(int64_t output_idx) override {
      return proxy_outputs_[output_idx].has_value() ? **proxy_outputs_[output_idx] : outputs_[output_idx].get();

    }
    std::array<std::reference_wrapper<Tensor>, 1> outputs_;
    std::array<c10::optional<c10::ExclusivelyOwned<Tensor>>, 1> proxy_outputs_;
};
```

`RegisterSchema.cpp`
```
TORCH_LIBRARY(aten, m) {
  ...
}
TORCH_LIBRARY(custom, m) {
    m.def("gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!)");

    m.def("gelu_(Tensor(a!) self, *, str approximate='none') -> Tensor(a!)");

    m.def("gelu(Tensor self, *, str approximate='none') -> Tensor");
};
```

Differential Revision: D36558459

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78015
Approved by: https://github.com/bdhirsh
2022-06-10 21:04:36 +00:00
7b3a0ff87a Port index.Tensor to structured kernels.
Tracking issue: #55070

Pull Request resolved: https://github.com/pytorch/pytorch/pull/69607

Approved by: https://github.com/bdhirsh
2022-06-10 17:27:47 +00:00
4b82ef7928 Revert "Port index.Tensor to structured kernels."
This reverts commit cfd84125bdb841f0efe038988bb87d645f419338.

Reverted https://github.com/pytorch/pytorch/pull/69607 on behalf of https://github.com/zengk95 due to This is breaking mac trunk tests cfd84125bd
2022-06-08 20:16:10 +00:00
cfd84125bd Port index.Tensor to structured kernels.
Tracking issue: #55070

Pull Request resolved: https://github.com/pytorch/pytorch/pull/69607

Approved by: https://github.com/bdhirsh
2022-06-08 18:17:52 +00:00
bcb424c8cf Fix #78675
Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78699

Approved by: https://github.com/tugsbayasgalan
2022-06-04 01:07:24 +00:00
fca1f495c2 Revert "Port index.Tensor to structured kernels."
This reverts commit 9fe6f1baf5d1c130e717fc32d993bb531ed05b62.

Reverted https://github.com/pytorch/pytorch/pull/69607 on behalf of https://github.com/suo due to this broke master, see: 9fe6f1baf5
2022-06-01 00:12:15 +00:00
9fe6f1baf5 Port index.Tensor to structured kernels.
Tracking issue: #55070

Pull Request resolved: https://github.com/pytorch/pytorch/pull/69607

Approved by: https://github.com/bdhirsh
2022-05-31 22:15:20 +00:00
02c4d877b4 Codegen Non-Native IR Nodes (#76535)
Add codegen infrastructure to generate IR nodes for non-native ops.

The proposed change is to add a `non_native` key to the `{backend}_native_functions.yaml` file that contains schema definitions similar to what is found in `native_functions.yaml`. e.g.
```
non_native:
    ...
    - func: expand(Tensor input, int[] size, bool is_scalar_expand) -> Tensor
    ...
```
these definitions are parsed into a `LazyIrSchema` that can be used for generating IR nodes using `GenLazyIR`.

Fixes #74628

CC: @wconstab @desertfire @henrytwo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76535
Approved by: https://github.com/wconstab
2022-05-24 19:29:23 +00:00
9e806619cc [Codegen] Remove view operator check in NativeFunctionGroups and allow skipping native function generation (#78145)
Summary:
This PR adds two features:
* A boolean to turn off native function generation in codegen
* Relaxing `view` operator check for `NativeFunctionGroups`

Differential Revision: D36604646

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78145
Approved by: https://github.com/iseeyuan, https://github.com/bdhirsh
2022-05-24 05:48:30 +00:00
0161e9eb00 [test] attempt to functionalize ops with mutable positional-only args
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76320

Approved by: https://github.com/ezyang
2022-05-19 18:50:34 +00:00
befa4e371e Fix typo
Fixes #77412

Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77488

Approved by: https://github.com/mruberry
2022-05-18 18:25:54 +00:00
f348b1b2b5 Add the Runtime components for MPS backend. (#76725)
The PR adds the runtime components and few basic operations like copy, as_strided for MPS backend.

Current list of identified TODOs are:

-  https://github.com/pytorch/pytorch/issues/77176
- Unify the logic with CUDACachingAllocator and remove redundant code.
-  https://github.com/pytorch/pytorch/issues/77170
- Look into using C++ smart pointers where possible with ObjC code
- Use empty_strided_generic() to implement the `empty_strided_mps` code
- https://github.com/pytorch/pytorch/issues/77144
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76725
Approved by: https://github.com/albanD
2022-05-11 17:19:45 +00:00
6df5a53127 Fix unmarked fstring
This error message currently prints the format string literally, because the string isn't marked with the `f`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76841
Approved by: https://github.com/bdhirsh
2022-05-04 21:18:05 +00:00
b204ad863f Revert "Revert "Allow specifying tags for aten operators in native_functions.yaml""
This reverts commit ea44645c9a682a4e212e64b94a86383c3388ed6b.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76456

Approved by: https://github.com/osalpekar
2022-04-28 02:04:57 +00:00
74e93f727a remove _is_foreach_op codegen special cases, clean up mutable return type checks
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76190

Approved by: https://github.com/ezyang
2022-04-25 21:34:17 +00:00
7d44b3675b functionalization: add support for zero_()
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75913

Approved by: https://github.com/albanD
2022-04-25 21:31:48 +00:00
36420b5e8c Rename tools/codegen to torchgen (#76275)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76275

In preparation for addressing
https://github.com/pytorch/pytorch/issues/73212

Diff was generated with:

```
git mv tools/codegen torchgen
git grep -l 'tools.codegen' | xargs sed -i 's/tools.codegen/torchgen/g'
sed -i "s/\${TOOLS_PATH}\/codegen/\${TORCH_ROOT}\/torchgen/g" caffe2/CMakeLists.txt
```

and a manual edits to:

* tools/test/test_gen_backend_stubs.py
* torchgen/build.bzl
* torchgen/gen_backend_stubs.py

aka this diff:

```
 diff --git a/tools/test/test_gen_backend_stubs.py b/tools/test/test_gen_backend_stubs.py
index 3dc26c6d2d..104054575e 100644
 --- a/tools/test/test_gen_backend_stubs.py
+++ b/tools/test/test_gen_backend_stubs.py
@@ -9,7 +9,7 @@ from torchgen.gen_backend_stubs import run
 from torchgen.gen import _GLOBAL_PARSE_NATIVE_YAML_CACHE  # noqa: F401

 path = os.path.dirname(os.path.realpath(__file__))
-gen_backend_stubs_path = os.path.join(path, '../torchgen/gen_backend_stubs.py')
+gen_backend_stubs_path = os.path.join(path, '../../torchgen/gen_backend_stubs.py')

 # gen_backend_stubs.py is an integration point that is called directly by external backends.
 # The tests here are to confirm that badly formed inputs result in reasonable error messages.
 diff --git a/torchgen/build.bzl b/torchgen/build.bzl
index ed04e35a43..d00078a3cf 100644
 --- a/torchgen/build.bzl
+++ b/torchgen/build.bzl
@@ -1,6 +1,6 @@
 def define_targets(rules):
     rules.py_library(
-        name = "codegen",
+        name = "torchgen",
         srcs = rules.glob(["**/*.py"]),
         deps = [
             rules.requirement("PyYAML"),
@@ -11,6 +11,6 @@ def define_targets(rules):

     rules.py_binary(
         name = "gen",
-        srcs = [":codegen"],
+        srcs = [":torchgen"],
         visibility = ["//visibility:public"],
     )
 diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py
index c1a672a655..beee7a15e0 100644
 --- a/torchgen/gen_backend_stubs.py
+++ b/torchgen/gen_backend_stubs.py
@@ -474,7 +474,7 @@ def run(
 ) -> None:

     # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
-    pytorch_root = pathlib.Path(__file__).parent.parent.parent.absolute()
+    pytorch_root = pathlib.Path(__file__).parent.parent.absolute()
     template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")

     def make_file_manager(install_dir: str) -> FileManager:
```

run_all_fbandroid_tests

Test Plan: sandcastle

Reviewed By: albanD, ngimel

Differential Revision: D35770317

fbshipit-source-id: 153ac4a7fef15b1e750812a90bfafdbc8f1ebcdf
(cherry picked from commit c6d485d1d4648fa1c8a4c14c5bf3d8e899b9b4dd)
2022-04-25 01:38:06 +00:00