This PR adds a new interface _aot_compile to `OptimizedModule`, so that the following is possible:
```
mod = SimpleLinearModule()
inputs = [
ModelInput(
args=(torch.randn(3, 3),),
kwargs={},
contexts=[torch.no_grad(), eval_mode(model)],
),
ModelInput(
args=(torch.randn(3, 3),), kwargs={}, contexts=[train_mode(model)]
),
]
assert isinstance(model, torch._dynamo.eval_frame.OptimizedModule)
model._aot_compile(
inputs,
)
```
After this PR, you can AOT precompile NanoGPT and use it to train directly. I'll share my fork of the repo to make this work.
## ModelInput
The `ModelInput` API is a work in progress; for now it represents a set of inputs and contexts to instruct the compiler to compile. Most commonly, this is "compile an eval mode with no grad, and a training mode with grad", but also contains things like autocasting contexts, etc.
## Dispatch
Dispatching is super simple here, we just iterate through all the precompiled fullgraphs and check guards for each one until there's one htat passes. I'm a bit worried that having this in python code is going to be too expensive. The guard checks are happening in C++ anyway, though, so the only python bottlenecked step here is just the for loop, so perhaps the overhead will not be high. I'll work on measuring this, though.
## TODOs
This PR does not support `mod.compile()`, only `torch.compile(mod)`. In order to support `mod.compile()`, we'll need to update torch.nn.Module with an updated implementation — I can add that frontend later.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162171
Approved by: https://github.com/zhxchen17
Summary:
The check is introduced in D82262053
- `scalar_value` could be a numpy object
- Move the check of `device.type` into `make_np` method where it happens only when it's a `torch.Tensor`.
Test Plan:
```
vizard launch -j 1x8 --launch=flow --config-path=pkg://vizard_projects.image_classification.configs --config-name=resnet50 ++flow.secure_group=ml_sensors ++flow.entitlement=ai_frameworks_pnb ++max_train_steps_per_epoch=10 ++max_epochs=5 ++log_every_n_steps=10 ++profiler=null ++max_eval_steps_per_epoch=10
```
Rollback Plan:
Differential Revision: D82383428
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162888
Approved by: https://github.com/xush6528
Per NVRTC doc - https://docs.nvidia.com/cuda/nvrtc/index.html#accessing-lowered-names, we can compile a templated kernel (e.g. `kernel<float>`) with the following steps
NVRTC side
- (new) `nvrtcAddNameExpression` -> C++ template e.g. `f<float>`
- `nvrtcCompileProgram`
- (new) `nvrtcGetLoweredName` -> get mangled name. need to do a copy since later this string is freed after NVRTC program is destroyed
- `nvrtcDestroyProgram`
CUDA side
- use mangled name instead of normal name -> profit
- `extern "C"` is not even needed
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162875
Approved by: https://github.com/msaroufim
Summary:
Adds support for TMA store in all TMA matmul templates (notably persistent_tma including addmm and scaled_mm). This works by requiring a template be registered with `tma_store=True` and when met constructs indices/range_trees to hook into the existing code base's TMA store support.
This also includes a couple notable changes:
- Adds support in the TMA template support for checking the output layout.
- Adds support for "hoisting" the tensor descriptor to the top of the kernel. This will currently only be used by template code right now, but in principle it can be generalized to other implementation.
- Supports considering multiple indices as the "contiguous" index. This is handled with support for transposing the input data when the alignment is no longer consistent. In general since the TMA support is derived from the index it doesn't seems reasonable that the 1D index math forces a certain alignment depending on index ordering so long as the layout matches.
Test Plan:
Tested with test_max_autotune.py unit tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160480
Approved by: https://github.com/NikhilAPatel
Summary:
When rewriting sympy expressions in the compiler codebase we want to generate
FloorDiv(a, b) CleanDiv(a, b) directly and not a//b. since the later become floor(a*pow(b, -1))
For symnodes we automatically handle that conversions in the symnode op dispatch.
I will follow up with an issue to track all other usages of //.
Block internal Model.
Test Plan:
add test
run existing tests.
dakechen1993 testing on the model.
Rollback Plan:
Differential Revision: D82362241
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162869
Approved by: https://github.com/ezyang
Summary:
Please see D21021645 for details about the optimization and why it's beneficial.
A similar change has been added to libstdc++ as well, see dbf8bd3c2f
Rollback Plan:
Reviewed By: yfeldblum
Differential Revision: D81960754
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162784
Approved by: https://github.com/swolchok
Summary:
One of our builds fails because the return value of fread is discarded. Explicit cast to void fixes the build.
```log
In file included from fbcode/caffe2/torch/csrc/jit/mobile/import.cpp:15:
fbcode/caffe2/torch/csrc/jit/mobile/file_format.h:156:3: error: ignoring return value of function declared with 'warn_unused_result' attribute [-Werror,-Wunused-result]
156 | fread(data.get(), size, 1, f);
| ^~~~~ ~~~~~~~~~~~~~~~~~~~~~~
1 error generated.
...
BUILD FAILED
Failed to build 'fbcode//caffe2:libtorch (cfg:opt-linux-x86_64-clang19-no-san-opt-by-default#fef256f7ee896871)'
```
Test Plan:
No runtime behavior change. CI.
Rollback Plan:
Differential Revision: D82265002
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162767
Approved by: https://github.com/Skylion007
Summary:
The SIMD path is using SLEEF version of `pow` which is slightly different from `std::pow`. The fix is to use the same vectorized code (with partial load and store) for the trailing data as well to ensure consistency between results.
Rollback Plan:
Differential Revision: D82265247
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162772
Approved by: https://github.com/swolchok
Investigated together with @pyemma and @taotaohuang001
## Problem
when calling exported module with dict nested in the args tuple, it will make following complaits
```
Traceback (most recent call last):
File "/home/chzhu/infinitrain/test_torch_export.py", line 32, in <module>
print(exported_model({"a2": torch.randn(10), "a1": torch.randn(10)}))
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 848, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 424, in __call__
raise e
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/fx/graph_module.py", line 411, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
return inner()
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1806, in inner
args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc]
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 929, in _fn
return fn(*args, **kwargs)
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_unlift.py", line 81, in _check_input_constraints_pre_hook
flat_args_with_path = _check_inputs_match(args, kwargs, self._in_spec)
File "/home/chzhu/infinitrain/build/infinitrain/environments/development-venv/lib/python3.10/site-packages/torch/export/_unlift.py", line 64, in _check_inputs_match
raise ValueError( # noqa: B904
ValueError: Trying to flatten user inputs with exported input tree spec:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(dict, ['a1', 'a2'], [*,
*])]),
TreeSpec(dict, [], [])])
but actually got inputs with tree spec of:
TreeSpec(tuple, None, [TreeSpec(tuple, None, [TreeSpec(dict, ['a2', 'a1'], [*,
*])]),
TreeSpec(dict, [], [])]).
Please check that the inputs have the same number and type of args and kwargs as the ones you used when tracing.
```
## How to reproduce the issue
```python
import torch
# create a nn.Module with data_batch as input and output as output
class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.linear = torch.nn.Linear(10, 1)
def forward(self, data_batch):
h1 = self.linear(data_batch["a1"])
h2 = self.linear(data_batch["a2"])
return h1 + h2
# torch export this module
model = MyModel()
example_args_forward = (
{
"a1": torch.randn(10),
"a2": torch.randn(10),
},
)
exported_model = torch.export.export(model, example_args_forward, strict=True)
# save the exported model
torch.export.save(exported_model, "exported_model.pt2")
# load the exported model
exported_model = torch.export.load("exported_model.pt2").module()
# run the exported model
print(exported_model({"a2": torch.randn(10), "a1": torch.randn(10)}))
```
## Root Cause
Input spec is encoded as [TreeSpec](582d278983/torch/utils/_pytree.py (L1059)) in torch export. With (args, kwargs) at the top level. When we call the exported model, it has a pre-execution [hook](582d278983/torch/export/_unlift.py (L66)) to check the input TreeSpec matches the received TreeSpec, where in Treespec, the dict key order is preserved. Something like
TreeSpec(dict, ['a2', 'a1'], [*,*])
To workaround this, the input check reorders [kwargs](582d278983/torch/export/_unlift.py (L67)), that is why kwargs can be out of order. But the dict nested in the args is not re-ordered, so any re-ordering of the keys will throw errors.
## Solution
Update eq_spec to handle the dict case, where we only guarantee that key set is the same without ordering constraints.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162618
Approved by: https://github.com/angelayi
Summary:
To support exporting a cuda model on a CPU-only machine under fake tensor mode.
User commonly need to move sample inputs to the cuda device with .to("cuda:0") or .to("cuda") call.
This diff supports this.
I expect the following pattern to work
```
with FakeTensorMode(allow_non_fake_inputs=True):
cuda_module = module.to("cuda:0")
cuda_sample_inputs = tuple([x.to("cuda:0") for x in sample_inputs])
with torch.no_grad():
ep = torch.export.export(cuda_module, cuda_sample_inputs)
```
Test Plan:
CI
Rollback Plan:
Differential Revision: D80181887
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160532
Approved by: https://github.com/henryoier, https://github.com/ezyang
Summary:
Sometimes checkpoint background process creation times out during gloo pg init.
Attempting to destroy the process during that time can block the trainer thread until the timeout completes.
This diff reduces the pg init timeout from 30m -> 10m to reduce the cleanup time.
Test Plan:
CI
Rollback Plan:
Differential Revision: D81724668
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162760
Approved by: https://github.com/meetv18
Summary: Updates the NoValidChoicesError logic to include some additional context for if not choices exists or if no choices compiled.
Test Plan:
NFC. Depending on CI.
Rollback Plan:
Differential Revision: D82312035
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162814
Approved by: https://github.com/mlazos
Currently, OpenReg supports Linux, Windows, and OS X, ensuring stability and ease of integration with third-party devices across all three platforms. It also doesn't rely on any other accelerators (such as CUDA or MPS).
Therefore, to minimize computational resource usage, `test_openreg` can be added to certain BLOCKLISTS to prevent its execution, limiting OpenReg's execution to only necessary scenarios.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161918
Approved by: https://github.com/albanD
ghstack dependencies: #161917
**Background:**
Almost all the tests in `test/test_openreg.py` are designed for `torch_openreg`, so placing these testcases in the test directory is not a good idea. Instead, they should be moved to the `tests` directory under `torch_openreg`, coordinating these tests with their corresponding functional logic.
**How to do:**
So how do we verify the quality of the third-party device integration mechanism?
We will maintain a `test_openreg` entrypoint in `test/run_test.py`.
This entrypoint will install `torch_openreg` and run all the testcases located in `torch_openreg`. As long as all testcases pass, we can guarantee that the out-of-tree backend integration mechanism is available.
**Next:**
We will also improve `torch_openreg's` test coverage in the future.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161917
Approved by: https://github.com/albanD
Minor bug fix for the `nll_loss` test.
Before this PR, it runs `torch.randint(high=0)`, which will fail because it would try to generate a number that >= low and < high, i.e. x>=0 and x<0.
The test did not fail because that line is not run when testing on CPU because it failed earlier because of a unsupported dtype.
However, as we support TPUs at Google, this line is reached first before the dtype check, which triggers the bug.
To my understanding, these OpInfo should be general enough to support different hardware.
Fixing this obvious bug would make it more general cross different hardware.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162763
Approved by: https://github.com/soulitzer
# why
- now everything is in place to just gather templates and run
the V.choices.get_mm_configs once per op
- enables any overrides inside V.choices.get_mm_configs to
have a full view of the options for an op, not just for
one template
# what
- replace multiple calls to V.choices.get_mm_configs with
calls to gather the active templates, and then using those
in a single call
# testing
```
python3 -bb -m pytest test/inductor/test_max_autotune.py -v
```
Differential Revision: [D81520571](https://our.internmc.facebook.com/intern/diff/D81520571)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161350
Approved by: https://github.com/eellison, https://github.com/jansel
ghstack dependencies: #161351
# why
- if we only use ExternKernelChoice we're not doing any codegen
- if we're not doing any codegen, we can use a FlexibleLayout
here, and provide deeper passes more chances to change it
# what
- if all the kernel template choices (KTC) are with a ExternKernelChoice
template, we switch to a FlexibleLayout before generating the choice
- add a test to make sure that works as intended (FlexibleLayout for
only extern, and FixedLayout if Triton is involved)
- caveats:
- because CPP, CUTLASS, and CK are not using
V.choices.get_mm_configs yet, we turn off the optimization
if either of those backends are in use. This will be relaxed
once they support this too
- because Triton templates are still using their own calls
(not a single call) to get_mm_configs, it's also turned
off there. The next diff unifies Triton + ATEN to a single
call to get_mm_configs and that in turn allows the optimization
there too
# testing
```
python3 -bb -m pytest test/inductor/test_max_autotune.py -v
```
Differential Revision: [D81520584](https://our.internmc.facebook.com/intern/diff/D81520584)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161351
Approved by: https://github.com/eellison, https://github.com/jansel