Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-12 06:44:55 +08:00

Compare commits: malfet-pat...mlazos/bwd (11 commits)

| Author | SHA1 | Date |
|---|---|---|
| | e457003553 | |
| | 5f990d00e4 | |
| | 706d566e2c | |
| | 3fdee99d7c | |
| | daed97afff | |
| | 53947adb1f | |
| | c297b02f12 | |
| | bd24774f50 | |
| | 525eb9fab9 | |
| | 7886070fc5 | |
| | 87d17e9dee | |
@@ -18,8 +18,6 @@ Please report security issues using https://github.com/pytorch/pytorch/security/

All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If an advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create a [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.

**Note on crashes and out of bounds access**: PyTorch is a computational framework that performs operations on behalf of the caller. Like many low-level libraries, PyTorch generally does not validate all inputs to every function: the responsibility for providing valid arguments lies with the calling code. While crashes and out of bounds memory access should be reported as bugs, they are generally not considered security vulnerabilities in PyTorch's threat model.

Please refer to the following page for our responsible disclosure policy, reward guidelines, and what should not be reported:

https://www.facebook.com/whitehat
@@ -533,7 +533,7 @@ class DTensorExportTest(TestCase):

        self.assertEqual(fn(z), gm(z)[0])

    def test_dtensor_data_dependent_index(self):
    def test_dtensor_data_dependent_index_and_slice(self):
        device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))

        class Foo(torch.nn.Module):
@@ -546,6 +546,35 @@ class DTensorExportTest(TestCase):
        y_dt = distribute_tensor(y, device_mesh, placements=[Replicate()])
        _dynamo_graph_capture_for_export(Foo())(x_dt, y_dt)

        class Bar(torch.nn.Module):
            def forward(self, x):
                val = torch.clamp(x.max(), min=1).item()
                torch._check(val >= 1)
                return x[:val]

        x = torch.randint(1000, (4, 64, 16))
        x_dt = distribute_tensor(x, device_mesh, placements=[Replicate()])
        gm = _dynamo_graph_capture_for_export(Bar())(x_dt)
        self.assertExpectedInline(
            """\
graph():
    %l_flat_args_0_ : [num_users=2] = placeholder[target=arg_0]
    %max_1 : [num_users=1] = call_method[target=max](args = (%l_flat_args_0_,), kwargs = {})
    %clamp : [num_users=1] = call_function[target=torch.clamp](args = (%max_1,), kwargs = {min: 1})
    %item : [num_users=2] = call_method[target=item](args = (%clamp,), kwargs = {})
    %ge_1 : [num_users=1] = call_function[target=operator.ge](args = (%item, 1), kwargs = {})
    %_assert_scalar_default : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_1, Runtime assertion failed for expression u0 >= 1 on node 'ge_1'), kwargs = {})
    %res : [num_users=2] = call_function[target=operator.getitem](args = (%l_flat_args_0_, slice(None, item, None)), kwargs = {})
    %getattr_1 : [num_users=1] = call_function[target=builtins.getattr](args = (%res, _local_tensor), kwargs = {})
    %sym_size_int : [num_users=2] = call_function[target=torch.ops.aten.sym_size.int](args = (%getattr_1, 0), kwargs = {})
    %ge_2 : [num_users=1] = call_function[target=operator.ge](args = (%sym_size_int, 0), kwargs = {})
    %_assert_scalar_default_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%ge_2, Runtime assertion failed for expression u2 >= 0 on node 'ge_2'), kwargs = {})
    %le : [num_users=1] = call_function[target=operator.le](args = (%sym_size_int, 4), kwargs = {})
    %_assert_scalar_default_2 : [num_users=0] = call_function[target=torch.ops.aten._assert_scalar.default](args = (%le, Runtime assertion failed for expression u2 <= 4 on node 'le'), kwargs = {})
    return (res,)""",  # noqa: B950
            str(gm.graph).strip(),
        )
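The `Bar` module above compiles only because `torch._check` feeds the scalar constraint into the symbolic-shapes system, making the data-dependent slice `x[:val]` well-formed. A minimal sketch of the same pattern outside of DTensor (illustrative, not part of this PR; assumes `capture_scalar_outputs` to keep `.item()` in the graph):

```python
import torch

torch._dynamo.config.capture_scalar_outputs = True  # allow .item() in the graph

def f(x):
    n = torch.clamp(x.max(), min=1).item()  # unbacked SymInt under compile
    torch._check(n >= 1)                    # constrain it so x[:n] is well-formed
    return x[:n]

compiled = torch.compile(f, fullgraph=True)
print(compiled(torch.randint(10, (8,))).shape)
```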

instantiate_parametrized_tests(DTensorExportTest)


@@ -335,6 +335,59 @@ class <lambda>(torch.nn.Module):
            """,
        )

    @requires_cuda
    @requires_multigpu()
    def test_new_event_api(self) -> None:
        from torch._dynamo.graph_bytecode_inputs import get_external_object_by_index
        from torch._dynamo.variables.streams import new_event

        def event_generation_backend(gm, *args, **kwargs):  # type: ignore[no-untyped-def]
            e0_ind = new_event()
            with torch.Stream(device="cuda:1"):
                get_external_object_by_index(e0_ind).record()
            e1_ind = new_event()
            self.assertNotEqual(e0_ind, e1_ind)
            self.assertNotEqual(
                get_external_object_by_index(e0_ind),
                get_external_object_by_index(e1_ind),
            )
            with gm.graph.inserting_after(next(iter(gm.graph.nodes))):
                gm.graph.call_function(
                    get_external_object_by_index, args=(1,), kwargs={}
                )
            return gm

        @torch.compile(backend=event_generation_backend)
        def fn(x):
            return x + 1

        fn(torch.ones(2, 2, device="cuda:0"))

    @requires_cuda
    def test_new_stream_api(self) -> None:
        from torch._dynamo.graph_bytecode_inputs import get_external_object_by_index
        from torch._dynamo.variables.streams import new_stream

        def stream_generation_backend(gm, *args, **kwargs):  # type: ignore[no-untyped-def]
            s0_ind = new_stream()
            s1_ind = new_stream()
            self.assertNotEqual(s0_ind, s1_ind)
            self.assertNotEqual(
                get_external_object_by_index(s0_ind),
                get_external_object_by_index(s1_ind),
            )
            with gm.graph.inserting_after(next(iter(gm.graph.nodes))):
                gm.graph.call_function(
                    get_external_object_by_index, args=(1,), kwargs={}
                )
            return gm

        @torch.compile(backend=stream_generation_backend)
        def fn(x):
            return x + 1

        fn(torch.ones(2, 2, device="cuda:0"))
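Both tests above drive compilation through a hand-written Dynamo backend. A backend is simply a callable that receives the traced `GraphModule`; a minimal sketch (generic example, not from this PR):

```python
import torch

def inspect_backend(gm, example_inputs):
    # A Dynamo backend takes (GraphModule, example_inputs) and returns a
    # callable; returning gm.forward runs the captured graph as-is.
    print(gm.graph)
    return gm.forward

@torch.compile(backend=inspect_backend)
def fn(x):
    return x + 1

fn(torch.ones(2, 2))
```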

    @requires_cuda
    def test_stream_with_mutation(self):
        def fn(x, y):
@@ -386,7 +439,68 @@ class <lambda>(torch.nn.Module):
        )

    @requires_cuda
    def test_stream_backward(self) -> None:
    def test_stream_backward_simple(self) -> None:
        def fn(x, y):
            s2 = torch.Stream()
            s0 = torch.Stream()
            with s0:
                y0 = 2 * x + y
            with s2:
                z = 2 * x + y

            return y0, z

        inp = (
            torch.ones(2, 2, requires_grad=True) + 1,
            torch.ones(2, 2, requires_grad=True),
        )
        expected = fn(*inp)
        (
            actual,
            _,
            fw_graphs,
            bw_graphs,
        ) = extract_graph(fn, *inp)
        self.assertEqual(len(fw_graphs), 1)
        self.assertEqual(expected, actual)
        self.assertExpectedInline(
            print_graph(fw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"):
        # Annotation: {'stream': 1}
        mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None
        add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2)

        # Annotation: {'stream': 0}
        add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None
        return (add, add_1)
""",
        )

        actual[1].sum().backward()
        self.assertExpectedInline(
            print_graph(bw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"):
        # Annotation: {'stream': 0}
        mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2)

        #
        add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None

        # Annotation: {'stream': 1}
        mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None

        #
        add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
        return (add_3, add_2)
""",
        )
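The expecttest above shows the forward `{'stream': ...}` annotations being mirrored onto the corresponding backward ops. A compact sketch of the eager-mode pattern being captured (assumes a CUDA device and a PyTorch build with the `torch.accelerator` API used elsewhere in this test file):

```python
import torch

x = torch.ones(2, 2, device="cuda", requires_grad=True)
y = torch.ones(2, 2, device="cuda", requires_grad=True)

s0 = torch.Stream()
with s0:                       # work issued inside the block runs on s0
    out = (2 * x + y).sum()
torch.accelerator.current_stream().wait_stream(s0)  # order before backward
out.backward()
```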

    @requires_cuda
    def test_stream_backward_sync(self) -> None:
        def fn(x, y):
            s2 = torch.Stream()
            s0 = torch.Stream()
@@ -523,6 +637,23 @@ class <lambda>(torch.nn.Module):
        torch.accelerator.set_stream(original_stream)
        reset_user_object_tracking()

    @requires_cuda
    def test_run_opcheck_wait_record_stream(self):
        from torch._dynamo.variables.streams import wait_stream
        from torch.library import opcheck

        s0 = torch.Stream()
        s1 = torch.Stream()
        s2 = torch.Stream()
        store_user_object_weakrefs(s0, s1, s2)

        sample_inputs = [
            (0, 1),
            (2, 0),
        ]
        for args in sample_inputs:
            opcheck(wait_stream, args)
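`torch.library.opcheck` runs a battery of consistency checks (schema, fake-tensor behavior, and related registrations) against sample inputs. A self-contained sketch with a hypothetical `demo::add_one` op, following the same shape as the test above:

```python
import torch
from torch.library import custom_op, opcheck

@custom_op("demo::add_one", mutates_args=())
def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1

@add_one.register_fake
def _(x):
    # Fake impl: describe output metadata without computing values
    return torch.empty_like(x)

opcheck(add_one, (torch.randn(3),))
```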

    @requires_cuda
    def test_inductor_lowering(self):
        with patch("torch._inductor.config.implicit_fallbacks", False):

@@ -274,7 +274,10 @@ class TestUtils(TestCase):


class TestAnalysis(TestCase):
    @skipIf(not SM80OrLater, "Requires SM80")
    @skipIf(
        (not torch.xpu.is_available()) and (not SM80OrLater),
        "Requires XPU or CUDA SM80",
    )
    def test_noop(self):
        with (
            patch("sys.stdout", new_callable=StringIO) as mock_stdout,
@@ -283,7 +286,10 @@ class TestAnalysis(TestCase):
            main()
        self.assertEqual(mock_stdout.getvalue(), "")

    @skipIf(not SM80OrLater, "Requires SM80")
    @skipIf(
        (not torch.xpu.is_available()) and (not SM80OrLater),
        "Requires XPU or CUDA SM80",
    )
    @dtypes(torch.float, torch.double, torch.float16)
    def test_diff(self, device, dtype):
        """
@@ -334,7 +340,11 @@ class TestAnalysis(TestCase):
        expected_flops = [4096000, 4096000, 223552896, 223552896, 0, 0, 0]
        verify_flops(self, expected_flops, out_profile)

    @skipIf(not SM80OrLater, "Requires SM80")
    @skipIf(
        (not torch.xpu.is_available()) and (not SM80OrLater),
        "Requires XPU or CUDA SM80",
    )
    @skipXPUIf(TEST_WITH_SLOW, "Skip because the test is too slow on XPU")
    @dtypes(torch.float, torch.double, torch.float16)
    @parametrize(
        "maxat",
@@ -504,7 +514,11 @@ class TestAnalysis(TestCase):
        self.assertTrue(seen_baddbmm)
        self.assertTrue(seen_conv)

    @skipIf(not SM80OrLater, "Requires SM80")
    @skipIf(
        (not torch.xpu.is_available()) and (not SM80OrLater),
        "Requires XPU or CUDA SM80",
    )
    @skipXPUIf(TEST_WITH_SLOW, "Skip because the test is too slow on XPU")
    @dtypes(torch.float, torch.float16)
    @parametrize(
        "maxat",
@@ -554,7 +568,10 @@ class TestAnalysis(TestCase):
        if event["name"] == "triton_poi_fused_add_randn_sin_0":
            event["args"]["kernel_num_gb"] = 0.002097168

    @skipIf(not SM80OrLater, "Requires SM80")
    @skipIf(
        (not torch.xpu.is_available()) and (not SM80OrLater),
        "Requires XPU or CUDA SM80",
    )
    @dtypes(torch.float, torch.float16)
    def test_combine_profiles(self, device, dtype):
        """
@@ -630,7 +647,10 @@ class TestAnalysis(TestCase):

        # Verify device properties are present
        self.assertIn("deviceProperties", combined_profile)
        self.assertGreater(len(combined_profile["deviceProperties"]), 0)
        # XPU currently does not have deviceProperties the way CUDA does.
        # See https://github.com/intel/torch-xpu-ops/issues/2247
        if torch.cuda.is_available():
            self.assertGreater(len(combined_profile["deviceProperties"]), 0)

        # Verify some trace events from each original profile are present
        combined_event_names = {
@@ -648,7 +668,7 @@ class TestAnalysis(TestCase):
        self.assertTrue(profile3_event_names.intersection(combined_event_names))


instantiate_device_type_tests(TestAnalysis, globals())
instantiate_device_type_tests(TestAnalysis, globals(), allow_xpu=True)

if __name__ == "__main__":
    run_tests()
@@ -2617,7 +2617,7 @@ TORCH_LIBRARY(test_autograd_cpp_node_saved_dynamic_$is_traceable, m) {
        )

        def fn():
            for i in [10, 100, 10, 20, 10]:
            for i in [10, 30, 10, 20, 10]:
                x = torch.ones(i, i, requires_grad=True)
                out = module.custom_op_backed_by_autograd_fn(x)
                loss = out.sum()
@@ -1988,6 +1988,20 @@ class CPUReproTests(TestCase):
    def test_tile2d_store_channel_shuffle_cl_quant_output_int8(self):
        self._test_tile2d_store_channel_shuffle_cl_quant_output_helper(torch.int8)

    @requires_vectorization
    def test_to_channels_last_fp8(self):
        def fn(x):
            return x.to(memory_format=torch.channels_last)

        for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
            torch._dynamo.reset()
            metrics.reset()
            self.common(
                fn,
                (torch.randn(20, 16, 48, 48).to(dtype=dtype),),
            )
            check_metrics_vec_kernel_count(2)

    def _test_dequant_relu_quant_dequant_relu_quant_lowering_helper(self, dtype):
        def fn(
            x,
@@ -2729,6 +2743,18 @@ class CPUReproTests(TestCase):
        actual = torch.compile(op)(t)
        self.assertEqual(expected, actual)

    def test_outer_mean_large_size(self):
        def fn(x):
            x = x.flatten()
            x_one = torch.ones_like(x)
            x = torch.outer(x, x_one)
            return torch.mean(x, dim=1)

        x = torch.randn(2, 2, 64, 64)
        expected = fn(x)
        actual = torch.compile(fn)(x)
        self.assertEqual(expected, actual, atol=1e-4, rtol=1e-4)

    @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
    @requires_vectorization
    @patch("torch.cuda.is_available", lambda: False)
@@ -42,6 +42,7 @@ def make_pallas(cls):
        cls,
        cls_prefix,
        suffix,
        (config, "cpu_backend", "pallas"),
        (config, "cuda_backend", "pallas"),
        xfail_prop="_expected_failure_pallas",
    )
@@ -336,6 +337,48 @@ class PallasTestsMixin:
        expected = operate_on_tensor(x_t_contiguous)
        self.assertEqual(result, expected)

    def test_strided_int_pallas(self):
        """Test strided access patterns with the Pallas backend."""

        def fn(x):
            # Access every other element (strided access)
            return x[::2] * 2.0

        compiled = self._compile(fn)

        x = torch.arange(16, dtype=torch.float32, device=self.DEVICE)
        result = compiled(x)
        expected = fn(x)
        self.assertEqual(result, expected)

    def test_strided_offset_pallas(self):
        """Test strided access with an offset."""

        def fn(x):
            # Access every other element starting from index 1
            return x[1::2] + 1.0

        compiled = self._compile(fn)

        x = torch.arange(16, dtype=torch.float32, device=self.DEVICE)
        result = compiled(x)
        expected = fn(x)
        self.assertEqual(result, expected)

    def test_strided_2d_pallas(self):
        """Test strided access on 2D tensors."""

        def fn(x):
            # Simple operation on a 2D tensor
            return x * 3.0

        compiled = self._compile(fn)

        x = torch.randn(8, 16, device=self.DEVICE)
        result = compiled(x)
        expected = fn(x)
        self.assertEqual(result, expected)


@unittest.skipUnless(HAS_PALLAS, "requires jax and pallas")
class PallasTestsCUDA(PallasTestsMixin, TestCase):
@@ -347,14 +390,16 @@ class PallasTestsCPU(PallasTestsMixin, TestCase):
    DEVICE = "cpu"


# Create test variants using the main test suite
# Note: Only enable GPU tests since Pallas primarily targets GPU
if hasattr(sys.modules.get(__name__), "test_torchinductor") and HAS_PALLAS:
    if getattr(test_torchinductor, "HAS_GPU", False):
        # Uncomment these to run the full test suite with the Pallas backend
        # make_pallas(test_torchinductor.SweepInputsGPUTest)
        # make_pallas(test_torchinductor.GPUTests)
        pass
if test_torchinductor.HAS_CPU and HAS_PALLAS:
    make_pallas(test_torchinductor.SweepInputsCpuTest)
    # make_pallas(test_torchinductor.CpuTests)


if test_torchinductor.HAS_GPU and HAS_PALLAS:
    # make_pallas(test_torchinductor.SweepInputsGPUTest)
    # make_pallas(test_torchinductor.GPUTests)
    pass


if __name__ == "__main__":
    if HAS_PALLAS:
@@ -1309,6 +1309,40 @@ class TestCppExtensionJIT(common.TestCase):
        # test if the build was successful
        self.assertEqual(success, True)

    @unittest.skipIf(
        not IS_LINUX or not check_compiler_is_gcc(get_cxx_compiler()),
        "PCH is only available on Linux with GCC",
    )
    def test_pch_command_injection(self):
        """Tests that PCH compilation is not vulnerable to command injection."""
        with tempfile.TemporaryDirectory() as tmpdir:
            exploit_file = os.path.join(tmpdir, "pch_exploit")
            # If executed by a shell, this would create exploit_file
            payload = f'; echo vulnerable > "{exploit_file}"'
            cpp_source = "void foo() {}"

            # Try to compile with the malicious payload in extra_cflags.
            # The compilation may succeed or fail, but the key test is whether
            # the shell command in the payload gets executed.
            try:
                torch.utils.cpp_extension.load_inline(
                    name="test_pch_injection",
                    cpp_sources=cpp_source,
                    functions=["foo"],
                    extra_cflags=[payload],
                    use_pch=True,
                    verbose=True,
                )
            except RuntimeError:
                # Compilation failure is expected since the payload is not a valid flag
                pass

            # The critical security check: verify the shell command was NOT executed
            self.assertFalse(
                os.path.exists(exploit_file),
                "Command injection vulnerability detected!",
            )


if __name__ == "__main__":
    common.run_tests()
@@ -10,7 +10,10 @@ from torch.fx import has_side_effect, Proxy
from .. import graph_break_hints
from ..bytecode_transformation import create_call_function
from ..exc import TYPE_CHECKING, unimplemented
from ..graph_bytecode_inputs import get_external_object_by_index
from ..graph_bytecode_inputs import (
    get_external_object_by_index,
    register_graph_created_object,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import FxTracebackAnnotateVariable
@@ -28,6 +31,26 @@ from torch._library.custom_ops import custom_op
Tensor = torch.Tensor


def new_event(*args: Any, **kwargs: Any) -> int:
    event = torch.Event(*args, **kwargs)
    return register_graph_created_object(
        event,
        EventVariable.make_construct_in_graph_event_fn(
            TupleVariable([]), ConstDictVariable({})
        ),
    )


def new_stream(*args: tuple[Any], **kwargs: Any) -> int:
    stream = torch.Stream(*args, **kwargs)  # type: ignore[no-matching-overload,call-overload]
    return register_graph_created_object(
        stream,
        StreamVariable.make_construct_in_graph_stream_fn(
            TupleVariable([]), ConstDictVariable({})
        ),
    )


def _get_stream_by_index(index: int) -> torch.Stream:
    stream = get_external_object_by_index(index)
    assert isinstance(stream, torch.Stream), (
@@ -115,6 +138,24 @@ def _(
has_side_effect(torch.ops.streams.wait_event.default)


@custom_op("streams::wait_stream", mutates_args=())
def wait_stream(waiting_stream_index: int, waited_on_stream_index: int) -> None:
    waiting = _get_stream_by_index(waiting_stream_index)
    waited_on = _get_stream_by_index(waited_on_stream_index)
    waiting.wait_stream(waited_on)


@wait_stream.register_fake
def _(
    waiting_stream_index: int,
    waited_on_stream_index: int,
) -> None:
    pass


has_side_effect(torch.ops.streams.wait_stream.default)


class SymbolicStreamState:
    """Track the currently entered stream, if any."""

@@ -33,6 +33,7 @@ from .graph_capture_wrappers import (
    handle_effect_tokens_fn,
)
from .schemas import AOTConfig, FxValue, SubclassMeta, TraceFn, ViewAndMutationMeta
from .streams import assign_backward_streams
from .utils import (
    call_and_expect_output_descs,
    copy_fwd_metadata_to_bw_nodes,
@@ -473,6 +474,10 @@ def aot_dispatch_autograd_graph(
    # fw node match might be erased
    copy_fwd_metadata_to_bw_nodes(fx_g)

    # After copying metadata, assign streams to gradient accumulation
    # nodes and insert synchronization
    assign_backward_streams(fx_g)

    fx_g.graph.eliminate_dead_code()
    if not aot_config.disable_functionalization:
        # There should be *NO* mutating ops in the graph at this point.

torch/_functorch/_aot_autograd/streams.py (new file, 46 lines)
@@ -0,0 +1,46 @@
import torch.fx
from torch._dynamo.graph_utils import _get_flat_args

from .utils import _is_backward_node_with_seq_nr, _is_forward_node_with_seq_nr


Node = torch.fx.Node


def seq_number(node: Node) -> int:
    assert "seq_nr" in node.meta, "No seq_nr in seq_number call"
    return node.meta.get("seq_nr")  # type: ignore[return-type]


def assign_backward_streams(gm: torch.fx.GraphModule) -> None:
    """Assigns backward streams to gradient accumulation nodes"""

    max_fw_seq_nr = -1
    max_bw_seq_nr = -1
    bw_nodes = []
    for node in gm.graph.nodes:
        if _is_forward_node_with_seq_nr(node):
            max_fw_seq_nr = max(max_fw_seq_nr, seq_number(node))
        elif _is_backward_node_with_seq_nr(node):
            bw_nodes.append(node)
            max_bw_seq_nr = max(max_bw_seq_nr, seq_number(node))

    if max_bw_seq_nr > max_fw_seq_nr:
        # In this case, there are some gradient accumulation nodes;
        # these nodes will need stream assignments.
        for node in bw_nodes:
            if seq_number(node) == max_bw_seq_nr:
                # Accumulation stream selection. Follow the rules from top to bottom to determine the accumulation stream:
                # 1. If the device of the gradient is the same as the device of the consumer,
                #    then the accumulation stream is the consumer node's stream.
                # 2. If the device of the gradient matches the device of the producer,
                #    then the accumulation stream is the producer node's stream.
                # 3. If neither is true, pick the current stream of the device of the gradient.
                # Accumulation stream synchronization:
                # Prior to accumulation, have the accumulation stream wait for the producer stream
                # and the stashed event (recorded on the previous producer stream).
                gradients = _get_flat_args(node, {})


def insert_sync(producer, consumer) -> None:
    pass
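`insert_sync` is left as a stub in this commit. Based on the synchronization scheme the comment block above describes, a hedged sketch of the event-based handoff it would encode (hypothetical helper, not part of this PR):

```python
import torch

def insert_sync_sketch(producer: torch.Stream, consumer: torch.Stream) -> None:
    # Record an event at the current tail of the producer stream, then make
    # the consumer (accumulation) stream wait on it. Ordering is enforced
    # on-device; the CPU never blocks.
    ev = torch.Event()
    with producer:
        ev.record()          # marks completion of work queued so far on producer
    consumer.wait_event(ev)  # consumer defers subsequent work until ev fires
```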

@@ -2205,28 +2205,22 @@ class CppKernel(Kernel):
        reduction_size = functools.reduce(
            operator.mul, self.call_ranges[self.reduction_depth :]
        )
        if config.cpp.dynamic_threads:
            # If dynamic threads, to be conservative,
            # use reduction_size as the range size
            rt_size = reduction_size
        else:
            rt_size = CeilDiv(reduction_size, parallel_num_threads())

        # chunk size to balance accuracy and performance
        chunk_size = 2**20
        chunk_size = 4096

        # use the acc helper if we cannot get a size hint
        try:
            rt_size_hint = V.graph.sizevars.size_hint(rt_size)
            reduction_size_hint = V.graph.sizevars.size_hint(reduction_size)
        except Exception:
            return True

        if rt_size_hint > chunk_size:
        if reduction_size_hint > chunk_size:
            # use the helper if the reduction size is too large
            V.graph.sizevars.check_lt(chunk_size, rt_size)
            V.graph.sizevars.check_lt(chunk_size, reduction_size)
            return True
        else:
            V.graph.sizevars.check_leq(rt_size, chunk_size)
            V.graph.sizevars.check_leq(reduction_size, chunk_size)
            return False
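With the threshold fixed at 4096 and the comparison switched from the per-thread `rt_size` to the full `reduction_size`, the helper path now triggers much earlier. For example, assuming 8 threads and `reduction_size = 65536`: the old code compared `CeilDiv(65536, 8) = 8192` against `2**20` and skipped the helper, while the new code compares 65536 against 4096 and takes it.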

    def _acc_helper_init(
@@ -2243,7 +2237,7 @@ class CppKernel(Kernel):
        )
        num_range_thread_expr = cexpr_index(num_range_thread)
        assert reduction_type in ["welford_reduce", "sum"]
        chunk_size = 4096 if reduction_type == "welford_reduce" else 2**20
        chunk_size = 4096
        num_chunks = CeilDiv(num_range_thread, chunk_size)
        helper_type = (
            "WelfordHelper"
@@ -3690,6 +3684,8 @@ class CppTile2DKernel(CppVecKernel):
        if self.tail_size or V.graph.get_dtype(name) in DTYPE_LOWP_FP + [
            torch.uint8,
            torch.int8,
            torch.float8_e4m3fn,
            torch.float8_e5m2,
        ]:
            line = f"{value}.store({storebuf}, {cexpr_index(self.num_elems)});"
        else:

@@ -10,10 +10,12 @@ import torch  # noqa: TC001
from torch.utils._ordered_set import OrderedSet

from .. import config
from ..runtime.runtime_utils import torch_dtype_to_jax
from ..utils import get_fused_kernel_name, get_kernel_metadata
from ..virtualized import V
from .block_analysis import BlockPatternMatcher
from .common import BackendFeature, CSEVariable, IndentedBuffer, OpOverrides
from .simd import SIMDKernel, SIMDScheduling
from .simd import pexpr, SIMDKernel, SIMDScheduling


if TYPE_CHECKING:
@@ -187,63 +189,168 @@ class PallasKernelOverrides(OpOverrides):
    def where(cond: str, a: str, b: str) -> str:
        return f"jnp.where({cond}, {a}, {b})"

    @staticmethod
    def to_dtype(
        x: str,
        dtype: torch.dtype,
        src_dtype: Optional[torch.dtype] = None,
        use_compute_types: bool = True,
    ) -> str:
        jax_dtype = torch_dtype_to_jax(dtype)
        return f"{x}.astype({jax_dtype})"


class PallasKernel(SIMDKernel):
    """
    Minimal Pallas kernel for simple elementwise operations.
    Pallas kernel for elementwise operations with support for strided/scatter access.

    Strategy:
    - Treat loads as full-array refs: "in_ptrX[...]"
    - Convert index expressions to JAX-compatible array slicing
    - Load/store using indexed access: "in_ptrX[slice]" or full-array "in_ptrX[...]"
    - Compute expression with Python operators (compatible with jax.numpy broadcasting)
    - Store as full-array ref assignment: "out_ptrY[...] = <expr>"
    - Generate Python code that defines a Pallas kernel and a host entrypoint.
    - Use the async_compile.pallas path to compile and load the Python code.
    """

    overrides = PallasKernelOverrides  # type: ignore[assignment]
    kexpr: Callable[[sympy.Expr], str] = pexpr  # Use the Python expression printer

    def _get_contiguous_index_str(self, index: sympy.Expr) -> str:
    def _get_index_str(self, index: sympy.Expr) -> str:
        """
        Validate that the index represents contiguous access and return the indexing string.
        Convert an index expression to a string suitable for Pallas indexing.

        For Pallas, we only support simple contiguous access patterns where the index
        is a single symbol (e.g., xindex) representing a flattened iteration.
        This ensures the load/store order is contiguous.
        Pallas operates on full arrays, so we need to convert index expressions
        to JAX array slicing. For example:
        - x0 -> "..." (contiguous access, full array)
        - 2*x0 -> "::2" (strided access with stride 2)
        - 2*x0 + 1 -> "1::2" (strided access with offset 1, stride 2)

        Args:
            index: The indexing expression to validate
            index: The indexing expression to convert

        Returns:
            The indexing string to use (currently always "...")

        Raises:
            Unsupported: If the index is not a simple contiguous pattern
            The indexing string to use in the generated code
        """
        # Prepare and simplify the index
        prepared_index = self.prepare_indexing(index)

        # For contiguous access, we expect a single symbol (like xindex)
        # or a simple integer (for scalar operations)
        # For simple single-symbol access (the contiguous case), we can use [...],
        # which is more efficient as it operates on the entire array at once
        if isinstance(prepared_index, sympy.Symbol):
            # This is the expected case: a single symbol representing contiguous iteration
            return "..."
        elif prepared_index.is_Integer:
            # Scalar case
            return "..."
            # Scalar index
            return str(prepared_index)
        else:
            # If there's any complex expression (ModularIndexing, FloorDiv, etc.),
            # it's not a simple contiguous pattern
            raise Unsupported(
                f"Pallas backend only supports contiguous access patterns. "
                f"Got complex index: {prepared_index}"
            )
            # Complex expression (strided/scatter access):
            # try to extract a stride and offset for common patterns
            return self._convert_to_jax_slice(prepared_index)

    def _convert_to_jax_slice(self, index: sympy.Expr) -> str:
        """
        Convert a sympy index expression to JAX slice notation.

        Handles common patterns like:
        - stride*var -> ::stride
        - stride*var + offset -> offset::stride

        For more complex patterns, falls back to explicit indexing.
        Uses BlockPatternMatcher for robust pattern matching.
        """
        # Get the iteration variables for this kernel
        if not self.range_trees:
            return "..."

        # Simplify the index
        index = V.graph.sizevars.simplify(index)
        free_symbols = index.free_symbols

        # Get the iteration variables from range_tree_nodes
        iter_vars = OrderedSet(self.range_tree_nodes.keys())

        # Find which iteration variable(s) are used
        used_vars = free_symbols & iter_vars

        if len(used_vars) == 0:
            # No iteration variables; this is a constant index
            return str(index)
        elif len(used_vars) == 1:
            # Single iteration variable: try to extract stride and offset using BlockPatternMatcher
            var = next(iter(used_vars))

            # Get the subexpression involving this variable
            var_expr = BlockPatternMatcher.get_subexpr_involving_symbol(index, var)

            # Try to match the affine pattern: stride * var
            stride = BlockPatternMatcher.match_affine_block_expr(var_expr, var)

            if stride is not None:
                # Extract the constant offset (terms not involving var)
                offset = index - var_expr
                offset = V.graph.sizevars.simplify(offset)

                # Generate JAX slice notation
                if stride == 1 and offset == 0:
                    # Contiguous access
                    return "..."
                elif offset == 0:
                    # Pure stride: ::stride
                    stride_str = self.kexpr(stride)
                    return f"::{stride_str}"
                else:
                    # Offset + stride: offset::stride
                    offset_str = self.kexpr(offset)
                    stride_str = self.kexpr(stride)
                    return f"{offset_str}::{stride_str}"
            else:
                # Couldn't match the affine pattern; fall back to the original logic
                offset = index - var_expr
                offset = V.graph.sizevars.simplify(offset)
                if offset == 0 and var_expr == var:
                    # Just the variable itself: unit stride
                    return "..."
        elif len(used_vars) > 1:
            # Multi-dimensional indexing.
            # For contiguous multi-dim access, all terms should have unit stride
            all_unit_stride = True
            for var in used_vars:
                var_expr = BlockPatternMatcher.get_subexpr_involving_symbol(index, var)
                stride = BlockPatternMatcher.match_affine_block_expr(var_expr, var)
                if stride != 1:
                    all_unit_stride = False
                    break

            if all_unit_stride:
                # Contiguous multi-dimensional access
                return "..."
            else:
                # Strided multi-dimensional access requires advanced indexing.
                # For now, use an ellipsis, which may work for many cases.
                # TODO: Implement proper multi-dimensional strided indexing
                return "..."

        # For the remaining complex cases, delegate (currently raises Unsupported)
        return self._generate_index_array(index)

    def _generate_index_array(self, index: sympy.Expr) -> str:
        """
        Generate JAX code to compute an index array for complex indexing patterns.

        For very complex patterns that can't be expressed as simple slices,
        we need to compute the indices explicitly. This is not yet fully implemented.
        """
        # For now, raise an error for complex patterns
        # TODO: Implement advanced indexing support
        raise Unsupported(
            f"Pallas backend does not yet support complex indexing pattern: {index}"
        )
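The mapping in `_get_index_str` mirrors ordinary slicing semantics; a quick illustration in plain jax.numpy of what the generated index strings select:

```python
import jax.numpy as jnp

x = jnp.arange(16.0)
x[...]    # index x0       -> whole array (contiguous)
x[::2]    # index 2*x0     -> every other element
x[1::2]   # index 2*x0 + 1 -> offset 1, stride 2
```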

    def load(self, name: str, index: sympy.Expr) -> CSEVariable:  # type: ignore[override]
        buf = self.args.input(name)
        dtype = V.graph.get_dtype(name)
        # Validate contiguous access and get the index string
        index_str = self._get_contiguous_index_str(index)
        # Pallas refs must be unpacked with [...] to load the array
        # Get the index string for the load operation
        index_str = self._get_index_str(index)
        # Pallas refs must be unpacked with [...] or [index] to load
        return self.cse.generate(
            self.compute,
            f"{buf}[{index_str}]",
@@ -257,9 +364,9 @@ class PallasKernel(SIMDKernel):
            raise Unsupported("pallas store mode not supported")
        out = self.args.output(name)
        self.store_buffer_names.add(name)
        # Validate contiguous access and get the index string
        index_str = self._get_contiguous_index_str(index)
        # Pallas refs must use [...] assignment to store back to the ref
        # Get the index string for the store operation
        index_str = self._get_index_str(index)
        # Pallas refs must use [...] or [index] assignment to store
        self.stores.writeline(f"{out}[{index_str}] = {value}")

    def codegen_kernel(self, name: Optional[str] = None) -> str:  # type: ignore[override]
@@ -284,6 +391,11 @@ class PallasKernel(SIMDKernel):
            "Pallas backend currently supports single-output elementwise kernels only"
        )

        # Get the output dtype at compile time
        output_name = live_outs[0]
        output_dtype = V.graph.get_dtype(output_name)
        output_dtype_jax = torch_dtype_to_jax(output_dtype)

        code = IndentedBuffer()
        code.splice(
            """
@@ -307,7 +419,10 @@ class PallasKernel(SIMDKernel):
        )
        code.writeline(f"def {kernel_name}_kernel({', '.join(kernel_params)}):")
        with code.indent():
            # Emit compute (CSE) and store lines; they reference *_ptr[...] directly
            # Emit compute (CSE) and store lines; they reference *_ptr[index] directly
            # The iteration variables are implicitly handled by JAX's vectorization:
            # when using [...], it processes the whole array;
            # when using explicit indices, they should be JAX-traced values
            for line in self.compute._lines:
                code.writeline(str(line))
            for line in self.stores._lines:
@@ -329,6 +444,9 @@ class PallasKernel(SIMDKernel):
        main_name = f"{kernel_name}_main"
        code.writeline(f"def {main_name}({', '.join(kernel_params)}, stream=None):")
        with code.indent():
            # Enable JAX x64 mode to support float64/int64 types
            code.writeline("# Enable JAX x64 mode for float64/int64 support")
            code.writeline("jax.config.update('jax_enable_x64', True)")
            # Identify inputs (in_ptr*) and output (out_ptr*)
            input_params = [
                p for p in kernel_params if p.startswith(("in_ptr", "in_out_ptr"))
@@ -343,24 +461,15 @@ class PallasKernel(SIMDKernel):
            output_param = output_params[0]

            # Convert inputs to JAX arrays
            code.writeline("# Convert Torch -> JAX for inputs")
            for inp in input_params:
                code.writeline(f"{inp}_jax = jax.dlpack.from_dlpack({inp})")
                code.writeline(
                    f"{inp}_jax = jax.dlpack.from_dlpack({inp}.contiguous())"
                )

            # Get the output metadata from the PyTorch tensor
            code.writeline("# Prepare output metadata from PyTorch tensor")
            code.writeline("# Map PyTorch dtype to JAX dtype")
            code.writeline("_torch_dtype_to_jax = {")
            code.writeline(
                " torch.float32: jnp.float32, torch.float64: jnp.float64, torch.float16: jnp.float16,"
            )
            code.writeline(
                " torch.int32: jnp.int32, torch.int64: jnp.int64, torch.int16: jnp.int16, torch.int8: jnp.int8,"
            )
            code.writeline(" torch.uint8: jnp.uint8, torch.bool: jnp.bool_,")
            code.writeline("}")
            code.writeline(f"out_shape = tuple({output_param}.shape)")
            code.writeline(f"out_dtype = _torch_dtype_to_jax[{output_param}.dtype]")
            code.writeline(f"out_dtype = {output_dtype_jax}")

            call_args = ["out_shape", "out_dtype"] + [
                f"{inp}_jax" for inp in input_params
@@ -187,3 +187,31 @@ def compile_mps_shader(source: str) -> Any:
        return torch.mps.compile_shader(source)
    except SyntaxError as err:
        raise SyntaxError(f"failed to compile {source} with {err.msg}") from err


def torch_dtype_to_jax(dtype: torch.dtype) -> str:
    """
    Map a PyTorch dtype to a JAX dtype expression.

    This helper is used at compile time in codegen to generate
    JAX dtype expressions for Pallas kernels.

    Args:
        dtype: PyTorch dtype to convert

    Returns:
        JAX dtype expression as a string (e.g., "jnp.float32")
    """
    dtype_map = {
        torch.float32: "jnp.float32",
        torch.float64: "jnp.float64",
        torch.float16: "jnp.float16",
        torch.bfloat16: "jnp.bfloat16",
        torch.int32: "jnp.int32",
        torch.int64: "jnp.int64",
        torch.int16: "jnp.int16",
        torch.int8: "jnp.int8",
        torch.uint8: "jnp.uint8",
        torch.bool: "jnp.bool_",
    }
    # str(dtype) renders as e.g. "torch.complex64"; strip the "torch." prefix
    # so the fallback produces a valid "jnp.<name>" expression.
    return dtype_map.get(dtype, f"jnp.{str(dtype).removeprefix('torch.')}")
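A quick sanity check of the mapping, including the fallback path (assuming the prefix-stripping fallback shown above):

```python
>>> torch_dtype_to_jax(torch.bfloat16)
'jnp.bfloat16'
>>> torch_dtype_to_jax(torch.complex64)  # not in the table: falls back
'jnp.complex64'
```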

@@ -36,6 +36,7 @@ from torch.distributed.tensor.placement_types import (
    Replicate,
    Shard,
)
from torch.fx.experimental.symbolic_shapes import statically_known_true


aten = torch.ops.aten
@@ -381,7 +382,7 @@ def gen_slice_strategy(op_schema: OpSchema) -> StrategyType:
        raise AssertionError(f"Expected int, got {type(dim)}")
    if start is None:
        start = 0
    if end is None or end > input_shape[dim]:
    if end is None or statically_known_true(end > input_shape[dim]):
        end = input_shape[dim]
    if not isinstance(start, IntLike):
        raise AssertionError(f"Expected IntLike, got {type(start)}")
@@ -395,13 +396,20 @@ def gen_slice_strategy(op_schema: OpSchema) -> StrategyType:
    start = normalize_dim(start, input_shape[dim])  # type: ignore[arg-type]
    end = normalize_dim(end, input_shape[dim])  # type: ignore[arg-type]

    redundant_slice = start == 0 and end == input_shape[dim] and step == 1
    statically_redundant_slice = (
        statically_known_true(start == 0)
        and statically_known_true(end == input_shape[dim])
        and statically_known_true(step == 1)
    )

    slice_strategy = OpStrategy([])

    for arg_strategy in input_strategy.strategies:
        arg_spec = arg_strategy.output_spec
        if not is_tensor_dim_sharded(arg_spec, dim=slice_dim) or redundant_slice:
        if (
            not is_tensor_dim_sharded(arg_spec, dim=slice_dim)
            or statically_redundant_slice
        ):
            # only add the strategy if the slice dim is not sharded
            # or the slice is statically redundant
            out_spec = DTensorSpec(mesh, arg_spec.placements)
            slice_strategy.strategies.append(
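The switch to `statically_known_true` matters for dynamic shapes: a plain `end > input_shape[dim]` on symbolic sizes would force a guard (or fail on unbacked symbols), whereas `statically_known_true` returns False whenever the comparison cannot be proven at compile time, so the slice is conservatively treated as non-redundant without specializing the graph.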

@@ -245,7 +245,9 @@ class FxNetAccFusionsFinder:


@compatibility(is_backward_compatible=False)
def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
def legalize_graph(
    gm: torch.fx.GraphModule, stable_topo_sort: bool = False
) -> torch.fx.GraphModule:
    """
    Replace the graph of the given GraphModule with one that contains the same nodes as the
    original, but in topologically sorted order.
@@ -255,6 +257,7 @@ def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:

    Arguments:
        gm: The graph module to topologically sort. It is modified in-place.
        stable_topo_sort: when True, PRIORITIZED_OPS are ignored and a plain FIFO topological order is used.

    Returns:
        The graph module, sorted in place
@@ -304,7 +307,11 @@ def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
        for user in cur.users:
            indeg[user] -= 1
            if indeg[user] == 0:
                if user.op == "call_function" and user.target in PRIORITIZED_OPS:
                if (
                    not stable_topo_sort
                    and user.op == "call_function"
                    and user.target in PRIORITIZED_OPS
                ):
                    queue.appendleft(user)
                else:
                    queue.append(user)
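The loop above is the relaxation step of Kahn's algorithm; with `stable_topo_sort=True` every ready node is appended FIFO, so the output order is deterministic with respect to the input order. A standalone sketch of the scheme (generic Python, not the FX implementation):

```python
from collections import deque

def topo_sort(nodes, users):
    # Kahn's algorithm over an adjacency map: node -> iterable of users.
    indeg = {n: 0 for n in nodes}
    for n in nodes:
        for u in users[n]:
            indeg[u] += 1
    queue = deque(n for n in nodes if indeg[n] == 0)
    order = []
    while queue:
        cur = queue.popleft()
        order.append(cur)
        for u in users[cur]:
            indeg[u] -= 1
            if indeg[u] == 0:
                queue.append(u)  # plain FIFO append == stable order
    return order

# Example: a -> b -> d, a -> c -> d
print(topo_sort(["a", "b", "c", "d"],
                {"a": ["b", "c"], "b": ["d"], "c": ["d"], "d": []}))
# ['a', 'b', 'c', 'd']
```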

@@ -220,22 +220,36 @@ def insert_subgm(
    submodule_name = sub_gm.__class__.__name__
    gm.add_submodule(submodule_name, sub_gm)

    def last_node(target_nodes: tuple[Node, ...]) -> Node | None:
        for node in reversed(gm.graph.nodes):
            if node in target_nodes:
                return node
        return None

    last_input_node: Node | None = last_node(orig_inputs)
    assert last_input_node is not None

    # Create a call_module node in main graph.
    module_node = gm.graph.call_module(submodule_name, args=orig_inputs, kwargs=None)

    output_node = sub_gm.graph.output_node()
    if len(orig_outputs) == 1 and not isinstance(output_node.args[0], tuple):
        # main_remapping[comp.orig_outputs[0]] = module_node
        orig_outputs[0].replace_all_uses_with(module_node, propagate_meta=True)
    else:
        for i, orig_output in enumerate(orig_outputs):
            # Use Proxy to record getitem access.
            proxy_out = torch.fx.Proxy(module_node)[i].node  # type: ignore[index]
            orig_output.replace_all_uses_with(proxy_out, propagate_meta=True)

    module_node.meta["val"] = tuple(
        orig_output.meta.get("val", None) for orig_output in orig_outputs
    with gm.graph.inserting_after(last_input_node):
        module_node = gm.graph.call_module(
            submodule_name, args=orig_inputs, kwargs=None
        )
    output_node = sub_gm.graph.output_node()

    next_node = module_node.next
    with gm.graph.inserting_before(next_node):
        if len(orig_outputs) == 1 and not isinstance(output_node.args[0], tuple):
            # main_remapping[comp.orig_outputs[0]] = module_node
            orig_outputs[0].replace_all_uses_with(module_node, propagate_meta=True)
        else:
            for i, orig_output in enumerate(orig_outputs):
                # Use Proxy to record getitem access.
                proxy_out = torch.fx.Proxy(module_node)[i].node  # type: ignore[index]
                orig_output.replace_all_uses_with(proxy_out, propagate_meta=True)

        module_node.meta["val"] = tuple(
            orig_output.meta.get("val", None) for orig_output in orig_outputs
        )
    return gm


@@ -269,7 +283,7 @@ def fuse_by_partitions(

    erase_nodes(gm, sorted_nodes)

    # topological sort original gm with newly created sub_gm
    legalize_graph(gm)
    legalize_graph(gm, stable_topo_sort=True)
    gm.graph.lint()

    return gm
@@ -1833,7 +1833,7 @@ def _check_and_build_extension_h_precompiler_headers(

    def build_precompile_header(pch_cmd) -> None:
        try:
            subprocess.check_output(pch_cmd, shell=True, stderr=subprocess.STDOUT)
            subprocess.check_output(shlex.split(pch_cmd), stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as e:
            raise RuntimeError(f"Compile PreCompile Header fail, command: {pch_cmd}") from e