Compare commits

...

17 Commits

Author SHA1 Message Date
79349d6d8c Update
[ghstack-poisoned]
2025-11-05 15:31:48 -08:00
a91772299a Update (base update)
[ghstack-poisoned]
2025-11-05 15:31:48 -08:00
08200280ce [CP][BE][3/N] Add _templated_ring_attention to the backward compatility stub (#166991)
While `_templated_ring_attention` is a private API, it is unfortunatelly used by some packages.
Add it to __all__ so that people can still use it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166991
Approved by: https://github.com/XilunWu
ghstack dependencies: #166456, #166501
2025-11-05 22:22:55 +00:00
ad7a57262c [12/N] Apply ruff UP035 rule (#166929)
This PR continues to apply ruff UP035 rule to test code and some remaining torch files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166929
Approved by: https://github.com/Lucaskabela
2025-11-05 22:06:19 +00:00
711a775878 fix nccl estimations (#167093)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167093
Approved by: https://github.com/kwen2501, https://github.com/eellison
2025-11-05 22:01:49 +00:00
e9a688f02e [DebugMode] output, tensor id annotations for DebugMode (#165076)
Adds optional "node" id for tensors, output info annotations to DebugMode, with `DebugMode(record_output=True, record_ids=True)`

Example output for `test_debug_mode_mm`, with both enabled:
```
  torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0))  ->  dt$12: f32[8, 32]| S(0)
    aten::mm(dt$2: f32[8, 8]| S(0), dt$3: f32[8, 32]| S(0))
      redistribute_input(1, S(0) -> R)
        redistribute_input(t$4: f32[1, 32], trace: S(0)->R)
          _c10d_functional::all_gather_into_tensor(t$5: f32[1, 32], 8, 0)  ->  t$6: f32[8, 32]
          _c10d_functional::wait_tensor(t$7: f32[8, 32])  ->  t$8: f32[8, 32]
      aten::mm(t$9: f32[1, 8], t$10: f32[8, 32])  ->  t$11: f32[1, 32]
  <method 'sum' of 'torch._C.TensorBase' objects>(dt$13: f32[8, 32]| S(0))  ->  dt$17: f32[]| P
    aten::sum(dt$14: f32[8, 32]| S(0))
      aten::sum(t$15: f32[1, 32])  ->  t$16: f32[]"""
```

Sadly the only way to get DTensor op outputs is to set `record_torchfunction=True`, as dispatch calls just defer to DTensor's dispatch logic.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165076
Approved by: https://github.com/zpcore
2025-11-05 22:00:11 +00:00
e69aaaf45a [user-streams] Add backward test (#167021)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167021
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #167019
2025-11-05 21:24:44 +00:00
fd8f368d31 [user-streams] Add graph annotation checks (#167019)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167019
Approved by: https://github.com/Lucaskabela
2025-11-05 21:24:44 +00:00
13d2cc7bd2 Remove python workaround for ContextDecorator (#167049)
This PR removes the import workaround for ContextDecorator because the import always succeeds in Py 3.10+.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167049
Approved by: https://github.com/Skylion007
2025-11-05 20:56:04 +00:00
c6c913d18e Add torch::stable::Tensor sizes and strides (#165153)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165153
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #164991, #165152
2025-11-05 20:55:34 +00:00
ef3f953966 Revert "[DebugMode] output, tensor id annotations for DebugMode (#165076)"
This reverts commit a64c7d740428010d700b4bcd395af8a7b2d5c21f.

Reverted https://github.com/pytorch/pytorch/pull/165076 on behalf of https://github.com/wdvr due to Sorry but this is breaking internally. See diff [D86245252](https://l.workplace.com/l.php?u=https%3A%2F%2Fwww.internalfb.com%2Fdiff%2FD86245252&h=AT1oPbS1XTv6HjYeYdxmDMW1-jlT0pS8yBO2iSfbPfUB9ydsEjFXBNT56QhV1v5TKc4_QaQNxykNowSKmb4fgenjOyCv20NuL7oV_Id5fhh32hhv1IpjgsDJYK-PBFfSfv_miLIWfNgj902KcgXojbBgDcDzQeS9lNt0GQ) for details. To validate your fixes internally, you can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/165076#issuecomment-3493358159))
2025-11-05 20:52:43 +00:00
ea44f12bce [13/N] Apply ruff UP035 rule (#167048)
This PR continues to apply ruff UP035 rule to test code and some remaining torch files.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167048
Approved by: https://github.com/Skylion007
2025-11-05 20:51:53 +00:00
a74fe75c45 Don't hardcode double argument for reduction base (#166951)
Fixes https://github.com/pytorch/pytorch/issues/43254

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166951
Approved by: https://github.com/ngimel, https://github.com/Skylion007
ghstack dependencies: #166813
2025-11-05 20:34:15 +00:00
6d30666bc1 Revert "[12/N] Apply ruff UP035 rule (#166929)"
This reverts commit 5863ba1b2e4de9ea0ae16a663465ec5d3d6f9f52.

Reverted https://github.com/pytorch/pytorch/pull/166929 on behalf of https://github.com/donigian due to Temporarily need to revert this to continue a revert for #165076. @cyyever Please re-merge after revert of #165076. ([comment](https://github.com/pytorch/pytorch/pull/166929#issuecomment-3493090596))
2025-11-05 20:02:47 +00:00
8e8cbb85ee Revert "[Inductor] Fix unbacked float symbol handling in kernel codegen (#166890)"
This reverts commit 0c7a4a6b48d49306eae8d0a9ee8d32b1899e5e23.

Reverted https://github.com/pytorch/pytorch/pull/166890 on behalf of https://github.com/malfet due to Looks like it broke torchfuzz tests, see fbd70fb84e/1 and same test on slow ([comment](https://github.com/pytorch/pytorch/pull/166890#issuecomment-3493011038))
2025-11-05 19:42:39 +00:00
fbd70fb84e Update typing docs to reference pyrefly (#166883)
Replacing mypy codumentation in the CONTRIBUTING.MD file with pyrefly references. I have made initial changes to https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch documentation, and will replace the script at the bottom with one tailored to the pyrefly tool as a follow-up.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166883
Approved by: https://github.com/malfet
2025-11-05 19:35:38 +00:00
6c5db82584 [Inductor] Naive foreach autotune support (#162053)
Initial autotuning support for foreach kernels, 4x improvement for some kernels in internal workload. More improvements can surely be made here in the future. Removing num_warps for definition to enable autotune support in generated wrapper code.

Before:
triton_for_fused_18.kd 🔍 | 4.986 ms | 4.986 ms | 2.493 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.098 ms | 0.098 ms | 0.049 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.036 ms | 0.036 ms | 0.018 ms | 2 |

After:
triton_for_fused_18.kd 🔍 | 1.273 ms | 1.273 ms | 0.636 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.044 ms | 0.044 ms | 0.022 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.024 ms | 0.024 ms | 0.012 ms | 2 |

num_warps=8 default due to https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/triton_combo_kernel.py#L374

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162053
Approved by: https://github.com/mlazos, https://github.com/naromero77amd, https://github.com/jeffdaily

Co-authored-by: Nichols A. Romero <nick.romero@amd.com>
2025-11-05 19:27:23 +00:00
27 changed files with 511 additions and 164 deletions

View File

@ -18,7 +18,7 @@ aspects of contributing to PyTorch.
- [Python Unit Testing](#python-unit-testing)
- [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest)
- [Local linting](#local-linting)
- [Running `mypy`](#running-mypy)
- [Running `pyrefly`](#running-pyrefly)
- [C++ Unit Testing](#c-unit-testing)
- [Run Specific CI Jobs](#run-specific-ci-jobs)
- [Merging your Change](#merging-your-change)
@ -281,7 +281,7 @@ dependencies as well as the nightly binaries into the repo directory.
**Prerequisites**:
The following packages should be installed with `pip`:
- `expecttest` and `hypothesis` - required to run tests
- `mypy` - recommended for linting
- `pyrefly` - recommended for type checking. [Pyrefly](https://pyrefly.org/)
- `pytest` - recommended to run tests more selectively
Running
```
@ -350,15 +350,32 @@ make lint
Learn more about the linter on the [lintrunner wiki page](https://github.com/pytorch/pytorch/wiki/lintrunner)
#### Running `mypy`
#### Running `pyrefly`
`mypy` is an optional static type checker for Python. We have multiple `mypy`
configs for the PyTorch codebase that are automatically validated against whenever the linter is run.
[Pyrefly](https://pyrefly.org/) is a high-performance static type checker for Python. It provides fast type checking along with IDE features like autocomplete and instant error feedback.
PyTorch uses Pyrefly for type checking across the codebase. The configuration is managed in `pyrefly.toml` at the root of the repository.
**Getting Started with Pyrefly:**
To run type checking on the PyTorch codebase:
```bash
pyrefly check
```
For more detailed error information with summaries:
```bash
pyrefly check --summarize-errors
```
**Learn More:**
- [Pyrefly Configuration](https://pyrefly.org/en/docs/configuration/) - Detailed configuration options
- [Pyrefly IDE Features](https://pyrefly.org/en/docs/IDE-features/) - Set up Pyrefly in your editor for real-time type checking
- [Python Typing Tutorial](https://pyrefly.org/en/docs/typing-for-python-developers/) - Learn about Python type annotations
See [Guide for adding type annotations to
PyTorch](https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch)
for more information on how to set up `mypy` and tackle type annotation
tasks.
for PyTorch-specific guidance on how to set up `pyrefly` and tackle type annotation tasks in this codebase.
### C++ Unit Testing

View File

@ -247,8 +247,8 @@ void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) {
});
}
template <typename func_t, typename vec_func_t>
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) {
template <typename func_t, typename vec_func_t, typename ident_t = double>
void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, ident_t ident = static_cast<ident_t>(0)) {
using traits = binary_function_traits<func_t>;
static_assert(
all_same<

View File

@ -339,33 +339,13 @@ void or_kernel_impl(TensorIterator& iter) {
}
}
template<typename scalar_t>
struct MinValuesOps: public at::native::MinOps<scalar_t> {
using arg_t = typename MinOps<scalar_t>::arg_t;
static scalar_t project(arg_t arg) {
return arg.first;
}
};
void min_values_kernel_impl(TensorIterator& iter) {
// This case is special because of Vectorized<int64_t> does not
// handle upper_bound<int64_t>().
// See: https://github.com/pytorch/pytorch/issues/43254
if (iter.dtype() == kLong || iter.dtype() == kUInt64) {
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
binary_kernel_reduce(
iter,
MinValuesOps<scalar_t>{},
std::pair<scalar_t, int64_t>(upper_bound<scalar_t>(), -1));
}), kLong, kUInt64);
return;
}
AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] {
binary_kernel_reduce_vec(
iter,
[](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); },
[](Vectorized<scalar_t> a, Vectorized<scalar_t> b) { return minimum(a, b); },
static_cast<double>(upper_bound<scalar_t>()));
upper_bound<scalar_t>());
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool);
}

View File

@ -47,20 +47,10 @@ Tensor sgd_out_of_place(
STD_TORCH_CHECK(param.get_device() == -1, "CPU device index = -1");
STD_TORCH_CHECK(param.get_device_index() == -1, "CPU device index = -1");
int64_t *param_sizes;
int64_t *param_strides;
aoti_torch_get_sizes(param.get(), &param_sizes);
aoti_torch_get_strides(param.get(), &param_strides);
// testing Tensor strides + stride
STD_TORCH_CHECK(param.strides()[0] == param.stride(0));
int32_t param_dtype;
aoti_torch_get_dtype(param.get(), &param_dtype);
int32_t param_device_type;
aoti_torch_get_device_type(param.get(), &param_device_type);
AtenTensorHandle out_ath;
aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath);
auto out = Tensor(out_ath);
auto out = new_empty(param, param.sizes());
sgd_math(
reinterpret_cast<float*>(param.data_ptr()),
@ -344,6 +334,8 @@ Tensor my_new_empty_dtype_variant(Tensor t) {
// Still using a std::vector below even though people can just pass in an
// initializer list (which will be implicitly converted to an HeaderOnlyArrayRef)
// directly.
// This is to test that passing in a std::vector works for BC. (It gets
// implicitly converted to HeaderOnlyArrayRef too!)
std::vector<int64_t> sizes = {2, 5};
auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
return new_empty(t, sizes, dtype);

View File

@ -8,21 +8,11 @@ from torch._dynamo.graph_deduplication import apply_graph_deduplication
from torch._dynamo.graph_utils import _detect_cycles
from torch._dynamo.output_graph import FakeRootModule
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import (
AotEagerAndRecordGraphs,
extract_graph_and_tracker,
normalize_gm,
)
from torch._dynamo.testing import extract_graph, extract_graph_and_tracker, normalize_gm
from torch.compiler import allow_in_graph
from torch.utils._ordered_set import OrderedSet
def extract_graph(fn, *args, **kwargs):
backend = AotEagerAndRecordGraphs()
result = torch.compile(backend=backend)(fn)(*args, **kwargs)
return result, backend.graphs, backend.fw_graphs
def graph_str(gm):
return normalize_gm(gm.print_readable(print_output=False))
@ -40,7 +30,7 @@ class GraphDededuplicationTests(TestCase):
super().tearDown()
def run_and_return_graphs(self, fn, *args, **kwargs):
return extract_graph(fn, *args, **kwargs)
return extract_graph(fn, *args, **kwargs)[0:3]
def run_and_get_simple_graph(self):
def fn(x, y):

View File

@ -1,7 +1,7 @@
# Owner(s): ["module: dynamo"]
import unittest
from collections.abc import Sequence
from typing import Any, Callable, Union
from collections.abc import Callable, Sequence
from typing import Any, Union
import torch
import torch._dynamo

View File

@ -1,5 +1,5 @@
# Owner(s): ["module: dynamo"]
from typing import Callable, NamedTuple, Optional
from typing import NamedTuple, Optional, TYPE_CHECKING
import torch
import torch._dynamo
@ -7,6 +7,10 @@ from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter, same
if TYPE_CHECKING:
from collections.abc import Callable
"""
This is an example of a pure-python version of autograd implemented by
@zdevito. It represents a rather challenging test case for TorchDynamo

View File

@ -1,11 +1,17 @@
# Owner(s): ["module: dynamo"]
import functools
import re
import unittest
import weakref
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.graph_bytecode_inputs import (
reset_user_object_tracking,
store_user_object_weakrefs,
)
from torch._dynamo.testing import extract_graph, remove_trailing_space
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import requires_cuda
@ -15,6 +21,14 @@ requires_multigpu = functools.partial(
)
def remove_file_comment(gm_str: str) -> str:
return remove_trailing_space(re.sub(r"File.*\n", "\n", gm_str))
def print_graph(graph: torch.fx.GraphModule) -> str:
return remove_file_comment(graph.print_readable())
class TestStreams(torch._dynamo.test_case.TestCase):
@classmethod
def setUpClass(cls):
@ -36,9 +50,7 @@ class TestStreams(torch._dynamo.test_case.TestCase):
@requires_cuda
def test_stream_enter_exit(self):
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
def fn(x, y, s1, s2):
with s1:
z1 = torch.add(x, y)
with s2:
@ -47,13 +59,36 @@ class TestStreams(torch._dynamo.test_case.TestCase):
return y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2), torch.Stream(), torch.Stream())
expected = fn(*inp)
fn_opt = torch.compile(fn, fullgraph=True)
actual = fn_opt(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': None}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': None}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': None}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
return (add_3,)
""",
)
@requires_cuda
@unittest.skip("Needs graph break support with annotation context")
def test_stream_context_graph_break(self):
def fn(x, y):
s2 = torch.Stream()
@ -70,9 +105,16 @@ class TestStreams(torch._dynamo.test_case.TestCase):
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
fn_opt = torch.compile(fn)
actual = fn_opt(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(expected, actual)
self.assertEqual(len(fw_graphs), 2)
self.assertExpectedInline(print_graph(fw_graphs[0]), """""")
self.assertExpectedInline(print_graph(fw_graphs[1]), """""")
@requires_cuda
def test_stream_input(self):
@ -155,35 +197,310 @@ class TestStreams(torch._dynamo.test_case.TestCase):
self.assertEqual(s_act, s_exp)
def test_nested_stream_enter_exit(self):
pass
def fn(x, y, s0, s1, s2):
with s1:
with s2:
z1 = torch.add(x, y)
with s0:
z0 = torch.add(x, y)
with s2:
y = 2 + z1
return z0, y
inp = (
torch.ones(2, 2) + 1,
torch.ones(2, 2),
torch.Stream(),
torch.Stream(),
torch.Stream(),
)
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': None}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': None}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': None}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
return (add_1, add_2)
""",
)
@unittest.skip("Needs graph break support with annotation context")
def test_stream_enter_exit_graph_break(self):
pass
@unittest.skip("Needs graph break support with annotation context")
def test_nested_stream_enter_exit_graph_break(self):
pass
def test_local_stream_enter_exit(self):
pass
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
with s1:
z1 = torch.add(x, y)
with s2:
z = torch.add(x, y)
y = z + 2 + z1
return y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': 1}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': 0}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': 0}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
return (add_3,)
""",
)
def test_local_stream_nested_enter_exit(self):
pass
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
s0 = torch.Stream()
with s1:
with s2:
z1 = torch.add(x, y)
with s0:
z0 = torch.add(x, y)
with s2:
y = 2 + z1
return z0, y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': 0}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': 2}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': 0}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
return (add_1, add_2)
""",
)
def test_stream_with_mutation(self):
pass
def fn(x, y):
s2 = torch.Stream()
s1 = torch.Stream()
s0 = torch.Stream()
with s1:
with s2:
x.add_(y)
with s0:
z1 = torch.add(y, y)
z0 = torch.add(z1, y)
with s2:
y = 2 + z1
return z0, y
inp = (torch.ones(2, 2) + 1, torch.ones(2, 2))
expected = fn(*inp)
(
actual,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': 0}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': 2}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1)
# Annotation: {'stream': 2}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, arg1_1); arg1_1 = None
# Annotation: {'stream': 0}
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
#
copy_: "f32[2, 2]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None
return (add_2, add_3)
""",
)
def test_stream_backward(self) -> None:
def fn(x, y):
s2 = torch.Stream()
s0 = torch.Stream()
with s0:
y0 = 2 * x + y
with s2:
z = 2 * x + y
return y0, z
inp = (
torch.ones(2, 2, requires_grad=True) + 1,
torch.ones(2, 2, requires_grad=True),
)
expected = fn(*inp)
(
actual,
_,
fw_graphs,
bw_graphs,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"):
# Annotation: {'stream': 1}
mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2)
# Annotation: {'stream': 0}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None
return (add, add_1)
""",
)
actual[1].sum().backward()
self.assertExpectedInline(
print_graph(bw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
# Annotation: {'stream': 0}
mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2)
#
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None
# Annotation: {'stream': 1}
mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
#
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
return (add_3, add_2)
""",
)
@requires_cuda
def test_run_opcheck(self):
def test_run_opcheck_fork_join(self):
from torch._dynamo.variables.streams import fork_stream, join_stream
from torch.library import opcheck
sample_inputs = [
(0, torch.device("cuda:0"), 1, torch.device("cuda:1")),
(2, torch.device("cuda:2"), 3, torch.device("cuda:1")),
]
for args in sample_inputs:
opcheck(fork_stream, args)
opcheck(join_stream, args)
original_stream = torch.accelerator.current_stream()
try:
s0 = torch.Stream()
s1 = torch.Stream()
store_user_object_weakrefs(s0, s1)
sample_inputs = [
(0, 1),
(1, 0),
]
for args in sample_inputs:
opcheck(fork_stream, args)
opcheck(join_stream, args)
finally:
torch.accelerator.set_stream(original_stream)
reset_user_object_tracking()
@requires_cuda
def test_run_opcheck_wait_record(self):
from torch._dynamo.variables.streams import record_event, wait_event
from torch.library import opcheck
original_stream = torch.accelerator.current_stream()
try:
s0 = torch.Stream()
s1 = torch.Stream()
e0 = torch.Event()
e1 = torch.Event()
store_user_object_weakrefs(s0, s1, e0, e1)
sample_inputs = [
(2, 0),
(3, 1),
]
for args in sample_inputs:
opcheck(wait_event, args)
opcheck(record_event, args)
finally:
torch.accelerator.set_stream(original_stream)
reset_user_object_tracking()
def test_is_marked_side_effectful(self):
self.assertIn(
torch.ops.streams.fork.default, torch.fx.node._side_effectful_functions
)
self.assertIn(
torch.ops.streams.join.default, torch.fx.node._side_effectful_functions
)
self.assertIn(
torch.ops.streams.wait_event.default,
torch.fx.node._side_effectful_functions,
)
self.assertIn(
torch.ops.streams.record_event.default,
torch.fx.node._side_effectful_functions,
)
if __name__ == "__main__":

View File

@ -14424,20 +14424,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
self.common(fn, (torch.randn(6, 4, device=GPU_TYPE).t().contiguous().t(),))
@skip_if_halide
@requires_cuda_and_triton
def test_unbacked_float_item(self):
def fn(x, max_val):
return torch.clamp(x, 0, max_val.item())
self.common(
fn,
(
torch.randn(10, 20, 30, device=self.device),
torch.tensor(5.0, device=self.device),
),
)
# end of class CommonTemplate - add new tests here

View File

@ -1,5 +1,5 @@
from typing import Union
from typing_extensions import assert_type, TypeAlias
from typing import TypeAlias, Union
from typing_extensions import assert_type
from torch import randn, Tensor

View File

@ -1,8 +1,9 @@
# mypy: allow-untyped-defs
# mypy: disable-error-code="type-arg"
from collections.abc import Callable
from datetime import timedelta
from enum import Enum
from typing import Any, Callable, Optional, overload, Union
from typing import Any, Optional, overload, Union
import torch
from torch import Tensor

View File

@ -87,6 +87,12 @@ def extract_graph_and_tracker(fn, *args, **kwargs): # type: ignore[no-untyped-d
return gm.graph, region_tracker # type: ignore[union-attr]
def extract_graph(fn, *args, **kwargs): # type: ignore[no-untyped-def]
backend = AotEagerAndRecordGraphs()
result = torch.compile(backend=backend)(fn)(*args, **kwargs)
return result, backend.graphs, backend.fw_graphs, backend.bw_graphs
def collect_results(
model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any
) -> list[Any]:

View File

@ -21,9 +21,9 @@ restoring state changes.
import inspect
import sys
import warnings
from collections.abc import Callable, Sequence
from collections.abc import Callable, Sequence, Sized
from contextlib import ExitStack
from typing import Any, ContextManager, Optional, Sized, TYPE_CHECKING, Union
from typing import Any, ContextManager, Optional, TYPE_CHECKING, Union
import torch._C
from torch._guards import Guard

View File

@ -5,11 +5,12 @@ from typing import Any, Optional
import torch
from torch._dynamo.variables.dicts import ConstDictVariable
from torch._dynamo.variables.lists import TupleVariable
from torch.fx import Proxy
from torch.fx import has_side_effect, Proxy
from .. import graph_break_hints
from ..bytecode_transformation import create_call_function
from ..exc import TYPE_CHECKING, unimplemented_v2
from ..graph_bytecode_inputs import get_external_object_by_index
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import FxTracebackAnnotateVariable
@ -27,46 +28,93 @@ from torch._library.custom_ops import custom_op
Tensor = torch.Tensor
def _get_stream_by_index(index: int) -> torch.Stream:
stream = get_external_object_by_index(index)
assert isinstance(stream, torch.Stream), (
f"Fork/join stream expected a stream object at index {index}"
)
return stream
def _get_event_by_index(index: int) -> torch.Event:
event = get_external_object_by_index(index)
assert isinstance(event, torch.Event), (
f"Record/wait event expected an event object at index {index}"
)
return event
@custom_op("streams::fork", mutates_args=())
def fork_stream(
from_index: int,
from_device: torch.device,
from_index: int, # kept to make stream transitions clearer
to_index: int,
to_device: torch.device,
) -> None:
pass
torch.accelerator.set_stream(_get_stream_by_index(to_index))
@fork_stream.register_fake
def _(
from_index: int,
from_device: torch.device,
from_index: int, # kept to make stream transitions clearer
to_index: int,
to_device: torch.device,
) -> None:
pass
has_side_effect(torch.ops.streams.fork.default)
@custom_op("streams::join", mutates_args=())
def join_stream(
from_index: int,
from_device: torch.device,
to_index: int,
to_device: torch.device,
) -> None:
pass
def join_stream(from_index: int, to_index: int) -> None:
torch.accelerator.set_stream(_get_stream_by_index(to_index))
@join_stream.register_fake
def _(
from_index: int,
from_device: torch.device,
to_index: int,
to_device: torch.device,
) -> None:
pass
has_side_effect(torch.ops.streams.join.default)
@custom_op("streams::record_event", mutates_args=())
def record_event(event_index: int, stream_index: int) -> None:
event = _get_event_by_index(event_index)
stream = _get_stream_by_index(stream_index)
stream.record_event(event)
@record_event.register_fake
def _(
event_index: int,
stream_index: int,
) -> None:
pass
has_side_effect(torch.ops.streams.record_event.default)
@custom_op("streams::wait_event", mutates_args=())
def wait_event(event_index: int, stream_index: int) -> None:
event = _get_event_by_index(event_index)
stream = _get_stream_by_index(stream_index)
stream.wait_event(event)
@wait_event.register_fake
def _(
event_index: int,
stream_index: int,
) -> None:
pass
has_side_effect(torch.ops.streams.wait_event.default)
class SymbolicStreamState:
"""Track the currently entered stream if any"""

View File

@ -2970,12 +2970,6 @@ class CppPythonBindingsCodeCache(CppCodeCache):
throw std::runtime_error("expected int arg");
return reinterpret_cast<uintptr_t>(result);
}}
template <> inline float parse_arg<float>(PyObject* args, size_t n) {{
auto result = PyFloat_AsDouble(PyTuple_GET_ITEM(args, n));
if(unlikely(result == -1.0 && PyErr_Occurred()))
throw std::runtime_error("expected float arg");
return static_cast<float>(result);
}}
{extra_parse_arg}

View File

@ -1732,15 +1732,9 @@ class KernelArgs:
call_args.append(self.wrap_ptr_arg(outer, dtype))
arg_types.append(f"{cpp_dtype}*")
for outer, inner in self.sizevars.items():
if isinstance(outer, sympy.Symbol) and symbol_is_type(
outer, (SymT.UNBACKED_FLOAT)
):
arg_defs.append(f"const float {inner}")
arg_types.append("const float")
else:
arg_defs.append(f"const {INDEX_TYPE} {inner}")
arg_types.append(f"const {INDEX_TYPE}")
arg_defs.append(f"const {INDEX_TYPE} {inner}")
call_args.append(self.wrap_size_arg(outer))
arg_types.append(f"const {INDEX_TYPE}")
if V.graph.wrapper_code:
V.graph.wrapper_code.ensure_size_computed(outer)
assert not self.workspace_args, "Workspace not supported on CPU "
@ -2359,7 +2353,6 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
SymT.UNBACKED_INT,
SymT.SIZE,
SymT.PRECOMPUTED_SIZE,
SymT.UNBACKED_FLOAT,
),
)
}

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import hashlib
from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING
import sympy # noqa: TC002
@ -17,6 +17,8 @@ from .simd import SIMDKernel, SIMDScheduling
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from ..ir import IRNode
from ..scheduler import BaseSchedulerNode

View File

@ -627,7 +627,7 @@ class ComboKernel(Kernel):
if heuristics == "foreach":
heuristics_line = f"""
@triton_heuristics.foreach(
num_warps={self.num_warps},
filename=__file__,
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r},
)

View File

@ -4,7 +4,6 @@ from typing import Any, Optional
import sympy
import torch
from torch.utils._sympy.symbol import symbol_is_type, SymT
from .. import config
from ..runtime.hints import AttrsDescriptorWrapper
@ -72,10 +71,6 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str:
return "constexpr"
elif isinstance(arg.expr, (float, sympy.Float)):
return "fp32"
elif isinstance(arg.expr, sympy.Symbol) and symbol_is_type(
arg.expr, (SymT.UNBACKED_FLOAT)
):
return "fp32"
elif isinstance(arg.expr, bool):
return "i1"

View File

@ -360,7 +360,7 @@ def estimate_nccl_collective_runtime_from_fx_node(
fx_node: torch.fx.Node,
override_size: Optional[int] = None,
# TODO(ivankobzarev): NCCL estimator sometimes fail unexpectedly, enable back after fix.
use_nccl_estimator: bool = False,
use_nccl_estimator: bool = True,
) -> float:
"""
Returns estimated NCCL collective runtime in nanoseconds (ns).

View File

@ -1,6 +1,6 @@
import os
from collections.abc import Callable
from functools import cache, partial
from typing import Callable
import torch
from torch._environment import is_fbcode

View File

@ -3586,13 +3586,24 @@ def user_autotune(
)
def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
def foreach(triton_meta, filename=None, inductor_meta=None):
"""
Compile a triton foreach kernel
"""
configs = []
# Naive autotuning path for num_warps
if not (
inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
):
configs.append(triton.Config({}, num_stages=1, num_warps=8))
else:
for warps in [1, 2, 4, 8]:
configs.append(triton.Config({}, num_stages=1, num_warps=warps))
return cached_autotune(
None,
[triton.Config({}, num_stages=1, num_warps=num_warps)],
configs,
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.TEMPLATE,

View File

@ -52,26 +52,7 @@ __all__ = [
"MemRecordsAcc",
]
try:
# Available in Python >= 3.2
from contextlib import ContextDecorator as _ContextDecorator
except ImportError:
import functools
class _ContextDecorator: # type: ignore[no-redef]
def __enter__(self):
raise NotImplementedError
def __exit__(self, exc_type, exc_val, exc_tb):
raise NotImplementedError
def __call__(self, func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
with self:
return func(*args, **kwargs)
return wrapped
from contextlib import ContextDecorator
# global python state - whether profiler is currently enabled
@ -744,8 +725,7 @@ class profile:
return all_function_events
# pyrefly: ignore [invalid-inheritance]
class record_function(_ContextDecorator):
class record_function(ContextDecorator):
"""Context manager/function decorator that adds a label to a code block/function when running autograd profiler.
Label will only appear if CPU activity tracing is enabled.

View File

@ -3593,6 +3593,7 @@ float ProcessGroupNCCL::endTimeEstimate() {
#ifdef NCCL_SIM_INFO_INITIALIZER
ncclSimInfo_t simInfo = NCCL_SIM_INFO_INITIALIZER;
C10D_NCCL_CHECK(ncclGroupSimulateEnd(&simInfo), std::nullopt);
--ncclActiveGroupCounter_;
return simInfo.estimatedTime;
#else
TORCH_CHECK(

View File

@ -4,6 +4,7 @@
#include <torch/headeronly/core/ScalarType.h>
#include <torch/headeronly/macros/Macros.h>
#include <torch/headeronly/util/Exception.h>
#include <torch/headeronly/util/HeaderOnlyArrayRef.h>
#include <torch/headeronly/util/shim_utils.h>
#include <climits>
#include <memory>
@ -13,6 +14,7 @@
HIDDEN_NAMESPACE_BEGIN(torch, stable)
using accelerator::DeviceIndex;
using torch::headeronly::IntHeaderOnlyArrayRef;
using torch::headeronly::ScalarType;
// The torch::stable::Tensor class is a highlevel C++ wrapper around
@ -93,6 +95,32 @@ class Tensor {
return numel;
}
// note: this API is, for all intents and purposes, the same as the one in
// TensorBase.h: it returns a borrowed reference of the dimension sizes of
// a Tensor.
//
// The only difference is that it returns a header-only IntHeaderOnlyArrayRef,
// which has slightly less functionality than a regular IntArrayRef. See
// [HeaderOnlyArrayRef vs ArrayRef note] for more details.
IntHeaderOnlyArrayRef sizes() const {
int64_t* sizes;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(ath_.get(), &sizes));
return IntHeaderOnlyArrayRef(sizes, dim());
}
// note: this API is, for all intents and purposes, the same as the one in
// TensorBase.h: it returns a borrowed reference of the strides of a
// Tensor.
//
// The only difference is that it returns a header-only IntHeaderOnlyArrayRef,
// which has slightly less functionality than a regular IntArrayRef. See
// [HeaderOnlyArrayRef vs ArrayRef note] for more details.
IntHeaderOnlyArrayRef strides() const {
int64_t* strides;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(ath_.get(), &strides));
return IntHeaderOnlyArrayRef(strides, dim());
}
// note: this is a subset of the original TensorBase API. It takes no
// arguments whereas the original API takes in a kwarg of memory format.
// Here, we assume the default contiguous memory format.

View File

@ -1,9 +1,8 @@
import functools
import math
import operator
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from datetime import timedelta
from typing import Callable
import torch
from torch._C import ScriptObject

View File

@ -10,6 +10,7 @@ from ._context_parallel._attention import (
_enable_context_parallel_dispatcher,
_is_causal_behavior,
_RotateMethod,
_templated_ring_attention,
context_parallel,
context_parallel_unshard,
set_rotate_method,
@ -22,6 +23,7 @@ from ._context_parallel._load_balancer import (
)
# TODO(fegin): add deprecation message once the final interfaces are concluded.
__all__ = [
"_CausalBehavior",
"_context_parallel_shard",
@ -31,6 +33,7 @@ __all__ = [
"_enable_context_parallel_dispatcher",
"_is_causal_behavior",
"_RotateMethod",
"_templated_ring_attention",
"context_parallel",
"context_parallel_unshard",
"set_rotate_method",