49 Commits

Author SHA1 Message Date
35c4130fd1 [2/N] Fix ruff warnings (#164460)
Apply ruff `SIM` rules.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164460
Approved by: https://github.com/ezyang
2025-10-04 03:40:32 +00:00
5f18f240de Add initial suppressions for pyrefly (#164177)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
`python3 scripts/lintrunner.py`
`pyrefly check`

---

Pyrefly check before: https://gist.github.com/maggiemoss/3a0aa0b6cdda0e449cd5743d5fce2c60
After:

```
 INFO Checking project configured at `/Users/maggiemoss/python_projects/pytorch/pyrefly.toml`
 INFO 0 errors (1,063 ignored)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164177
Approved by: https://github.com/Lucaskabela
2025-10-02 20:57:41 +00:00
beb4d7816d [BE]: ruff PLC0207 - use maxsplit kwarg (#160107)
Automatically replaces split with rsplit when relevant and only performs the split up to the first (or last) value. This allows the split function to return early and improves efficiency.
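
As a small illustration of the kind of rewrite described above (the dotted name is made up for this example and is not taken from the PR):

```python
# Hypothetical dotted name, used only for illustration.
qualname = "torch._custom_op.impl.CustomOp"

# Before: splits on every dot, then throws away all but the last piece.
last_before = qualname.split(".")[-1]

# After: rsplit with maxsplit=1 performs a single split from the right.
last_after = qualname.rsplit(".", maxsplit=1)[-1]

# Taking only the first piece needs a single split from the left.
first_after = qualname.split(".", maxsplit=1)[0]
```

Both forms produce the same last component; the `maxsplit` version simply stops after the first separator it needs.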

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160107
Approved by: https://github.com/albanD
2025-08-08 03:14:59 +00:00
86b1116f22 pyfmt lint torch/_custom_op/* (#155782)
file torch/_custom_op/functional.py does not exist
file torch/_custom_op/__init__.py is empty.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155782
Approved by: https://github.com/Skylion007
2025-06-12 23:04:11 +00:00
f2cfe8b59f PEP585 update - mostly toplevels (#145178)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145178
Approved by: https://github.com/bobrenjc93
2025-01-22 02:21:14 +00:00
c434a64f31 Delete torch._library.register_functional_op (#145110)
Fixes #117816, #117834, #117871

This has been superseded by auto_functionalized_v2. There are no
internal usages and this is private API so it is safe to delete.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145110
Approved by: https://github.com/williamwen42
ghstack dependencies: #145109
2025-01-18 00:58:25 +00:00
e393c7fa05 Tighten torch.library.infer_schema input types (#130705)
Made the following changes:
- mutates_args is now keyword-only and mandatory. This is to align with
  torch.library.custom_op (which makes it mandatory because it's easy to
  miss)
- op_name is now keyword-only. This helps the readability of the API
- updated all usages of infer_schema
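
A sketch of the calling convention described in the list above. `infer_schema_sketch` and `add_` are hypothetical stand-ins that only model where the parameters sit in the signature; they are not the real implementation:

```python
def infer_schema_sketch(prototype_function, /, *, mutates_args, op_name=None):
    """Hypothetical stand-in: mutates_args is keyword-only and required,
    op_name is keyword-only and optional."""
    return {"op_name": op_name, "mutates_args": tuple(mutates_args)}


def add_(x, y):
    """Placeholder function to pass in; its body is irrelevant here."""


def call_fails(*args, **kwargs):
    """Return True if the call is rejected with a TypeError."""
    try:
        infer_schema_sketch(*args, **kwargs)
    except TypeError:
        return True
    return False


accepted = infer_schema_sketch(add_, mutates_args=["x"], op_name="mylib::add_")
missing_rejected = call_fails(add_)              # mutates_args omitted
positional_rejected = call_fails(add_, ["x"])    # mutates_args passed positionally
```

Making the argument mandatory means a caller cannot forget to state what is mutated, even when the answer is "nothing".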

This change is not BC-breaking because we introduced
torch.library.infer_schema a couple of days ago.

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130705
Approved by: https://github.com/yushangdi
ghstack dependencies: #131777
2024-07-29 16:01:19 +00:00
68a4f2a3df Revert "Tighten torch.library.infer_schema input types (#130705)"
This reverts commit ca2d424c6e5358f9fee8dc9ee7477de76b50f848.

Reverted https://github.com/pytorch/pytorch/pull/130705 on behalf of https://github.com/atalman due to Failing internal CI ([comment](https://github.com/pytorch/pytorch/pull/130705#issuecomment-2230821876))
2024-07-16 12:57:11 +00:00
ca2d424c6e Tighten torch.library.infer_schema input types (#130705)
Made the following changes:
- mutates_args is now keyword-only and mandatory. This is to align with
  torch.library.custom_op (which makes it mandatory because it's easy to
  miss)
- op_name is now keyword-only. This helps the readability of the API
- updated all usages of infer_schema

This change is not BC-breaking because we introduced
torch.library.infer_schema a couple of days ago.

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130705
Approved by: https://github.com/yushangdi
2024-07-15 16:43:57 +00:00
7bbd6cf931 [custom_ops] Mark older custom ops prototypes as deprecated (#130032)
I've had at least one person try to call APIs from here.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130032
Approved by: https://github.com/yushangdi, https://github.com/williamwen42
2024-07-03 21:11:05 +00:00
9972e5f447 Rename impl_abstract to register_fake, part 2/2 (#123938)
This PR renames the implementation details of register_fake to align
more with the new name. It is in its own PR because this is risky
(torch.package sometimes depends on private library functions and
implementation details).

Test Plan:
- tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123938
Approved by: https://github.com/williamwen42
2024-06-14 14:37:24 +00:00
dcfa7702c3 Flip default value for mypy disallow_untyped_defs [1/11] (#127838)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127838
Approved by: https://github.com/oulgen
2024-06-08 18:16:33 +00:00
8b08b0f340 [BE] enable ruff rule Q from flake8-quotes (#127713)
Enable [ruff rule `Q`](https://docs.astral.sh/ruff/rules/#flake8-quotes-q) from flake8-quotes. Fixes:

- [avoidable-escaped-quote (Q003)](https://docs.astral.sh/ruff/rules/avoidable-escaped-quote/#avoidable-escaped-quote-q003)
- [unnecessary-escaped-quote (Q004)](https://docs.astral.sh/ruff/rules/unnecessary-escaped-quote/#unnecessary-escaped-quote-q004)
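
For illustration, the two rules target patterns like the following (the strings are made up, and double quotes are assumed to be the preferred style):

```python
# Q003: the escape can be avoided by switching the outer quote style.
q003_before = 'It\'s a custom op'
q003_after = "It's a custom op"

# Q004: the backslash is unnecessary because the inner quote differs
# from the outer one.
q004_before = "say \'hi\'"
q004_after = "say 'hi'"
```

In both cases the string value is unchanged; only the source spelling differs.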

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127713
Approved by: https://github.com/ezyang
2024-06-02 23:25:26 +00:00
a8e17b2d4d Move schema inference to torch._library (#124199)
After this PR, we can delete torch._custom_op/torch._custom_ops (except that
there are external libraries depending on it).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124199
Approved by: https://github.com/albanD
ghstack dependencies: #124180, #124200, #124299, #124134
2024-04-19 17:56:30 +00:00
f0eb162730 Revert "Switch quantized_decomposed over to new custom ops API (#123454)"
This reverts commit 638729c0cdf3ce4274f4d68f8e46e5a1cd36cbe8.

Reverted https://github.com/pytorch/pytorch/pull/123454 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/123454#issuecomment-2051738976))
2024-04-12 13:14:59 +00:00
638729c0cd Switch quantized_decomposed over to new custom ops API (#123454)
We are taking API feedback. Changes:
- I removed some of the default values (they weren't being used).
- I was unable to convert the last op (which is essentially an
  autograd.Function registered as CompositeImplicitAutograd). That one
  is "incorrectly registered"; I punt fixing it to the future.

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123454
Approved by: https://github.com/andrewor14
ghstack dependencies: #123453, #123578
2024-04-11 13:18:06 +00:00
8a5e7a01b5 [custom_op] Schema inference now includes default values (#123453)
If the function has default values, we should be able to do schema
inference and put the default values into the schema.
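
A toy sketch of the idea, using only the standard library. `schema_with_defaults`, `TYPE_NAMES`, and `scale` are hypothetical names for this example, and the output format is an approximation rather than the exact schema syntax:

```python
import inspect

# Toy mapping from Python annotations to schema type names (an assumption
# made for this sketch).
TYPE_NAMES = {int: "int", float: "float", bool: "bool"}


def schema_with_defaults(fn):
    """Build a schema-like string, carrying over any default values."""
    sig = inspect.signature(fn)
    args = []
    for name, param in sig.parameters.items():
        text = f"{TYPE_NAMES[param.annotation]} {name}"
        if param.default is not inspect.Parameter.empty:
            text += f"={param.default}"
        args.append(text)
    return f"({', '.join(args)}) -> {TYPE_NAMES[sig.return_annotation]}"


def scale(x: float, factor: float = 2.0, rounds: int = 1, clamp: bool = False) -> float:
    return x * factor


schema = schema_with_defaults(scale)
```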

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123453
Approved by: https://github.com/albanD
2024-04-11 13:18:02 +00:00
cd6c58baea [custom_ops] mutated_args -> mutates_args (#123437)
This seemed better, since when you're constructing a custom op you need
to provide "the args that the custom op mutates".

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123437
Approved by: https://github.com/albanD
ghstack dependencies: #123108, #123109, #123110, #123129
2024-04-05 22:03:51 +00:00
81e7a7c955 Add mutated_args field to custom_op (#123129)
If provided, we:
- autogenerate an ADInplaceOrView implementation
- assume that no mutated inputs are returned as outputs. There are
  already aliasing runtime checks that check this.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123129
Approved by: https://github.com/albanD
ghstack dependencies: #123108, #123109, #123110
2024-04-05 22:03:51 +00:00
621fdc9db8 infer_schema can add alias annotations when passed a list of mutated args (#122343)
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122343
Approved by: https://github.com/ezyang
ghstack dependencies: #122319, #122320
2024-03-21 21:39:07 +00:00
639d6201b4 Expand the types infer_schema can infer (#122320)
This PR allows it to infer:
- None return as ()
- List[Tensor] as Tensor[]
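
The two mappings above can be sketched with a toy type translator. The `Tensor` class here is a placeholder so the example runs without torch, and `type_name` and `toy_schema` are hypothetical helpers, not the real inference code:

```python
import inspect
import typing


class Tensor:
    """Placeholder type so the sketch runs without torch installed."""


def type_name(annotation):
    """Translate a Python annotation into a schema-like type name."""
    if annotation is None or annotation is type(None):
        return "()"  # a None return is written as an empty tuple
    if typing.get_origin(annotation) is list:
        (elem,) = typing.get_args(annotation)
        return f"{type_name(elem)}[]"  # List[T] becomes T[]
    return {Tensor: "Tensor", int: "int", float: "float"}[annotation]


def toy_schema(fn):
    sig = inspect.signature(fn)
    args = ", ".join(
        f"{type_name(p.annotation)} {name}" for name, p in sig.parameters.items()
    )
    return f"({args}) -> {type_name(sig.return_annotation)}"


def split(x: Tensor, n: int) -> typing.List[Tensor]: ...
def fill_(x: Tensor, value: float) -> None: ...


split_schema = toy_schema(split)
fill_schema = toy_schema(fill_)
```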

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122320
Approved by: https://github.com/ezyang, https://github.com/soulitzer
ghstack dependencies: #122319
2024-03-21 21:39:07 +00:00
0dd78f1828 Add standalone tests for infer_schema (#122319)
We're gonna reuse this helper in the new python custom ops API. Given a
function with type annotations, `infer_schema(fun)` returns an inferred
schema.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122319
Approved by: https://github.com/ezyang, https://github.com/soulitzer
2024-03-21 21:39:04 +00:00
afabed6ae6 [inductor][custom ops] Add tag to custom ops to preserve stride orders in inductor (#117298)
fixes #116715

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117298
Approved by: https://github.com/eellison
2024-01-21 18:47:01 +00:00
10923f8720 Revert "[inductor][custom ops] Add tag to custom ops to preserve stride orders in inductor (#117298)"
This reverts commit 1967394690f144a7ba1717eccec977286cafe2da.

Reverted https://github.com/pytorch/pytorch/pull/117298 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is failing in MacOS 1967394690, may be due to a landrace ([comment](https://github.com/pytorch/pytorch/pull/117298#issuecomment-1901594120))
2024-01-20 02:14:58 +00:00
1967394690 [inductor][custom ops] Add tag to custom ops to preserve stride orders in inductor (#117298)
fixes #116715

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117298
Approved by: https://github.com/eellison
2024-01-20 01:37:28 +00:00
ad09d81694 Allow functionalization to work with optional mutable (#114803)
Summary: - Added functionalization to allow Optionals

Test Plan: CI tests.

Reviewed By: zou3519

Differential Revision: D51209981

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114803
Approved by: https://github.com/zou3519
2023-11-30 23:48:03 +00:00
d197f5c72b Remove unused call to inspect.stack() in torch/_custom_op/impl.py (#114698)
Summary: Fetching the stack isn't free and this variable isn't used. Let's not do the work.

Test Plan: Wait for tests

Differential Revision: D51629732

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114698
Approved by: https://github.com/zou3519, https://github.com/Skylion007
2023-11-29 19:33:52 +00:00
b7b2178204 [BE]: Remove useless lambdas (#113602)
Applies PLW0108 which removes useless lambda calls in Python, the rule is in preview so it is not ready to be enabled by default just yet. These are the autofixes from the rule.
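
A minimal example of the pattern this rule flags (the list contents are made up for illustration):

```python
names = ["impl", "Abstract", "backward"]

# Before: the lambda only forwards its argument to str.lower.
sorted_before = sorted(names, key=lambda s: str.lower(s))

# After: pass the callable directly.
sorted_after = sorted(names, key=str.lower)
```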

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113602
Approved by: https://github.com/albanD
2023-11-14 20:06:48 +00:00
bbd5b935e4 Use pytree.tree_leaves everywhere (#112324)
This changes all the instances I could find of `tree_flatten(...)[0]` or
`x, _ = tree_flatten` to use `tree_leaves`.
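
To show why the two spellings are interchangeable, here is a toy stand-in for the pytree helpers. These simplified `tree_flatten` and `tree_leaves` functions are written for this sketch only and handle just lists, tuples, and dicts:

```python
def tree_flatten(tree):
    """Toy version: return (leaves, spec) for nested lists/tuples/dicts."""
    if isinstance(tree, dict):
        keys, children = list(tree), list(tree.values())
    elif isinstance(tree, (list, tuple)):
        keys, children = None, list(tree)
    else:
        return [tree], None  # a leaf has no structure to record
    parts = [tree_flatten(child) for child in children]
    leaves = [leaf for sub_leaves, _ in parts for leaf in sub_leaves]
    spec = (type(tree), keys, [sub_spec for _, sub_spec in parts])
    return leaves, spec


def tree_leaves(tree):
    """Toy version: collect the leaves without building a spec."""
    if isinstance(tree, dict):
        tree = list(tree.values())
    if isinstance(tree, (list, tuple)):
        return [leaf for child in tree for leaf in tree_leaves(child)]
    return [tree]


args = {"x": [1, 2], "y": (3, {"z": 4})}
leaves_old_style = tree_flatten(args)[0]
leaves_new_style = tree_leaves(args)
```

When only the leaves are needed, the second form states the intent directly and skips the unused spec.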

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112324
Approved by: https://github.com/lezcano
ghstack dependencies: #112327, #112323
2023-10-30 03:39:04 +00:00
bd0ea72b28 torch.library: Create helper function is_functional_schema (#111660)
I will need this again soon.

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111660
Approved by: https://github.com/soulitzer
2023-10-27 15:20:25 +00:00
35e48e262c [custom op] Use canonical API to constrain unbacked values (#108372)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108372
Approved by: https://github.com/angelayi, https://github.com/ezyang
2023-10-10 05:14:28 +00:00
0daa7d4815 [test][docs] Fix doctest warnings for syntax errors (#110517)
Fixes some syntax errors in doctests found in CI tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110517
Approved by: https://github.com/albanD
2023-10-05 00:00:06 +00:00
f8fcc54f70 Add torch.library.impl_abstract (#109912)
Changelog:
- torch.library.impl_abstract optionally accepts a torch.library.Library
  object. If passed in, then the lifetime of the registration is tied to
  the Library object.
- we've also changed torch.library.impl_abstract to work on all
  operators, including overloads.
- we refactored the `torch._custom_ops.*` and `torch._custom_op.*`
  impl_abstract APIs and put them under torch._library. This is the
  final resting place for them. I will follow-up with deleting
  all the `torch._custom_ops.*` stuff later.
- There is a new "SimpleOperatorRegistry" where we actually collect the
  abstract_impl. We will expand this to also hold the other
  torch._custom_ops.* APIs when we move those to torch.library

NB: Previously we had designed
`impl_abstract` assuming a very high-level Python-only custom op API.
We've revisited that since; now, impl_abstract works for all custom ops,
no matter python or C++, no matter the schema. The new refactored design
reflects this better.

Test Plan:
- existing and new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109912
Approved by: https://github.com/ezyang
2023-09-26 01:59:50 +00:00
8124a6c40c [TORCH_LIBRARY] Add impl_abstract_pystub (#109529)
We want users to be able to define custom ops in C++ but put the
abstract impl in Python (since it is easier to write them in Python and
the abstract impl better models device semantics and data-dependent
operators).

`m.impl_abstract_pystub(opname, python_module, context)` declares the
abstract_impl of the operator to exist in the given python module.
When the abstract_impl needs to be accessed (either via FakeTensor or
Meta), and it does not exist, the PyTorch Dispatcher will yell
with a descriptive error message.

Some details:
- We construct a new global AbstractImplPyStub mapping in
  Dispatcher.cpp. Read/write to this map is protected by the Dispatcher
  lock.
- We add a new Meta Tensor fallback kernel. The fallback errors out if there is
  no meta kernel, but also offers a nicer error message if we see that there is
  a pystub.
- We create a `torch._utils_internal.throw_abstract_impl_not_imported_error`
  helper function to throw errors. This way, we can throw different error
  messages in OSS PyTorch vs internal PyTorch. To invoke this from C++, we
  added a PyInterpreter::throw_abstract_impl_not_imported_error.

Differential Revision: [D49464753](https://our.internmc.facebook.com/intern/diff/D49464753/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109529
Approved by: https://github.com/ezyang, https://github.com/bdhirsh
2023-09-22 04:55:36 +00:00
d9342cde6e custom ops: don't error if autograd input is a tensor subclass (#109248)
This is needed to allow the custom ops in our custom op autograd tests to accept `FunctionalTensor` arguments as inputs that we compute gradients for. Previously, custom ops would raise an error if you tried to pass in a tensor subclass when using autograd.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109248
Approved by: https://github.com/zou3519
ghstack dependencies: #109024
2023-09-20 04:37:31 +00:00
2932b0bf37 Extend impl_backward to be usable with torch.library operators (#106817)
- impl_save_for_backward/impl_backward only work for functional,
non-view schemas. We validate this.
- impl_save_for_backward/impl_backward raise if there already exists an
autograd implementation from torch.library / TORCH_LIBRARY.
- Operators constructed via custom_op receive an "autograd indirection
kernel". The "autograd indirection kernel" automatically pulls the
constructed autograd kernel out of a dict. When
impl_save_for_backward/impl_backward get used with torch.library
operators, we also register the "autograd indirection kernel" so we can
reuse the logic.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106817
Approved by: https://github.com/soulitzer
ghstack dependencies: #106799, #106800
2023-08-14 14:33:46 +00:00
db9a0cf689 Extend impl_backward to handle non-Tensor outputs (#106800)
Recall that the user must give us a backward function that accepts
`(ctx, saved, *grads)`, with one grad per output. Previously,
impl_backward only worked for functions that return one or more Tensors.

The new semantics are that if the output has:
- a TensorList, the backward function provided by the user will receive
a List[Tensor] of grads for that output.
- a number, the backward function provided by the user will receive
None as the grad.

Also recall that impl_backward is implemented by registering an
autograd.Function to the autograd dispatch key.
We needed to make the following changes:
- If an output is a TensorList, autograd.Function will ignore it. So we
need to tree-flatten it before returning it from the autograd.Function
- This means that the autograd.Function receives a flat list of grads
during the backwards pass. We need to tree-unflatten it into the correct
shape before passing it to the user-defined backward
- We modify the logic of output_differentiability. Only
Tensor/TensorList outputs can be marked as differentiable. If a
TensorList is marked as non-differentiable, then this is equivalent to
all Tensors in the list being non-differentiable. There is no
finer-grained control over this (to match derivatives.yaml).
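
The flatten/unflatten round trip described above can be sketched in plain Python. Strings stand in for Tensors, and `flatten_outputs` and `unflatten_grads` are hypothetical helpers written for this example rather than the code in the PR:

```python
def flatten_outputs(outputs):
    """Splice list outputs into one flat list and remember their lengths."""
    flat, spec = [], []
    for out in outputs:
        if isinstance(out, list):
            flat.extend(out)
            spec.append(len(out))  # a TensorList of this length
        else:
            flat.append(out)
            spec.append(None)      # a single, non-list output
    return flat, spec


def unflatten_grads(flat_grads, spec):
    """Regroup flat grads so there is exactly one entry per original output."""
    grads, i = [], 0
    for length in spec:
        if length is None:
            grads.append(flat_grads[i])
            i += 1
        else:
            grads.append(flat_grads[i:i + length])
            i += length
    return grads


# An op shaped like (TensorList, int): three "tensors" and a number.
outputs = (["t0", "t1", "t2"], 3)
flat, spec = flatten_outputs(outputs)

# The number output gets None as its grad.
flat_grads = ["g0", "g1", "g2", None]
grads = unflatten_grads(flat_grads, spec)
```

The regrouped `grads` has one entry per output, which matches the `(ctx, saved, *grads)` calling convention of the user-defined backward.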

Test Plan:
- There are new `numpy_split_copy` (returns TensorList) and
`numpy_split_copy_with_int` (returns (TensorList, int)) operators in
custom_op_db
- Added tests for output_differentiability into test/test_custom_ops.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106800
Approved by: https://github.com/soulitzer
ghstack dependencies: #106799
2023-08-14 14:33:46 +00:00
9fcce1baf1 [custom_op] Allow constructor to infer more types (#106799)
This expands the torch._custom_ops.custom_op API to be able to construct
operators that return (int, bool, float, Scalar, List[Tensor]) to make
it more in-line with our torch.library API.

NB: there need to be updates to our custom_op autograd registration
API. For ease of review those changes will go in the next PR up but I
can squash if requested.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106799
Approved by: https://github.com/soulitzer
2023-08-14 14:33:43 +00:00
16b6873885 [custom_ops] extend impl_abstract to work with existing torch.library ops (#106088)
This PR extends impl_abstract to work with existing
torch.library/TORCH_LIBRARY ops.

There's a question of what to do if the user calls impl_abstract
and the op already has a registration for:
- DispatchKey::Meta. We raise.
- DispatchKey::CompositeImplicitAutograd. We raise.
- DispatchKey::CompositeExplicitAutograd. To be pragmatic, we don't
raise, since the user's CompositeExplicitAutograd might work for all
other backends but Meta.
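
The policy in the list above amounts to a small check. In this sketch, dispatch keys are plain strings, and both the function names and the op name `mylib::foo` are hypothetical:

```python
# Keys whose existing registration makes impl_abstract raise, per the list above.
RAISING_KEYS = {"Meta", "CompositeImplicitAutograd"}


def check_abstract_impl_allowed(op_name, existing_keys):
    """Raise if the op already has a registration that conflicts."""
    conflicts = sorted(RAISING_KEYS.intersection(existing_keys))
    if conflicts:
        raise RuntimeError(f"{op_name} already has a registration for {conflicts}")


def is_allowed(existing_keys):
    try:
        check_abstract_impl_allowed("mylib::foo", existing_keys)
    except RuntimeError:
        return False
    return True


allowed_with_explicit = is_allowed({"CPU", "CompositeExplicitAutograd"})
allowed_with_meta = is_allowed({"Meta"})
allowed_with_implicit = is_allowed({"CompositeImplicitAutograd"})
```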

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106088
Approved by: https://github.com/soulitzer
ghstack dependencies: #106075, #106076
2023-08-08 13:53:20 +00:00
cebff39fad [custom_ops] make custom_ops.impl work on existing operators (#106076)
The design is that we construct a CustomOp object around the existing
operator and then use it to register things. It is totally OK if the
operator isn't functional (unlike torch._custom_ops.custom_op that can
only construct functional operators).

If the operator already has an implementation from a backend (either via
direct registration to e.g. DispatchKey::CPU, or an indirect
registration like CompositeImplicitAutograd/CompositeExplicitAutograd),
we raise an error.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106076
Approved by: https://github.com/soulitzer
ghstack dependencies: #106075
2023-08-08 13:53:20 +00:00
60a4ac3068 [custom_ops] Block overload names (#106075)
These are valid with the torch.library API, but (1) they add complexity
and (2) I have never seen a custom op actually use an overload name
before. For simplicity we block all overloads.

Test Plan:
- new test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106075
Approved by: https://github.com/soulitzer
2023-08-08 13:53:18 +00:00
c5b9dc1f40 Optimize stack frame inspection in torch._custom_op.impl:CustomOp._register_impl (#105940)
Summary: This is surprisingly expensive when the stack is deep. We can instead just process the specific stack frame that's relevant -- it's much faster.

Test Plan:
```
import inspect
import sys
import time

def make_deep_stack(fn, n: int = 10):
    if n > 0:
        return make_deep_stack(fn, n - 1)

    return fn()

def full_stack():
    return inspect.stack()[1][3]

def via_current_frame():
    return inspect.getframeinfo(sys._getframe(1))[2]

start = time.perf_counter()
for _ in range(1000):
    make_deep_stack(full_stack)
print(f"full_stack took {time.perf_counter() - start}s")

start = time.perf_counter()
for _ in range(1000):
    make_deep_stack(via_current_frame)
print(f"via_current_frame took {time.perf_counter() - start}s")

> full_stack took 31.788201928138733s
> via_current_frame took 2.33455612603575s
```

Differential Revision: D47674015

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105940
Approved by: https://github.com/zou3519
2023-07-31 15:49:33 +00:00
dad65d09f2 Update custom op API (#105947)
As described in
https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk/edit

This PR changes the CustomOp API to be private and adds new public
wrappers around it so that the user does not need to know about the
"CustomOp" object. We've effectively changed the "CustomOp" object to be
some metadata about the operator that the user does not directly
interact with.

The "updated custom op API" is in torch._custom_ops. Pending good customer
feedback, we will promote this module to torch.custom_ops.

NB: I cannot move around the older torch._custom_op APIs yet because
people are already using them.

Test Plan:
- I changed all of our tests to use the new `torch._custom_ops` module
instead of the old CustomOp API.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105947
Approved by: https://github.com/soulitzer
2023-07-28 13:30:58 +00:00
79c5e33349 [BE] Enable ruff's UP rules and autoformat nn/ mps/ and torch/ (#105436)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105436
Approved by: https://github.com/malfet, https://github.com/albanD
2023-07-21 07:38:46 +00:00
3897c479af Add API to construct the functional variant of an op (#102293)
`register_functional_op`:
- constructs the functional variant of an op
- registers a functionalization kernel to the op

To get this to work:
- `register_functional_op` makes assumptions that it checks about the
op's schema. In particular, the op is not allowed to return anything it
mutates. We can relax these constraints in the future.
- We add a "boxed" python functionalization kernel that handles this
case.

I'm not actually sure (or convinced) this should be public API or how
it should work. If we want this to be public, then it should probably be
a torch.library API, but does that also mean we should give the same
lifetime guarantees? If so, then it would be up to the user to construct
a Library object to actually register the functional variant onto.

Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102293
Approved by: https://github.com/bdhirsh
2023-06-02 13:36:50 +00:00
74f10b9ea5 Switch most Python RAII guard usages to context manager (#102642)
There are some I can't easily switch due to reasons like:
- Dynamo modelling the guard
- BC concerns (for torch.autograd.set_multithreading_enabled)

Test Plan:
- existing tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102642
Approved by: https://github.com/albanD
2023-06-01 16:28:37 +00:00
fc31b3a106 Allow existing "Python RAII guards" to be used as context managers (#102579)
This PR adds a `py_context_manager_DEPRECATED` that converts a C++ RAII
guard to an object that may be either used as Python context manager or
as a "Python RAII guard".

We don't convert all of them to Python context manager only due to BC
reasons; people in OSS and internally actually rely on these APIs and I
don't want to break them. We are justified in breaking BC if we wanted
to, but it seemed like too much work for not a lot of gain.

The API is postfixed with "DEPRECATED" to indicate that people should
really use `py_context_manager` (converts C++ RAII guard to Python
context manager) instead.

Test Plan:
- this PR converts all PyTorch usages of _AutoDispatchBelowAutograd to
context manager. I can do the rest in follow-ups.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102579
Approved by: https://github.com/bdhirsh, https://github.com/albanD
2023-05-31 19:55:38 +00:00
723f111545 [custom_op] explicit autograd API (#101824)
This PR adds an explicit API for registering a backward formula for a
CustomOp. In the end state, we will likely have this explicit API and a
magic API (which is sugar on top of an explicit API), since different
parties of users prefer different ones.

Concretely, to define a backward formula for a CustomOp:
- a user must provide us a "save for backward" function that accepts
(inputs, output) and returns exactly what they want saved for backward
- a user must provide us a "backward" function that accepts
(ctx, saved, *grads) and returns us the grad_inputs. The grad_inputs
are returned as a dict mapping str to a gradient.
Please see the changes in custom_op_db.py for examples of the API.
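
As a rough sketch of the shape of the two functions, here is a scalar multiply with floats standing in for Tensors. The names, the `SimpleNamespace` used for `inputs`, and passing `None` as `ctx` are assumptions made for this example; the actual examples are the ones in custom_op_db.py:

```python
from types import SimpleNamespace


def mul_save_for_backward(inputs, output):
    """Accepts (inputs, output) and returns what should be saved for backward."""
    return {"x": inputs.x, "y": inputs.y}


def mul_backward(ctx, saved, grad_out):
    """Accepts (ctx, saved, *grads) and returns a dict mapping each input
    name to its gradient."""
    return {"x": grad_out * saved["y"], "y": grad_out * saved["x"]}


# Forward pass of the toy op: out = x * y.
inputs = SimpleNamespace(x=3.0, y=4.0)
output = inputs.x * inputs.y

saved = mul_save_for_backward(inputs, output)
grad_inputs = mul_backward(None, saved, 1.0)  # ctx is unused in this sketch
```

Returning gradients keyed by input name avoids relying on positional order between the op's signature and the backward function.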

There are a number of pieces to this PR and I'm happy to split it if it
helps. They are:
- The actual APIs for specifying the two functions
(impl_save_for_backward, impl_backward)
- The autograd kernel: we take the functions the user gives us and
construct an autograd.Function object that we then register to
the Autograd dispatch key
- Indirection for the autograd kernel. We add a layer of indirection so
that one can swap out the autograd kernel. This is necessary because by
default, we register an "autograd not implemented" kernel as the
Autograd implementation but then swap it for the actual kernel when the
user provides it.
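
The indirection in the last item can be pictured as a lookup performed at call time. This is a toy sketch with a plain dict and hypothetical names (`autograd_kernels`, `make_indirection`, `mylib::double`), not the dispatcher mechanism itself:

```python
# Maps an op name to whichever autograd kernel is currently active.
autograd_kernels = {}


def autograd_not_implemented(*args):
    raise NotImplementedError("no backward formula registered")


def make_indirection(op_name):
    """Register the default kernel and return a wrapper that looks up the
    current kernel every time it is called."""
    autograd_kernels[op_name] = autograd_not_implemented

    def indirection(*args):
        return autograd_kernels[op_name](*args)

    return indirection


kernel = make_indirection("mylib::double")

try:
    kernel(3)
    raised_before_swap = False
except NotImplementedError:
    raised_before_swap = True

# Later, the user provides the real kernel; the wrapper itself is unchanged.
autograd_kernels["mylib::double"] = lambda x: 2 * x
result_after_swap = kernel(3)
```

Because the wrapper is what gets registered, swapping the kernel never requires re-registering anything.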

Test Plan:
- We apply this API to give backward formulas for things in
custom_op_db. We then hook up custom_op_db to the Autograd OpInfo tests.
- Various tests in test_python_dispatch.py to check error cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101824
Approved by: https://github.com/ezyang
2023-05-23 18:31:29 +00:00
8487105fae [custom_op] Create a new torch._custom_op namespace (#101823)
torch/custom_op.py is getting long, and the autograd pieces are going to
make it even longer. I'm planning on just organizing the files under
a torch/_custom_op folder.

Note that the imports now look a bit crazy (from torch._custom_op.impl
import...) but they will look more OK when we figure out the plan to
make custom_op public (coming later).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101823
Approved by: https://github.com/ezyang, https://github.com/albanD, https://github.com/bdhirsh
2023-05-23 18:31:29 +00:00