Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-12 06:44:55 +08:00

Compare commits (14 commits): ciflow/tru ... d4l3k/debu

| SHA1 |
|---|
| 67f237a245 |
| e401a56b96 |
| 22650c89fb |
| c62a17a2fb |
| 713e289ae7 |
| 69784a0dbe |
| 3c2409c465 |
| 724cd32b0c |
| b62935d1a5 |
| ccc8c117dc |
| 86db4de10f |
| 12860892f8 |
| 694592ac1e |
| 285748e838 |
@@ -337,7 +337,7 @@ test_python() {

test_python_smoke() {
  # Smoke tests for H100/B200
  time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
  time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
  assert_git_not_dirty
}
.github/labeler.yml (vendored): 9 changed lines
@@ -138,7 +138,8 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/Blas.cpp
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- torch/**/*cublas*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py

@@ -148,7 +149,8 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/Blas.cpp
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- torch/**/*cublas*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py

@@ -158,7 +160,8 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/Blas.cpp
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
- third_party/fbgemm
.gitignore (vendored): 1 changed line
@@ -127,7 +127,6 @@ torch/test/
torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
torch/version.py
torch/_inductor/kernel/vendored_templates/*
minifier_launcher.py
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*

@@ -210,8 +210,12 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
/test/inductor/test_flex_attention.py @drisspg
/test/inductor/test_flex_decoding.py @drisspg

# Low Precision GEMMs
# Low Precision & Grouped GEMMs
/aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58
/aten/src/ATen/native/cuda/GroupedBlas.cpp @drisspg @slayton58
/aten/src/ATen/native/cuda/ScaledBlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58
/aten/src/ATen/cuda/CUDAScaledBlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDAScaledBlas.h @drisspg @slayton58
/test/test_scaled_matmul_cuda.py @drisspg @slayton58
setup.py: 34 changed lines
@@ -630,37 +630,6 @@ def mirror_files_into_torchgen() -> None:
        raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`")


def mirror_inductor_external_kernels() -> None:
    """
    Copy external kernels into Inductor so they are importable.
    """
    paths = [
        (
            CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py",
            CWD
            / "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py",
        ),
    ]
    for new_path, orig_path in paths:
        # Create the dirs involved in new_path if they don't exist
        if not new_path.exists():
            new_path.parent.mkdir(parents=True, exist_ok=True)

        # Copy the files from the orig location to the new location
        if orig_path.is_file():
            shutil.copyfile(orig_path, new_path)
            continue
        if orig_path.is_dir():
            if new_path.exists():
                # copytree fails if the tree exists already, so remove it.
                shutil.rmtree(new_path)
            shutil.copytree(orig_path, new_path)
            continue
        raise RuntimeError(
            "Check the file paths in `mirror_inductor_external_kernels()`"
        )


# ATTENTION: THIS IS AI SLOP
def extract_variant_from_version(version: str) -> str:
    """Extract variant from version string, defaulting to 'cpu'."""

@@ -1647,8 +1616,6 @@ def main() -> None:
    if RUN_BUILD_DEPS:
        build_deps()

    mirror_inductor_external_kernels()

    (
        ext_modules,
        cmdclass,

@@ -1682,7 +1649,6 @@ def main() -> None:
            "_inductor/codegen/aoti_runtime/*.cpp",
            "_inductor/script.ld",
            "_inductor/kernel/flex/templates/*.jinja",
            "_inductor/kernel/templates/*.jinja",
            "_export/serde/*.yaml",
            "_export/serde/*.thrift",
            "share/cmake/ATen/*.cmake",
@ -76,7 +76,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
store=dist.FileStore(self.file_name, self.world_size),
|
||||
)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_replicate_transformer(self):
|
||||
"""
|
||||
This tests that replicate works on a transformer model with fully_shard and replicate layers
|
||||
@ -126,7 +126,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
for parameter in layer.parameters():
|
||||
self.assertEqual(parameter.placements, (Shard(dim=0),))
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_replicate_transformer_managed_modules(self):
|
||||
"""
|
||||
This tests that replicate managed modules works properly. In this test we use a Transformer Module with 3 layers,
|
||||
@ -178,7 +178,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
replicate_model = replicate(replicate_model)
|
||||
self.assertEqual(len(_get_managed_modules((replicate_model,))), 21)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_replicate_tp_device_mesh(self):
|
||||
"""
|
||||
This tests that a user can pass in a device mesh to replicate a module
|
||||
@ -206,7 +206,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
self.assertEqual(parameter.device_mesh.shape, (2,))
|
||||
self.assertEqual(parameter.placements, (Replicate(),))
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_train_replicate_fsdp(self):
|
||||
"""
|
||||
Tests that replicate_model has the same behavior as original model when training
|
||||
@ -253,7 +253,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
self.assertEqual(replicate_loss, loss)
|
||||
check_sharded_parity(self, model, replicate_model)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_train_parity_2d_mlp(self):
|
||||
"""
|
||||
Verifies when a device mesh is passed in, the model has the same behavior as the original model when training
|
||||
|
||||
@ -299,7 +299,7 @@ class TestDTensorReshardMeshChange(DTensorTestBase):
|
||||
|
||||
@with_comms
|
||||
@with_temp_dir
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_dtensor_checkpoint_with_uneven_shards(self) -> None:
|
||||
"""
|
||||
Saving a dtensor with uneven shards.
|
||||
@ -436,6 +436,7 @@ class TestCheckpointableReshard(DTensorTestBase):
|
||||
|
||||
@with_comms
|
||||
@with_temp_dir
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_uneven_reshard_with_checkpointable_api(self) -> None:
|
||||
"""
|
||||
Saves a 1d distributed tensor that has shards with uneven sizes using Checkpointable API.
|
||||
@ -498,6 +499,7 @@ class TestCheckpointableReshard(DTensorTestBase):
|
||||
|
||||
@with_comms
|
||||
@with_temp_dir
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_uneven_reshard_with_dtensor_shards_wrapper_api(self) -> None:
|
||||
"""
|
||||
Saves a 1d distributed tensor that has shards with uneven sizes using Checkpointable API.
|
||||
|
||||
@ -886,7 +886,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
|
||||
self.assertEqual(cpu_model_value, meta_model_value)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_setting_meta_device_model_broadcasting_and_memory(self) -> None:
|
||||
# This test verifies that we can set model state dict by a meta device model
|
||||
# With the correlated changes in state_dict, meta device model should be accepted
|
||||
|
||||
@ -39,6 +39,7 @@ from torch.nn.modules.loss import MSELoss
|
||||
from torch.testing._internal.common_distributed import (
|
||||
MultiProcContinuousTest,
|
||||
requires_accelerator_dist_backend,
|
||||
skip_if_lt_x_gpu,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
check_leaked_tensors,
|
||||
@ -231,6 +232,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@parametrize("ScheduleClass", [_ScheduleForwardOnly])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_forward_only(self, ScheduleClass):
|
||||
mod, mod_ref, x, _, _ = setup_models_and_data(self.config)
|
||||
x_clone = x.clone()
|
||||
@ -274,6 +276,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
ScheduleInterleavedZeroBubble,
|
||||
],
|
||||
)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_eval_inference_mode(self, ScheduleClass):
|
||||
num_microbatches = 4
|
||||
if ScheduleClass in [
|
||||
@ -351,6 +354,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
ScheduleInterleavedZeroBubble,
|
||||
],
|
||||
)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_return_output(self, ScheduleClass):
|
||||
num_microbatches = 4
|
||||
if ScheduleClass in [
|
||||
@ -406,6 +410,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_multi_iter(self, ScheduleClass):
|
||||
mod, _, x, target, loss_fn = setup_models_and_data(self.config)
|
||||
chunks = 4
|
||||
@ -429,6 +434,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_kwargs_with_tracer(self, ScheduleClass):
|
||||
mod = ModelWithKwargs(d_hid, splits=self.world_size)
|
||||
mod.to(self.device)
|
||||
@ -481,6 +487,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_grad_with_tracer(self, ScheduleClass):
|
||||
mod, ref_mod, x, target, loss_fn = setup_models_and_data(self.config)
|
||||
|
||||
@ -523,6 +530,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
)
|
||||
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
|
||||
@parametrize("shape_inference", [True, False])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_grad_with_manual(self, ScheduleClass, shape_inference):
|
||||
mod, ref_mod, x, target, loss_fn = setup_models_and_data(self.config)
|
||||
|
||||
@ -586,6 +594,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
ScheduleInterleavedZeroBubble,
|
||||
],
|
||||
)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_grad_with_manual_interleaved(self, ScheduleClass):
|
||||
stages_per_rank = 2
|
||||
n_stages = stages_per_rank * self.world_size
|
||||
@ -650,6 +659,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@parametrize("ScheduleClass", [ScheduleInterleavedZeroBubble])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass):
|
||||
stages_per_rank = 2
|
||||
n_stages = stages_per_rank * self.world_size
|
||||
@ -736,6 +746,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
"schedule_class",
|
||||
[ScheduleZBVZeroBubble, ScheduleDualPipeV],
|
||||
)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_v_shape_schedules(self, schedule_class):
|
||||
n_stages = 8
|
||||
rank_stages = {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4]}
|
||||
@ -780,6 +791,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_custom_function_callback(self):
|
||||
"""Test the custom function callback functionality with _PipelineScheduleRuntime."""
|
||||
n_stages = 8
|
||||
@ -979,6 +991,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
"ScheduleClass",
|
||||
[ScheduleInterleavedZeroBubble, ScheduleInterleaved1F1B],
|
||||
)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_zero_bubble_with_model_kwargs(self, ScheduleClass):
|
||||
stages_per_rank = 2
|
||||
n_stages = stages_per_rank * self.world_size
|
||||
@ -1072,6 +1085,7 @@ class CustomSchedulesTest(MultiProcContinuousTest):
|
||||
"schedule_class",
|
||||
[ScheduleVShaped, ScheduleUnbalanced],
|
||||
)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_non_symmetric_stage_ids(self, schedule_class):
|
||||
n_stages = schedule_class.n_stages
|
||||
rank_stages = schedule_class.rank_stages
|
||||
@ -1121,6 +1135,7 @@ class CustomSchedulesTest(MultiProcContinuousTest):
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@parametrize("ScheduleClass", [ScheduleWithReorderedB])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass):
|
||||
n_stages = 2
|
||||
stages_per_rank = 1
|
||||
@ -1181,6 +1196,7 @@ class CustomSchedulesTest(MultiProcContinuousTest):
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@parametrize("ScheduleClass", [ScheduleWithW])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_schedule_with_native_zero_bubble(self, ScheduleClass):
|
||||
n_stages = ScheduleClass.n_stages
|
||||
num_microbatches = ScheduleClass.num_microbatches
|
||||
|
||||
@ -26,6 +26,7 @@ from torch.distributed.tensor.parallel import (
|
||||
RowwiseParallel,
|
||||
SequenceParallel,
|
||||
)
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
create_local_tensor_test_class,
|
||||
@ -764,6 +765,7 @@ class DistMathOpsTest(DTensorTestBase):
|
||||
self.assertEqual(grad1_norm.device_mesh, mesh_y)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_foreach_add_different_mesh(self):
|
||||
mesh_shape = (2, self.world_size // 2)
|
||||
mesh_2d = init_device_mesh(
|
||||
|
||||
@ -577,7 +577,7 @@ class DistTensorReplicateStrategyRegistrationTest(DTensorTestBase):
|
||||
self.assertEqual(
|
||||
comm_mode.get_comm_counts(),
|
||||
{
|
||||
torch.ops.c10d_functional.all_gather_into_tensor: 4,
|
||||
torch.ops.c10d_functional.all_gather_into_tensor: self.world_size,
|
||||
},
|
||||
)
|
||||
expected_cost = [
|
||||
|
||||
@ -485,7 +485,7 @@ elif TEST_XPU:
|
||||
def exit_if_lt_x_accelerators(x):
|
||||
if torch.accelerator.is_available():
|
||||
if torch.accelerator.device_count() < x:
|
||||
sys.exit(TEST_SKIPS[f"multi-accelerator-{x}"].exit_code)
|
||||
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
|
||||
|
||||
|
||||
def with_comms(func=None):
|
||||
|
||||
@ -2109,6 +2109,52 @@ Detected recompile when torch.compile stance is 'fail_on_recompile'. filename: '
|
||||
with self.assertRaises(Unsupported):
|
||||
outer_f2(inp)
|
||||
|
||||
def test_disable_recursive_flags(self):
|
||||
class SimpleLinear(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.layer0 = torch.nn.Linear(4, 4)
|
||||
|
||||
def forward(self, inp):
|
||||
return self.layer0(torch.sigmoid(inp))
|
||||
|
||||
class SimpleModel(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.layer0 = SimpleLinear()
|
||||
self.layer1 = torch.nn.Linear(4, 4)
|
||||
|
||||
def forward(self, inp):
|
||||
z = self.layer0(torch.sin(inp))
|
||||
return self.layer1(z)
|
||||
|
||||
for recursive_flag in [True, False]:
|
||||
model = SimpleModel()
|
||||
other_model = SimpleModel()
|
||||
|
||||
model.forward = torch._dynamo.disable(
|
||||
model.forward,
|
||||
recursive=recursive_flag,
|
||||
)
|
||||
self.assertEqual(
|
||||
torch._dynamo.is_dynamo_disable_recursive(model.forward),
|
||||
recursive_flag,
|
||||
)
|
||||
|
||||
other_model = torch._dynamo.disable(other_model, recursive=recursive_flag)
|
||||
self.assertEqual(
|
||||
torch._dynamo.is_dynamo_disable_recursive(
|
||||
other_model.forward
|
||||
if isinstance(other_model, torch.nn.Module)
|
||||
else other_model
|
||||
),
|
||||
recursive_flag,
|
||||
)
|
||||
|
||||
# check the model is compilable
|
||||
torch.compile(model)
|
||||
torch.compile(other_model)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
@ -422,34 +422,41 @@ from user code:
|
||||
import optree
|
||||
|
||||
@torch.compile(backend="eager")
|
||||
def fn(x):
|
||||
d = {"a": 1}
|
||||
optree.tree_flatten_with_path(d)
|
||||
return torch.sin(x)
|
||||
|
||||
def post_munge(s):
|
||||
s = re.sub(
|
||||
r"optree\.\S*\.flatten_with_path",
|
||||
"optree.<path>.flatten_with_path",
|
||||
s,
|
||||
)
|
||||
return re.sub(
|
||||
r"qualname: \S*flatten_with_path",
|
||||
"qualname: <path>.flatten_with_path",
|
||||
s,
|
||||
def fn1(x):
|
||||
tree = {"a": x, "b": (x - 1, 2 * x)}
|
||||
sin, cos = optree.tree_transpose_map(
|
||||
lambda t: (torch.sin(t), torch.cos(t)),
|
||||
tree,
|
||||
)
|
||||
return sin, cos
|
||||
|
||||
fn(torch.randn(4))
|
||||
self.assertEqual(len(counters["graph_break"]), 1)
|
||||
fn1(torch.randn(4))
|
||||
self.assertEqual(len(counters["graph_break"]), 0)
|
||||
|
||||
@torch.compile(backend="eager")
|
||||
def fn2(x):
|
||||
spec = optree.treespec_deque([])
|
||||
return spec, x
|
||||
|
||||
fn2(torch.randn(4))
|
||||
self.assertGreaterEqual(len(counters["graph_break"]), 1)
|
||||
first_graph_break = next(iter(counters["graph_break"].keys()))
|
||||
|
||||
def post_munge(string):
|
||||
return re.sub(
|
||||
r"(optree\.|qualname: )\S*(\.make_from_collection)",
|
||||
r"\1<path>\2",
|
||||
string,
|
||||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
post_munge(first_graph_break),
|
||||
"""\
|
||||
Attempted to call function marked as skipped
|
||||
Explanation: Dynamo cannot trace optree C/C++ function optree.<path>.flatten_with_path.
|
||||
Explanation: Dynamo cannot trace optree C/C++ function optree.<path>.make_from_collection.
|
||||
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
|
||||
|
||||
Developer debug context: module: optree._C, qualname: <path>.flatten_with_path, skip reason: <missing reason>
|
||||
Developer debug context: module: optree._C, qualname: <path>.make_from_collection, skip reason: <missing reason>
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""",
|
||||
)
|
||||
|
||||
@ -1,154 +0,0 @@
|
||||
# Owner(s): ["module: inductor"]
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch._inductor import config
|
||||
from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch
|
||||
from torch._inductor.test_case import run_tests, TestCase as InductorTestCase
|
||||
from torch._inductor.utils import ensure_cute_available
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not (ensure_cute_available() and is_datacenter_blackwell_arch()),
|
||||
"CuTeDSL library or Blackwell device not available",
|
||||
)
|
||||
@instantiate_parametrized_tests
|
||||
class TestCuTeDSLGroupedGemm(InductorTestCase):
|
||||
def _get_inputs(
|
||||
self,
|
||||
group_size: int,
|
||||
M_hint: int,
|
||||
K: int,
|
||||
N: int,
|
||||
device: str,
|
||||
dtype: torch.dtype,
|
||||
alignment: int = 16,
|
||||
) -> tuple[Tensor, Tensor, Tensor]:
|
||||
# --- Random, tile-aligned M sizes ---
|
||||
M_sizes = (
|
||||
torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int)
|
||||
* alignment
|
||||
)
|
||||
|
||||
M_total = torch.sum(M_sizes).item()
|
||||
|
||||
# --- Construct input tensors ---
|
||||
A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1
|
||||
B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01
|
||||
|
||||
# --- Build offsets (no leading zero, strictly increasing) ---
|
||||
offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device)
|
||||
|
||||
return (A, B, offsets)
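The offsets convention used above (cumulative row counts, no leading zero) is what tells the grouped matmul how to split the packed `A` rows into groups. A tiny illustration of that bookkeeping, using made-up sizes rather than the test's random ones:

```python
import torch

M_sizes = torch.tensor([128, 64, 32])       # per-group row counts
offsets = torch.cumsum(M_sizes, dim=0)      # tensor([128, 192, 224])

# Group 0 owns rows 0:128 of the packed (sum(M), K) matrix A,
# group 1 owns rows 128:192, and group 2 owns rows 192:224.
```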
|
||||
|
||||
@parametrize("group_size", (2, 8))
|
||||
@parametrize("M_hint", (256, 1024))
|
||||
@parametrize("K", (64, 128))
|
||||
@parametrize("N", (128, 256))
|
||||
def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int):
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype)
|
||||
|
||||
def grouped_gemm_fn(A_packed, B_batched, offs):
|
||||
return torch._grouped_mm(A_packed, B_batched, offs=offs)
|
||||
|
||||
# Eager execution
|
||||
c_eager = grouped_gemm_fn(A, B, offsets)
|
||||
|
||||
# Test with Cute backend
|
||||
with config.patch(
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTEDSL",
|
||||
"test_configs.autotune_choice_name_regex": "cutedsl",
|
||||
"autotune_fallback_to_aten": False,
|
||||
}
|
||||
):
|
||||
grouped_gemm_compiled = torch.compile(
|
||||
grouped_gemm_fn, backend="inductor", dynamic=False
|
||||
)
|
||||
c_compiled = grouped_gemm_compiled(A, B, offsets)
|
||||
|
||||
self.assertEqual(c_eager.dtype, dtype)
|
||||
self.assertEqual(c_compiled.dtype, dtype)
|
||||
torch.testing.assert_close(c_eager, c_compiled)
|
||||
|
||||
@parametrize("layout_A", ("contiguous", "offset", "padded", "view"))
|
||||
@parametrize("layout_B", ("contiguous", "broadcasted"))
|
||||
def test_grouped_gemm_assorted_layouts(
|
||||
self,
|
||||
layout_A: str,
|
||||
layout_B: str,
|
||||
):
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
G, K, N = 8, 64, 128
|
||||
M_sizes = [128] * G
|
||||
sum_M = sum(M_sizes)
|
||||
offsets = torch.tensor(
|
||||
[sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
A_base = torch.randn(sum_M, K, device=device, dtype=dtype)
|
||||
A = A_base
|
||||
|
||||
if layout_A == "offset":
|
||||
# allocate bigger buffer than needed, use nonzero storage offset
|
||||
storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype)
|
||||
offset = 128 # skip first 128 elements
|
||||
A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1))
|
||||
elif layout_A == "padded":
|
||||
# simulate row pitch > K (row_stride = K + pad)
|
||||
row_pitch = K + 8
|
||||
storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype)
|
||||
A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1))
|
||||
elif layout_A == "view":
|
||||
A_storage = torch.randn(sum_M * K, device=device, dtype=dtype)
|
||||
A = A_storage.view(sum_M, K)
|
||||
assert A._base is not None
|
||||
assert A.shape == (sum_M, K)
|
||||
|
||||
B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01
|
||||
|
||||
if layout_B == "broadcasted":
|
||||
# Broadcast B across groups (zero stride along G)
|
||||
B = B[0].expand(G, K, N)
|
||||
assert B.stride(0) == 0
|
||||
|
||||
def grouped_gemm_fn(A_packed, B_batched, offs):
|
||||
return torch._grouped_mm(A_packed, B_batched, offs=offs)
|
||||
|
||||
# --- eager ---
|
||||
c_eager = grouped_gemm_fn(A, B, offsets)
|
||||
|
||||
# --- compiled (CUTE backend) ---
|
||||
with config.patch(
|
||||
{
|
||||
"max_autotune": True,
|
||||
"max_autotune_gemm_backends": "CUTEDSL",
|
||||
"test_configs.autotune_choice_name_regex": "cutedsl",
|
||||
"autotune_fallback_to_aten": False,
|
||||
}
|
||||
):
|
||||
grouped_gemm_compiled = torch.compile(
|
||||
grouped_gemm_fn, backend="inductor", dynamic=False
|
||||
)
|
||||
c_compiled = grouped_gemm_compiled(A, B, offsets)
|
||||
|
||||
self.assertEqual(c_eager.dtype, dtype)
|
||||
self.assertEqual(c_compiled.dtype, dtype)
|
||||
torch.testing.assert_close(c_eager, c_compiled)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
@ -1217,6 +1217,43 @@ class TestPatternMatcher(TestCase):
|
||||
_, (code) = run_and_get_code(fn2, args[0], args[1], args[2])
|
||||
FileCheck().check_not("extern_kernels.addmm(").run(code[0])
|
||||
|
||||
def test_addmm_alpha_beta_with_pointwise(self):
|
||||
# Test that addmm with alpha/beta != 1 is unfused correctly with pointwise ops
|
||||
# See https://github.com/pytorch/pytorch/issues/167313
|
||||
x = torch.rand(2, device=GPU_TYPE)
|
||||
a = torch.rand(2, 3, device=GPU_TYPE)
|
||||
b = torch.rand(3, 2, device=GPU_TYPE)
|
||||
|
||||
def f(x, a, b):
|
||||
return torch.nn.functional.relu(torch.addmm(x, a, b, alpha=0.8, beta=0.2))
|
||||
|
||||
fc = torch.compile(f)
|
||||
|
||||
expected = f(x, a, b)
|
||||
actual = fc(x, a, b)
|
||||
|
||||
# The compiled version should produce the same result as eager
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
# Verify that addmm is unfused (should not use extern_kernels.addmm)
|
||||
# The pattern should be replaced with beta * x + alpha * (a @ b)
|
||||
_, (code) = run_and_get_code(fc, x, a, b)
|
||||
FileCheck().check_not("extern_kernels.addmm(").run(code[0])
|
||||
|
||||
# Test with alpha=1, beta=1 (default) - should also unfuse
|
||||
def f_default(x, a, b):
|
||||
return torch.nn.functional.relu(torch.addmm(x, a, b))
|
||||
|
||||
fc_default = torch.compile(f_default)
|
||||
expected_default = f_default(x, a, b)
|
||||
actual_default = fc_default(x, a, b)
|
||||
|
||||
torch.testing.assert_close(actual_default, expected_default)
|
||||
|
||||
# Should unfuse and not use extern_kernels.addmm
|
||||
_, (code) = run_and_get_code(fc_default, x, a, b)
|
||||
FileCheck().check_not("extern_kernels.addmm(").run(code[0])
|
||||
|
||||
def test_serialized_patterns_up_to_date(self):
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._inductor.fx_passes import joint_graph
|
||||
|
||||
@ -7486,7 +7486,7 @@ class TestFXMemoryProfiler(TestCase):
|
||||
return fx_frames
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
|
||||
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
|
||||
def test_fx_memory_profiler_augmentation(self):
|
||||
"""Test that memory snapshots are augmented with FX debug information."""
|
||||
|
||||
|
||||
@ -4251,7 +4251,7 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||
@skipIfRocm
|
||||
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
|
||||
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
|
||||
def test_profiler_stack_trace_augmentation(self):
|
||||
"""
|
||||
Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
|
||||
@ -4307,7 +4307,7 @@ event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)"""
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||
@skipIfRocm
|
||||
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
|
||||
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
|
||||
def test_profiler_multiple_modules(self):
|
||||
"""
|
||||
Test that multiple compiled modules under the same profiler session
|
||||
@ -4351,7 +4351,7 @@ event=cudaLaunchKernel node=sub stack_trace=return x - 1"""
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
|
||||
@skipIfRocm
|
||||
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
|
||||
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
|
||||
def test_profiler_nested_graph_modules(self):
|
||||
"""
|
||||
Test that nested graph modules (e.g., graph modules calling subgraphs)
|
||||
|
||||
@ -165,21 +165,6 @@ def get_tests(workflow_run_id: int, workflow_run_attempt: int) -> list[dict[str,
|
||||
return flattened
|
||||
|
||||
|
||||
def get_tests_for_circleci(
|
||||
workflow_run_id: int, workflow_run_attempt: int
|
||||
) -> list[dict[str, Any]]:
|
||||
# Parse the reports and transform them to JSON
|
||||
test_cases = []
|
||||
for xml_report in Path(".").glob("**/test/test-reports/**/*.xml"):
|
||||
test_cases.extend(
|
||||
parse_xml_report(
|
||||
"testcase", xml_report, workflow_run_id, workflow_run_attempt
|
||||
)
|
||||
)
|
||||
|
||||
return test_cases
|
||||
|
||||
|
||||
def summarize_test_cases(test_cases: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Group test cases by classname, file, and job_id. We perform the aggregation
|
||||
manually instead of using the `test-suite` XML tag because xmlrunner does
|
||||
@ -258,21 +243,11 @@ if __name__ == "__main__":
|
||||
required=True,
|
||||
help="Head repository of the workflow",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--circleci",
|
||||
action="store_true",
|
||||
help="If this is being run through circleci",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"Workflow id is: {args.workflow_run_id}")
|
||||
|
||||
if args.circleci:
|
||||
test_cases = get_tests_for_circleci(
|
||||
args.workflow_run_id, args.workflow_run_attempt
|
||||
)
|
||||
else:
|
||||
test_cases = get_tests(args.workflow_run_id, args.workflow_run_attempt)
|
||||
test_cases = get_tests(args.workflow_run_id, args.workflow_run_attempt)
|
||||
|
||||
# Flush stdout so that any errors in the upload show up last in the logs.
|
||||
sys.stdout.flush()
|
||||
|
||||
@ -100,7 +100,9 @@ class Logger:
|
||||
def _set_static_graph(self) -> None: ...
|
||||
|
||||
class _WorkerServer:
|
||||
def __init__(self, socket_path: str) -> None: ...
|
||||
port: int
|
||||
|
||||
def __init__(self, host_or_file: str, port: int = ...) -> None: ...
|
||||
def shutdown(self) -> None: ...
|
||||
|
||||
def get_debug_level(): ...
|
||||
|
||||
@ -32,6 +32,7 @@ from .decorators import (
|
||||
error_on_graph_break,
|
||||
forbid_in_graph,
|
||||
graph_break,
|
||||
is_dynamo_disable_recursive,
|
||||
mark_dynamic,
|
||||
mark_static,
|
||||
mark_static_address,
|
||||
@ -87,6 +88,7 @@ __all__ = [
|
||||
"forbid_in_graph",
|
||||
"graph_break",
|
||||
"is_compiling",
|
||||
"is_dynamo_disable_recursive",
|
||||
"list_backends",
|
||||
"lookup_backend",
|
||||
"mark_dynamic",
|
||||
|
||||
@ -15,7 +15,7 @@ import dataclasses
|
||||
import re
|
||||
import sys
|
||||
import types
|
||||
from collections import Counter
|
||||
from collections import Counter, deque
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
|
||||
@ -597,32 +597,35 @@ class PyCodegen:
|
||||
|
||||
graphargs = self.tx.output.graphargs
|
||||
|
||||
seen_sources: OrderedSet[Source] = OrderedSet()
|
||||
|
||||
def collect_temp_source(source: Source) -> None:
|
||||
if source in seen_sources:
|
||||
# This source is used at least twice, so it can be reused
|
||||
self.mark_source_temp(source)
|
||||
# Dont trace source further. This prevents us from marking too
|
||||
# many nodes as temp sources.
|
||||
return
|
||||
|
||||
seen_sources.add(source)
|
||||
|
||||
def extract_nested_sources(source: Source) -> list[Source]:
|
||||
nested_sources: list[Source] = []
|
||||
if isinstance(source, ChainedSource):
|
||||
collect_temp_source(source.base)
|
||||
|
||||
nested_sources.append(source.base)
|
||||
if isinstance(source, DictGetItemSource) and isinstance(
|
||||
source.index, Source
|
||||
):
|
||||
collect_temp_source(source.index)
|
||||
nested_sources.append(source.index)
|
||||
return nested_sources
|
||||
|
||||
def collect_temp_sources(sources: deque[Source], codegen: PyCodegen) -> None:
|
||||
seen_sources: OrderedSet[Source] = OrderedSet()
|
||||
while sources:
|
||||
current_source = sources.popleft()
|
||||
if current_source in seen_sources:
|
||||
# This source is used at least twice, so it can be reused
|
||||
codegen.mark_source_temp(current_source)
|
||||
# Dont trace source further. This prevents us from marking too
|
||||
# many nodes as temp sources.
|
||||
continue
|
||||
seen_sources.add(current_source)
|
||||
sources.extend(extract_nested_sources(current_source))
|
||||
|
||||
# Collect all the sources that are used more than once, so that we can
|
||||
# generate tmp variables in the generated pre-graph bytecode. This
|
||||
# essentially implements CSE.
|
||||
for arg in graphargs:
|
||||
if arg.source is not None:
|
||||
collect_temp_source(arg.source)
|
||||
collect_temp_sources(
|
||||
deque([arg.source for arg in graphargs if arg.source is not None]), self
|
||||
)
|
||||
|
||||
cm_var = None
|
||||
if config.record_runtime_overhead:
|
||||
|
||||
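For readers skimming the diff, the gist of the change above is that recursive collection of duplicated sources is replaced with an explicit worklist. A minimal, standalone sketch of that traversal pattern with generic objects (not Dynamo's `Source` types, purely illustrative):

```python
from collections import deque

def find_reused(roots, children_of):
    """Return items reachable from `roots` that are visited more than once."""
    seen, reused = set(), set()
    queue = deque(roots)
    while queue:
        item = queue.popleft()
        if item in seen:
            reused.add(item)   # seen twice: a candidate for a temp/CSE variable
            continue           # don't expand further, mirroring the code above
        seen.add(item)
        queue.extend(children_of(item))
    return reused

# e.g. a tiny DAG where node "c" is referenced from both "a" and "b"
graph = {"a": ["c"], "b": ["c"], "c": []}
assert find_reused(["a", "b"], lambda n: graph[n]) == {"c"}
```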
@ -740,11 +740,8 @@ enable_aot_compile = False
|
||||
# HACK: this is for testing custom ops profiling only
|
||||
_custom_ops_profile: Optional[Any] = None
|
||||
|
||||
# Experimental: If True, graph module will register fx metadata during recompile()
|
||||
enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated]
|
||||
default=False,
|
||||
env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
|
||||
)
|
||||
# Deprecated! Please use the config in torch/fx/experimental/_config instead.
|
||||
enrich_profiler_metadata: bool = False
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.utils._config_typing import * # noqa: F401, F403
|
||||
|
||||
@ -828,6 +828,7 @@ def trace_frame(
|
||||
raise
|
||||
finally:
|
||||
tracer.output.call_cleanup_hooks()
|
||||
tracer.f_locals = {}
|
||||
|
||||
try:
|
||||
run_tracer()
|
||||
|
||||
@ -96,6 +96,7 @@ def disable(fn=None, recursive=True, *, reason=None, wrapping=True): # type: ig
|
||||
nonrecursive_disable_wrapper._torchdynamo_disable = True # type: ignore[attr-defined]
|
||||
nonrecursive_disable_wrapper._torchdynamo_disable_msg = reason # type: ignore[attr-defined]
|
||||
nonrecursive_disable_wrapper._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
|
||||
nonrecursive_disable_wrapper._torchdynamo_disable_recursive = False # type: ignore[attr-defined]
|
||||
# pyrefly: ignore [bad-return]
|
||||
return nonrecursive_disable_wrapper
|
||||
|
||||
@@ -1023,3 +1024,13 @@ def error_on_graph_break(
    The default value of torch.compile's `error_on_graph_break` setting is False.
    """
    return ErrorOnGraphBreakDecoratorContextManager(error_on_graph_break)


def is_dynamo_disable_recursive(method: Callable[[Any], Any]) -> Optional[bool]:
    """
    Check if a method is marked as `dynamo_disable` recursively. It returns:
    - True if disable(recursive=True)
    - False if disable(recursive=False)
    - None if method is not a disable decorator
    """
    return getattr(method, "_torchdynamo_disable_recursive", None)
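A minimal usage sketch of the helper added above, relying only on the `torch._dynamo.disable` behavior shown in this diff (the wrappers record their choice on `_torchdynamo_disable_recursive`):

```python
import torch

def f(x):
    return torch.sin(x) + 1

g = torch._dynamo.disable(f, recursive=False)  # nonrecursive wrapper
h = torch._dynamo.disable(f, recursive=True)   # recursive (DisableContext) wrapper

assert torch._dynamo.is_dynamo_disable_recursive(g) is False
assert torch._dynamo.is_dynamo_disable_recursive(h) is True
assert torch._dynamo.is_dynamo_disable_recursive(f) is None  # not a disable wrapper
```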
@ -1155,6 +1155,8 @@ class DisableContext(_TorchDynamoContext):
|
||||
# of decorators.
|
||||
_fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
|
||||
|
||||
_fn._torchdynamo_disable_recursive = True # type: ignore[attr-defined]
|
||||
|
||||
return _fn
|
||||
|
||||
def __reduce__(self) -> tuple[type[DisableContext], tuple[Any, ...]]:
|
||||
|
||||
@ -11,6 +11,16 @@ from typing import Any, TYPE_CHECKING, TypeVar
|
||||
import optree
|
||||
import optree._C
|
||||
import optree.utils
|
||||
from optree import (
|
||||
is_namedtuple,
|
||||
is_namedtuple_class,
|
||||
is_namedtuple_instance,
|
||||
is_structseq,
|
||||
is_structseq_class,
|
||||
is_structseq_instance,
|
||||
namedtuple_fields,
|
||||
structseq_fields,
|
||||
)
|
||||
|
||||
import torch.utils._cxx_pytree as cxx_pytree # noqa: F401
|
||||
from torch.utils._pytree import BUILTIN_TYPES, STANDARD_DICT_TYPES
|
||||
@ -26,7 +36,26 @@ if TYPE_CHECKING:
|
||||
from torch.utils._cxx_pytree import PyTree
|
||||
|
||||
|
||||
__all__: list[str] = []
|
||||
__all__ = [
|
||||
"is_namedtuple",
|
||||
"is_namedtuple_class",
|
||||
"is_namedtuple_instance",
|
||||
"is_structseq",
|
||||
"is_structseq_class",
|
||||
"is_structseq_instance",
|
||||
"namedtuple_fields",
|
||||
"structseq_fields",
|
||||
"treespec_leaf",
|
||||
"treespec_tuple",
|
||||
"treespec_dict",
|
||||
"tree_is_leaf",
|
||||
"tree_iter",
|
||||
"tree_leaves",
|
||||
"tree_flatten",
|
||||
"tree_flatten_with_path",
|
||||
"tree_structure",
|
||||
"tree_unflatten",
|
||||
]
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
@ -48,21 +77,20 @@ def _(*args: Any, **kwargs: Any) -> bool:
|
||||
|
||||
|
||||
__name = ""
|
||||
for __name in (
|
||||
"is_namedtuple",
|
||||
"is_namedtuple_class",
|
||||
"is_namedtuple_instance",
|
||||
"is_structseq",
|
||||
"is_structseq_class",
|
||||
"is_structseq_instance",
|
||||
"namedtuple_fields",
|
||||
"structseq_fields",
|
||||
for __name, __func in (
|
||||
("is_namedtuple", is_namedtuple),
|
||||
("is_namedtuple_class", is_namedtuple_class),
|
||||
("is_namedtuple_instance", is_namedtuple_instance),
|
||||
("is_structseq", is_structseq),
|
||||
("is_structseq_class", is_structseq_class),
|
||||
("is_structseq_instance", is_structseq_instance),
|
||||
("namedtuple_fields", namedtuple_fields),
|
||||
("structseq_fields", structseq_fields),
|
||||
):
|
||||
__func = getattr(optree, __name)
|
||||
globals()[__name] = substitute_in_graph(__func, can_constant_fold_through=True)(
|
||||
__func.__python_implementation__
|
||||
)
|
||||
__all__ += [__name] # noqa: PLE0604
|
||||
globals()[__name] = substitute_in_graph(
|
||||
__func, # type: ignore[arg-type]
|
||||
can_constant_fold_through=True,
|
||||
)(__func.__python_implementation__) # type: ignore[attr-defined]
|
||||
del __func
|
||||
del __name
|
||||
|
||||
@ -78,7 +106,7 @@ def tree_is_leaf(
|
||||
) -> bool:
|
||||
if (tree is None and none_is_leaf) or (is_leaf is not None and is_leaf(tree)):
|
||||
return True
|
||||
if optree.register_pytree_node.get(type(tree), namespace=namespace) is None: # type: ignore[attr-defined]
|
||||
if optree.register_pytree_node.get(type(tree), namespace=namespace) is None:
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -113,9 +141,6 @@ def tree_iter(
|
||||
stack.extend(reversed(children))
|
||||
|
||||
|
||||
__all__ += ["tree_iter"]
|
||||
|
||||
|
||||
@substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True) # type: ignore[arg-type]
|
||||
def tree_leaves(
|
||||
tree: PyTree,
|
||||
@ -135,9 +160,6 @@ def tree_leaves(
|
||||
)
|
||||
|
||||
|
||||
__all__ += ["tree_leaves"]
|
||||
|
||||
|
||||
class _Asterisk(str):
|
||||
__slots__ = ()
|
||||
|
||||
@ -168,7 +190,7 @@ class PyTreeSpec:
|
||||
num_leaves: int = field(init=False)
|
||||
num_children: int = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
def __post_init__(self, /) -> None:
|
||||
if self._type is None:
|
||||
assert len(self._children) == 0
|
||||
assert self._metadata is None
|
||||
@ -187,7 +209,7 @@ class PyTreeSpec:
|
||||
object.__setattr__(self, "num_leaves", num_leaves)
|
||||
object.__setattr__(self, "num_children", num_children)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
def __repr__(self, /) -> str:
|
||||
def helper(treespec: PyTreeSpec) -> str:
|
||||
if treespec.is_leaf():
|
||||
assert treespec.type is None
|
||||
@ -221,29 +243,78 @@ class PyTreeSpec:
|
||||
]
|
||||
return f"PyTreeSpec({', '.join(inner)})"
|
||||
|
||||
def __len__(self) -> int:
|
||||
def __len__(self, /) -> int:
|
||||
return self.num_leaves
|
||||
|
||||
@property
|
||||
def type(self) -> builtins.type | None:
|
||||
def type(self, /) -> builtins.type | None:
|
||||
return self._type
|
||||
|
||||
def is_leaf(self) -> bool:
|
||||
def is_leaf(self, /) -> bool:
|
||||
return self.num_nodes == 1 and self.num_leaves == 1
|
||||
|
||||
def children(self) -> list[PyTreeSpec]:
|
||||
def paths(self, /) -> list[tuple[Any, ...]]:
|
||||
def helper(treespec: PyTreeSpec, path_prefix: list[Any]) -> None:
|
||||
if treespec.is_leaf():
|
||||
paths.append(path_prefix)
|
||||
return
|
||||
|
||||
for entry, subspec in zip(
|
||||
treespec._entries,
|
||||
treespec._children,
|
||||
strict=True,
|
||||
):
|
||||
helper(subspec, path_prefix + [entry])
|
||||
|
||||
paths: list[list[Any]] = []
|
||||
helper(self, [])
|
||||
return [tuple(path) for path in paths]
|
||||
|
||||
def accessors(self, /) -> list[optree.PyTreeAccessor]:
|
||||
def helper(
|
||||
treespec: PyTreeSpec,
|
||||
entry_path_prefix: list[optree.PyTreeEntry],
|
||||
) -> None:
|
||||
if treespec.is_leaf():
|
||||
entry_paths.append(entry_path_prefix)
|
||||
return
|
||||
|
||||
node_type = treespec.type
|
||||
assert node_type is not None
|
||||
handler = optree.register_pytree_node.get(
|
||||
node_type, namespace=treespec.namespace
|
||||
)
|
||||
assert handler is not None
|
||||
kind: optree.PyTreeKind = handler.kind
|
||||
path_entry_type: type[optree.PyTreeEntry] = handler.path_entry_type
|
||||
|
||||
for entry, subspec in zip(
|
||||
treespec._entries,
|
||||
treespec._children,
|
||||
strict=True,
|
||||
):
|
||||
helper(
|
||||
subspec,
|
||||
entry_path_prefix + [path_entry_type(entry, node_type, kind)],
|
||||
)
|
||||
|
||||
entry_paths: list[list[optree.PyTreeEntry]] = []
|
||||
helper(self, [])
|
||||
return [optree.PyTreeAccessor(path) for path in entry_paths]
|
||||
|
||||
def children(self, /) -> list[PyTreeSpec]:
|
||||
return list(self._children)
|
||||
|
||||
def child(self, index: int) -> PyTreeSpec:
|
||||
def child(self, index: int, /) -> PyTreeSpec:
|
||||
return self._children[index]
|
||||
|
||||
def entries(self) -> list[Any]:
|
||||
def entries(self, /) -> list[Any]:
|
||||
return list(self._entries)
|
||||
|
||||
def entry(self, index: int) -> Any:
|
||||
def entry(self, index: int, /) -> Any:
|
||||
return self._entries[index]
|
||||
|
||||
def flatten_up_to(self, tree: PyTree) -> list[PyTree]:
|
||||
def flatten_up_to(self, tree: PyTree, /) -> list[PyTree]:
|
||||
def helper(
|
||||
treespec: PyTreeSpec,
|
||||
node: PyTree,
|
||||
@ -324,14 +395,14 @@ class PyTreeSpec:
|
||||
f"expected {treespec._metadata!r}, but got {metadata!r}.", # namedtuple type mismatch
|
||||
)
|
||||
|
||||
for subtree, subspec in zip(children, treespec._children):
|
||||
for subtree, subspec in zip(children, treespec._children, strict=True):
|
||||
helper(subspec, subtree, subtrees)
|
||||
|
||||
subtrees: list[PyTree] = []
|
||||
helper(self, tree, subtrees)
|
||||
return subtrees
|
||||
|
||||
def unflatten(self, leaves: Iterable[Any]) -> PyTree:
|
||||
def unflatten(self, leaves: Iterable[Any], /) -> PyTree:
|
||||
if not isinstance(leaves, (list, tuple)):
|
||||
leaves = list(leaves)
|
||||
if len(leaves) != self.num_leaves:
|
||||
@ -408,7 +479,7 @@ def treespec_tuple(
|
||||
"All children PyTreeSpecs must have the same `namespace` value "
|
||||
f"as the parent; expected {namespace!r}, got: {children!r}.",
|
||||
)
|
||||
handler = optree.register_pytree_node.get(tuple, namespace=namespace) # type: ignore[attr-defined]
|
||||
handler = optree.register_pytree_node.get(tuple, namespace=namespace)
|
||||
assert handler is not None
|
||||
return PyTreeSpec(
|
||||
tuple(children),
|
||||
@ -531,7 +602,69 @@ def tree_flatten(
|
||||
return leaves, treespec
|
||||
|
||||
|
||||
__all__ += ["tree_flatten"]
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
optree._C.flatten,
|
||||
# We need to disable constant folding here because we want the function to reference the
|
||||
# PyTreeSpec class defined above, not the one in the C++ module.
|
||||
can_constant_fold_through=False,
|
||||
)
|
||||
def _C_flatten(
|
||||
tree: PyTree,
|
||||
/,
|
||||
leaf_predicate: Callable[[PyTree], bool] | None = None,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> tuple[list[Any], PyTreeSpec]:
|
||||
return tree_flatten( # type: ignore[return-value]
|
||||
tree,
|
||||
is_leaf=leaf_predicate,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
optree.tree_flatten_with_path,
|
||||
# We need to disable constant folding here because we want the function to reference the
|
||||
# PyTreeSpec class defined above, not the one in the C++ module.
|
||||
can_constant_fold_through=False,
|
||||
)
|
||||
def tree_flatten_with_path(
|
||||
tree: PyTree,
|
||||
/,
|
||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||
*,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> tuple[list[tuple[Any, ...]], list[Any], PyTreeSpec]:
|
||||
leaves, treespec = tree_flatten(
|
||||
tree,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
return treespec.paths(), leaves, treespec # type: ignore[return-value]
|
||||
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
optree._C.flatten_with_path,
|
||||
# We need to disable constant folding here because we want the function to reference the
|
||||
# PyTreeSpec class defined above, not the one in the C++ module.
|
||||
can_constant_fold_through=False,
|
||||
)
|
||||
def _C_flatten_with_path(
|
||||
tree: PyTree,
|
||||
/,
|
||||
leaf_predicate: Callable[[PyTree], bool] | None = None,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> tuple[list[tuple[Any, ...]], list[Any], PyTreeSpec]:
|
||||
return tree_flatten_with_path( # type: ignore[return-value]
|
||||
tree,
|
||||
is_leaf=leaf_predicate,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
@ -556,9 +689,6 @@ def tree_structure(
|
||||
)[1]
|
||||
|
||||
|
||||
__all__ += ["tree_structure"]
|
||||
|
||||
|
||||
@substitute_in_graph( # type: ignore[arg-type]
|
||||
optree.tree_unflatten,
|
||||
# We need to disable constant folding here because we want the function to reference the
|
||||
@ -574,55 +704,6 @@ def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree:
|
||||
return treespec.unflatten(leaves)
|
||||
|
||||
|
||||
__all__ += ["tree_unflatten"]
|
||||
|
||||
|
||||
@substitute_in_graph(optree.tree_map, can_constant_fold_through=True) # type: ignore[arg-type]
|
||||
def tree_map(
|
||||
func: Callable[..., Any],
|
||||
tree: PyTree,
|
||||
/,
|
||||
*rests: PyTree,
|
||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> PyTree:
|
||||
leaves, treespec = tree_flatten(
|
||||
tree,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
|
||||
return treespec.unflatten(map(func, *flat_args))
|
||||
|
||||
|
||||
__all__ += ["tree_map"]
|
||||
|
||||
|
||||
@substitute_in_graph(optree.tree_map_, can_constant_fold_through=True) # type: ignore[arg-type]
|
||||
def tree_map_(
|
||||
func: Callable[..., Any],
|
||||
tree: PyTree,
|
||||
/,
|
||||
*rests: PyTree,
|
||||
is_leaf: Callable[[PyTree], bool] | None = None,
|
||||
none_is_leaf: bool = False,
|
||||
namespace: str = "",
|
||||
) -> PyTree:
|
||||
leaves, treespec = tree_flatten(
|
||||
tree,
|
||||
is_leaf=is_leaf,
|
||||
none_is_leaf=none_is_leaf,
|
||||
namespace=namespace,
|
||||
)
|
||||
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
|
||||
deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
|
||||
return tree
|
||||
|
||||
|
||||
__all__ += ["tree_map_"]
|
||||
|
||||
_none_registration = optree.register_pytree_node.get(type(None))
|
||||
assert _none_registration is not None
|
||||
|
||||
@@ -669,5 +750,5 @@ def dict_unflatten(
) -> dict[_KT, _VT]:
    original_keys, sorted_keys = metadata
    d = dict.fromkeys(original_keys)
    d.update(zip(sorted_keys, values))
    d.update(zip(sorted_keys, values, strict=True))
    return d  # type: ignore[return-value]
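For reference, a small standalone illustration of what `dict_unflatten` does with its `(original_keys, sorted_keys)` metadata (plain Python, no torch or optree imports needed):

```python
# Leaves are stored in sorted-key order, but the rebuilt dict must keep the
# original insertion order of the keys.
original_keys = ["b", "a"]
sorted_keys = ["a", "b"]
values = [1, 2]  # leaf for "a" is 1, leaf for "b" is 2

d = dict.fromkeys(original_keys)                 # {"b": None, "a": None}
d.update(zip(sorted_keys, values, strict=True))
assert d == {"b": 2, "a": 1}
assert list(d) == ["b", "a"]                     # original key order preserved
```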
@ -550,10 +550,6 @@ max_autotune_flex_search_space: Literal["DEFAULT", "EXHAUSTIVE"] = os.environ.ge
|
||||
"TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT"
|
||||
).upper() # type: ignore[assignment]
|
||||
|
||||
cutedsl_enable_autotuning: bool = (
|
||||
os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1"
|
||||
)
|
||||
|
||||
# DEPRECATED. This setting is ignored.
|
||||
autotune_fallback_to_aten = False
|
||||
|
||||
|
||||
@ -913,6 +913,10 @@ def _get_optimization_cflags(
|
||||
if not config.is_fbcode():
|
||||
if platform.machine() == "ppc64le":
|
||||
cflags.append("mcpu=native")
|
||||
elif platform.machine() == "riscv64":
|
||||
cflags.append("march=rv64gc")
|
||||
elif platform.machine() == "riscv32":
|
||||
cflags.append("march=rv32gc")
|
||||
else:
|
||||
cflags.append("march=native")
|
||||
|
||||
|
||||
@@ -1516,17 +1516,29 @@ def should_prefer_unfused_addmm(match):


@register_graph_pattern(
    CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()),
    CallFunction(
        aten.addmm,
        KeywordArg("inp"),
        Arg(),
        Arg(),
        beta=KeywordArg("beta"),
        alpha=KeywordArg("alpha"),
    ),
    # pyrefly: ignore [bad-argument-type]
    pass_dict=pass_patterns[2],
    extra_check=should_prefer_unfused_addmm,
)
def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp):
    def repl(inp, x1, x2):
        return x1 @ x2 + inp
def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp, alpha, beta):
    def repl(inp, x1, x2, alpha, beta):
        mm_result = x1 @ x2
        if alpha != 1:
            mm_result = alpha * mm_result
        if beta != 1:
            inp = beta * inp
        return inp + mm_result

    # pyrefly: ignore [bad-argument-type]
    match.replace_by_example(repl, [inp, mat1, mat2])
    match.replace_by_example(repl, [inp, mat1, mat2, alpha, beta])


def is_valid_addmm_fusion(match):
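The replacement above decomposes `addmm` into pointwise ops while preserving the `alpha`/`beta` scaling. A small eager-mode sanity check of the identity the pattern relies on (plain PyTorch, no Inductor involved):

```python
import torch

inp = torch.rand(2)
a = torch.rand(2, 3)
b = torch.rand(3, 2)
alpha, beta = 0.8, 0.2

fused = torch.addmm(inp, a, b, alpha=alpha, beta=beta)
unfused = beta * inp + alpha * (a @ b)   # what the replacement graph computes

torch.testing.assert_close(fused, unfused)
```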
@ -1,8 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@ -14,7 +12,6 @@ from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
|
||||
from .. import config
|
||||
from ..codegen.wrapper import PythonWrapperCodegen
|
||||
from ..ir import _IntLike, Layout, TensorBox
|
||||
from ..utils import load_template
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -257,7 +254,3 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates"
|
||||
load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR)
|
||||
|
||||
@ -1,13 +1,11 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import logging
|
||||
from dataclasses import asdict, dataclass
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch._dynamo.utils import counters
|
||||
from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate
|
||||
from torch._inductor.runtime.triton_compat import tl
|
||||
from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs
|
||||
from torch._inductor.virtualized import V
|
||||
from torch.utils._triton import has_triton
|
||||
|
||||
@ -24,13 +22,11 @@ from ..utils import (
|
||||
get_num_sms,
|
||||
has_free_symbols,
|
||||
use_aten_gemm_kernels,
|
||||
use_blackwell_cutedsl_grouped_mm,
|
||||
use_triton_template,
|
||||
)
|
||||
from .mm_common import (
|
||||
_is_static_problem,
|
||||
check_supported_striding,
|
||||
load_kernel_template,
|
||||
persistent_grouped_mm_grid,
|
||||
)
|
||||
|
||||
@ -517,11 +513,6 @@ triton_scaled_grouped_mm_template = TritonTemplate(
|
||||
source=triton_grouped_mm_source,
|
||||
)
|
||||
|
||||
cutedsl_grouped_mm_template = CuteDSLTemplate(
|
||||
name="grouped_gemm_cutedsl",
|
||||
source=load_kernel_template("cutedsl_mm_grouped"),
|
||||
)
|
||||
|
||||
|
||||
def grouped_mm_args(
|
||||
mat1: TensorBox,
|
||||
@ -723,44 +714,43 @@ def _tuned_grouped_mm_common(
|
||||
# Checking only for the equality of corresponding dims of
|
||||
# multiplicands here, relying on meta function checks for
|
||||
# everything else.
|
||||
if len(m1_size) == 2:
|
||||
if len(m2_size) == 2:
|
||||
m, k1 = m1_size
|
||||
k2, _ = m2_size
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g = offs.get_size()[0]
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = True, True
|
||||
else:
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g1 = offs.layout.size[0]
|
||||
m, k1 = m1_size
|
||||
g2, k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = True, False
|
||||
else:
|
||||
if len(m2_size) == 2:
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g1 = offs.layout.size[0]
|
||||
g2, m, k1 = m1_size
|
||||
k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = False, True
|
||||
else:
|
||||
g1, m, k1 = m1_size
|
||||
g2, k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = False, False
|
||||
|
||||
if (
|
||||
is_nonzero
|
||||
and use_triton_template(layout)
|
||||
and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result)
|
||||
):
|
||||
scaled = scale_a is not None
|
||||
if len(m1_size) == 2:
|
||||
if len(m2_size) == 2:
|
||||
m, k1 = m1_size
|
||||
k2, _ = m2_size
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g = offs.get_size()[0]
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = True, True
|
||||
else:
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g1 = offs.layout.size[0]
|
||||
m, k1 = m1_size
|
||||
g2, k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = True, False
|
||||
else:
|
||||
if len(m2_size) == 2:
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
g1 = offs.layout.size[0]
|
||||
g2, m, k1 = m1_size
|
||||
k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = False, True
|
||||
else:
|
||||
g1, m, k1 = m1_size
|
||||
g2, k2, _ = m2_size
|
||||
g = V.graph.sizevars.check_equals_and_simplify(g1, g2)
|
||||
V.graph.sizevars.check_equals(k1, k2)
|
||||
a_is_2d, b_is_2d = False, False
|
||||
|
||||
a_is_k_major = mat_a.get_stride()[-1] == 1
|
||||
b_is_k_major = mat_b.get_stride()[-2] == 1
|
||||
@ -798,22 +788,6 @@ def _tuned_grouped_mm_common(
|
||||
**config.kwargs,
|
||||
)
|
||||
|
||||
if use_blackwell_cutedsl_grouped_mm(
|
||||
mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result
|
||||
):
|
||||
for config in get_groupgemm_configs():
|
||||
kwargs = dict(
|
||||
ACC_DTYPE="cutlass.Float32",
|
||||
)
|
||||
|
||||
cutedsl_grouped_mm_template.maybe_append_choice(
|
||||
choices,
|
||||
input_nodes=input_nodes,
|
||||
layout=layout,
|
||||
**kwargs,
|
||||
**asdict(config),
|
||||
)
|
||||
|
||||
input_gen_fns = {
|
||||
4: lambda x: create_offsets(
|
||||
x, m1_size, m2_size, offs.get_size() if offs is not None else None
|
||||
|
||||
@@ -1,333 +0,0 @@
import functools
from torch._inductor.runtime.runtime_utils import ceildiv
from cutlass.utils import TensorMapUpdateMode
{{gen_defines()}}
# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ----
from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import (
    GroupedGemmKernel,
)


# Note about caching:
# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor
# maintains its own local caching system. At this stage, all compile-time
# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel
# name itself ({{kernel_name}}) are permanently baked into the file, so they
# do not need to be included in any cache key.
#
# The caching mechanism is split into two levels:
#
# 1. prep_cache
#    Caches the compiled executor for build_group_ptrs_from_bases(). This
#    kernel depends only on the tensor shapes, strides, and dtypes of A/B/C,
#    and can therefore be safely reused across runs with different group
#    partitioning (`offs`).
#
# 2. gemm_cache
#    Caches the compiled Grouped GEMM executor. Its key extends the prep
#    cache key with hardware- and grid-specific parameters:
#    (prep_cache_key, max_active_clusters, total_num_clusters).
#    This is necessary because different `offs` tensors can change the
#    per-group problem sizes and thus alter `total_num_clusters`, which in
#    turn changes the grid shape and persistent scheduler configuration.
#    Kernels compiled for one grid cannot be safely reused for another.
#
#
# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically,
# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead,
# despite depending only on the GPU type. We cache this function to mitigate
# redundant recompiles even when shape/stride/dtype cache misses force kernel
# regeneration. A follow-up study will investigate the root cause.

prep_cache = {}
gemm_cache = {}


@functools.lru_cache
def get_hardware_info():
    hw = cutlass.utils.HardwareInfo()
    sm_count = hw.get_max_active_clusters(1)
    max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N)

    return (sm_count, max_active_clusters)


def get_prep_cache_key(input_a, input_b, output):
    """
    Returns a tuple key for caching the preprocessing kernel executor based on kernel name,
    shapes, strides, and dtypes of input/output tensors.
    """
    return (
        tuple(input_a.shape),
        tuple(input_a.stride()),
        input_a.dtype,
        tuple(input_b.shape),
        tuple(input_b.stride()),
        input_b.dtype,
        tuple(output.shape),
        tuple(output.stride()),
        output.dtype,
    )


def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters):
    """
    Returns a tuple key for caching the gemm kernel executor by extending the
    prep cache key with hardware- and grid-specific parameters.
    """
    return (
        prep_cache_key,
        max_active_clusters,
        total_num_clusters,
    )


@cute.kernel
def build_group_ptrs_from_bases_kernel(
    base_A_u64: cutlass.Int64,  # device addr of input_a (bytes)
    base_B_u64: cutlass.Int64,  # device addr of input_b (bytes)
    base_C_u64: cutlass.Int64,  # device addr of Output (bytes)
    offs: cute.Tensor,  # [G], cutlass.Int32/64 cumulative
    K: cutlass.Constexpr,
    N: cutlass.Constexpr,
    sizeof_element: cutlass.Int32,  # bytes
    # -------- STRIDES (in ELEMENTS) --------
    stride_A_m_elems: cutlass.Constexpr,  # A.stride(0)
    stride_A_k_elems: cutlass.Constexpr,  # A.stride(1)
    stride_B0_elems: cutlass.Constexpr,  # B.stride(0)
    stride_Bk_elems: cutlass.Constexpr,  # B.stride(1)
    stride_Bn_elems: cutlass.Constexpr,  # B.stride(2)
    stride_C_m_elems: cutlass.Constexpr,  # C.stride(0)
    stride_C_n_elems: cutlass.Constexpr,  # C.stride(1)
    # -------- OUTPUTS --------
    out_ptrs: cute.Tensor,  # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr)
    out_problem: cute.Tensor,  # [G,4] cutlass.Int32: (m_g, n, k, 1)
    out_strides_abc: cute.Tensor,  # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]]
):
    tidx, _, _ = cute.arch.thread_idx()
    g = tidx

    m_beg_i32 = 0
    if g > 0:
        m_beg_i32 = offs[g - 1]
    m_end_i32 = offs[g]
    m_g_i32 = m_end_i32 - m_beg_i32

    a_byte_off = (
        cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element)
    )
    c_byte_off = (
        cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element)
    )
    b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element)

    # ---- pointers ----
    out_ptrs[g, 0] = base_A_u64 + a_byte_off
    out_ptrs[g, 1] = base_B_u64 + b_byte_off
    out_ptrs[g, 2] = base_C_u64 + c_byte_off

    # ---- (m, n, k, 1) ----
    out_problem[g, 0] = m_g_i32
    out_problem[g, 1] = N
    out_problem[g, 2] = K
    out_problem[g, 3] = cutlass.Int32(1)

    # ---- strides ----
    out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems)
    out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems)
    out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems)
    out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems)
    out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems)
    out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems)


@cute.jit
def launch_build_group_ptrs_from_bases(
    base_A_u64: cutlass.Int64,
    base_B_u64: cutlass.Int64,
    base_C_u64: cutlass.Int64,
    offs: cute.Tensor,
    G: cutlass.Constexpr,
    K: cutlass.Constexpr,
    N: cutlass.Constexpr,
    sizeof_element: cutlass.Constexpr,
    stride_A_m_elems: cutlass.Constexpr,
    stride_A_k_elems: cutlass.Constexpr,
    stride_B0_elems: cutlass.Constexpr,
    stride_Bk_elems: cutlass.Constexpr,
    stride_Bn_elems: cutlass.Constexpr,
    stride_C_m_elems: cutlass.Constexpr,
    stride_C_n_elems: cutlass.Constexpr,
    out_ptrs: cute.Tensor,  # [G,3] cutlass.Int64
    out_problem: cute.Tensor,  # [G,4] cutlass.Int32
    out_strides_abc: cute.Tensor,  # [3,2] cutlass.Int32
    stream: cuda.CUstream,
):
    build_group_ptrs_from_bases_kernel(
        base_A_u64,
        base_B_u64,
        base_C_u64,
        offs,
        K,
        N,
        sizeof_element,
        stride_A_m_elems,
        stride_A_k_elems,
        stride_B0_elems,
        stride_Bk_elems,
        stride_Bn_elems,
        stride_C_m_elems,
        stride_C_n_elems,
        out_ptrs,
        out_problem,
        out_strides_abc,
    ).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream)


{{def_kernel("input_a", "input_b", "input_a_offs")}}
    stream = cuda.CUstream(stream)

    input_b = input_b.transpose(1, 2)

    sumM, K = input_a.shape
    G, N, Kb = input_b.shape

    dev = input_a.device

    base_A_u64 = int(input_a.data_ptr())
    base_B_u64 = int(input_b.data_ptr())
    base_C_u64 = int({{get_output()}}.data_ptr())

    ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64)
    probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32)
    strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32)
    ptrs = from_dlpack(ptrs_t)
    probs = from_dlpack(probs_t)
    strides = from_dlpack(strides_t)

    prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}})
    prep_executor = prep_cache.get(prep_cache_key)

    if prep_executor is None:
        sizeof_element = int(input_a.element_size())
        sA_m, sA_k = map(int, input_a.stride())
        sB_0, sB_n, sB_k = map(int, input_b.stride())
        sC_m, sC_n = map(int, {{get_output()}}.stride())

        prep_executor = cute.compile(
            launch_build_group_ptrs_from_bases,
            base_A_u64=base_A_u64,
            base_B_u64=base_B_u64,
            base_C_u64=base_C_u64,
            offs=from_dlpack(input_a_offs),
            G=int(G),
            K=int(K),
            N=int(N),
            sizeof_element=sizeof_element,
            stride_A_m_elems=sA_m,
            stride_A_k_elems=sA_k,
            stride_B0_elems=sB_0,
            stride_Bk_elems=sB_k,
            stride_Bn_elems=sB_n,
            stride_C_m_elems=sC_m,
            stride_C_n_elems=sC_n,
            out_ptrs=ptrs,
            out_problem=probs,
            out_strides_abc=strides,
            stream=stream,
        )

        prep_cache[prep_cache_key] = prep_executor

    prep_executor(
        base_A_u64=base_A_u64,
        base_B_u64=base_B_u64,
        base_C_u64=base_C_u64,
        offs=from_dlpack(input_a_offs),
        out_ptrs=ptrs,
        out_problem=probs,
        out_strides_abc=strides,
        stream=stream,
    )

    # --- Tensormap workspace per SM ---
    num_tensormap_buffers, max_active_clusters = get_hardware_info()
    tensormap_shape = (
        num_tensormap_buffers,
        GroupedGemmKernel.num_tensormaps,
        GroupedGemmKernel.bytes_per_tensormap // 8,
    )
    tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64)
    tensormap_workspace = from_dlpack(tensormap_workspace_t)

    # --- Total clusters ---
    def compute_total_num_clusters(
        problem_sizes_mnkl,
        cluster_tile_shape_mn,
    ):
        total_num_clusters = 0
        for m, n, _, _ in problem_sizes_mnkl:
            num_clusters_mn = tuple(
                ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn)
            )
            total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn)
        return total_num_clusters

    # Compute cluster tile shape
    def compute_cluster_tile_shape(
        mma_tiler_mn,
        cluster_shape_mn,
        use_2cta_instrs,
    ):
        cta_tile_shape_mn = list(mma_tiler_mn)
        if use_2cta_instrs:
            cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2
        return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn))

    cluster_tile_shape_mn = compute_cluster_tile_shape(
        (TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA)
    )

    total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn))

    gemm_cache_key = get_gemm_cache_key(
        prep_cache_key, max_active_clusters, total_num_clusters
    )
    gemm_executor = gemm_cache.get(gemm_cache_key)

    if gemm_executor is None:
        grouped_gemm = GroupedGemmKernel(
            acc_dtype=ACC_DTYPE,
            use_2cta_instrs=USE_2_CTA,
            mma_tiler_mn=(TILE_M, TILE_N),
            cluster_shape_mn=(CLUSTER_M, CLUSTER_N),
            tensormap_update_mode=TENSORMAP_UPDATE_MODE,
        )

        gemm_executor = cute.compile(
            grouped_gemm,
            from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
            from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
            from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
            G,
            probs,
            strides,
            ptrs,
            total_num_clusters,
            tensormap_workspace,
            max_active_clusters,
            stream,
        )

        gemm_cache[gemm_cache_key] = gemm_executor

    gemm_executor(
        from_dlpack(input_a.unsqueeze(-1), assumed_align=16),
        from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16),
        from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16),
        probs,
        strides,
        ptrs,
        tensormap_workspace,
        stream,
    )
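The cache-key note in the removed template above hinges on total_num_clusters varying with `offs`: the same total M split differently across groups can launch a different number of clusters. A small back-of-the-envelope sketch of that effect, using the same per-group ceildiv rule as compute_total_num_clusters (the tile numbers below are illustrative, e.g. TILE_M=128, TILE_N=192, CLUSTER_M=2, CLUSTER_N=1 without 2-CTA gives a cluster tile of 256 x 192):

from math import ceil

def total_clusters(problem_sizes_mn, cluster_tile_mn):
    # One (M, N) entry per group; each group contributes ceil(M/tile_m) * ceil(N/tile_n) clusters.
    tile_m, tile_n = cluster_tile_mn
    return sum(ceil(m / tile_m) * ceil(n / tile_n) for m, n in problem_sizes_mn)

# Both partitions cover 1024 rows in total, but the grid size differs,
# so a GEMM executor compiled for one cannot be reused for the other.
print(total_clusters([(512, 768), (512, 768)], (256, 192)))  # 16
print(total_clusters([(300, 768), (724, 768)], (256, 192)))  # 20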
@@ -1,141 +0,0 @@
from dataclasses import dataclass
from enum import auto, Enum
from itertools import product

import torch._inductor.config as config


class TensorMapUpdateMode(Enum):
    """Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency."""

    SMEM = auto()
    GMEM = auto()


@dataclass(frozen=True)
class CuTeGemmConfig:
    TILE_M: int = 128
    TILE_N: int = 192
    CLUSTER_M: int = 2
    CLUSTER_N: int = 1
    USE_2_CTA: bool = False
    TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM


def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]:
    """
    Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
    For information regarding valid config sets, see:
    https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py
    """

    # Tile_n is always the same regardless of 2cta
    tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256]

    # Valid clusters
    clusters_no_2cta = [
        (1, 1),
        (1, 2),
        (1, 4),
        (1, 8),
        (1, 16),
        (2, 1),
        (2, 2),
        (2, 4),
        (2, 8),
        (4, 1),
        (4, 2),
        (4, 4),
        (8, 1),
        (8, 2),
        (16, 1),
    ]
    clusters_2cta = [
        (2, 1),
        (2, 2),
        (2, 4),
        (2, 8),
        (4, 1),
        (4, 2),
        (4, 4),
        (8, 1),
        (8, 2),
        (16, 1),
    ]

    configs: list[CuTeGemmConfig] = []

    for use_2cta, cluster_set, tile_m_range in [
        (False, clusters_no_2cta, [64, 128]),
        (True, clusters_2cta, [128, 256]),
    ]:
        for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product(
            [TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM],
            tile_m_range,
            tile_n_vals,
            cluster_set,
        ):
            configs.append(
                CuTeGemmConfig(
                    tile_m,
                    tile_n,
                    cluster_m,
                    cluster_n,
                    USE_2_CTA=use_2cta,
                    TENSORMAP_UPDATE_MODE=tensormap_update_mode,
                )
            )

    return configs


def get_default_groupgemm_configs() -> list[CuTeGemmConfig]:
    """
    Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.
    """

    config_tuples = [
        (128, 256, 2, 1, False, TensorMapUpdateMode.SMEM),
        (256, 160, 2, 1, True, TensorMapUpdateMode.GMEM),
        (256, 256, 2, 1, True, TensorMapUpdateMode.GMEM),
        (64, 32, 1, 1, False, TensorMapUpdateMode.GMEM),
        (64, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
        (128, 256, 1, 2, False, TensorMapUpdateMode.SMEM),
        (256, 256, 2, 2, True, TensorMapUpdateMode.GMEM),
        (128, 256, 1, 2, False, TensorMapUpdateMode.GMEM),
        (64, 32, 1, 1, False, TensorMapUpdateMode.SMEM),
        (256, 256, 2, 1, True, TensorMapUpdateMode.SMEM),
        (128, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
        (256, 256, 8, 1, True, TensorMapUpdateMode.GMEM),
        (64, 32, 1, 2, False, TensorMapUpdateMode.SMEM),
        (256, 192, 2, 1, True, TensorMapUpdateMode.GMEM),
        (256, 256, 2, 2, True, TensorMapUpdateMode.SMEM),
        (128, 96, 1, 2, False, TensorMapUpdateMode.SMEM),
        (64, 192, 1, 1, False, TensorMapUpdateMode.SMEM),
        (64, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
        (64, 192, 1, 1, False, TensorMapUpdateMode.GMEM),
        (128, 64, 1, 1, False, TensorMapUpdateMode.GMEM),
        (64, 160, 1, 1, False, TensorMapUpdateMode.GMEM),
        (64, 256, 1, 1, False, TensorMapUpdateMode.GMEM),
    ]

    return [CuTeGemmConfig(*args) for args in config_tuples]


def get_groupgemm_configs() -> list[CuTeGemmConfig]:
    """
    Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel.

    Note: CuTeDSL autotuning is still experimental; enabling it may trigger kernel launch failures
    or unstable results. By default, autotuning is disabled and we return only
    a single baseline config.
    """
    if (
        config.cutedsl_enable_autotuning
        and config.max_autotune_gemm_search_space == "EXHAUSTIVE"
    ):
        return get_exhaustive_groupgemm_configs()
    elif config.cutedsl_enable_autotuning:
        return get_default_groupgemm_configs()
    else:
        return [get_default_groupgemm_configs()[0]]
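For reference, the lowering code earlier in this diff splats each of these configs into the template's compile-time constants via **asdict(config). A minimal standalone sketch of that expansion follows; the dataclass and enum are re-declared here only so the snippet runs on its own, mirroring the definitions removed above:

from dataclasses import asdict, dataclass
from enum import auto, Enum

class TensorMapUpdateMode(Enum):  # mirrors the removed helper above
    SMEM = auto()
    GMEM = auto()

@dataclass(frozen=True)
class CuTeGemmConfig:  # field order matches the removed dataclass above
    TILE_M: int = 128
    TILE_N: int = 192
    CLUSTER_M: int = 2
    CLUSTER_N: int = 1
    USE_2_CTA: bool = False
    TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM

# The baseline config returned when autotuning is disabled.
cfg = CuTeGemmConfig(128, 256, 2, 1)
print(asdict(cfg))
# {'TILE_M': 128, 'TILE_N': 256, 'CLUSTER_M': 2, 'CLUSTER_N': 1,
#  'USE_2_CTA': False, 'TENSORMAP_UPDATE_MODE': <TensorMapUpdateMode.SMEM: 1>}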
@@ -1911,84 +1911,6 @@ def use_triton_blackwell_tma_template(
    return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch()


@functools.lru_cache(maxsize=1)
def ensure_cute_available() -> bool:
    """Check if CuTeDSL is importable; cache the result for reuse.

    Call ensure_cute_available.cache_clear() after installing CuTeDSL
    in the same interpreter to retry the import.
    """
    try:
        return importlib.util.find_spec("cutlass.cute") is not None
    except ImportError:
        return False


def use_blackwell_cutedsl_grouped_mm(
    mat_a: Any,
    mat_b: Any,
    layout: Layout,
    a_is_2d: bool,
    b_is_2d: bool,
    offs: Optional[Any],
    bias: Optional[Any],
    scale_result: Optional[Any],
) -> bool:
    """
    Returns True if we can use the blackwell kernel for grouped mm.
    Required conditions:
        1. CuTeDSL backend is enabled
        2. CuTeDSL is available
        3. We are on a blackwell arch
        4. The dtype is bf16
        5. Max autotune or max autotune gemm is enabled
        6. A, B, and the output are 16B aligned
        7. We are not using dynamic shapes
        8. A is 2d
        9. B is 3d
        10. Offsets are provided
        11. Bias and Scale are not provided
    """
    if not ensure_cute_available():
        return False

    if not _use_autotune_backend("CUTEDSL"):
        return False

    from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch

    if not is_gpu(layout.device.type):
        return False

    if not is_datacenter_blackwell_arch():
        return False

    layout_dtypes = [torch.bfloat16]
    if not _use_template_for_gpu(layout, layout_dtypes):
        return False

    if not (config.max_autotune or config.max_autotune_gemm):
        return False

    # Checks for 16B ptr and stride alignment
    if not can_use_tma(mat_a, mat_b, output_layout=layout):
        return False

    if any(is_dynamic(x) for x in [mat_a, mat_b]):
        return False

    if not a_is_2d or b_is_2d:
        return False

    if offs is None:
        return False

    if bias is not None or scale_result is not None:
        return False

    return True


def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
    from .virtualized import V
@@ -1,5 +1,7 @@
#include <torch/csrc/distributed/c10d/control_plane/Handlers.hpp>

#include <torch/csrc/distributed/c10d/FlightRecorder.hpp>

#include <fmt/format.h>
#include <mutex>
#include <shared_mutex>
@@ -63,6 +65,14 @@ RegisterHandler pingHandler{"ping", [](const Request&, Response& res) {
  res.setStatus(200);
}};

RegisterHandler frTracehandler(
    "fr_trace_json",
    [](const Request&, Response& res) {
      auto trace = ::c10d::dump_fr_trace_json(true, true);
      res.setContent(std::move(trace), "application/json");
      res.setStatus(200);
    });

} // namespace

void registerHandler(const std::string& name, HandlerFunc f) {

@@ -152,11 +152,17 @@ WorkerServer::WorkerServer(const std::string& hostOrFile, int port) {
    TORCH_CHECK(
        server_.bind_to_port(hostOrFile, 80),
        fmt::format("Error binding to {}", hostOrFile));
  } else if (port == 0) {
    C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
    port_ = server_.bind_to_any_port(hostOrFile);
    TORCH_CHECK(
        port_ >= 0, fmt::format("Error binding to {}:{}", hostOrFile, port));
  } else {
    C10D_WARNING("Server listening to TCP {}:{}", hostOrFile, port);
    TORCH_CHECK(
        server_.bind_to_port(hostOrFile, port),
        fmt::format("Error binding to {}:{}", hostOrFile, port));
    port_ = port;
  }

  serverThread_ = std::thread([this]() {

@@ -19,9 +19,14 @@ class TORCH_API WorkerServer : public c10::intrusive_ptr_target {

  void shutdown();

  int port() {
    return port_;
  }

 private:
  httplib::Server server_;
  std::thread serverThread_;
  int port_;
};

} // namespace c10d::control_plane

@@ -46,6 +46,7 @@

#include <fmt/format.h>
#include <pybind11/chrono.h>
#include <pybind11/functional.h>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/DMAConnectivity.hpp>
#include <torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp>
@@ -381,8 +382,9 @@ void _register_comm_hook(
    ::c10d::Reducer& reducer,
    py::object state,
    py::object comm_hook) {
  reducer.register_comm_hook(std::make_unique<::c10d::PythonCommHook>(
      std::move(state), std::move(comm_hook)));
  reducer.register_comm_hook(
      std::make_unique<::c10d::PythonCommHook>(
          std::move(state), std::move(comm_hook)));
}

// Called from DDP's Python API to create a c10d C++ comm hook.
@@ -882,37 +884,39 @@ This class does not support ``__members__`` property.)");
          [](const ::c10d::ReduceOp& self, const py::dict& memo) {
            return ::c10d::ReduceOp(self);
          })
      .def(py::pickle(
          [](const ::c10d::ReduceOp& r) {
            // __getstate__
            if (r.op_ != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) {
              return py::make_tuple(r.op_, py::none());
            }
            TORCH_CHECK(r.supplement_.defined(), "Invalid PREMUL_SUM ReduceOp");
            const auto* preMulSupplement =
                reinterpret_cast<::c10d::NCCLPreMulSumSupplement*>(
                    r.supplement_.get());
            if (!preMulSupplement->tensor_factor.defined()) {
              return py::make_tuple(r.op_, preMulSupplement->double_factor);
            } else {
              return py::make_tuple(r.op_, preMulSupplement->tensor_factor);
            }
          },
          [](const py::tuple& t) {
            // __setstate__
            TORCH_CHECK(t.size() == 2, "Invalid state");
            const auto op =
                static_cast<::c10d::ReduceOp::RedOpType>(t[0].cast<uint8_t>());
            if (op != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) {
              return ::c10d::ReduceOp(op);
            }
            const auto preMulSupplement_factor = t[1];
            if (py::isinstance<py::float_>(preMulSupplement_factor)) {
              return ::c10d::makeNCCLPreMulSum(t[1].cast<double>());
            } else {
              return ::c10d::makeNCCLPreMulSum(t[1].cast<at::Tensor>());
            }
          }));
      .def(
          py::pickle(
              [](const ::c10d::ReduceOp& r) {
                // __getstate__
                if (r.op_ != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) {
                  return py::make_tuple(r.op_, py::none());
                }
                TORCH_CHECK(
                    r.supplement_.defined(), "Invalid PREMUL_SUM ReduceOp");
                const auto* preMulSupplement =
                    reinterpret_cast<::c10d::NCCLPreMulSumSupplement*>(
                        r.supplement_.get());
                if (!preMulSupplement->tensor_factor.defined()) {
                  return py::make_tuple(r.op_, preMulSupplement->double_factor);
                } else {
                  return py::make_tuple(r.op_, preMulSupplement->tensor_factor);
                }
              },
              [](const py::tuple& t) {
                // __setstate__
                TORCH_CHECK(t.size() == 2, "Invalid state");
                const auto op = static_cast<::c10d::ReduceOp::RedOpType>(
                    t[0].cast<uint8_t>());
                if (op != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) {
                  return ::c10d::ReduceOp(op);
                }
                const auto preMulSupplement_factor = t[1];
                if (py::isinstance<py::float_>(preMulSupplement_factor)) {
                  return ::c10d::makeNCCLPreMulSum(t[1].cast<double>());
                } else {
                  return ::c10d::makeNCCLPreMulSum(t[1].cast<at::Tensor>());
                }
              }));

  py::enum_<::c10d::ReduceOp::RedOpType>(reduce_op, "RedOpType")
      .value("SUM", ::c10d::ReduceOp::RedOpType::SUM)
@@ -3579,10 +3583,11 @@ Example::
          [](std::optional<bool> includeCollectives,
             std::optional<bool> includeStackTraces,
             std::optional<bool> onlyActive) {
            return py::bytes(::c10d::dump_xccl_trace(
                includeCollectives.value_or(true),
                includeStackTraces.value_or(true),
                onlyActive.value_or(false)));
            return py::bytes(
                ::c10d::dump_xccl_trace(
                    includeCollectives.value_or(true),
                    includeStackTraces.value_or(true),
                    onlyActive.value_or(false)));
          },
          py::arg("includeCollectives") = std::optional<bool>(),
          py::arg("includeStackTraces") = std::optional<bool>(),
@@ -4112,8 +4117,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
          "_dump_nccl_trace_json",
          [](std::optional<bool> includeCollectives,
             std::optional<bool> onlyActive) {
            return py::bytes(::c10d::dump_nccl_trace_json(
                includeCollectives.value_or(true), onlyActive.value_or(false)));
            return py::bytes(
                ::c10d::dump_nccl_trace_json(
                    includeCollectives.value_or(true), onlyActive.value_or(false)));
          },
          py::arg("includeCollectives") = std::optional<bool>(),
          py::arg("onlyActive") = std::optional<bool>(),
@@ -4130,10 +4136,11 @@ such as `dist.all_reduce(tensor, async_op=True)`.
          [](std::optional<bool> includeCollectives,
             std::optional<bool> includeStackTraces,
             std::optional<bool> onlyActive) {
            return py::bytes(::c10d::dump_nccl_trace(
                includeCollectives.value_or(true),
                includeStackTraces.value_or(true),
                onlyActive.value_or(false)));
            return py::bytes(
                ::c10d::dump_nccl_trace(
                    includeCollectives.value_or(true),
                    includeStackTraces.value_or(true),
                    onlyActive.value_or(false)));
          },
          py::arg("includeCollectives") = std::optional<bool>(),
          py::arg("includeStackTraces") = std::optional<bool>(),
@@ -4157,8 +4164,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
          "_dump_fr_trace_json",
          [](std::optional<bool> includeCollectives,
             std::optional<bool> onlyActive) {
            return py::bytes(::c10d::dump_fr_trace_json(
                includeCollectives.value_or(true), onlyActive.value_or(false)));
            return py::bytes(
                ::c10d::dump_fr_trace_json(
                    includeCollectives.value_or(true), onlyActive.value_or(false)));
          },
          py::arg("includeCollectives") = std::optional<bool>(),
          py::arg("onlyActive") = std::optional<bool>(),
@@ -4175,10 +4183,11 @@ such as `dist.all_reduce(tensor, async_op=True)`.
          [](std::optional<bool> includeCollectives,
             std::optional<bool> includeStackTraces,
             std::optional<bool> onlyActive) {
            return py::bytes(::c10d::dump_fr_trace(
                includeCollectives.value_or(true),
                includeStackTraces.value_or(true),
                onlyActive.value_or(false)));
            return py::bytes(
                ::c10d::dump_fr_trace(
                    includeCollectives.value_or(true),
                    includeStackTraces.value_or(true),
                    onlyActive.value_or(false)));
          },
          py::arg("includeCollectives") = std::optional<bool>(),
          py::arg("includeStackTraces") = std::optional<bool>(),
@@ -4203,7 +4212,9 @@ such as `dist.all_reduce(tensor, async_op=True)`.
          }),
          py::arg("host_or_file"),
          py::arg("port") = -1)
      .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown);
      .def("shutdown", &::c10d::control_plane::WorkerServer::shutdown)
      .def_property_readonly(
          "port", &::c10d::control_plane::WorkerServer::port);

  module.def(
      "_get_handler",
@@ -4219,6 +4230,25 @@ such as `dist.all_reduce(tensor, async_op=True)`.
      Returns the handler with the specified name.
      )");

  module.def(
      "_register_handler",
      [](const std::string& name, py::function handler) {
        ::c10d::control_plane::registerHandler(
            name,
            [handler](
                const ::c10d::control_plane::Request& req,
                ::c10d::control_plane::Response& res) {
              py::gil_scoped_acquire acquire;
              handler(std::ref(req), std::ref(res));
            });
      },

      py::arg("name"),
      py::arg("handler"),
      R"(
      Registers a handler by name.
      )");

  module.def(
      "_get_handler_names",
      &::c10d::control_plane::getHandlerNames,
@@ -4238,10 +4268,7 @@ such as `dist.all_reduce(tensor, async_op=True)`.
      .def("body", &::c10d::control_plane::Request::body)
      .def("params", &::c10d::control_plane::Request::params);

  py::class_<
      ::c10d::control_plane::Response,
      std::shared_ptr<::c10d::control_plane::Response>,
      PythonResponse>(
  py::class_<::c10d::control_plane::Response, PythonResponse>(
      module,
      "_Response",
      R"(
@@ -4267,9 +4294,10 @@ such as `dist.all_reduce(tensor, async_op=True)`.
} // namespace

// c10d methods on torch._C
static PyMethodDef methods[] = { // NOLINT
    {"_c10d_init", c10d_init, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};
static PyMethodDef methods[] =
    { // NOLINT
        {"_c10d_init", c10d_init, METH_NOARGS, nullptr},
        {nullptr, nullptr, 0, nullptr}};

// NOLINTNEXTLINE(misc-use-internal-linkage)
PyMethodDef* python_functions() {
torch/distributed/debug.py (new file, 239 lines)
@@ -0,0 +1,239 @@
import os
import socket
import multiprocessing
import requests
from concurrent.futures import ThreadPoolExecutor
import json
import time
import tempfile

from flask import Flask

import torch.distributed as dist
from torch.profiler import (
    profile,
    ProfilerActivity,
    record_function,
    _ExperimentalConfig,
)
from torch._C._distributed_c10d import _WorkerServer, _register_handler


def _torch_profile(req, resp):
    experimental_config = _ExperimentalConfig(
        profile_all_threads=True,
    )
    with profile(record_shapes=True, experimental_config=experimental_config) as prof:
        time.sleep(2)

    with tempfile.NamedTemporaryFile(prefix="torch_debug", suffix=".json") as f:
        prof.export_chrome_trace(f.name)
        resp.set_content(open(f.name, "rb").read(), "application/json")
    resp.set_status(200)


_register_handler("torch_profile", _torch_profile)

MASTER_ADDR = os.environ["MASTER_ADDR"]
MASTER_PORT = int(os.environ["MASTER_PORT"])
RANK = int(os.environ["RANK"])
WORLD_SIZE = int(os.environ["WORLD_SIZE"])


def _tcpstore_client() -> dist.Store:
    store = dist.TCPStore(
        host_name=MASTER_ADDR,
        port=MASTER_PORT,
        is_master=False,
    )
    store = dist.PrefixStore("debug_server", store)
    return store


def fetch_all(endpoint: str) -> list[bytes]:
    store = _tcpstore_client()
    keys = [f"rank{r}" for r in range(WORLD_SIZE)]
    addrs = store.multi_get(keys)
    addrs = [f"{addr.decode()}/handler/{endpoint}" for addr in addrs]

    with ThreadPoolExecutor(max_workers=10) as executor:
        resps = executor.map(requests.post, addrs)

    return addrs, resps


app = Flask(__name__)


def nav():
    return """
<style>
body {
    font-family: sans-serif;
}
pre {
    white-space: pre-wrap;
    max-width: 100%;
}
</style>
<h1>Torch Distributed Debug Server</h1>

<ul>
    <li><a href="/">Home</a></li>
    <li><a href="/stacks">Python Stack Traces</a></li>
    <li><a href="/fr_trace">FlightRecorder</a></li>
    <li><a href="/fr_trace_nccl">FlightRecorder NCCL</a></li>
    <li><a href="/profile">torch profiler</a></li>
</ul>
"""


@app.route("/")
def index():
    return nav()


@app.route("/stacks")
def stacks():
    addrs, resps = fetch_all("dump_traceback")

    def generate():
        yield nav()

        yield "<h2>Stacks</h2>"

        for i, addr, resp in zip(range(len(addrs)), addrs, resps):
            yield f"<h3>Rank {i}: {addr}</h3>"
            if resp.status_code != 200:
                yield f"<p>Failed to fetch: status={resp.status_code}</p>"

            stack = resp.text
            yield f"<pre>{stack}</pre>"

    return generate()


def format_json(blob: str):
    parsed = json.loads(blob)
    return json.dumps(parsed, indent=2)


@app.route("/fr_trace")
def fr_trace():
    addrs, resps = fetch_all("fr_trace_json")

    def generate():
        yield nav()

        yield "<h2>FlightRecorder</h2>"

        for i, addr, resp in zip(range(len(addrs)), addrs, resps):
            yield f"<h3>Rank {i}: {addr}</h3>"
            if resp.status_code != 200:
                yield f"<p>Failed to fetch: status={resp.status_code}</p>"

            stack = format_json(resp.text)
            yield f"<pre>{stack}</pre>"

    return generate()


@app.route("/fr_trace_nccl")
def fr_trace_nccl():
    addrs, resps = fetch_all("dump_nccl_trace_json?onlyactive=true")

    def generate():
        yield nav()

        yield "<h2>FlightRecorder NCCL</h2>"

        for i, addr, resp in zip(range(len(addrs)), addrs, resps):
            yield f"<h3>Rank {i}: {addr}</h3>"
            if resp.status_code != 200:
                yield f"<p>Failed to fetch: status={resp.status_code}</p>"

            stack = format_json(resp.text)
            yield f"<pre>{stack}</pre>"

    return generate()


@app.route("/profile")
def profiler():
    addrs, resps = fetch_all("torch_profile")

    def generate():
        yield nav()

        yield """
<h2>torch profile</h2>
<script>
function stringToArrayBuffer(str) {
    const encoder = new TextEncoder();
    return encoder.encode(str).buffer;
}
async function openPerfetto(data) {
    const ui = window.open('https://ui.perfetto.dev/#!/');
    if (!ui) { alert('Popup blocked. Allow popups for this page and click again.'); return; }

    // Perfetto readiness handshake: PING until we receive PONG
    await new Promise((resolve, reject) => {
        const onMsg = (e) => {
            if (e.source === ui && e.data === 'PONG') {
                window.removeEventListener('message', onMsg);
                clearInterval(pinger);
                resolve();
            }
        };
        window.addEventListener('message', onMsg);
        const pinger = setInterval(() => { try { ui.postMessage('PING', '*'); } catch (_e) {} }, 250);
        setTimeout(() => { clearInterval(pinger); window.removeEventListener('message', onMsg); reject(); }, 20000);
    }).catch(() => { alert('Perfetto UI did not respond. Try again.'); return; });

    ui.postMessage({
        perfetto: {
            buffer: stringToArrayBuffer(JSON.stringify(data)),
            title: "torch profiler",
            fileName: "trace.json",
        }
    }, '*');
}
</script>
"""

        for i, addr, resp in zip(range(len(addrs)), addrs, resps):
            yield f"<h3>Rank {i}: {addr}</h3>"
            if resp.status_code != 200:
                yield f"<p>Failed to fetch: status={resp.status_code}</p>"

            stack = resp.text
            yield f"""
<script>
function run{i}() {{
    var data = {stack};
    openPerfetto(data);
}}
</script>
<button onclick="run{i}()">View {i}</button>
"""

    return generate()


def _interactive_server() -> None:
    app.run(host="::", port=25999)


def enable_debug_server() -> None:
    global _worker_server, _p

    store = _tcpstore_client()

    _worker_server = _WorkerServer("::", 0)
    store.set(f"rank{RANK}", f"http://{socket.gethostname()}:{_worker_server.port}")

    if RANK == 0:
        _p = multiprocessing.Process(
            target=_interactive_server,
        )
        _p.start()
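A possible way to wire the new debug server into a torchrun-launched training script follows. This is a sketch, not part of the PR; it assumes flask and requests are installed and that torchrun has already set MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE, since the module above reads them at import time and registers its handlers then.

# launched e.g. with: torchrun --nproc-per-node 8 train.py
import torch.distributed as dist

def main() -> None:
    dist.init_process_group("nccl")

    # Import after the process group env vars exist; the import registers the
    # torch_profile handler and reads MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE.
    from torch.distributed.debug import enable_debug_server

    enable_debug_server()  # every rank starts a _WorkerServer; rank 0 also serves the Flask UI on port 25999

    # ... training loop; browse to http://<rank 0 host>:25999/ for stacks, FlightRecorder and profiles ...

    dist.destroy_process_group()

if __name__ == "__main__":
    main()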
@@ -1485,6 +1485,7 @@ class PipelineScheduleMulti(_PipelineSchedule):
        output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
        use_full_backward: Optional[bool] = None,
        scale_grads: bool = True,
        backward_requires_autograd: bool = True,
    ):
        # Init parent
        super().__init__(
@@ -1517,6 +1518,11 @@ class PipelineScheduleMulti(_PipelineSchedule):
        # This will be set during init of derived schedules
        self.pipeline_order: dict[int, list[Optional[_Action]]] = {}

        # When using a custom backward function, we may or may not need autograd to be used
        # for the backward pass. This flag is used to determine whether or torch.is_grad_enabled()
        # check should be performed before the step function.
        self._backward_requires_autograd = backward_requires_autograd

        if use_full_backward is not None:
            logger.warning(
                "Deprecation warning: 'use_full_backward' is no longer supported. "
@@ -1609,7 +1615,11 @@ class PipelineScheduleMulti(_PipelineSchedule):
            losses: a list to store the losses for each microbatch.
            return_outputs: whether to return the outputs from the last stage.
        """
        if self._has_backward and not torch.is_grad_enabled():
        if (
            self._has_backward
            and self._backward_requires_autograd
            and not torch.is_grad_enabled()
        ):
            raise RuntimeError(
                "step() requires gradients to be enabled for backward computation; "
                "it should not be used under torch.no_grad() context. "
@@ -1891,7 +1901,7 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
        Args:
            computation_type: The computation type for which to register the custom function
            custom_function: The function to execute when this computation type is encountered.
                Must have signature: (stage: _PipelineStageBase, mb_index: int, *args, **kwargs) -> None
                Must have signature: (action: _Action, ctx: _PipelineContext) -> None
        """
        # Ensure that the computation type is valid
        if computation_type not in (
@@ -1900,10 +1910,13 @@ class _PipelineScheduleRuntime(PipelineScheduleMulti):
            BACKWARD_INPUT,
            BACKWARD_WEIGHT,
            OVERLAP_F_B,
            UNSHARD,
            RESHARD,
            REDUCE_GRAD,
        ):
            raise ValueError(
                f"Invalid computation type {computation_type}. Only FORWARD, FULL_BACKWARD, \
                BACKWARD_INPUT, BACKWARD_WEIGHT, and OVERLAP_F_B are supported."
                BACKWARD_INPUT, BACKWARD_WEIGHT, OVERLAP_F_B, UNSHARD, RESHARD and REDUCE_GRAD are supported."
            )

        # Check if computation_type is already registered
@@ -2296,6 +2309,7 @@ class ScheduleLoopedBFS(_PipelineScheduleRuntime):
        loss_fn: Optional[Union[Callable, _Loss]] = None,
        output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
        scale_grads: bool = True,
        backward_requires_autograd: bool = True,
    ):
        super().__init__(
            stages=stages,
@@ -2303,6 +2317,7 @@ class ScheduleLoopedBFS(_PipelineScheduleRuntime):
            loss_fn=loss_fn,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
            backward_requires_autograd=backward_requires_autograd,
        )

        # 1. Create the pipeline_order (all ranks do this calculation)
@@ -2510,6 +2525,7 @@ class ScheduleInterleaved1F1B(_PipelineScheduleRuntime):
        kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
        scale_grads: bool = True,
        backward_requires_autograd: bool = True,
    ):
        self.pp_group_size = stages[0].group_size
        super().__init__(
@@ -2520,6 +2536,7 @@ class ScheduleInterleaved1F1B(_PipelineScheduleRuntime):
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
            backward_requires_autograd=backward_requires_autograd,
        )
        self.n_local_stages = len(stages)
        self.rank = stages[0].group_rank
@@ -2622,6 +2639,7 @@ class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime):
        kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
        scale_grads: bool = True,
        backward_requires_autograd: bool = True,
    ):
        # TODO: we dont support input/weight backward split with torch.compile
        _check_torch_compile_compatibility(stages, self.__class__.__name__)
@@ -2634,6 +2652,7 @@ class ScheduleInterleavedZeroBubble(_PipelineScheduleRuntime):
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
            backward_requires_autograd=backward_requires_autograd,
        )
        self.n_local_stages = len(stages)
        self.rank = stages[0].group_rank
@@ -2819,6 +2838,7 @@ class ScheduleZBVZeroBubble(_PipelineScheduleRuntime):
        kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
        scale_grads: bool = True,
        backward_requires_autograd: bool = True,
    ):
        # TODO: we dont support input/weight backward split with torch.compile
        _check_torch_compile_compatibility(stages, self.__class__.__name__)
@@ -2831,6 +2851,7 @@ class ScheduleZBVZeroBubble(_PipelineScheduleRuntime):
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
            backward_requires_autograd=backward_requires_autograd,
        )
        self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
            self.pp_group_size, self._num_stages, style="v"
@@ -2995,6 +3016,7 @@ class ScheduleDualPipeV(_PipelineScheduleRuntime):
        kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
        scale_grads: bool = True,
        backward_requires_autograd: bool = True,
    ):
        # TODO: we dont support input/weight backward split with torch.compile
        _check_torch_compile_compatibility(stages, self.__class__.__name__)
@@ -3007,6 +3029,7 @@ class ScheduleDualPipeV(_PipelineScheduleRuntime):
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
            backward_requires_autograd=backward_requires_autograd,
        )
        self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
            self.pp_group_size, self._num_stages, style="v"

@@ -2,6 +2,8 @@ import os
import sys
from typing import Optional

from torch.utils._config_module import Config, install_config_module


# [@compile_ignored: debug] Fails hard instead of graph breaking on guard on data dependent errors.
no_data_dependent_graph_break = (
@@ -100,7 +102,11 @@ backed_size_oblivious = False
# Skip dtype check in meta registrations. Only used for systems that does its own dtype checking.
skip_dtype_check_in_meta_registrations = False

from torch.utils._config_module import install_config_module
# Experimental: If True, graph module will register fx metadata during recompile()
enrich_profiler_metadata: bool = Config(  # type: ignore[var-annotated]
    default=False,
    env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE",
)


install_config_module(sys.modules[__name__])

@@ -20,6 +20,7 @@ from torch.nn.modules.module import _addindent
from torch.package import Importer, PackageExporter, PackageImporter, sys_importer

from ._compatibility import compatibility
from .experimental import _config as fx_experimental_config
from .graph import (
    _BoxedCodeGen,
    _custom_builtins,
@@ -858,14 +859,15 @@ class {module_name}(torch.nn.Module):
        called after editing the contained ``graph``, otherwise the generated
        code of this ``GraphModule`` will be out of date.
        """
        # Do not import anything inside recompile, it might slow down the
        # function and cause perf regression. Import outside of the method instead.
        if isinstance(self._graph._codegen, _PyTreeCodeGen):
            self._in_spec = self._graph._codegen.pytree_info.in_spec
            self._out_spec = self._graph._codegen.pytree_info.out_spec

        from torch._dynamo import config as dynamo_config

        python_code = self._graph.python_code(
            root_module="self", record_func=dynamo_config.enrich_profiler_metadata
            root_module="self",
            record_func=fx_experimental_config.enrich_profiler_metadata,
        )
        self._code = python_code.src
        self._lineno_map = python_code._lineno_map
@@ -874,7 +876,7 @@ class {module_name}(torch.nn.Module):
        cls = type(self)
        co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}

        if dynamo_config.enrich_profiler_metadata:
        if fx_experimental_config.enrich_profiler_metadata:
            # Generate metadata and register for profiler augmentation
            node_metadata: dict[int, dict[str, Any]] = {}
            for i, node in enumerate(self._graph.nodes):

@@ -1771,6 +1771,22 @@ class MultiProcContinuousTest(TestCase):
                    cls._run_test_given_id(test_id)
                    completion_queue.put(test_id)
                except BaseException as ex:  # noqa: B036
                    if isinstance(ex, SystemExit):
                        # Get exit code from the process
                        exit_code = getattr(ex, "code", None)

                        # Look up exit code in TEST_SKIPS to see if it is a valid skip
                        skip_entry = next(
                            (v for v in TEST_SKIPS.values() if v.exit_code == exit_code),
                            None,
                        )

                        # If we found an entry, we want to skip the test and the object back to the main process
                        if skip_entry:
                            completion_queue.put(unittest.SkipTest(skip_entry.message))
                            # Skip exception handling below, move to main thread for processing the skip
                            continue

                    raised_exception = True
                    # Send the exception and stack trace back to the dispatcher
                    exc_info = sys.exc_info()
@@ -1892,6 +1908,8 @@ class MultiProcContinuousTest(TestCase):
        # Wait for the workers to finish the test
        for i, completion_queue in enumerate(self.completion_queues):
            rv = completion_queue.get()
            if isinstance(rv, unittest.SkipTest):
                raise rv
            if isinstance(rv, BaseException):
                # Hit an exception, re-raise it in the main process.
                logger.warning(

@@ -114,8 +114,6 @@ class ProfilingMode(Enum):
    PROFILING = 3

# Set by parse_cmd_line_args() if called
CI_FUNCTORCH_ROOT = ""
CI_PT_ROOT = ""
CI_TEST_PREFIX = ""
DISABLED_TESTS_FILE = ""
GRAPH_EXECUTOR : Optional[ProfilingMode] = None
@@ -959,8 +957,6 @@ def _get_test_report_path():
    return os.path.join('test-reports', test_source)

def parse_cmd_line_args():
    global CI_FUNCTORCH_ROOT
    global CI_PT_ROOT
    global CI_TEST_PREFIX
    global DISABLED_TESTS_FILE
    global GRAPH_EXECUTOR
@@ -1039,10 +1035,8 @@ def parse_cmd_line_args():

    set_rng_seed()

    # CI Prefix path used only on CI environment
    # CI Prefix path used only on CI environment
    CI_TEST_PREFIX = str(Path(os.getcwd()))
    CI_PT_ROOT = str(Path(os.getcwd()).parent)
    CI_FUNCTORCH_ROOT = str(os.path.join(Path(os.getcwd()).parent, "functorch"))

def wait_for_process(p, timeout=None):
    try:
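Closing note on the enrich_profiler_metadata hunks above: after the move, the flag lives in torch.fx.experimental._config and is backed by an environment variable, so it can be flipped either way before GraphModule.recompile() runs. A sketch follows; the module path and env var name are taken verbatim from the hunks, but treat the exact enabling semantics as an assumption.

import os

# Option 1: environment variable, read when the config module is imported
# (name spelled exactly as in the hunk above).
os.environ["TORCH_ENRICH_RPOFILER_STACK_TRACE"] = "1"

# Option 2: set the config attribute directly before recompile() is called.
import torch.fx.experimental._config as fx_experimental_config
fx_experimental_config.enrich_profiler_metadata = True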