Compare commits


11 Commits

Author SHA1 Message Date
e457003553 [user-streams] Assign streams to gradient accum in bwd
ghstack-source-id: c355023c56ad417e101dec575d236d117e1645a1
Pull-Request: https://github.com/pytorch/pytorch/pull/167513
2025-11-11 01:54:37 -08:00
5f990d00e4 [user-streams] wait_stream op
ghstack-source-id: 0a5156564bb1b4077e2a2d8b39e4c7e22d94c750
Pull-Request: https://github.com/pytorch/pytorch/pull/167512
2025-11-11 01:54:36 -08:00
706d566e2c [user-streams] Allow new streams to be created and registered during compilation
ghstack-source-id: dd84afd5fd3147dc3df4b05e1348a2a651ffd093
Pull-Request: https://github.com/pytorch/pytorch/pull/167511
2025-11-11 01:54:36 -08:00
3fdee99d7c [user-streams] Allow new events to be created and registered during compilation
ghstack-source-id: 2fdbe21f2780969fdecc5ca48e537780f6f7e299
Pull-Request: https://github.com/pytorch/pytorch/pull/167510
2025-11-11 01:54:35 -08:00
daed97afff [Inductor] fix CppTile2DKernel for fp8 datatype (#167451)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167451
Approved by: https://github.com/Xia-Weiwen, https://github.com/jansel
2025-11-11 09:25:14 +00:00
53947adb1f [Inductor] optimize the heuristics of sum reduction (#163144)
Fix https://github.com/pytorch/pytorch/issues/151400.
**Summary:**
Optimize the heuristics of sum reduction by reducing the chunk size of cascade sum to improve numerical stability.
I ran the Inductor benchmark with this PR on CPU, and no performance regression was seen.
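For intuition, here is a minimal sketch (plain eager PyTorch, not the Inductor-generated kernel; 2**20 and 4096 are simply the old and new chunk sizes from this PR) of why a smaller cascade chunk helps: each chunk is reduced into a single float32 accumulator, so a smaller chunk keeps every accumulator small relative to its addends and the rounding error stops growing with the full reduction size.

```
import torch

x = torch.full((98304,), 0.1, dtype=torch.float32)
ref = x.double().sum().item()  # float64 reference

def sequential_chunked_sum(t, chunk):
    # Mimic the kernel structure: one float32 accumulator runs over each
    # chunk sequentially, then the per-chunk partials are combined the same way.
    def seq(vals):
        acc = torch.zeros((), dtype=torch.float32)
        for v in vals:
            acc = acc + v
        return acc

    partials = [seq(c) for c in t.split(chunk)]
    return seq(torch.stack(partials)).item()

for chunk in (2**20, 4096):  # old vs. new cascade chunk size
    err = abs(sequential_chunked_sum(x, chunk) - ref)
    print(f"chunk={chunk}: |error| vs float64 = {err:.4f}")
```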

**Example:**
Take https://github.com/pytorch/pytorch/issues/151400 as an example:
```
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch._inductor import config

config.fallback_random = True
torch.set_grad_enabled(False)
torch.manual_seed(0)

class Model(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x):
        vec = x.flatten()
        vec_one = torch.ones_like(vec)
        x = torch.outer(vec, vec_one)
        return torch.mean(x, dim=1)

model = Model()

x = torch.randn(3, 8, 64, 64)  # error will be amplified as the input tensor gets larger

inputs = [x]

def run_test(model, inputs, backend):
    if backend != "eager":
        model = torch.compile(model, backend=backend)
    torch.manual_seed(0)
    output = model(*inputs)
    return output

output = run_test(model, inputs, 'eager')
c_output = run_test(model, inputs, 'inductor')
fp64 = run_test(model.to(dtype=torch.float64), [inputs[0].to(dtype=torch.float64)], 'eager')

print(torch.allclose(output, c_output, rtol=1e-3, atol=1e-3))
print(torch.max(torch.abs(c_output - output)))
print(torch._dynamo.utils.same(output, c_output, fp64))

```

**Logs:**
- Before
```
False
tensor(0.0052)
False
```

- After
```
True
tensor(0.0004)
True
```

**Generated code:**
- Before
```
cpp_fused_mean_mul_ones_like_view_0 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(float* in_out_ptr0,
                       const float* in_ptr0)
{
    auto out_ptr0 = in_out_ptr0;
    #pragma omp parallel num_threads(240)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(98304L); x0+=static_cast<int64_t>(16L))
            {
                {
                    float tmp_acc0 = 0;
                    at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
                    for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(98304L); x1+=static_cast<int64_t>(1L))
                    {
                        {
                            if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(98304L)))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                                auto tmp1 = static_cast<float>(1.0);
                                auto tmp2 = at::vec::Vectorized<float>(tmp1);
                                auto tmp3 = tmp0 * tmp2;
                                tmp_acc0_vec = tmp_acc0_vec + tmp3;
                            }
                        }
                    }
                    if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(98304L)))
                    {
                        tmp_acc0_vec.store(out_ptr0 + static_cast<int64_t>(x0));
                    }
                }
                {
                    if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(98304L)))
                    {
                        auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                        auto tmp1 = static_cast<float>(98304.0);
                        auto tmp2 = at::vec::Vectorized<float>(tmp1);
                        auto tmp3 = tmp0 / tmp2;
                        tmp3.store(in_out_ptr0 + static_cast<int64_t>(x0));
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (3, 8, 64, 64), (32768, 4096, 64, 1))
        buf0 = empty_strided_cpu((98304, ), (1, ), torch.float32)
        buf1 = buf0; del buf0  # reuse
        # [Provenance debug handles] cpp_fused_mean_mul_ones_like_view_0:1
        cpp_fused_mean_mul_ones_like_view_0(buf1, arg0_1)
        del arg0_1
        return (buf1, )
```

- After
```
cpp_fused_mean_mul_ones_like_view_0 = async_compile.cpp_pybinding(['float*', 'const float*'], '''
#include <torch/csrc/inductor/cpp_prefix.h>
extern "C"  void  kernel(float* in_out_ptr0,
                       const float* in_ptr0)
{
    auto out_ptr0 = in_out_ptr0;
    #pragma omp parallel num_threads(240)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for
            for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(98304L); x0+=static_cast<int64_t>(16L))
            {
                {
                    float tmp_acc0 = 0;
                    at::vec::Vectorized<float> tmp_acc0_vec = at::vec::Vectorized<float>(0);
                    at::vec::Vectorized<float> masked_tmp_acc0_vec = at::vec::Vectorized<float>(0);
                    CascadeSumHelper<float, 4096> scalar_cascade_helper0(static_cast<int64_t>(98304L));
                    CascadeSumHelper<at::vec::Vectorized<float>, 4096> cascade_helper0(static_cast<int64_t>(98304L));
                    CascadeSumHelper<at::vec::Vectorized<float>, 4096> masked_cascade_helper0(static_cast<int64_t>(0L));
                    for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(98304L); x1+=static_cast<int64_t>(1L))
                    {
                        {
                            if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(98304L)))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(in_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                                auto tmp1 = static_cast<float>(1.0);
                                auto tmp2 = at::vec::Vectorized<float>(tmp1);
                                auto tmp3 = tmp0 * tmp2;
                                tmp_acc0_vec = cascade_sum_combine(tmp3, &cascade_helper0);
                            }
                        }
                    }
                    tmp_acc0 = cascade_sum_final(&scalar_cascade_helper0);
                    tmp_acc0_vec = cascade_sum_final(&cascade_helper0);
                    masked_tmp_acc0_vec = cascade_sum_final(&masked_cascade_helper0);
                    if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(98304L)))
                    {
                        tmp_acc0_vec = tmp_acc0_vec + masked_tmp_acc0_vec;
                        tmp_acc0_vec.store(out_ptr0 + static_cast<int64_t>(x0));
                    }
                }
                {
                    if(C10_LIKELY(x0 >= static_cast<int64_t>(0) && x0 < static_cast<int64_t>(98304L)))
                    {
                        auto tmp0 = at::vec::Vectorized<float>::loadu(out_ptr0 + static_cast<int64_t>(x0), static_cast<int64_t>(16));
                        auto tmp1 = static_cast<float>(98304.0);
                        auto tmp2 = at::vec::Vectorized<float>(tmp1);
                        auto tmp3 = tmp0 / tmp2;
                        tmp3.store(in_out_ptr0 + static_cast<int64_t>(x0));
                    }
                }
            }
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

class Runner:
    def __init__(self, partitions):
        self.partitions = partitions

    def recursively_apply_fns(self, fns):
        new_callables = []
        for fn, c in zip(fns, self.partitions):
            new_callables.append(fn(c))
        self.partitions = new_callables

    def call(self, args):
        arg0_1, = args
        args.clear()
        assert_size_stride(arg0_1, (3, 8, 64, 64), (32768, 4096, 64, 1))
        buf0 = empty_strided_cpu((98304, ), (1, ), torch.float32)
        buf1 = buf0; del buf0  # reuse
        # [Provenance debug handles] cpp_fused_mean_mul_ones_like_view_0:1
        cpp_fused_mean_mul_ones_like_view_0(buf1, arg0_1)
        del arg0_1
        return (buf1, )
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163144
Approved by: https://github.com/CaoE, https://github.com/mingfeima, https://github.com/jansel
2025-11-11 09:25:00 +00:00
c297b02f12 [DTensor] statically_known_true for slice strategy (#166990)
Avoids data-dependent errors for out-of-bounds & redundant slice checks.

The sharding logic that immediately depends on this only checks for redundant slices, and is saying: "it's safe to reuse the input placements if a) the slicing dimension isn't sharded, or b) the slice is redundant, so just pretend this op didn't happen".

This has a slight effect on output placements when a slice is performed on a sharded dim and dynamic shapes are involved (size/start/end/step). Now, if the slice isn't obviously redundant, we won't immediately consider the input placements valid (even if they could be for very particular runtime shapes) and will instead select strategies valid for the general case, which here means unsharding the slicing dim.

For backed symbols, we could choose to recompile when the redundant case is hit, by switching to `guard_or_false`, but it's not obvious how desirable this is.
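A minimal sketch of the check pattern this relies on (plain ints stand in for possibly-symbolic sizes; the helper name is made up for illustration): `statically_known_true` only returns True when the expression is provably true at compile time and conservatively returns False otherwise, so it never raises the data-dependent errors that evaluating a comparison on an unbacked SymInt can.

```
from torch.fx.experimental.symbolic_shapes import statically_known_true

def is_statically_redundant_slice(start, end, step, dim_size):
    # With concrete ints this is just the plain comparison; with SymInts it
    # answers False whenever the result cannot be proven statically.
    return (
        statically_known_true(start == 0)
        and statically_known_true(end == dim_size)
        and statically_known_true(step == 1)
    )

print(is_statically_redundant_slice(0, 64, 1, 64))  # True: provably redundant
print(is_statically_redundant_slice(0, 32, 1, 64))  # False: a real slice
```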

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166990
Approved by: https://github.com/laithsakka
2025-11-11 08:04:09 +00:00
bd24774f50 [XPU][Test] Enable XPU tests in inductor/test_analysis.py (#166840)
This PR enables XPU devices in test_analysis.py.

For performance reasons, it skips some slow tests; the full scope can be enabled with:

```
export PYTORCH_TEST_WITH_SLOW=1
```

**PR Stack:**

- https://github.com/pytorch/pytorch/pull/166840 : This PR enables the tests and skips the ones that fail
- https://github.com/pytorch/pytorch/pull/166839 : This fixes the bug and enables the full tests for XPU

**Some skipped test time:**

```
test_augment_trace_against_flop_counter_maxat0_xpu_float16 [49.0863s]
test_augment_trace_against_flop_counter_maxat0_xpu_float32 [18.2268s]
test_augment_trace_against_flop_counter_maxat1_xpu_float16 [85.6549s]
test_augment_trace_against_flop_counter_maxat1_xpu_float32 [329.0832s]
test_augment_trace_against_flop_counter_maxat2_xpu_float16 [24.4825s]
test_augment_trace_against_flop_counter_maxat2_xpu_float32 [19.0688s]
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166840
Approved by: https://github.com/guangyey, https://github.com/jansel
2025-11-11 07:49:07 +00:00
525eb9fab9 Fix command injection vulnerability in PCH compilation (#167502)
Fixed a command injection vulnerability in PreCompiled Header (PCH) compilation where extra_cflags were passed to subprocess with shell=True, allowing arbitrary command execution through malicious compiler flags.

Changed subprocess.check_output(pch_cmd, shell=True) to use shlex.split() to safely parse the command without shell interpretation. This prevents shell metacharacters (;, |, &, etc.) in extra_cflags from being executed as shell commands.
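A minimal standalone sketch of the change (the compiler command and payload below are illustrative, not the exact cpp_extension invocation): parsing the command with shlex and dropping shell=True turns metacharacters in user-supplied flags into literal argv entries instead of a second shell command.

```
import shlex
import subprocess

# A user-controlled flag carrying an embedded shell command.
payload = '; echo vulnerable > /tmp/pch_exploit'
pch_cmd = f"g++ -x c++-header header.h -o header.h.gch {payload}"

# Unsafe pattern being removed: with shell=True the `;` would start a second
# command and create /tmp/pch_exploit even though the compile itself fails.
#   subprocess.check_output(pch_cmd, shell=True, stderr=subprocess.STDOUT)

# Safe pattern: split into argv without shell interpretation; the payload is
# just a bogus compiler argument and no shell ever runs.
try:
    subprocess.check_output(shlex.split(pch_cmd), stderr=subprocess.STDOUT)
except (subprocess.CalledProcessError, OSError):
    pass  # compilation failure is expected; the point is no shell execution
```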

Added test case test_pch_command_injection that verifies:
1. PCH compilation attempts with malicious payloads in extra_cflags
2. Shell commands embedded in flags are not executed
3. Exploit file is not created, proving no shell execution occurred

Note: On RHEL/Fedora and other systems with versioned GCC compilers, the test depends on https://github.com/pytorch/pytorch/pull/167501 being merged first, otherwise the test will be skipped due to GCC detection issues.

Fixes #167480

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167502
Approved by: https://github.com/malfet
2025-11-11 07:27:41 +00:00
7886070fc5 Use stable topological sort in fuse_by_partitions (#167397)
legalize_graph() performs a topo sort that shuffles the nodes in a global way, making the result unpredictable.
We should avoid this in graph passes in general.

This problem was discovered when testing regional_inductor, where a single fused region triggered the global reordering.
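For reference, a minimal sketch (generic adjacency lists, not torch.fx) of a stable Kahn-style topological sort: ready nodes join a FIFO queue in their original order and are never re-prioritized, so an already-sorted graph comes back unchanged, unlike the appendleft path legalize_graph takes for PRIORITIZED_OPS.

```
from collections import deque

def stable_topo_sort(nodes, users):
    # nodes: list in original program order; users: node -> list of consumers
    indeg = {n: 0 for n in nodes}
    for n in nodes:
        for u in users.get(n, []):
            indeg[u] += 1
    queue = deque(n for n in nodes if indeg[n] == 0)  # seeded in original order
    out = []
    while queue:
        cur = queue.popleft()
        out.append(cur)
        for u in users.get(cur, []):
            indeg[u] -= 1
            if indeg[u] == 0:
                queue.append(u)  # plain append: no node jumps the queue
    return out

# An already topologically ordered graph stays in the same order.
print(stable_topo_sort(["a", "b", "c"], {"a": ["c"], "b": ["c"]}))  # ['a', 'b', 'c']
```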

Before
https://www.internalfb.com/intern/diffing/?before_paste_number=2029217728&after_paste_number=2029218006&regex_remove_pattern=&enable_regex_remove=0&strip_empty_lines=0&line_wrap=0&selected_tab=plain_diff

After
https://www.internalfb.com/intern/diffing/?paste_number=2029162294&regex_remove_pattern=&enable_regex_remove=0&strip_empty_lines=0&line_wrap=0&selected_tab=plain_diff

Left is gm before regional_inductor, right is after.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/167397
Approved by: https://github.com/ezyang
2025-11-11 07:14:02 +00:00
87d17e9dee [pallas backend] Implementing Strided/Scatter Access (#167426)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167426
Approved by: https://github.com/yarongmu-google, https://github.com/jansel
2025-11-11 06:32:25 +00:00
18 changed files with 637 additions and 100 deletions

View File

@ -18,8 +18,6 @@ Please report security issues using https://github.com/pytorch/pytorch/security/
All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If an advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create a [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
**Note on crashes and out of bounds access**: PyTorch is a computational framework that performs operations on behalf of the caller. Like many low-level libraries, PyTorch generally does not validate all inputs to every function—the responsibility for providing valid arguments lies with the calling code. While crashes and out of bounds memory access should be reported as bugs, they are generally not considered security vulnerabilities in PyTorch's threat model.
Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported:
https://www.facebook.com/whitehat

View File

@ -533,7 +533,7 @@ class DTensorExportTest(TestCase):
self.assertEqual(fn(z), gm(z)[0])
def test_dtensor_data_dependent_index(self):
def test_dtensor_data_dependent_index_and_slice(self):
device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
class Foo(torch.nn.Module):
@ -546,6 +546,35 @@ class DTensorExportTest(TestCase):
y_dt = distribute_tensor(y, device_mesh, placements=[Replicate()])
_dynamo_graph_capture_for_export(Foo())(x_dt, y_dt)
class Bar(torch.nn.Module):
def forward(self, x):
val = torch.clamp(x.max(), min=1).item()
torch._check(val >= 1)
return x[:val]
x = torch.randint(1000, (4, 64, 16))
x_dt = distribute_tensor(x, device_mesh, placements=[Replicate()])
gm = _dynamo_graph_capture_for_export(Bar())(x_dt)
self.assertExpectedInline(
"""\
graph():
%l_flat_args_0_ : [num_users=2] = placeholder[target=arg_0]
%max_1 : [num_users=1] = call_method[target=max](args = (%l_flat_args_0_,), kwargs = {})
%clamp : [num_users=1] = call_function[target=torch.clamp](args = (%max_1,), kwargs = {min: 1})
%item : [num_users=2] = call_method[target=item](args = (%clamp,), kwargs = {})
%ge_1 : [num_users=1] = call_function[target=operator.ge](args = (%item, 1), kwargs = {})
%_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_1, Runtime assertion failed for expression u0 >= 1 on node 'ge_1'), kwargs = {})
%res : [num_users=2] = call_function[target=operator.getitem](args = (%l_flat_args_0_, slice(None, item, None)), kwargs = {})
%getattr_1 : [num_users=1] = call_function[target=builtins.getattr](args = (%res, _local_tensor), kwargs = {})
%sym_size_int : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%getattr_1, 0), kwargs = {})
%ge_2 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int, 0), kwargs = {})
%_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_2, Runtime assertion failed for expression u2 >= 0 on node 'ge_2'), kwargs = {})
%le : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int, 4), kwargs = {})
%_assert_scalar_default_2 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le, Runtime assertion failed for expression u2 <= 4 on node 'le'), kwargs = {})
return (res,)""", # noqa: B950
str(gm.graph).strip(),
)
instantiate_parametrized_tests(DTensorExportTest)

View File

@ -335,6 +335,59 @@ class <lambda>(torch.nn.Module):
""",
)
@requires_cuda
@requires_multigpu()
def test_new_event_api(self) -> None:
from torch._dynamo.graph_bytecode_inputs import get_external_object_by_index
from torch._dynamo.variables.streams import new_event
def event_generation_backend(gm, *args, **kwargs): # type: ignore[no-untyped-def]
e0_ind = new_event()
with torch.Stream(device="cuda:1"):
get_external_object_by_index(e0_ind).record()
e1_ind = new_event()
self.assertNotEqual(e0_ind, e1_ind)
self.assertNotEqual(
get_external_object_by_index(e0_ind),
get_external_object_by_index(e1_ind),
)
with gm.graph.inserting_after(next(iter(gm.graph.nodes))):
gm.graph.call_function(
get_external_object_by_index, args=(1,), kwargs={}
)
return gm
@torch.compile(backend=event_generation_backend)
def fn(x):
return x + 1
fn(torch.ones(2, 2, device="cuda:0"))
@requires_cuda
def test_new_stream_api(self) -> None:
from torch._dynamo.graph_bytecode_inputs import get_external_object_by_index
from torch._dynamo.variables.streams import new_stream
def stream_generation_backend(gm, *args, **kwargs): # type: ignore[no-untyped-def]
s0_ind = new_stream()
s1_ind = new_stream()
self.assertNotEqual(s0_ind, s1_ind)
self.assertNotEqual(
get_external_object_by_index(s0_ind),
get_external_object_by_index(s1_ind),
)
with gm.graph.inserting_after(next(iter(gm.graph.nodes))):
gm.graph.call_function(
get_external_object_by_index, args=(1,), kwargs={}
)
return gm
@torch.compile(backend=stream_generation_backend)
def fn(x):
return x + 1
fn(torch.ones(2, 2, device="cuda:0"))
@requires_cuda
def test_stream_with_mutation(self):
def fn(x, y):
@ -386,7 +439,68 @@ class <lambda>(torch.nn.Module):
)
@requires_cuda
def test_stream_backward(self) -> None:
def test_stream_backward_simple(self) -> None:
def fn(x, y):
s2 = torch.Stream()
s0 = torch.Stream()
with s0:
y0 = 2 * x + y
with s2:
z = 2 * x + y
return y0, z
inp = (
torch.ones(2, 2, requires_grad=True) + 1,
torch.ones(2, 2, requires_grad=True),
)
expected = fn(*inp)
(
actual,
_,
fw_graphs,
bw_graphs,
) = extract_graph(fn, *inp)
self.assertEqual(len(fw_graphs), 1)
self.assertEqual(expected, actual)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"):
# Annotation: {'stream': 1}
mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2)
# Annotation: {'stream': 0}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None
return (add, add_1)
""",
)
actual[1].sum().backward()
self.assertExpectedInline(
print_graph(bw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
# Annotation: {'stream': 0}
mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2)
#
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None
# Annotation: {'stream': 1}
mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None
#
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
return (add_3, add_2)
""",
)
@requires_cuda
def test_stream_backward_sync(self) -> None:
def fn(x, y):
s2 = torch.Stream()
s0 = torch.Stream()
@ -523,6 +637,23 @@ class <lambda>(torch.nn.Module):
torch.accelerator.set_stream(original_stream)
reset_user_object_tracking()
@requires_cuda
def test_run_opcheck_wait_record_stream(self):
from torch._dynamo.variables.streams import wait_stream
from torch.library import opcheck
s0 = torch.Stream()
s1 = torch.Stream()
s2 = torch.Stream()
store_user_object_weakrefs(s0, s1, s2)
sample_inputs = [
(0, 1),
(2, 0),
]
for args in sample_inputs:
opcheck(wait_stream, args)
@requires_cuda
def test_inductor_lowering(self):
with patch("torch._inductor.config.implicit_fallbacks", False):

View File

@ -274,7 +274,10 @@ class TestUtils(TestCase):
class TestAnalysis(TestCase):
@skipIf(not SM80OrLater, "Requires SM80")
@skipIf(
(not torch.xpu.is_available()) and (not SM80OrLater),
"Requires XPU or CUDA SM80",
)
def test_noop(self):
with (
patch("sys.stdout", new_callable=StringIO) as mock_stdout,
@ -283,7 +286,10 @@ class TestAnalysis(TestCase):
main()
self.assertEqual(mock_stdout.getvalue(), "")
@skipIf(not SM80OrLater, "Requires SM80")
@skipIf(
(not torch.xpu.is_available()) and (not SM80OrLater),
"Requires XPU or CUDA SM80",
)
@dtypes(torch.float, torch.double, torch.float16)
def test_diff(self, device, dtype):
"""
@ -334,7 +340,11 @@ class TestAnalysis(TestCase):
expected_flops = [4096000, 4096000, 223552896, 223552896, 0, 0, 0]
verify_flops(self, expected_flops, out_profile)
@skipIf(not SM80OrLater, "Requires SM80")
@skipIf(
(not torch.xpu.is_available()) and (not SM80OrLater),
"Requires XPU or CUDA SM80",
)
@skipXPUIf(TEST_WITH_SLOW, "Skip because test too slow on XPU")
@dtypes(torch.float, torch.double, torch.float16)
@parametrize(
"maxat",
@ -504,7 +514,11 @@ class TestAnalysis(TestCase):
self.assertTrue(seen_baddbmm)
self.assertTrue(seen_conv)
@skipIf(not SM80OrLater, "Requires SM80")
@skipIf(
(not torch.xpu.is_available()) and (not SM80OrLater),
"Requires XPU or CUDA SM80",
)
@skipXPUIf(TEST_WITH_SLOW, "Skip because test too slow on XPU")
@dtypes(torch.float, torch.float16)
@parametrize(
"maxat",
@ -554,7 +568,10 @@ class TestAnalysis(TestCase):
if event["name"] == "triton_poi_fused_add_randn_sin_0":
event["args"]["kernel_num_gb"] = 0.002097168
@skipIf(not SM80OrLater, "Requires SM80")
@skipIf(
(not torch.xpu.is_available()) and (not SM80OrLater),
"Requires XPU or CUDA SM80",
)
@dtypes(torch.float, torch.float16)
def test_combine_profiles(self, device, dtype):
"""
@ -630,7 +647,10 @@ class TestAnalysis(TestCase):
# Verify device properties are present
self.assertIn("deviceProperties", combined_profile)
self.assertGreater(len(combined_profile["deviceProperties"]), 0)
# XPU currently does not have the deviceProperties like CUDA.
# See https://github.com/intel/torch-xpu-ops/issues/2247
if torch.cuda.is_available():
self.assertGreater(len(combined_profile["deviceProperties"]), 0)
# Verify some trace events from each original profile are present
combined_event_names = {
@ -648,7 +668,7 @@ class TestAnalysis(TestCase):
self.assertTrue(profile3_event_names.intersection(combined_event_names))
instantiate_device_type_tests(TestAnalysis, globals())
instantiate_device_type_tests(TestAnalysis, globals(), allow_xpu=True)
if __name__ == "__main__":
run_tests()

View File

@ -2617,7 +2617,7 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved_dynamic_$is_traceable, m) {
)
def fn():
for i in [10, 100, 10, 20, 10]:
for i in [10, 30, 10, 20, 10]:
x = torch.ones(i, i, requires_grad=True)
out = module.custom_op_backed_by_autograd_fn(x)
loss = out.sum()

View File

@ -1988,6 +1988,20 @@ class CPUReproTests(TestCase):
def test_tile2d_store_channel_shuffle_cl_quant_output_int8(self):
self._test_tile2d_store_channel_shuffle_cl_quant_output_helper(torch.int8)
@requires_vectorization
def test_to_channels_last_fp8(self):
def fn(x):
return x.to(memory_format=torch.channels_last)
for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
torch._dynamo.reset()
metrics.reset()
self.common(
fn,
(torch.randn(20, 16, 48, 48).to(dtype=dtype),),
)
check_metrics_vec_kernel_count(2)
def _test_dequant_relu_quant_dequant_relu_quant_lowering_helper(self, dtype):
def fn(
x,
@ -2729,6 +2743,18 @@ class CPUReproTests(TestCase):
actual = torch.compile(op)(t)
self.assertEqual(expected, actual)
def test_outer_mean_large_size(self):
def fn(x):
x = x.flatten()
x_one = torch.ones_like(x)
x = torch.outer(x, x_one)
return torch.mean(x, dim=1)
x = torch.randn(2, 2, 64, 64)
expected = fn(x)
actual = torch.compile(fn)(x)
self.assertEqual(expected, actual, atol=1e-4, rtol=1e-4)
@unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
@requires_vectorization
@patch("torch.cuda.is_available", lambda: False)

View File

@ -42,6 +42,7 @@ def make_pallas(cls):
cls,
cls_prefix,
suffix,
(config, "cpu_backend", "pallas"),
(config, "cuda_backend", "pallas"),
xfail_prop="_expected_failure_pallas",
)
@ -336,6 +337,48 @@ class PallasTestsMixin:
expected = operate_on_tensor(x_t_contiguous)
self.assertEqual(result, expected)
def test_strided_int_pallas(self):
"""Test strided access patterns with the Pallas backend."""
def fn(x):
# Access every other element (strided access)
return x[::2] * 2.0
compiled = self._compile(fn)
x = torch.arange(16, dtype=torch.float32, device=self.DEVICE)
result = compiled(x)
expected = fn(x)
self.assertEqual(result, expected)
def test_strided_offset_pallas(self):
"""Test strided access with offset."""
def fn(x):
# Access every other element starting from index 1
return x[1::2] + 1.0
compiled = self._compile(fn)
x = torch.arange(16, dtype=torch.float32, device=self.DEVICE)
result = compiled(x)
expected = fn(x)
self.assertEqual(result, expected)
def test_strided_2d_pallas(self):
"""Test strided access on 2D tensors."""
def fn(x):
# Simple operation on 2D tensor
return x * 3.0
compiled = self._compile(fn)
x = torch.randn(8, 16, device=self.DEVICE)
result = compiled(x)
expected = fn(x)
self.assertEqual(result, expected)
@unittest.skipUnless(HAS_PALLAS, "requires jax and pallas")
class PallasTestsCUDA(PallasTestsMixin, TestCase):
@ -347,14 +390,16 @@ class PallasTestsCPU(PallasTestsMixin, TestCase):
DEVICE = "cpu"
# Create test variants using the main test suite
# Note: Only enable GPU tests since Pallas primarily targets GPU
if hasattr(sys.modules.get(__name__), "test_torchinductor") and HAS_PALLAS:
if getattr(test_torchinductor, "HAS_GPU", False):
# Uncomment these to run full test suite with Pallas backend
# make_pallas(test_torchinductor.SweepInputsGPUTest)
# make_pallas(test_torchinductor.GPUTests)
pass
if test_torchinductor.HAS_CPU and HAS_PALLAS:
make_pallas(test_torchinductor.SweepInputsCpuTest)
# make_pallas(test_torchinductor.CpuTests)
if test_torchinductor.HAS_GPU and HAS_PALLAS:
# make_pallas(test_torchinductor.SweepInputsGPUTest)
# make_pallas(test_torchinductor.GPUTests)
pass
if __name__ == "__main__":
if HAS_PALLAS:

View File

@ -1309,6 +1309,40 @@ class TestCppExtensionJIT(common.TestCase):
# test if build was successful
self.assertEqual(success, True)
@unittest.skipIf(
not IS_LINUX or not check_compiler_is_gcc(get_cxx_compiler()),
"PCH is only available on Linux with GCC",
)
def test_pch_command_injection(self):
"""Tests that PCH compilation is not vulnerable to command injection."""
with tempfile.TemporaryDirectory() as tmpdir:
exploit_file = os.path.join(tmpdir, "pch_exploit")
# If executed by shell, this would create exploit_file
payload = f'; echo vulnerable > "{exploit_file}"'
cpp_source = "void foo() {}"
# Try to compile with malicious payload in extra_cflags
# The compilation may succeed or fail, but the key test is whether
# the shell command in the payload gets executed
try:
torch.utils.cpp_extension.load_inline(
name="test_pch_injection",
cpp_sources=cpp_source,
functions=["foo"],
extra_cflags=[payload],
use_pch=True,
verbose=True,
)
except RuntimeError:
# Compilation failure is expected since payload is not a valid flag
pass
# The critical security check: verify the shell command was NOT executed
self.assertFalse(
os.path.exists(exploit_file),
"Command injection vulnerability detected!",
)
if __name__ == "__main__":
common.run_tests()

View File

@ -10,7 +10,10 @@ from torch.fx import has_side_effect, Proxy
from .. import graph_break_hints
from ..bytecode_transformation import create_call_function
from ..exc import TYPE_CHECKING, unimplemented
from ..graph_bytecode_inputs import get_external_object_by_index
from ..graph_bytecode_inputs import (
get_external_object_by_index,
register_graph_created_object,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import FxTracebackAnnotateVariable
@ -28,6 +31,26 @@ from torch._library.custom_ops import custom_op
Tensor = torch.Tensor
def new_event(*args: Any, **kwargs: Any) -> int:
event = torch.Event(*args, **kwargs)
return register_graph_created_object(
event,
EventVariable.make_construct_in_graph_event_fn(
TupleVariable([]), ConstDictVariable({})
),
)
def new_stream(*args: tuple[Any], **kwargs: Any) -> int:
stream = torch.Stream(*args, **kwargs) # type: ignore[no-matching-overload,call-overload]
return register_graph_created_object(
stream,
StreamVariable.make_construct_in_graph_stream_fn(
TupleVariable([]), ConstDictVariable({})
),
)
def _get_stream_by_index(index: int) -> torch.Stream:
stream = get_external_object_by_index(index)
assert isinstance(stream, torch.Stream), (
@ -115,6 +138,24 @@ def _(
has_side_effect(torch.ops.streams.wait_event.default)
@custom_op("streams::wait_stream", mutates_args=())
def wait_stream(waiting_stream_index: int, waited_on_stream_index: int) -> None:
waiting = _get_stream_by_index(waiting_stream_index)
waited_on = _get_stream_by_index(waited_on_stream_index)
waiting.wait_stream(waited_on)
@wait_stream.register_fake
def _(
event_index: int,
stream_index: int,
) -> None:
pass
has_side_effect(torch.ops.streams.wait_stream.default)
class SymbolicStreamState:
"""Track the currently entered stream if any"""

View File

@ -33,6 +33,7 @@ from .graph_capture_wrappers import (
handle_effect_tokens_fn,
)
from .schemas import AOTConfig, FxValue, SubclassMeta, TraceFn, ViewAndMutationMeta
from .streams import assign_backward_streams
from .utils import (
call_and_expect_output_descs,
copy_fwd_metadata_to_bw_nodes,
@ -473,6 +474,10 @@ def aot_dispatch_autograd_graph(
# fw node match might be erased
copy_fwd_metadata_to_bw_nodes(fx_g)
# After copying metadata, assign streams to gradient accumulation
# nodes and insert synchronization
assign_backward_streams(fx_g)
fx_g.graph.eliminate_dead_code()
if not aot_config.disable_functionalization:
# There should be *NO* mutating ops in the graph at this point.

View File

@ -0,0 +1,46 @@
import torch.fx
from torch._dynamo.graph_utils import _get_flat_args
from .utils import _is_backward_node_with_seq_nr, _is_forward_node_with_seq_nr
Node = torch.fx.Node
def seq_number(node: Node) -> int:
assert "seq_nr" in node.meta, "No seq nr in seq_number call"
return node.meta.get("seq_nr") # type: ignore[return-type]
def assign_backward_streams(gm: torch.fx.GraphModule) -> None:
"""Assigns backward streams to gradient accumulation nodes"""
max_fw_seq_nr = -1
max_bw_seq_nr = -1
bw_nodes = []
for node in gm.graph.nodes:
if _is_forward_node_with_seq_nr(node):
max_fw_seq_nr = max(max_fw_seq_nr, seq_number(node))
elif _is_backward_node_with_seq_nr(node):
bw_nodes.append(node)
max_bw_seq_nr = max(max_bw_seq_nr, seq_number(node))
if max_bw_seq_nr > max_fw_seq_nr:
# in this case, there are some gradient accumulation nodes
# these nodes will need stream assignments
for node in bw_nodes:
if seq_number(node) == max_bw_seq_nr:
# Accumulation stream selection. Follow the rules from top to bottom to determine the accumulation stream:
# 1. If the device of the gradient is the same as the device of the consumer,
# then the accumulation stream is the consumer node's stream.
# 2. If the device of the gradient matches the device of the producer,
# then accumulation stream is the producer node's stream.
# 3. If neither is true, pick the current stream of the device of the gradient.
# Accumulation stream synchronization:
# Prior to accumulation, have the accumulation stream wait for producer stream
# and the stashed event (recorded on the previous producer stream).
gradients = _get_flat_args(node, {})
def insert_sync(producer, consumer) -> None:
pass

View File

@ -2205,28 +2205,22 @@ class CppKernel(Kernel):
reduction_size = functools.reduce(
operator.mul, self.call_ranges[self.reduction_depth :]
)
if config.cpp.dynamic_threads:
# If dynamic threads, to be conservative,
# use reduction_size as the range size
rt_size = reduction_size
else:
rt_size = CeilDiv(reduction_size, parallel_num_threads())
# chunk size to balance accuracy and performance
chunk_size = 2**20
chunk_size = 4096
# use acc helper If cannot get size_hint
try:
rt_size_hint = V.graph.sizevars.size_hint(rt_size)
reduction_size_hint = V.graph.sizevars.size_hint(reduction_size)
except Exception:
return True
if rt_size_hint > chunk_size:
if reduction_size_hint > chunk_size:
# use helper if the reduction size is too large
V.graph.sizevars.check_lt(chunk_size, rt_size)
V.graph.sizevars.check_lt(chunk_size, reduction_size)
return True
else:
V.graph.sizevars.check_leq(rt_size, chunk_size)
V.graph.sizevars.check_leq(reduction_size, chunk_size)
return False
def _acc_helper_init(
@ -2243,7 +2237,7 @@ class CppKernel(Kernel):
)
num_range_thread_expr = cexpr_index(num_range_thread)
assert reduction_type in ["welford_reduce", "sum"]
chunk_size = 4096 if reduction_type == "welford_reduce" else 2**20
chunk_size = 4096
num_chunks = CeilDiv(num_range_thread, chunk_size)
helper_type = (
"WelfordHelper"
@ -3690,6 +3684,8 @@ class CppTile2DKernel(CppVecKernel):
if self.tail_size or V.graph.get_dtype(name) in DTYPE_LOWP_FP + [
torch.uint8,
torch.int8,
torch.float8_e4m3fn,
torch.float8_e5m2,
]:
line = f"{value}.store({storebuf}, {cexpr_index(self.num_elems)});"
else:

View File

@ -10,10 +10,12 @@ import torch # noqa: TC001
from torch.utils._ordered_set import OrderedSet
from .. import config
from ..runtime.runtime_utils import torch_dtype_to_jax
from ..utils import get_fused_kernel_name, get_kernel_metadata
from ..virtualized import V
from .block_analysis import BlockPatternMatcher
from .common import BackendFeature, CSEVariable, IndentedBuffer, OpOverrides
from .simd import SIMDKernel, SIMDScheduling
from .simd import pexpr, SIMDKernel, SIMDScheduling
if TYPE_CHECKING:
@ -187,63 +189,168 @@ class PallasKernelOverrides(OpOverrides):
def where(cond: str, a: str, b: str) -> str:
return f"jnp.where({cond}, {a}, {b})"
@staticmethod
def to_dtype(
x: str,
dtype: torch.dtype,
src_dtype: Optional[torch.dtype] = None,
use_compute_types: bool = True,
) -> str:
jax_dtype = torch_dtype_to_jax(dtype)
return f"{x}.astype({jax_dtype})"
class PallasKernel(SIMDKernel):
"""
Minimal Pallas kernel for simple elementwise operations.
Pallas kernel for elementwise operations with support for strided/scatter access.
Strategy:
- Treat loads as full-array refs: "in_ptrX[...]"
- Convert index expressions to JAX-compatible array slicing
- Load/store using indexed access: "in_ptrX[slice]" or full-array "in_ptrX[...]"
- Compute expression with Python operators (compatible with jax.numpy broadcasting)
- Store as full-array ref assignment: "out_ptrY[...] = <expr>"
- Generate Python code that defines a Pallas kernel and a host entrypoint.
- Use async_compile.pallas path to compile and load Python code.
"""
overrides = PallasKernelOverrides # type: ignore[assignment]
kexpr: Callable[[sympy.Expr], str] = pexpr # Use Python expression printer
def _get_contiguous_index_str(self, index: sympy.Expr) -> str:
def _get_index_str(self, index: sympy.Expr) -> str:
"""
Validate that the index represents contiguous access and return the indexing string.
Convert an index expression to a string suitable for Pallas indexing.
For Pallas, we only support simple contiguous access patterns where the index
is a single symbol (e.g., xindex) representing a flattened iteration.
This ensures the load/store order is contiguous.
Pallas operates on full arrays, so we need to convert index expressions
to JAX array slicing. For example:
- x0 -> "..." (contiguous access, full array)
- 2*x0 -> "::2" (strided access with stride 2)
- 2*x0 + 1 -> "1::2" (strided access with offset 1, stride 2)
Args:
index: The indexing expression to validate
index: The indexing expression to convert
Returns:
The indexing string to use (currently always "...")
Raises:
Unsupported: If the index is not a simple contiguous pattern
The indexing string to use in generated code
"""
# Prepare and simplify the index
prepared_index = self.prepare_indexing(index)
# For contiguous access, we expect a single symbol (like xindex)
# or a simple integer (for scalar operations)
# For simple single-symbol access (contiguous case), we can use [...]
# which is more efficient as it operates on the entire array at once
if isinstance(prepared_index, sympy.Symbol):
# This is the expected case: a single symbol representing contiguous iteration
return "..."
elif prepared_index.is_Integer:
# Scalar case
return "..."
# Scalar index
return str(prepared_index)
else:
# If there's any complex expression (ModularIndexing, FloorDiv, etc.),
# it's not a simple contiguous pattern
raise Unsupported(
f"Pallas backend only supports contiguous access patterns. "
f"Got complex index: {prepared_index}"
)
# Complex expression (strided/scatter access)
# Try to extract stride and offset for common patterns
return self._convert_to_jax_slice(prepared_index)
def _convert_to_jax_slice(self, index: sympy.Expr) -> str:
"""
Convert a sympy index expression to JAX slice notation.
Handles common patterns like:
- stride*var -> ::stride
- stride*var + offset -> offset::stride
For more complex patterns, falls back to explicit indexing.
Uses BlockPatternMatcher for robust pattern matching.
"""
# Get the iteration variables for this kernel
if not self.range_trees:
return "..."
# Simplify the index
index = V.graph.sizevars.simplify(index)
free_symbols = index.free_symbols
# Get iteration variables from range_tree_nodes
iter_vars = OrderedSet(self.range_tree_nodes.keys())
# Find which iteration variable(s) are used
used_vars = free_symbols & iter_vars
if len(used_vars) == 0:
# No iteration variables, this is a constant index
return str(index)
elif len(used_vars) == 1:
# Single iteration variable - try to extract stride and offset using BlockPatternMatcher
var = next(iter(used_vars))
# Get the subexpression involving this variable
var_expr = BlockPatternMatcher.get_subexpr_involving_symbol(index, var)
# Try to match affine pattern: stride * var
stride = BlockPatternMatcher.match_affine_block_expr(var_expr, var)
if stride is not None:
# Extract the constant offset (terms not involving var)
offset = index - var_expr
offset = V.graph.sizevars.simplify(offset)
# Generate JAX slice notation
if stride == 1 and offset == 0:
# Contiguous access
return "..."
elif offset == 0:
# Pure stride: ::stride
stride_str = self.kexpr(stride)
return f"::{stride_str}"
else:
# Offset + stride: offset::stride
offset_str = self.kexpr(offset)
stride_str = self.kexpr(stride)
return f"{offset_str}::{stride_str}"
else:
# Couldn't match affine pattern, fall back to original logic
offset = index - var_expr
offset = V.graph.sizevars.simplify(offset)
if offset == 0 and var_expr == var:
# Just the variable itself, unit stride
return "..."
elif len(used_vars) > 1:
# Multi-dimensional indexing
# For contiguous multi-dim access, all terms should have unit stride
all_unit_stride = True
for var in used_vars:
var_expr = BlockPatternMatcher.get_subexpr_involving_symbol(index, var)
stride = BlockPatternMatcher.match_affine_block_expr(var_expr, var)
if stride != 1:
all_unit_stride = False
break
if all_unit_stride:
# Contiguous multi-dimensional access
return "..."
else:
# Strided multi-dimensional access - requires advanced indexing
# For now, use ellipsis which may work for many cases
# TODO: Implement proper multi-dimensional strided indexing
return "..."
# For complex cases, raise an error
return self._generate_index_array(index)
def _generate_index_array(self, index: sympy.Expr) -> str:
"""
Generate JAX code to compute an index array for complex indexing patterns.
For very complex patterns that can't be expressed as simple slices,
we need to compute the indices explicitly. This is not yet fully implemented.
"""
# For now, raise an error for complex patterns
# TODO: Implement advanced indexing support
raise Unsupported(
f"Pallas backend does not yet support complex indexing pattern: {index}"
)
def load(self, name: str, index: sympy.Expr) -> CSEVariable: # type: ignore[override]
buf = self.args.input(name)
dtype = V.graph.get_dtype(name)
# Validate contiguous access and get index string
index_str = self._get_contiguous_index_str(index)
# Pallas refs must be unpacked with [...] to load the array
# Get index string for load operation
index_str = self._get_index_str(index)
# Pallas refs must be unpacked with [...] or [index] to load
return self.cse.generate(
self.compute,
f"{buf}[{index_str}]",
@ -257,9 +364,9 @@ class PallasKernel(SIMDKernel):
raise Unsupported("pallas store mode not supported")
out = self.args.output(name)
self.store_buffer_names.add(name)
# Validate contiguous access and get index string
index_str = self._get_contiguous_index_str(index)
# Pallas refs must use [...] assignment to store back to the ref
# Get index string for store operation
index_str = self._get_index_str(index)
# Pallas refs must use [...] or [index] assignment to store
self.stores.writeline(f"{out}[{index_str}] = {value}")
def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[override]
@ -284,6 +391,11 @@ class PallasKernel(SIMDKernel):
"Pallas backend currently supports single-output elementwise kernels only"
)
# Get output dtype at compile time
output_name = live_outs[0]
output_dtype = V.graph.get_dtype(output_name)
output_dtype_jax = torch_dtype_to_jax(output_dtype)
code = IndentedBuffer()
code.splice(
"""
@ -307,7 +419,10 @@ class PallasKernel(SIMDKernel):
)
code.writeline(f"def {kernel_name}_kernel({', '.join(kernel_params)}):")
with code.indent():
# Emit compute (CSE) and store lines; they reference *_ptr[...] directly
# Emit compute (CSE) and store lines; they reference *_ptr[index] directly
# The iteration variables are implicitly handled by JAX's vectorization
# When using [...], it processes the whole array
# When using explicit indices, they should be JAX-traced values
for line in self.compute._lines:
code.writeline(str(line))
for line in self.stores._lines:
@ -329,6 +444,9 @@ class PallasKernel(SIMDKernel):
main_name = f"{kernel_name}_main"
code.writeline(f"def {main_name}({', '.join(kernel_params)}, stream=None):")
with code.indent():
# Enable JAX x64 mode to support float64/int64 types
code.writeline("# Enable JAX x64 mode for float64/int64 support")
code.writeline("jax.config.update('jax_enable_x64', True)")
# Identify inputs (in_ptr*) and output (out_ptr*)
input_params = [
p for p in kernel_params if p.startswith(("in_ptr", "in_out_ptr"))
@ -343,24 +461,15 @@ class PallasKernel(SIMDKernel):
output_param = output_params[0]
# Convert inputs to JAX arrays
code.writeline("# Convert Torch -> JAX for inputs")
for inp in input_params:
code.writeline(f"{inp}_jax = jax.dlpack.from_dlpack({inp})")
code.writeline(
f"{inp}_jax = jax.dlpack.from_dlpack({inp}.contiguous())"
)
# Get output metadata from PyTorch tensor
code.writeline("# Prepare output metadata from PyTorch tensor")
code.writeline("# Map PyTorch dtype to JAX dtype")
code.writeline("_torch_dtype_to_jax = {")
code.writeline(
" torch.float32: jnp.float32, torch.float64: jnp.float64, torch.float16: jnp.float16,"
)
code.writeline(
" torch.int32: jnp.int32, torch.int64: jnp.int64, torch.int16: jnp.int16, torch.int8: jnp.int8,"
)
code.writeline(" torch.uint8: jnp.uint8, torch.bool: jnp.bool_,")
code.writeline("}")
code.writeline(f"out_shape = tuple({output_param}.shape)")
code.writeline(f"out_dtype = _torch_dtype_to_jax[{output_param}.dtype]")
code.writeline(f"out_dtype = {output_dtype_jax}")
call_args = ["out_shape", "out_dtype"] + [
f"{inp}_jax" for inp in input_params

View File

@ -187,3 +187,31 @@ def compile_mps_shader(source: str) -> Any:
return torch.mps.compile_shader(source)
except SyntaxError as err:
raise SyntaxError(f"failed to compile {source} with {err.msg}") from err
def torch_dtype_to_jax(dtype: torch.dtype) -> str:
"""
Map PyTorch dtype to JAX dtype expression.
This helper is used at compile time in codegen to generate
JAX dtype expressions for Pallas kernels.
Args:
dtype: PyTorch dtype to convert
Returns:
JAX dtype expression as string (e.g., "jnp.float32")
"""
dtype_map = {
torch.float32: "jnp.float32",
torch.float64: "jnp.float64",
torch.float16: "jnp.float16",
torch.bfloat16: "jnp.bfloat16",
torch.int32: "jnp.int32",
torch.int64: "jnp.int64",
torch.int16: "jnp.int16",
torch.int8: "jnp.int8",
torch.uint8: "jnp.uint8",
torch.bool: "jnp.bool_",
}
return dtype_map.get(dtype, f"jnp.{dtype}")

View File

@ -36,6 +36,7 @@ from torch.distributed.tensor.placement_types import (
Replicate,
Shard,
)
from torch.fx.experimental.symbolic_shapes import statically_known_true
aten = torch.ops.aten
@ -381,7 +382,7 @@ def gen_slice_strategy(op_schema: OpSchema) -> StrategyType:
raise AssertionError(f"Expected int, got {type(dim)}")
if start is None:
start = 0
if end is None or end > input_shape[dim]:
if end is None or statically_known_true(end > input_shape[dim]):
end = input_shape[dim]
if not isinstance(start, IntLike):
raise AssertionError(f"Expected IntLike, got {type(start)}")
@ -395,13 +396,20 @@ def gen_slice_strategy(op_schema: OpSchema) -> StrategyType:
start = normalize_dim(start, input_shape[dim]) # type: ignore[arg-type]
end = normalize_dim(end, input_shape[dim]) # type: ignore[arg-type]
redundant_slice = start == 0 and end == input_shape[dim] and step == 1
statically_redundant_slice = (
statically_known_true(start == 0)
and statically_known_true(end == input_shape[dim])
and statically_known_true(step == 1)
)
slice_strategy = OpStrategy([])
for arg_strategy in input_strategy.strategies:
arg_spec = arg_strategy.output_spec
if not is_tensor_dim_sharded(arg_spec, dim=slice_dim) or redundant_slice:
if (
not is_tensor_dim_sharded(arg_spec, dim=slice_dim)
or statically_redundant_slice
):
# only add the strategy if the slice dim is not sharded
out_spec = DTensorSpec(mesh, arg_spec.placements)
slice_strategy.strategies.append(

View File

@ -245,7 +245,9 @@ class FxNetAccFusionsFinder:
@compatibility(is_backward_compatible=False)
def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
def legalize_graph(
gm: torch.fx.GraphModule, stable_topo_sort: bool = False
) -> torch.fx.GraphModule:
"""
Replace the graph of the given GraphModule with one that contains the same nodes as the
original, but in topologically sorted order.
@ -255,6 +257,7 @@ def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
Arguments:
gm: The graph module to topologically sort. It is modified in-place.
stable_topo_sort: when True, PRIORITIZED_OPS are ignored.
Returns:
The graph module in-place sorted
@ -304,7 +307,11 @@ def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
for user in cur.users:
indeg[user] -= 1
if indeg[user] == 0:
if user.op == "call_function" and user.target in PRIORITIZED_OPS:
if (
not stable_topo_sort
and user.op == "call_function"
and user.target in PRIORITIZED_OPS
):
queue.appendleft(user)
else:
queue.append(user)

View File

@ -220,22 +220,36 @@ def insert_subgm(
submodule_name = sub_gm.__class__.__name__
gm.add_submodule(submodule_name, sub_gm)
def last_node(target_nodes: tuple[Node, ...]) -> Node | None:
for node in reversed(gm.graph.nodes):
if node in target_nodes:
return node
return None
last_input_node: Node | None = last_node(orig_inputs)
assert last_input_node is not None
# Create a call_module node in main graph.
module_node = gm.graph.call_module(submodule_name, args=orig_inputs, kwargs=None)
output_node = sub_gm.graph.output_node()
if len(orig_outputs) == 1 and not isinstance(output_node.args[0], tuple):
# main_remapping[comp.orig_outputs[0]] = module_node
orig_outputs[0].replace_all_uses_with(module_node, propagate_meta=True)
else:
for i, orig_output in enumerate(orig_outputs):
# Use Proxy to record getitem access.
proxy_out = torch.fx.Proxy(module_node)[i].node # type: ignore[index]
orig_output.replace_all_uses_with(proxy_out, propagate_meta=True)
module_node.meta["val"] = tuple(
orig_output.meta.get("val", None) for orig_output in orig_outputs
with gm.graph.inserting_after(last_input_node):
module_node = gm.graph.call_module(
submodule_name, args=orig_inputs, kwargs=None
)
output_node = sub_gm.graph.output_node()
next_node = module_node.next
with gm.graph.inserting_before(next_node):
if len(orig_outputs) == 1 and not isinstance(output_node.args[0], tuple):
# main_remapping[comp.orig_outputs[0]] = module_node
orig_outputs[0].replace_all_uses_with(module_node, propagate_meta=True)
else:
for i, orig_output in enumerate(orig_outputs):
# Use Proxy to record getitem access.
proxy_out = torch.fx.Proxy(module_node)[i].node # type: ignore[index]
orig_output.replace_all_uses_with(proxy_out, propagate_meta=True)
module_node.meta["val"] = tuple(
orig_output.meta.get("val", None) for orig_output in orig_outputs
)
return gm
@ -269,7 +283,7 @@ def fuse_by_partitions(
erase_nodes(gm, sorted_nodes)
# topological sort original gm with newly created sub_gm
legalize_graph(gm)
legalize_graph(gm, stable_topo_sort=True)
gm.graph.lint()
return gm

View File

@ -1833,7 +1833,7 @@ def _check_and_build_extension_h_precompiler_headers(
def build_precompile_header(pch_cmd) -> None:
try:
subprocess.check_output(pch_cmd, shell=True, stderr=subprocess.STDOUT)
subprocess.check_output(shlex.split(pch_cmd), stderr=subprocess.STDOUT)
except subprocess.CalledProcessError as e:
raise RuntimeError(f"Compile PreCompile Header fail, command: {pch_cmd}") from e