Thanks @awgu for raising this issue and the small repro
From offline discussion with @albanD, in the case where a forward returns multiple outputs with different devices, we'd want to select the ready queue based on the device of the first one. Even though this is somewhat arbitrary, we prefer this over deciding which ready queue to push based on whichever input buffer's we happen to compute last, which can vary depending on more factors and thus be harder to reason about. This is in theory bc-breaking, but it seems unlikely that someone would depend on this behavior.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135633
Approved by: https://github.com/albanD
The regression from https://github.com/pytorch/pytorch/issues/132281 pinpoints e4ace1a396 as the cause. The main delta that commit introduces is that we now manually check `is_inference()` and call `increment_version()` (a pybind call) on every mutated input tensor to the graph.
This PR attempts to reduce overhead a bit by bundling up all of those checks into a single pybind call, by:
(1) updating `torch.autograd.graph.increment_version()` to accept a `Union[Tensor, List[Tensor]]`
(2) updating its semantics to no-op if you pass in a tensor with no version counter, instead of erroring
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132652
Approved by: https://github.com/albanD
Inductor would like a way to have activations that do not escape the backward graph marked as "donated", so we can re-use their memory during memory planning here: https://github.com/pytorch/pytorch/pull/130580
For this to be safe though, we need to know at runtime that autograd does not plan to retain the current autograd graph (either for another call to .backward() later, or if double backward is being used). In the linked PR, the current plan is to error when we detect this situation, and ask the user to turn off the donated buffer config (although if/once we get to the point of always delaying backward compilation to runtime, we can just wait until we know the runtime value to compile).
There isn't a way to know if the currently running backward is run with `retain_graph=True` from python - @soulitzer helped me figure out where to grab it so I added a python binding for it under `ctx.is_retain_graph()`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131038
Approved by: https://github.com/soulitzer
This is useful for splitting grad to run in two parts while preserving intermediates:
<details>
<summary>
Click to see code
</summary>
```python
import collections
import weakref
from torch.autograd.graph import GradientEdge
def _get_grad_fn_or_grad_acc(t):
if t.requires_grad and t.grad_fn is None:
return t.view_as(t).grad_fn.next_functions[0][0]
else:
return t.grad_fn
def reverse_closure(roots, target_nodes):
# Recurse until we reach a target node
closure = set()
actual_target_nodes = set()
q: Deque = collections.deque()
for node in roots:
if node is not None and node not in closure:
closure.add(node)
q.append(node)
while q:
node = q.popleft()
reverse_edges = node.metadata.get("reverse_edges", [])
for holder_ref, idx in reverse_edges:
ref = holder_ref()
if ref is not None:
raise RuntimeError("Reverse graph is no longer alive")
fn = ref.node
if fn in closure or fn is None:
continue
if fn in target_nodes:
actual_target_nodes.add(fn)
continue
closure.add(fn)
q.append(fn)
return closure, actual_target_nodes
# Enable weak pointer
class Holder():
def __init__(self, node):
self.node = node
# TODO: use weak references to avoid reference cycle
def construct_reverse_graph(roots):
q: Deque = collections.deque()
root_seen = set()
reverse_graph_refs = []
for node in roots:
if node is not None and node not in root_seen:
q.append(node)
root_seen.add(node)
while q:
node = q.popleft()
for fn, idx in node.next_functions:
if fn is not None:
# Don't necessarily need to store on the graph
reverse_edges = fn.metadata.get("reverse_edges", [])
if len(reverse_edges) == 0:
q.append(fn)
holder = Holder(node)
holder_ref = weakref.ref(holder)
reverse_graph_refs.append(holder)
reverse_edges.append((holder_ref, idx))
fn.metadata["reverse_edges"] = reverse_edges
return reverse_graph_refs
def get_param_groups(inputs, params):
inputs_closure, _ = reverse_closure(inputs, set())
param_groups = dict() # keyed on intermediates
for i, param in enumerate(params):
closure, intersected = reverse_closure([param], inputs_closure)
param_group = {
"params": set([param]),
"intermediates": set(intersected),
}
for input_node in intersected:
existing = param_groups.get(input_node, None)
if existing is not None:
existing["params"] = existing["params"].union(param_group["params"])
existing["intermediates"] = existing["intermediates"].union(param_group["intermediates"])
param_group = existing
else:
param_groups[input_node] = param_group
# Sanity check: union of all param_groups params should be equal to all params
union_params = set()
seen_ids = set()
unique_param_groups = []
for param_group in param_groups.values():
if id(param_group) not in seen_ids:
seen_ids.add(id(param_group))
unique_param_groups.append(param_group)
union_params = union_params.union(param_group["params"])
assert union_params == set(params)
return unique_param_groups
def compute_grads_only_inputs2(roots, inps, weights):
root_grad_fns = list(map(_get_grad_fn_or_grad_acc, roots))
inp_grad_fns = list(map(_get_grad_fn_or_grad_acc, inps))
weight_grad_fns = list(map(_get_grad_fn_or_grad_acc, weights))
reverse_graph_refs = construct_reverse_graph(root_grad_fns)
param_groups = get_param_groups(inp_grad_fns, weight_grad_fns)
del reverse_graph_refs
for param_group in param_groups:
for i, intermediate in enumerate(param_group["intermediates"]):
def get_hook(param_group, i):
def hook(grad_inputs):
if param_group.get("grads", None) is None:
param_group["grads"] = [None] * len(param_group["intermediates"])
param_group["grads"][i] = grad_inputs
return hook
# These are always "split" nodes that we need to recompute, so
# save their inputs.
intermediate.register_prehook(get_hook(param_group, i))
dinputs = torch.autograd.grad((out,), inputs=tuple(inps), grad_outputs=(torch.ones_like(out),), retain_graph=True)
return dinputs, param_groups
def compute_grads_only_weights2(user_weights, param_groups):
all_dweights = dict()
for param_group in param_groups:
# TODO: Handle case where intermediate can have multiple outputs
intermediate_edges = tuple(GradientEdge(i, 0) for i in param_group["intermediates"])
weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"])
assert all(len(g) == 1 for g in param_group["grads"])
# [NEW!] Able to pass a GradientEdge to autograd.grad as output
# We do not need to retain_graph because... guarantee no overlap?
print("trying to execute: ", intermediate_edges, weights_edges)
dweights = torch.autograd.grad(intermediate_edges, weights_edges, grad_outputs=sum(param_group["grads"], tuple()))
for w, dw in zip(param_group["params"], dweights):
all_dweights[w] = dw
# return grads in the original order weights were provided in
out = []
for w in user_weights:
grad_acc = _get_grad_fn_or_grad_acc(w)
out.append(all_dweights[grad_acc])
return tuple(out)
```
</details>
```python
import torch.nn as nn
# Setup
mod1 = nn.Linear(10, 10)
mod2 = nn.Linear(10, 10)
a = torch.rand(10, requires_grad=True)
weights = tuple(mod1.parameters()) + tuple(mod2.parameters())
inps = (a,)
out = mod2(mod1(a))
class LoggingTensorMode(torch.utils._python_dispatch.TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
rs = func(*args, **kwargs)
print(f"{func.__module__}.{func.__name__}")
return rs
print(" -- SPLIT -- ")
# Compute gradients in two parts
with LoggingTensorMode():
print("PART 1")
dinputs, state = compute_grads_only_inputs2((out,), inps, weights)
print("PART 2")
dweights = compute_grads_only_weights2(weights, state)
out = mod2(mod1(a))
print(" -- REF -- ")
# Compare with reference
with LoggingTensorMode():
ref_all_gradients = torch.autograd.grad(out, inputs=tuple(inps) + weights, grad_outputs=(torch.ones_like(out),))
for actual, ref in zip(dinputs + dweights, ref_all_gradients):
print(torch.allclose(actual, ref))
```
<img width="598" alt="image" src="https://github.com/pytorch/pytorch/assets/13428986/3681b8a7-3ab4-4d1d-a836-abef6913e671">
```
PART 1
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.ones_like.default
V0603 10:17:21.590878 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1ee160> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591204 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591578 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x100d7ae50> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
V0603 10:17:21.591747 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a60> with grad_outputs: [f32[10]]
torch._ops.aten.view.default
V0603 10:17:21.591834 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
V0603 10:17:21.591922 8300067520 torch/autograd/graph.py:751] Executing: <ViewBackward0 object at 0x12a1e4a90> with grad_outputs: [f32[1, 10]]
torch._ops.aten.view.default
PART 2
trying to execute: (GradientEdge(node=<AddmmBackward0 object at 0x12a1e4bb0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a21b130>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b7c0>, output_nr=0))
V0603 10:17:21.592223 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1e4bb0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.592421 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a1cad60> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
trying to execute: (GradientEdge(node=<AddmmBackward0 object at 0x12a1ee0d0>, output_nr=0),) (GradientEdge(node=<AccumulateGrad object at 0x12a1e41c0>, output_nr=0), GradientEdge(node=<AccumulateGrad object at 0x12a21b670>, output_nr=0))
V0603 10:17:21.593481 8300067520 torch/autograd/graph.py:751] Executing: <AddmmBackward0 object at 0x12a1ee0d0> with grad_outputs: [f32[1, 10]]
torch._ops.aten.t.default
torch._ops.aten.mm.default
torch._ops.aten.t.default
torch._ops.aten.sum.dim_IntList
torch._ops.aten.view.default
V0603 10:17:21.593750 8300067520 torch/autograd/graph.py:751] Executing: <TBackward0 object at 0x12a21b2b0> with grad_outputs: [f32[10, 10]]
torch._ops.aten.t.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
torch._ops.aten.view.default
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127766
Approved by: https://github.com/albanD
In this PR:
- Ensure that if a tensor not requiring grad is saved for backward unpacking does not trigger a detach (unless the user installs a saved tensor pack hook that returns a tensor requiring grad).
- Update non-reentrant checkpoint to also no longer detach for this case.
Alternatives:
- For custom autograd Function, you could directly save on ctx to work around this, but that would not work for when we switch to using custom ops.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127959
Approved by: https://github.com/YuqingJ
ghstack dependencies: #125795, #128545, #129262
Update ruff to 0.5.0 so we can enable all the some of the new checks I've been wanting to add to the codebase. First just updating the code to comply with some rule changes and a couple minor API changes / deprecations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129744
Approved by: https://github.com/ezyang
Fixes https://github.com/pytorch/pytorch/issues/128611
We detach using tensor_data, which already preserves the version counter, so there is no reason to save it prior to unpacking:
```
at::TensorBase VariableHooks::tensor_data(const at::TensorBase& self) const {
TORCH_CHECK(self.defined(), "cannot call tensor_data() on undefined tensor");
auto self_impl_copy = self.unsafeGetTensorImpl()->shallow_copy_and_detach(
/*version_counter=*/self.unsafeGetTensorImpl()->version_counter(),
/*allow_tensor_metadata_change=*/
self.unsafeGetTensorImpl()->allow_tensor_metadata_change());
return at::Tensor(self_impl_copy);
}
```
This changes the behavior when hooks are involved:
- Previously, if you had a hook that replaced the saved tensor with an entirely new tensor, we would've smashed the saved version counter onto that during unpack, which is not quite correct because the tensor returned by user's pack hook is not necessarily aliased to the tensor originally being saved (unlikely), and even if it were, the version counter would already be shared, if the user did their operations not in inference mode (unlikely).
- In this PR, we restore the version counter using the version counter from the unpack hook's output.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128545
Approved by: https://github.com/albanD
ghstack dependencies: #125795
### bc-breaking for existing users of the private API:
- Existing policy functions must now change their return value to be [CheckpointPolicy](c0b40ab42e/torch/utils/checkpoint.py (L1204-L1230)) Enum instead of bool.
- To restore previous behavior, return `PREFER_RECOMPUTE` instead of `False` and `{PREFER,MUST}_SAVE` instead of `True` depending whether you prefer the compiler to override your policy.
- Policy function now accepts a `ctx` object instead of `mode` for its first argument.
- To restore previous behavior, `mode = "recompute" if ctx.is_recompute else "forward"`.
- Existing calls to `_pt2_selective_checkpoint_context_fn_gen` must be renamed to `create_selective_checkpoint_contexts `. The way you use the API remains the same. It would've been nice to do something different (not make the user have to use functools.partial?), but this was the easiest to compile (idk if this should actually be a constraint).
Related doc: https://docs.google.com/document/d/1BKyizkZPdri9mHqdDOLAUpkI7SbbKfLHRFVVpK9ZWqo/edit
Memory considerations:
- As with the existing SAC, cached values are cleared upon first use.
- We error if the user wishes to backward a second time on a region forwarded with SAC enabled.
In-place:
- We use version counting to enforce that if any cached tensor has been mutated. In-place operations not mutating cached tensors are allowed.
- `allow_cache_entry_mutation=True` can be passed to disable this check (useful in the case of auto AC where the user is cleverly also saves the output of the in-place)
Randomness, views
- Currently in this PR, we don't do anything special for randomness or views, the author of the policy function is expected to handle them properly. (Would it would be beneficial to error? - we either want to save all or recompute all random tensors)
Tensor object preservation
- ~We guarantee that if a tensor does not requires grad, and it is saved, then what you get out is the same tensor object.~ UPDATE: We guarantee that if a tensor is of non-differentiable dtype AND it is not a view, and it is saved, then what you get out is the same tensor object. This is a nice guarantee for nested tensors which care about the object identity of of the offsets tensor.
Policy function
- Enum values are `{MUST,PREFER}_{SAVE,RECOMPUTE}` (bikeshed welcome). Alternatively there was `{SAVE,RECOMPUTE}_{NON_,}OVERRIDABLE`. The former was preferred bc it seemed clearer that two `MUST` clashing should error, versus it is ambiguous whether two `NON_OVERRIDABLE` being stacked should silently ignore or error.
- The usage of Enum today. There actually is NO API to stack SAC policies today. The only thing the Enum should matter for in the near term is the compiler. The stacking SAC policy would be useful if someone wants to implement something like simple FSDP, but it is not perfect because with a policy of `PREFER_SAVE` you are actually saving more than autograd would save normally (would be fixed with AC v3).
- The number of times we call the policy_fn is something that should be documented as part of public API. We call the policy function for all ops except ~~detach~~ UPDATE : metadata ops listed in `torch.utils.checkpoint.SAC_IGNORED_OPS`) because these ops may be called a different number of times by AC itself between forward and recompute.
- The policy function can be a stateful object (we do NOT make separate copies of this object for forward/recompute, the user is expected to handle that via is_recompute see below).
Tensors guaranteed to be the same tensor as-is
- Policy function signature takes ctx object as its first argument. The ctx function is an object encapsulating info that may be useful to the user, it currently only holds "is_recompute". Adding this indirection gives us flexibility to add more attrs later if necessary.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125795
Approved by: https://github.com/Chillee, https://github.com/fmassa
FIXES#113263. Same idea as in https://github.com/pytorch/pytorch/pull/113417, but we need a more intrusive C API to silently nop default saved tensor hooks, in order to support user-code that use torch.autograd.disable_saved_tensors_hooks (see test_unpack_hooks_can_be_disabled). We mock the output of get_hooks while leaving push/pop untouched.
For compiled autograd, we're firing pack hooks once and unpack hooks twice right now, I'll look into this separately from this issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123196
Approved by: https://github.com/soulitzer
Related doc: https://docs.google.com/document/d/1BKyizkZPdri9mHqdDOLAUpkI7SbbKfLHRFVVpK9ZWqo/edit
Memory considerations:
- As with the existing SAC, cached values are cleared upon first use.
- We error if the user wishes to backward a second time on a region forwarded with SAC enabled.
In-place:
- We use version counting to enforce that if any cached tensor has been mutated. In-place operations not mutating cached tensors are allowed.
- `allow_cache_entry_mutation=True` can be passed to disable this check (useful in the case of auto AC where the user is cleverly also saves the output of the in-place)
Randomness, views
- Currently in this PR, we don't do anything special for randomness or views, the author of the policy function is expected to handle them properly. (Would it would be beneficial to error? - we either want to save all or recompute all random tensors)
Tensor object preservation
- We guarantee that if a tensor does not requires grad, and it is saved, then what you get out is the same tensor object. If the tensor does require grad, we must detach to avoid creating a reference cycle. This is a nice guarantee for nested tensors which care about the object identity of of the offsets tensor.
Policy function
- Enum values are `{MUST,PREFER}_{SAVE,RECOMPUTE}` (bikeshed welcome). Alternatively there was `{SAVE,RECOMPUTE}_{NON_,}OVERRIDABLE`. The former was preferred bc it seemed clearer that two `MUST` clashing should error, versus it is ambiguous whether two `NON_OVERRIDABLE` being stacked should silently ignore or error.
- The usage of Enum today. There actually is NO API to stack SAC policies today. The only thing the Enum should matter for in the near term is the compiler. The stacking SAC policy would be useful if someone wants to implement something like simple FSDP, but it is not perfect because with a policy of `PREFER_SAVE` you are actually saving more than autograd would save normally (would be fixed with AC v3).
- The number of times we call the policy_fn is something documented part of public API. We call the policy function for all ops except detach because detach is itself called a different number of times by AC between forward and recompute.
- The policy function can be a stateful object (we do NOT make separate copies of this object for forward/recompute, the user is expected to handle that via is_recompute see below).
Tensors guaranteed to be the same tensor as-is
- Policy function signature takes ctx object as its first argument. The ctx function is an object encapsulating info that may be useful to the user, it currently only holds "is_recompute". Adding this indirection gives us flexibility to add more attrs later if necessary.
"bc-breaking" for existing users of the private API:
- Existing policy functions must now change their return value to use the Enum.
- Existing calls to `_pt2_selective_checkpoint_context_fn_gen` must be renamed to `gen_selective_checkpoint_context_fn`. The way you use the API remains the same. It would've been nice to do something different (not make the user have to use functools.partial?), but this was the easiest to compile (idk if this should actually be a constraint).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125795
Approved by: https://github.com/Chillee, https://github.com/fmassa
autograd already marks nodes as needed or not before calling calling compiled autograd. so our worklist already skips nodes not specified in the `inputs` kwarg.
For the .backward(inputs=) case, I'm keeping the grads as outputs, just like for .grad(inputs=), this is to still guard on graph_output when we collect the nodes. This does not get DCE'd rn, and is ignored in the post graph bytecode.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128252
Approved by: https://github.com/jansel
Most commonly CPU scalars used for philox random seed. Right now, any cpu input will skip cudagraphing the entire graph. We need both the traced graph and the runtime inputs to be cudaified.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125382
Approved by: https://github.com/jansel
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
Resolves#126888
- #126888
This PR is split from PR #126898.
- #126898
------
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.
Note that only warnings that their messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.
Resolves#126888
- #126888
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
This PR unifies the CUDA, XPU and PrivateUse1 in the torch profiler. Now CUDA, XPU and PrivateUse1 can together use string object `use_device` to distinguish each other and share one device path for calculating kineto time durations and memory statistics for post processing.
#suppress-api-compatibility-check
Co-authored-by: Aaron Enye Shi <enye.shi@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123247
Approved by: https://github.com/aaronenyeshi
Adds a ruff lint rule to ban raising raw exceptions. Most of these should at the very least be runtime exception, value errors, type errors or some other errors. There are hundreds of instance of these bad exception types already in the codebase, so I have noqa'd most of them. Hopefully this error code will get commiters to rethink what exception type they should raise when they submit a PR.
I also encourage people to gradually go and fix all the existing noqas that have been added so they can be removed overtime and our exception typing can be improved.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124570
Approved by: https://github.com/ezyang
Update ruff to 0.4.1 .
This version fixes a lot false negatives/false positives, is 20-40% faster, and has various other bug fixes.
Below is a before and after table showing the execution time of ruff lint and ruff format in milliseconds courtesy of https://astral.sh/blog/ruff-v0.4.0
| Repository | Linter (v0.3) | Linter (v0.4) | Formatter (v0.3) | Formatter (v0.4) |
|----------------------------------------------------|---------------|---------------|------------------|------------------|
| [pytorch/pytorch](https://github.com/pytorch/pytorch) | 328.7 | 251.8 | 351.1 | 274.9 |
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124549
Approved by: https://github.com/ezyang
This PR unifies the CUDA, XPU and PrivateUse1 in the torch profiler. Now CUDA, XPU and PrivateUse1 can together use string object `use_device` to distinguish each other and share one device path for calculating kineto time durations and memory statistics for post processing.
#suppress-api-compatibility-check
Co-authored-by: Aaron Enye Shi <enye.shi@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123247
Approved by: https://github.com/aaronenyeshi, https://github.com/gujinghui
This PR:
* Introduces an ATen op for creating true jagged views from a dense values buffer
* `_nested_view_from_jagged(values, offsets, lengths, ragged_idx, dummy)`
* This ops is implemented on the Python side using torch.library so we can return a subclass instance
* `jagged_from_list()` now uses this instead of the old autograd.Function `NestedViewFromBuffer`
* The latter op is used for non-contiguous JTs returned via `torch.nested.narrow()`
* `dummy` is an awful hack to ensure that `NestedTensor.__torch_dispatch__()` is invoked for our view
* Introduces an ATen op for accessing the `values` component of an NT via a view
* `_nested_get_values(nt)`
* **Removes** the autograd.Functions `ViewNestedFromBuffer` and `ViewBufferFromNested` in favor of `nested_from_values_offsets()` / `nested_from_values_offsets_lengths()` and `nt.values()`, respectively.
* Changes test code to prefer `as_nested_tensor()` over `jagged_from_list()` directly
* Similarly, avoid `buffer_from_jagged()`, preferring `values()`
* Depends on general subclass view fake-ification on the PT2 side (handled solely in previous PRs in the stack)
With these changes, the semantics of jagged layout NTs are such that they are considered a true view of the underlying `values` buffer. This means views of jagged NTs are views of the underlying buffer as well, simplifying some handling.
Differential Revision: [D54269922](https://our.internmc.facebook.com/intern/diff/D54269922)
Co-authored-by: voznesenskym <voznesenskym@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113279
Approved by: https://github.com/ezyang
This PR:
* Introduces an ATen op for creating true jagged views from a dense values buffer
* `_nested_view_from_jagged(values, offsets, lengths, ragged_idx, dummy)`
* This ops is implemented on the Python side using torch.library so we can return a subclass instance
* `jagged_from_list()` now uses this instead of the old autograd.Function `NestedViewFromBuffer`
* The latter op is used for non-contiguous JTs returned via `torch.nested.narrow()`
* `dummy` is an awful hack to ensure that `NestedTensor.__torch_dispatch__()` is invoked for our view
* Introduces an ATen op for accessing the `values` component of an NT via a view
* `_nested_get_values(nt)`
* **Removes** the autograd.Functions `ViewNestedFromBuffer` and `ViewBufferFromNested` in favor of `nested_from_values_offsets()` / `nested_from_values_offsets_lengths()` and `nt.values()`, respectively.
* Changes test code to prefer `as_nested_tensor()` over `jagged_from_list()` directly
* Similarly, avoid `buffer_from_jagged()`, preferring `values()`
* Depends on general subclass view fake-ification on the PT2 side (handled solely in previous PRs in the stack)
With these changes, the semantics of jagged layout NTs are such that they are considered a true view of the underlying `values` buffer. This means views of jagged NTs are views of the underlying buffer as well, simplifying some handling.
Differential Revision: [D54269922](https://our.internmc.facebook.com/intern/diff/D54269922)
Co-authored-by: voznesenskym <voznesenskym@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113279
Approved by: https://github.com/ezyang
This does not introduce a new test but is tested by checking that all the classes we already have still behave as before now that they don't explicitly disable torch_function.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120632
Approved by: https://github.com/ezyang
- There are no usages of this internally.
- There are very few usages of this in OSS (most of these are forks of old
repositories).
- This flag doesn't do anything.
We're deprecating it to prevent confusion. I will delete it immediately
after the branch cut.
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121413
Approved by: https://github.com/albanD, https://github.com/soulitzer
# Motivation
According to [[RFC] Intel GPU Runtime Upstreaming](https://github.com/pytorch/pytorch/issues/114842), the 5th runtime component we would like to upstream is `Guard`. We will cover device guard and stream guard in this PR.
# Design
Device guard is used mainly for op dispatcher in PyTorch. Currently, PyTorch already has a device guard abstraction `c10::impl::DeviceGuardImplInterface`. In our design, we will introduce an `XPUGuardImpl` class inherits from `c10::impl::DeviceGuardImplInterface`. Register `XPUGuardImpl` to PyTorch after we implement the device switch management mechanism in `XPUGuardImpl`. Besides, we will introduce `XPUGuard`, `OptionalXPUGuard`, `XPUStreamGuard`, and `OptionalXPUStreamGuard`. They are all following the design of CUDA's counterpart. The corresponding C++ file should be placed in c10/xpu/ folder.
# Additional Context
It is unnecessary to add `Guard` code to PyTorch frontend.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118523
Approved by: https://github.com/EikanWang, https://github.com/gujinghui, https://github.com/jgong5, https://github.com/malfet
ghstack dependencies: #120315