Fix: #141251
This PR adds a few static guard checks when decomposing and lowering the `slice`
operation, so that we avoid adding unnecessary guards. Specifically, when clamping the end
values.
In summary, the changes are:
- `slice` dynamo decomposition: checks `end >= sizes[dim]` statically. If we don't know
that, the following guard ensures that we (don't) need clamping.
- `evaluate_min` inductor `sizevar` function: checks whether we can solve it statically or
not, before actually creating a new guard.
The latter had to be changed because `evaluate_min` (called by `ir.SliceView` constructor)
would always try to create a guard based on the hints operation result. However, if both
`left` and `right` hints were true, it would default to `left <= right` guard. By checking
the guards statically before, we can avoid that.
```python
N = 16
@torch.compile(backend="inductor", dynamic=False, fullgraph=True)
def fn(x):
splits = torch.ops.aten.split.Tensor(x, N)
first = splits[0]
return torch.ops.aten.slice.Tensor(first, 0, 0, N)
x = torch.arange(N)
torch._dynamo.mark_dynamic(x, 0)
fn(x)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142372
Approved by: https://github.com/ezyang
Fixes https://github.com/pytorch/pytorch/issues/141332
`F.logsigmoid` will return two outputs: `output` and `buffer`.
For `F.logsigmoid` cpu path, it will use buffer to store some intermediate values and use them when computing gradients, so it returns a `buffer` tensor with nonzero size. For cuda and xpu paths, buffer is useless, so the `buffer ` tensor size of xpu `F.logsigmoid` will be zero, just like cuda. The root cause of the issue is that the codes in `decompositions.py` (ref:https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py#L2803) only handle the cuda cases, when the a fake tensor with device is xpu run to here, it will use the cpu path and return a `buffer` with nonzero size, which is conflict to the implementation of intel xpu concrete tensor. Therefore this pr add conditions to handle xpu cases. Make sure the two returned buffer sizes match each other.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141333
Approved by: https://github.com/guangyey, https://github.com/EikanWang, https://github.com/ezyang
Previously the split decomp would return the input when there were no splits. this errors in torch.compile (or FakeTensorMode) with :
> RuntimeError: View operation returned a tensor that is the same as the input base tensor. This is no longer allowed; you must explicitly create a new tensor (e.g., using .detach()). As a user, you could have made a mistake implementing __torch_dispatch__ or a Python operator decomposition or meta registration; if that's not the case, please report a bug to PyTorch or the backend you are using.
Fix for https://github.com/pytorch/pytorch/issues/133394
Differential Revision: [D65635070](https://our.internmc.facebook.com/intern/diff/D65635070)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140065
Approved by: https://github.com/bdhirsh
Previously the split decomp would return the input when there were no splits. this errors in torch.compile (or FakeTensorMode) with :
> RuntimeError: View operation returned a tensor that is the same as the input base tensor. This is no longer allowed; you must explicitly create a new tensor (e.g., using .detach()). As a user, you could have made a mistake implementing __torch_dispatch__ or a Python operator decomposition or meta registration; if that's not the case, please report a bug to PyTorch or the backend you are using.
Fix for https://github.com/pytorch/pytorch/issues/133394
Differential Revision: [D65635070](https://our.internmc.facebook.com/intern/diff/D65635070)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140065
Approved by: https://github.com/bdhirsh
**Summary**
In the case of LLaMA2, for a linear operation with an activation size of `(4, 1, 4096)` and a stride of `(4096, 128, 1)` which has been decomposed into `matmul`. And the decomposition of `matmul` results in `bmm` due to a strict continuity check. We can align the continuity check with ATen by skip dim of size 1 to enable decomposition into `mm` instead.
**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k test_linear_input_non_contiguous_3D_wo_bias
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139172
Approved by: https://github.com/jgong5, https://github.com/ezyang
Summary:
# context
* enable the `_get_user_embeddings` function
* run failed at P1610151892
```
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
GuardOnDataDependentSymNode: Could not guard on data-dependent expression u22 <= 0 (unhinted: u22 <= 0). (Size-like symbols: u22)
ATTENTION: guard_size_oblivious would fix the error, evaluating expression to False.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.
Potential framework code culprit (scroll up for full backtrace):
File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/38472faba4e3e6c1/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/_decomp/decompositions.py", line 1692, in native_layer_norm_backward
if M <= 0 or N <= 0:
```
```
N = prod(inner_dims) # type: ignore[arg-type]
M = prod(outer_dims) # type: ignore[arg-type]
if M <= 0 or N <= 0:
return (
input.new_zeros(input_shape) if output_mask[0] else None,
input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
input.new_zeros(input_shape[axis:]) if output_mask[2] else None,
)
```
# changes
* use guard_size_oblivious since the new_zeros return is kind of optimization, shouldn't impact the correctness of the follow up code logic.
* the size `ret[i][j]` could be zero, so the change in V1 isn't valid
* for more details: [post](https://fb.workplace.com/groups/6829516587176185/permalink/8003616173099548/)
```
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):
```
# past
* found `u22` was introduced at
```
def _wait_impl(self) -> List[List[int]]:
# Can not use is_torchdynamo_compiling(), as every such condition should be independent for compilation with graph breaks.
if isinstance(self._splits_awaitable, dist.Work):
self._splits_awaitable.wait()
ret = self._output_tensor.view(self.num_workers, -1).T.tolist() # <------ u22 introduced here
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
for i in range(len(ret)):
for j in range(len(ret[i])):
torch._check_is_size(ret[i][j]) # <---------- my question: why the _check_is_size isn't enough??
torch._check(ret[i][j] > 0) # <------ added by diff V1
```
Test Plan:
# run command
```
TORCH_SHOW_CPP_STACKTRACES=1 TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 TORCH_LOGS="+graph_code,output_code,dynamic,aot,guards,verbose_guards,recompiles,graph_breaks" TORCH_TRACE=/var/tmp/tt buck2 run fbcode//mode/opt fbcode//aps_models/ads/icvr:icvr_launcher_live -- mode=fmc/local_ig_fm_v4_mini training.pipeline_type=pt2 2>&1 | tee -a `tagT`.`tagH`.log
```
# results
* before
**without enabling `_get_user_embeddings`**
[14 Failures and Restarts](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmp2eNI7p/failures_and_restarts.html)
log: P1610151892
{F1889387940}
* V1
enable `_get_user_embeddings`
with `torch._check(ret[i][j] > 0)`
[13 Failures and Restarts](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmp6J1iY9/failures_and_restarts.html)
{F1889388378}
* V2
enable `_get_user_embeddings`
with `if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):`
[tlparse](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpFhZZyC/index.html)
if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):
Differential Revision: D63424929
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136798
Approved by: https://github.com/ezyang
When tensor folding occurs during matmul operation returned tensor is a view. This can cause issues when matmul is used inside a custom function and such view is then returned as output. Then it cannot be modified inplace and causes errors.
It can be especially problematic when after such function inplace allreduce is performed.
Issue is resolved when unsafe_view is returned from matmul instead. This solution aligns matmul decomposition with eager implementation in such a way that a non view tensor is returned.
Test included in this PR reproduces the issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134568
Approved by: https://github.com/zou3519
Summary: Update SDPA decomposition to match updated stride from D62009189 which aligns strides with the `aten._scaled_dot_product_attention_math.default`, which makes `t.permute().continuous().permute()` no longer necessary.
Test Plan: CI
Differential Revision: D62278378
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135297
Approved by: https://github.com/drisspg
## Description
Create decomposition of _unsafe_index_put (non-core aten) that turns it into index_put (core aten)
## Testing
Phi3 mini + LoRA model successfully passed `to_edge` after failing due to a non-core aten `unsafe_index_put` getting introduced in a decomposition during joint graph calculations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133365
Approved by: https://github.com/pianpwk
Summary:
# context
* when running an IG FM training with PT2 we found there are a few graph break due to torch.diff call in [jagged_tensor.py](https://fburl.com/code/cwssxabc)
```
_length: List[int] = (
_length_per_key_from_stride_per_key(torch.diff(offsets), stride_per_key)
if variable_stride_per_key
else torch.sum(torch.diff(offsets).view(-1, stride), dim=1).tolist()
)
```
* look into the failure, we found the TORCH_CHECK in diff should be TORCH_SYM_CHECK
* slice_forward error: df3d7729e, [tlparse](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpxXZ2em/index.html)
```
RestartAnalysis
Tried to use data-dependent value in the subsequent computation. This can happen when we encounter unbounded dynamic value that is unknown during tracing time. You will need to explicitly give hint to the compiler. Please take a look at torch._check OR torch._check_is_size APIs. Could not guard on data-dependent expression ((5*u37 + u38)//(u37 + u38)) < 0 (unhinted: ((5*u37 + u38)//(u37 + u38)) < 0). (Size-like symbols: u38, u37)
ATTENTION: guard_size_oblivious would fix the error, evaluating expression to False.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.
Potential framework code culprit (scroll up for full backtrace):
File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/_decomp/decompositions.py", line 771, in slice_forward
if end_val < 0:
```
* after this diff: [tlparse](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpAhv2Sh/failures_and_restarts.html)
Test Plan:
# command
* run model
```
TORCH_SHOW_CPP_STACKTRACES=1 TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 TORCH_LOGS="+graph_code,output_code,dynamic,aot,guards,verbose_guards,recompiles,graph_breaks" TORCH_TRACE=/var/tmp/tt buck2 run fbcode//mode/opt fbcode//aps_models/ads/icvr:icvr_launcher_live -- mode=fmc/local_ig_fm_v4_mini training.pipeline_type=pt2
```
* generate tlparse
```
tlparse `ls -t /var/tmp/tt/* | head -1`
```
Reviewed By: ezyang
Differential Revision: D56339251
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133740
Approved by: https://github.com/ezyang
# Summary
Changes the stance of SDPA on what to do for fully masked out rows
## Current Behavior
Several PyTorch users have expressed frustration over this issue:
- https://github.com/pytorch/pytorch/issues/41508
- https://github.com/pytorch/pytorch/issues/103749
- https://github.com/pytorch/pytorch/issues/103963
These are significant issues with extensive discussion but no satisfactory resolution. The PyTorch team's consensus, as stated here:
https://github.com/pytorch/pytorch/issues/24816#issuecomment-524415617
Can be paraphrased as follows:
When passing in fully masked out rows, attention becomes ambiguous. We have two main options:
1. Uniformly attend to all values:
```python
scores[masked_out_rows] = 1 / len(row)
out[masked_out_rows] = 1 / len(row) * value
```
2. Decide that attention between no queries (masked) and no keys (masked) is meaningless:
```python
output[fully_masked_rows] = NaN
```
We went with option 2. Partially because it was easier to implement, but also people argued that users can slice the output to remove the NaNs:
``` Python
>fill_value = -float("inf")
>row0 = torch.randn(4)
>row1 = torch.tensor([(fill_value for _ in range(4)])
>matrix = torch.stack([row0, row1]).requires_grad_(True)
>out = torch.softmax(matrix, 1)
>out = out[0]
>print(out)
tensor([0.5377, 0.2729, 0.0692, 0.1201])
```
Cool, problem solved. But what happends when you call backwards..
```Python
>out.backward(torch.ones_like(out))
>print(matrix.grad)
tensor([[3.0957e-08, 1.4157e-08, 7.7802e-10, 1.3713e-08],
[ nan, nan, nan, nan]])
```
Those pesky NaNs are back!
## Why do we see NaNs today?
The core of the problem revolves around using softmax function in sdpa:
```python
> row = torch.tensor([(-float("inf")) for _ in range(4)])
> torch.softmax(row, 0)
tensor([nan, nan, nan, nan])
```
## Quick Aside: Masking in Attention
Attention itself doesn't have a concept of masking. The `sdpa` function has an argument called `attn_mask`, which would be more accurately named `attn_bias`. This is because we don't actually "mask" entries when computing attention. Instead, due to implementation details([performance](https://github.com/pytorch/pytorch/issues/25110#issuecomment-524519087)), we add a value to the masked-out query/key pairs.
We use a large negative number (typically -inf) to decrease the attention weight, as softmax assigns more weight to larger values.
## Alternative Approaches
If we use a very large negative number instead of -inf:
```python
> row = torch.tensor([(-1e6) for _ in range(4)])
> torch.softmax(row, 0)
tensor([0.2500, 0.2500, 0.2500, 0.2500])
```
However if users always remembered to "slice" out their outputs i.e.:
```Python
>fill_value = -1e6
>...
>out.backward(torch.ones_like(out))
>print(matrix.grad)
tensor([[-0.0563, -0.0564, 0.1613, -0.0486],
[ 0.0000, 0.0000, 0.0000, 0.0000]])
```
This would bring us back into a better state.
## A Third Option
We don't necessarily need to alter the behavior of softmax for -inf or very large negative numbers. The fundamental goal is to exclude certain query/key pairs from attention, regardless of the underlying implementation.
This PR implements the new semantic for masking w/ attention in fully masked-out rows:
```python
out[masked_out_rows] = 0
```
**Important Note**: This idea isn't entirely new. The [MaskedTensor](https://pytorch.org/tutorials/prototype/maskedtensor_overview#safe-softmax) prototype, a tensor subclass, was designed to handle such cases. However, it remains a prototype feature and hasn't gained widespread adoption.
## Details
This PR stack does 3 things:
1. Adds a PRIVATE _safe_softmax op
2. Updates semantic for flash_cpu fused kernel
3. Updates semantic for efficient_cuda fused kernel
_safe_softmax is not supposed to be used generically and is only meant to be used within the context of SDPA. Due to this fact instead of decomposing softmax and checking for -inf rows we instead "cheat" and use nan_to_num.
Why I think this is okay? (please find a counter point if avail)
There are multiple ways NaNs can emerge. For the fully masked out rows case nan_to_num works. But what if there were other NaNs, wouldn't this silently remove them?
The only case that this can happen is if the input itself had a NaN or an Inf
For example:
```Python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = torch.finfo(torch.float16).max
print(a.softmax(-1))
```
Will return
`tensor([0., 1., 0., 0.], dtype=torch.float16)`
Where
```Python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = float("inf")
a.softmax(-1)
```
returns:
`tensor([nan, nan, nan, nan], dtype=torch.float16)`
If we dont want to even allow for the possibility of "inf" or "NaN" attention scores to be converted to 0 then we can implemented it something like this
```Python
max = torch.max(a, dim=-1, keepdim=True)
exp = torch.exp(a - max.values)
denom = torch.sum(exp, dim=-1, keepdim=True)
softmax = exp / denom
softmax = torch.where(max.values == float('-inf'), 0.0, softmax)
```
however we would be paying for this in math performance.
## Why Now
I think one point that has substantially changed where PyTorch should lie on this argument is the fact that we have fused implementations for SDPA now. And these fused implementations allow us to easily and performantly support this new semantic.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131060
Approved by: https://github.com/jbschlosser
This PR makes sure all current tests in the sparsity export test suite pass. Note that there will probably be anecdotal cases that need fixing after this, but the general idea of preserving sparsity metadata has been completed.
Fixes: https://github.com/pytorch/pytorch/issues/117188
```
$ PYTORCH_TEST_WITH_DYNAMO=0 python test/export/test_sparse.py ........................................................................................................................................................
----------------------------------------------------------------------
Ran 152 tests
OK
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132690
Approved by: https://github.com/ezyang
`aten._to_copy` can receive a python number as input. This occurs in
torch.compile support for vmap (see #130188). Previously, this would
raise an assertion error. This PR changes it so that if we see a python
number, we call torch.scalar_tensor on it first (h/t @bdhirsh).
Fixes#130362Fixes#130188
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130381
Approved by: https://github.com/Chillee