Summary:
`None` and `Ellipsis` in multi-dimensional indexing were previously not covered.
Moreover, we introduce a small optimization for `slice(None)` and a passthrough when symints do not appear in the indexing.
The remaining case is indexing by tensor, which is fairly complicated; we pass through in that case.
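For illustration, here is the kind of multi-dimensional indexing involved (just an example of the indexing forms, not the PR's internal code):
```python
import torch

x = torch.randn(2, 3, 4)

# `None` inserts a new unit dimension and `Ellipsis` expands to the remaining dims;
# both can appear alongside regular slices in multi-dimensional indexing.
y = x[None, ..., 0]        # shape: (1, 2, 3)
z = x[:, None, ..., 1:3]   # shape: (2, 1, 3, 2)

# slice(None) is equivalent to `:`, so a full-slice index is effectively a no-op per dimension.
assert torch.equal(x[(slice(None),) * x.dim()], x)
```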
Test Plan:
added tests
Rollback Plan:
Differential Revision: D77943929
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157821
Approved by: https://github.com/pianpwk
Summary:
This DIFF is to fix the following issue:
In the Python source code for CompiledFxGraph, the FX graph segment for the Triton kernel is broken. For example, take the following function:
```python
def fn(a, b, c):
    x = torch.nn.functional.linear(a, b)
    x = x.sin()
    x = x.t() + c
    return x
```
Inductor compiles this FX graph into two nodes: the first is an mm, and the second is a Triton kernel for sin + transpose + add. The FX graph segment for the Triton kernel looks like the following:
```
Graph fragment:
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute_1, %arg2_1), kwargs = {})
```
Basically, only the "add" node from the FX graph is included.
The root cause is that the function caffe2/torch/_inductor/utils.py:gather_origins does not detect realized nodes correctly.
To fix this issue, the IRNode is checked against the following IRNode types:
ir.ComputedBuffer,
ir.InputsKernel,
ir.InputBuffer,
ir.ReinterpretView,
ir.TemplateBuffer,
If it is one of them, it is realized; otherwise, it is not.
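A minimal sketch of this check (not the exact diff in `gather_origins`), assuming the listed classes live under `torch._inductor.ir`:
```python
import torch._inductor.ir as ir

REALIZED_TYPES = (
    ir.ComputedBuffer,
    ir.InputsKernel,
    ir.InputBuffer,
    ir.ReinterpretView,
    ir.TemplateBuffer,
)

def is_realized(node) -> bool:
    # A node counts as realized if it is an instance of one of the
    # buffer-like IR types above; anything else is treated as unrealized.
    return isinstance(node, REALIZED_TYPES)
```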
Test Plan:
buck2 run mode/opt caffe2/test/inductor:provenance_tracing -- caffe2.test.inductor.test_provenance_tracing.TestProvenanceTracingArtifact.test_triton_kernel_to_post_grad_tracing_cuda
Rollback Plan:
Differential Revision: D77748371
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157578
Approved by: https://github.com/mlazos
@animesh pointed out that using the whitelist for strides can result in confusing graphs, as follows:
```
s60: "Sym(s60)", L_hidden_states_: "bf16[1, 4096, 3072][s60, 3072, 1]cuda:0"
```
We probably want to capture the relationship between sizes and strides anyway, so let's make it so the whitelist only makes the sizes dynamic. That same graph now looks like this:
```
L_hidden_states_: "bf16[1, 4096, 64][262144, 64, 1]cuda:0"
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157960
Approved by: https://github.com/pianpwk
Summary:
- Adding mmap-ing
- More efficient writing in larger chunks
Latency drops from ~150s to ~6s for a simple row-wise consolidation of a 7 GB model sharded across 4 ranks.
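A rough sketch of the idea (not the actual consolidation code; the chunk size and helper name are illustrative): memory-map each shard and stream it into the output file in large chunks instead of many small writes.
```python
import mmap

CHUNK_SIZE = 64 * 1024 * 1024  # illustrative 64 MB write chunks

def copy_shard_into(out_file, shard_path, offset):
    # Memory-map the shard read-only and copy it into the consolidated
    # file in large chunks rather than per-tensor writes.
    with open(shard_path, "rb") as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        out_file.seek(offset)
        for start in range(0, len(mm), CHUNK_SIZE):
            out_file.write(mm[start:start + CHUNK_SIZE])
```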
Test Plan:
ran consolidation with the following code:
```
from torch.distributed.checkpoint._consolidate_hf_safetensors import consolidate_safetensors_files
import time

# Placeholder paths: base_path holds the sharded safetensors, consolidated_path is the output location.
base_path = "/path/to/sharded_checkpoint"
consolidated_path = "/path/to/consolidated_output"

start_time = time.time()
consolidate_safetensors_files(base_path, consolidated_path)
end_time = time.time()
print(f"Time taken: {end_time - start_time} seconds")
```
With the old code this took a couple of minutes; it is now down to ~6s.
Internal users can find the tensor shards in the manifold path: manifold://ankita_test_bucket/tree/safetensors
Rollback Plan:
Differential Revision: D77960054
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157936
Approved by: https://github.com/teja-rao, https://github.com/pradeepfn
If the final output file is in remote storage, create a local temp directory to write the files, then upload them to the remote storage after they are written.
Add a new config to the storage writer, `enable_consolidation`, so we don't need to rely on the presence of the `consolidation_output_path` to decide if consolidation is enabled. If `enable_consolidation` is True and `consolidation_output_path` isn't provided, the consolidated safetensors will be added to the same path as the sharded ones.
Differential Revision: [D77554585](https://our.internmc.facebook.com/intern/diff/D77554585/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157371
Approved by: https://github.com/pradeepfn
Fixes #157195
### Summary:
Fixed Issue 157195 by adding a new error message for torch.binomial in **aten/src/ATen/native/Distributions.cpp**
### Explanation
According to the issue,
```
import torch
torch.binomial(torch.tensor([10]).long(), torch.tensor([0.5]))
```
`RuntimeError: Found dtype Float but expected Long`
It looks like we are getting a Tensor error rather than a binomial function error. Since the error is coming from **pytorch/aten/src/ATen/TensorIterator.cpp**, it seems like it is trying to align the tensor data to the same datatype for smooth tensor computations instead of giving a binomial function error.
I tried using both arguments as longs and both as ints, and got the right binomial function error:
```
torch.binomial(torch.tensor([10]).long(), torch.tensor([0.5]).long())
NotImplementedError: "binomial_cpu" not implemented for 'Long'
```
```
torch.binomial(torch.tensor([10.0]).int(), torch.tensor([0.5]).int())
NotImplementedError: "binomial_cpu" not implemented for 'Int'
```
But when the two arguments have different datatypes, the TensorIterator.cpp error comes back, trying to align the datatypes:
`RuntimeError: Found dtype Float but expected Long`
I then tried to find where the NotImplementedError was raised and found it in **pytorch/aten/src/ATen/Dispatch.h**, lines 193-211:
```
#define AT_DISPATCH_SWITCH(TYPE, NAME, ...)                                  \
  [&] {                                                                      \
    const auto& the_type = TYPE;                                             \
    constexpr const char* at_dispatch_name = NAME;                           \
    /* don't use TYPE again in case it is an expensive or side-effect op */  \
    at::ScalarType _st = ::detail::scalar_type(the_type);                    \
    RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st);                     \
    switch (_st) {                                                           \
      __VA_ARGS__                                                            \
      default:                                                               \
        TORCH_CHECK_NOT_IMPLEMENTED(                                         \
            false,                                                           \
            '"',                                                             \
            at_dispatch_name,                                                \
            "\" not implemented for '",                                      \
            toString(_st),                                                   \
            "'");                                                            \
    }                                                                        \
  }()
```
In the **AT_DISPATCH_SWITCH** function, it picks a tensor and its datatype and checks whether the tensor's datatype matches the supported datatypes; if not, we get the NotImplementedError. Unfortunately, I think the **AT_DISPATCH_SWITCH** function uses the `common_dtype` from TensorIterator in order to run, so TensorIterator.cpp needs to run before the AT_DISPATCH_SWITCH function.
### Summary: We are getting the wrong error message because **TensorIterator.cpp** gets called and errors out due to Tensor datatype mismatch before we can get the right error message in **Dispatch.h** for torch.binomial not supporting that datatype.
### Options for the Fix
**Option 1**: Make the error message in TensorIterator.cpp more general so it applies to torch.binomial. An error message along the lines of:
`RunTime Error : "Tensor Datatypes", op.target_dtype," and ", common_dtype_, "are different "`
**Option 2**: Add an error message for the binomial function datatype mismatch before the TensorIterator.cpp error message gets triggered.
Although Option 1 seemed easier, I think Option 2 is better, as it is more specific to the binomial function, while Option 1 would affect all tensors with a datatype mismatch.
**This PR applies the fix for Option 2**
After the fix:
```
torch.binomial(torch.tensor([10]).long(), torch.tensor([0.5]))
RuntimeError: Binomial function arguments count and prob must have same datatype of type Float, got: count = Long, prob = Float
```
```
torch.binomial(torch.tensor([10]).long(), torch.tensor([0.5]).long())
NotImplementedError: "binomial_cpu" not implemented for 'Long'
```
@malfet
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157658
Approved by: https://github.com/soulitzer
Prior to this PR, torch.onnx.export(..., dynamo=True, verify=True, report=True) did not support symbolic arguments. An example is the following:
```python
import torch

class M(torch.nn.Module):
    def forward(self, a, x):
        return a + torch.tensor(1) + x

op = torch.onnx.export(
    M(),
    (1, torch.ones(2)),
    dynamic_shapes=(torch.export.Dim.DYNAMIC, {0: torch.export.Dim.DYNAMIC}),
    dynamo=True,
    report=True,
)
```
Symbolic arguments are like constant arguments in that they don't have tensor_meta either. Besides, torch.export.export supports model inputs having constants, which is different from the legacy issue https://github.com/pytorch/pytorch/issues/99534, where we tried to get the FX graph directly from dynamo export. Thus, `_remove_non_tensor` is deleted from args processing.
NOTE: If a ConstantArgument shows up in the exported_program, it is kept to align the length of inputs to the nn.Module, but it's irrelevant to the model graph, which is why the input is omitted in the ONNX model.
The test `test_constant_argument_user_input_is_omitted_in_onnx_graph` needs #157719
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157734
Approved by: https://github.com/justinchuby
Dynamo was aggressively specializing on lazy VTs over `set_name_hint` in
`STORE_FAST`, etc., and `isinstance` in `LOAD_FAST_CHECK`. This caused
regional `torch.compile` for optimizing ComfyUI GGUF + LoRA to either
(1) exceed the recompilation limit of 8, which results in suboptimal
performance, or (2), even if the recompilation limit is increased, incur
unnecessarily high compilation time (180s vs. 20s for Flux).
This patch fixes the recompilation issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156891
Approved by: https://github.com/williamwen42, https://github.com/mlazos
Summary:
Change AOTI_RUNTIME_DEVICE_CHECK to the following depending on device:
AOTI_RUNTIME_CUDA_CHECK
AOTI_RUNTIME_XPU_CHECK
AOTI_RUNTIME_CPU_CHECK
Currently in the codebase, only `AOTI_RUNTIME_CUDA_CHECK` is used.
This shouldn't change anything as of now, but we do this to prepare for simultaneously loading multiple backends (e.g., CPU and CUDA) in AOTI standalone.
We don't want people writing `AOTI_RUNTIME_DEVICE_CHECK` for both CPU and CUDA checks. This could cause compilation problems when we statically link both CPU and CUDA models.
Test Plan:
CI
Rollback Plan:
Reviewed By: muchulee8
Differential Revision: D77742977
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157818
Approved by: https://github.com/jingsh
## Issue being addressed
`is_sparse` presents itself as determining whether a tensor is sparse. However, it only checks whether the tensor uses the `sparse_coo` layout. This has led to confusion among developers: when non-COO sparse tensors are provided, it returns false, despite those tensors being sparse.
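For example, the same sparse matrix stored in COO vs. CSR layout gives different answers:
```python
import torch

dense = torch.tensor([[1., 0.], [0., 2.]])
coo = dense.to_sparse()        # sparse_coo layout
csr = dense.to_sparse_csr()    # sparse_csr layout

print(coo.is_sparse)  # True
print(csr.is_sparse)  # False, even though the tensor is sparse
```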
## Considered Remedy
Fixing this is doable; however, it would add complexity, as existing systems may depend on this behavior remaining consistent. Even inside PyTorch, is_sparse is used by `bform`, which states that it supports only `sparse_csr` and `sparse_coo`, meaning additional work/thought would have to go into handling `sparse_csc` and `sparse_bsr`.
## Remedy provided in this PR
Given these complications, the lowest-risk, highest-gain action was to add clear warning messaging to the function for now, to avoid confusing developers utilizing it. The rest of the function's behavior remains identical.
## Issue content
Addresses issue number: #101385
Original issue: https://github.com/pytorch/pytorch/issues/101385
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157378
Approved by: https://github.com/soulitzer
Summary: We added the ability to make Annotating Global or Local based on an input flag in PyTorch, but didn't add the args to the linter.
Reviewed By: mzzchy
Differential Revision: D77959409
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157858
Approved by: https://github.com/mzzchy
Add the env var AOT_INDUCTOR_ENABLE_LTO to enable clang's ThinLTO by setting AOT_INDUCTOR_ENABLE_LTO=1. LTO is disabled by default because it may increase build time.
Rollback Plan:
Differential Revision: D77899195
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157773
Approved by: https://github.com/desertfire
This is to unblock "dp2ep" Expert Parallel + TP integration in torchtitan https://github.com/pytorch/torchtitan/pull/1324.
It does two things:
1. Slightly modifies the glue code for FSDP/HSDP + TP to work with FSDP/HSDP + EP and FSDP/HSDP + EP + TP. I kept the name `FSDPParam._tp_spec` to make the change minimal. We can consider renaming it in the future if it confuses people, but I heard @wanchaol has a plan to rewrite DTensor strided sharding entirely.
2. Lifts the check of `_validate_tp_mesh_dim` for `torch.distributed.tensor.parallel.parallelize_module`, as in EP or EP+TP this check is too strict. In particular, it assumes a DeviceMesh must have `mesh_dim_names`, which is not always true. I'm also removing the file `torch/distributed/tensor/parallel/_utils.py`, to which this check entirely belongs, as the other check there, `_deprecate_warnings`, added two years ago, is not used any more.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157216
Approved by: https://github.com/wanchaol, https://github.com/weifengpy
I was confused about why the distributed tests weren't showing up quickly on HUD. It's because the call of run_tests.py for distributed didn't include the upload-artifacts-while-running flag, so I set it to default to IS_CI so I don't need to put the flag everywhere.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157868
Approved by: https://github.com/huydhn