Compare commits

..

1 Commit

Author SHA1 Message Date
f31507da3e add manual bucketing 2025-10-28 10:45:01 -07:00
13 changed files with 1972 additions and 2077 deletions

View File

@@ -129,7 +129,7 @@ function install_129 {
}
function install_128 {
-CUDNN_VERSION=9.8.0.87
+CUDNN_VERSION=9.10.2.21
echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1"
# install CUDA 12.8.1 in the same container
install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux

View File

@@ -272,6 +272,18 @@ def smoke_test_cuda(
torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version())
print(f"Torch cuDNN version: {torch_cudnn_version}")
torch_cudnn_compile_version = torch._C._cudnn.getCompileVersion()
print(f"Torch cuDNN compile-time version: {torch_cudnn_compile_version}")
torch_cudnn_runtime_version = tuple(
[int(x) for x in torch_cudnn_version.split(".")]
)
if torch_cudnn_runtime_version != torch_cudnn_compile_version:
raise RuntimeError(
"cuDNN runtime version doesn't match comple version. "
f"Loaded: {torch_cudnn_runtime_version} "
f"Expected: {torch_cudnn_compile_version}"
)
if sys.platform in ["linux", "linux2"]:
torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version())
print(f"Torch nccl; version: {torch_nccl_version}")

View File

@@ -1,8 +1,3 @@
---
name: docstring
description: Write docstrings for PyTorch functions and methods following PyTorch conventions. Use when writing or updating docstrings in PyTorch code.
---
# PyTorch Docstring Writing Guide
This skill describes how to write docstrings for functions and methods in the PyTorch project, following the conventions in `torch/_tensor_docs.py` and `torch/nn/functional.py`.
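As an illustrative aside, here is a minimal sketch of the docstring style that guide refers to; the function and its body are hypothetical, and only the layout follows the `torch/nn/functional.py` conventions:

```python
import torch


def scaled_add(input, other, alpha=1.0):
    r"""scaled_add(input, other, alpha=1.0) -> Tensor

    Adds ``other``, scaled by ``alpha``, to ``input``. Hypothetical function,
    shown only to illustrate the docstring layout.

    Args:
        input (Tensor): the input tensor.
        other (Tensor): the tensor to add to :attr:`input`.
        alpha (float, optional): the multiplier for :attr:`other`. Default: ``1.0``

    Example::

        >>> a = torch.randn(3)
        >>> b = torch.randn(3)
        >>> scaled_add(a, b, alpha=0.5)
    """
    return input + alpha * other
```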

View File

@@ -1,385 +0,0 @@
---
name: skill-writer
description: Guide users through creating Agent Skills for Claude Code. Use when the user wants to create, write, author, or design a new Skill, or needs help with SKILL.md files, frontmatter, or skill structure.
---
# Skill Writer
This Skill helps you create well-structured Agent Skills for Claude Code that follow best practices and validation requirements.
## When to use this Skill
Use this Skill when:
- Creating a new Agent Skill
- Writing or updating SKILL.md files
- Designing skill structure and frontmatter
- Troubleshooting skill discovery issues
- Converting existing prompts or workflows into Skills
## Instructions
### Step 1: Determine Skill scope
First, understand what the Skill should do:
1. **Ask clarifying questions**:
- What specific capability should this Skill provide?
- When should Claude use this Skill?
- What tools or resources does it need?
- Is this for personal use or team sharing?
2. **Keep it focused**: One Skill = one capability
- Good: "PDF form filling", "Excel data analysis"
- Too broad: "Document processing", "Data tools"
### Step 2: Choose Skill location
Determine where to create the Skill:
**Personal Skills** (`~/.claude/skills/`):
- Individual workflows and preferences
- Experimental Skills
- Personal productivity tools
**Project Skills** (`.claude/skills/`):
- Team workflows and conventions
- Project-specific expertise
- Shared utilities (committed to git)
### Step 3: Create Skill structure
Create the directory and files:
```bash
# Personal
mkdir -p ~/.claude/skills/skill-name
# Project
mkdir -p .claude/skills/skill-name
```
For multi-file Skills:
```
skill-name/
├── SKILL.md (required)
├── reference.md (optional)
├── examples.md (optional)
├── scripts/
│ └── helper.py (optional)
└── templates/
└── template.txt (optional)
```
### Step 4: Write SKILL.md frontmatter
Create YAML frontmatter with required fields:
```yaml
---
name: skill-name
description: Brief description of what this does and when to use it
---
```
**Field requirements**:
- **name**:
- Lowercase letters, numbers, hyphens only
- Max 64 characters
- Must match directory name
- Good: `pdf-processor`, `git-commit-helper`
- Bad: `PDF_Processor`, `Git Commits!`
- **description**:
- Max 1024 characters
- Include BOTH what it does AND when to use it
- Use specific trigger words users would say
- Mention file types, operations, and context
**Optional frontmatter fields**:
- **allowed-tools**: Restrict tool access (comma-separated list)
```yaml
allowed-tools: Read, Grep, Glob
```
Use for:
- Read-only Skills
- Security-sensitive workflows
- Limited-scope operations
### Step 5: Write effective descriptions
The description is critical for Claude to discover your Skill.
**Formula**: `[What it does] + [When to use it] + [Key triggers]`
**Examples**:
✅ **Good**:
```yaml
description: Extract text and tables from PDF files, fill forms, merge documents. Use when working with PDF files or when the user mentions PDFs, forms, or document extraction.
```
✅ **Good**:
```yaml
description: Analyze Excel spreadsheets, create pivot tables, and generate charts. Use when working with Excel files, spreadsheets, or analyzing tabular data in .xlsx format.
```
❌ **Too vague**:
```yaml
description: Helps with documents
description: For data analysis
```
**Tips**:
- Include specific file extensions (.pdf, .xlsx, .json)
- Mention common user phrases ("analyze", "extract", "generate")
- List concrete operations (not generic verbs)
- Add context clues ("Use when...", "For...")
### Step 6: Structure the Skill content
Use clear Markdown sections:
```markdown
# Skill Name
Brief overview of what this Skill does.
## Quick start
Provide a simple example to get started immediately.
## Instructions
Step-by-step guidance for Claude:
1. First step with clear action
2. Second step with expected outcome
3. Handle edge cases
## Examples
Show concrete usage examples with code or commands.
## Best practices
- Key conventions to follow
- Common pitfalls to avoid
- When to use vs. not use
## Requirements
List any dependencies or prerequisites:
```bash
pip install package-name
```
## Advanced usage
For complex scenarios, see [reference.md](reference.md).
```
### Step 7: Add supporting files (optional)
Create additional files for progressive disclosure:
**reference.md**: Detailed API docs, advanced options
**examples.md**: Extended examples and use cases
**scripts/**: Helper scripts and utilities
**templates/**: File templates or boilerplate
Reference them from SKILL.md:
```markdown
For advanced usage, see [reference.md](reference.md).
Run the helper script:
\`\`\`bash
python scripts/helper.py input.txt
\`\`\`
```
### Step 8: Validate the Skill
Check these requirements (a small validation sketch follows the checklist):
✅ **File structure**:
- [ ] SKILL.md exists in correct location
- [ ] Directory name matches frontmatter `name`
✅ **YAML frontmatter**:
- [ ] Opening `---` on line 1
- [ ] Closing `---` before content
- [ ] Valid YAML (no tabs, correct indentation)
- [ ] `name` follows naming rules
- [ ] `description` is specific and < 1024 chars
✅ **Content quality**:
- [ ] Clear instructions for Claude
- [ ] Concrete examples provided
- [ ] Edge cases handled
- [ ] Dependencies listed (if any)
✅ **Testing**:
- [ ] Description matches user questions
- [ ] Skill activates on relevant queries
- [ ] Instructions are clear and actionable
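As a hedged illustration of the checks above, here is a minimal, hypothetical validation helper (not part of any Skill; assumes PyYAML is installed):

```python
import re
import sys

import yaml  # PyYAML, assumed available


def validate_skill(skill_md_path: str) -> list[str]:
    """Return a list of SKILL.md frontmatter problems (sketch only)."""
    text = open(skill_md_path, encoding="utf-8").read()
    if not text.startswith("---"):
        return ["missing opening '---' on line 1"]
    try:
        frontmatter = yaml.safe_load(text.split("---", 2)[1]) or {}
    except yaml.YAMLError as exc:
        return [f"invalid YAML frontmatter: {exc}"]
    problems = []
    name = frontmatter.get("name") or ""
    description = frontmatter.get("description") or ""
    if not re.fullmatch(r"[a-z0-9-]{1,64}", name):
        problems.append("name must use lowercase letters, numbers, hyphens; max 64 chars")
    if not description or len(description) > 1024:
        problems.append("description is required and must be under 1024 characters")
    return problems


if __name__ == "__main__":
    print(validate_skill(sys.argv[1]) or "OK")
```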
### Step 9: Test the Skill
1. **Restart Claude Code** (if running) to load the Skill
2. **Ask relevant questions** that match the description:
```
Can you help me extract text from this PDF?
```
3. **Verify activation**: Claude should use the Skill automatically
4. **Check behavior**: Confirm Claude follows the instructions correctly
### Step 10: Debug if needed
If Claude doesn't use the Skill:
1. **Make description more specific**:
- Add trigger words
- Include file types
- Mention common user phrases
2. **Check file location**:
```bash
ls ~/.claude/skills/skill-name/SKILL.md
ls .claude/skills/skill-name/SKILL.md
```
3. **Validate YAML**:
```bash
cat SKILL.md | head -n 10
```
4. **Run debug mode**:
```bash
claude --debug
```
## Common patterns
### Read-only Skill
```yaml
---
name: code-reader
description: Read and analyze code without making changes. Use for code review, understanding codebases, or documentation.
allowed-tools: Read, Grep, Glob
---
```
### Script-based Skill
```yaml
---
name: data-processor
description: Process CSV and JSON data files with Python scripts. Use when analyzing data files or transforming datasets.
---
# Data Processor
## Instructions
1. Use the processing script:
\`\`\`bash
python scripts/process.py input.csv --output results.json
\`\`\`
2. Validate output with:
\`\`\`bash
python scripts/validate.py results.json
\`\`\`
```
### Multi-file Skill with progressive disclosure
```yaml
---
name: api-designer
description: Design REST APIs following best practices. Use when creating API endpoints, designing routes, or planning API architecture.
---
# API Designer
Quick start: See [examples.md](examples.md)
Detailed reference: See [reference.md](reference.md)
## Instructions
1. Gather requirements
2. Design endpoints (see examples.md)
3. Document with OpenAPI spec
4. Review against best practices (see reference.md)
```
## Best practices for Skill authors
1. **One Skill, one purpose**: Don't create mega-Skills
2. **Specific descriptions**: Include trigger words users will say
3. **Clear instructions**: Write for Claude, not humans
4. **Concrete examples**: Show real code, not pseudocode
5. **List dependencies**: Mention required packages in description
6. **Test with teammates**: Verify activation and clarity
7. **Version your Skills**: Document changes in content
8. **Use progressive disclosure**: Put advanced details in separate files
## Validation checklist
Before finalizing a Skill, verify:
- [ ] Name is lowercase, hyphens only, max 64 chars
- [ ] Description is specific and < 1024 chars
- [ ] Description includes "what" and "when"
- [ ] YAML frontmatter is valid
- [ ] Instructions are step-by-step
- [ ] Examples are concrete and realistic
- [ ] Dependencies are documented
- [ ] File paths use forward slashes
- [ ] Skill activates on relevant queries
- [ ] Claude follows instructions correctly
## Troubleshooting
**Skill doesn't activate**:
- Make description more specific with trigger words
- Include file types and operations in description
- Add "Use when..." clause with user phrases
**Multiple Skills conflict**:
- Make descriptions more distinct
- Use different trigger words
- Narrow the scope of each Skill
**Skill has errors**:
- Check YAML syntax (no tabs, proper indentation)
- Verify file paths (use forward slashes)
- Ensure scripts have execute permissions
- List all dependencies
## Examples
See the documentation for complete examples:
- Simple single-file Skill (commit-helper)
- Skill with tool permissions (code-reviewer)
- Multi-file Skill (pdf-processing)
## Output format
When creating a Skill, I will:
1. Ask clarifying questions about scope and requirements
2. Suggest a Skill name and location
3. Create the SKILL.md file with proper frontmatter
4. Include clear instructions and examples
5. Add supporting files if needed
6. Provide testing instructions
7. Validate against all requirements
The result will be a complete, working Skill that follows all best practices and validation rules.

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -1,171 +0,0 @@
#pragma once
#include <ATen/core/Tensor.h>
namespace at::native {
using at::blas::ScalingType;
using at::blas::SwizzleType;
namespace {
// TODO: https://github.com/pytorch/pytorch/pull/59380#pullrequestreview-725310492
c10::MaybeOwned<Tensor> inline resolve_conj_if_indicated(const Tensor& tensor, bool resolve_conj) {
if (resolve_conj && tensor.is_conj()) {
return c10::MaybeOwned<Tensor>::owned(tensor.resolve_conj());
} else {
return c10::MaybeOwned<Tensor>::borrowed(tensor);
}
}
c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor, bool transpose_result) {
if (tensor.is_non_overlapping_and_dense()) { // common case
transpose_tensor = tensor.is_contiguous();
return resolve_conj_if_indicated(tensor, transpose_result ? transpose_tensor : !transpose_tensor);
}
IntArrayRef tensor_strides = tensor.strides();
IntArrayRef tensor_sizes = tensor.sizes();
if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) {
transpose_tensor = false;
return resolve_conj_if_indicated(tensor, !transpose_result);
} else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) {
transpose_tensor = true;
return resolve_conj_if_indicated(tensor, transpose_result);
} else {
transpose_tensor = true;
return c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous));
}
}
c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor) {
if (tensor.is_non_overlapping_and_dense()) { // common case
transpose_tensor = tensor.is_contiguous();
return resolve_conj_if_indicated(tensor, true);
}
IntArrayRef tensor_strides = tensor.strides();
IntArrayRef tensor_sizes = tensor.sizes();
if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) {
transpose_tensor = false;
return resolve_conj_if_indicated(tensor, true);
} else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) {
transpose_tensor = true;
return resolve_conj_if_indicated(tensor, true);
} else {
transpose_tensor = true;
return c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous));
}
}
} // namespace
/**
* @brief Prepares matrices for CUBLAS operation
*
* This constructor prepares tensors for cuBLAS.
* The main difference is that PyTorch uses row-major as the default and
* cuBLAS expects column-major.
*
* @details
* To enable row-major output while using CUBLAS,
* we use the mathematical identity that (A × B)^T = B^T × A^T.
*
* Transpose in this context refers to cuBLAS's (Fortran) definition of transpose (row-major):
* T = row-major, N = col-major
*
* Example:
* For matrices A (M×K)(row-major) and B (K×N)(row-major):
* - Standard multiplication: A × B = (M×K) × (K×N) = M×N result (row-major)
* - Using our transpose trick: (B^T × A^T) = (N×K)(T) × (K×M)(T) = N×M(N)
* - However, since the output from cuBLAS is column-major, this is
*   equivalent to an output of size M×N row-major, as expected
*
* The transpose flags are derived from the layouts of the passed in tensors
*
* If the operands are in packed float4 format, `k`, `lda` and `ldb` are adjusted
* to their unpacked values to match what cuBLAS expects.
*
* @param mat1 First input matrix
* @param mat2 Second input matrix
* @param c Output matrix (result)
* @param scale_a Optional scaling factor for first matrix
* @param scale_b Optional scaling factor for second matrix
* @param scale_result Optional scaling factor for result
*/
struct cublasCommonArgs {
cublasCommonArgs(
const Tensor& mat1,
const Tensor& mat2,
Tensor& c,
const std::optional<Tensor>& scale_a = std::nullopt,
const std::optional<Tensor>& scale_b = std::nullopt,
const std::optional<Tensor>& scale_result = std::nullopt,
const std::optional<ScalingType>& scaling_choice_a = std::nullopt,
const std::optional<ScalingType>& scaling_choice_b = std::nullopt) {
bool transpose_result = false, transpose_a = false, transpose_b = false;
result = prepare_matrix_for_cublas(c, transpose_result);
mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_a, transpose_result);
matb = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_b, transpose_result);
// Handle scale tensors if provided
if (scale_a && scale_b) {
// By default since we return in row-major we run the gemm
// as B.T @ A.T, check transpose_result to determine if we flip the scales
scale_mata_ptr = transpose_result ? scale_b->data_ptr() : scale_a->data_ptr();
scale_mata_dtype = transpose_result ? scale_b->scalar_type() : scale_a->scalar_type();
scaling_mata_type = transpose_result ? scaling_choice_b : scaling_choice_a;
scale_matb_ptr = transpose_result ? scale_a->data_ptr() : scale_b->data_ptr();
scale_matb_dtype = transpose_result ? scale_a->scalar_type() : scale_b->scalar_type();
scaling_matb_type = transpose_result ? scaling_choice_a : scaling_choice_b;
}
if (scale_result) {
scale_result_ptr = scale_result->data_ptr();
scale_result_dtype = scale_result->scalar_type();
}
// Update transpose flags
if (transpose_result) {
transpose_a = !transpose_a;
transpose_b = !transpose_b;
}
auto sizes_a = mata->sizes();
auto sizes_b = matb->sizes();
m = sizes_a[transpose_result ? 1 : 0];
k = sizes_a[transpose_result ? 0 : 1];
n = sizes_b[transpose_result ? 0 : 1];
lda = mata->stride((transpose_a == transpose_result) ? 1 : 0);
ldb = matb->stride((transpose_b == transpose_result) ? 1 : 0);
result_ld = result->stride(transpose_result ? 0 : 1);
transa = transpose_a ? mata->is_conj() ? 'c' : 't' : 'n';
transb = transpose_b ? matb->is_conj() ? 'c' : 't' : 'n';
// cuBLAS expects unpacked values of `k`, `lda` and `ldb`, adjust for 4x2 packing
// if the gemm operands are in packed float4
if (mat1.dtype() == at::kFloat4_e2m1fn_x2 && mat2.dtype() == at::kFloat4_e2m1fn_x2) {
k = k * 2;
lda = lda * 2;
ldb = ldb * 2;
}
}
// Matrix members
char transa, transb;
int64_t m, n, k;
int64_t lda, ldb, result_ld;
c10::MaybeOwned<Tensor> mata, matb, result;
// Scale members
void* scale_mata_ptr = nullptr;
void* scale_matb_ptr = nullptr;
void* scale_result_ptr = nullptr;
std::optional<c10::ScalarType> scale_mata_dtype;
std::optional<ScalingType> scaling_mata_type;
std::optional<c10::ScalarType> scale_matb_dtype;
std::optional<ScalingType> scaling_matb_type;
std::optional<c10::ScalarType> scale_result_dtype;
};
} // namespace at::native
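The doc comment on `cublasCommonArgs` above relies on the identity (A × B)^T = B^T × A^T to obtain row-major output from a column-major library. A quick numerical sanity check of that identity (a sketch; shapes are arbitrary and torch is used only for illustration):

```python
import torch

A = torch.randn(3, 4)  # M x K, row-major
B = torch.randn(4, 5)  # K x N, row-major

lhs = (A @ B).T        # transpose of the product, N x M
rhs = B.T @ A.T        # product of the swapped transposes, N x M
assert torch.allclose(lhs, rhs)
```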

View File

@@ -279,7 +279,6 @@ class SymmetricMemoryTest(MultiProcContinuousTest):
# MultiProcContinuousTest will skip all the following tests if a test fails (
# we should fix this too). We still want to get the test signals for the core
# symmetric memory APIs when Async TP ops fail.
@skip_if_rocm_multiprocess # AsyncTP is not yet supported on ROCm
@instantiate_parametrized_tests
@requires_cuda_p2p_access()
class AsyncTPTest(MultiProcContinuousTest):

View File

@@ -2,11 +2,8 @@
# flake8: noqa: B950
import functools
import json
import os
import random
import string
import tempfile
import unittest
import warnings
from collections import namedtuple
@@ -7048,120 +7045,6 @@ class TestLearnableBiases(InductorTestCase):
def test_flex_attention_with_dynamic_max_autotune_graph_partition(self, device):
self._test_flex_attention_with_dynamic_max_autotune(device)
@skip_on_cpu
def test_flex_attention_logging(self, device):
with tempfile.TemporaryDirectory() as tmpdir:
log_file = os.path.join(tmpdir, "flex_attention_configs")
with patch.dict(
os.environ, {"TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE": log_file}
):
query = torch.randn(
1,
2,
128,
64,
device=device,
dtype=torch.float16,
requires_grad=True,
)
key = torch.randn(
1,
2,
128,
64,
device=device,
dtype=torch.float16,
requires_grad=True,
)
value = torch.randn(
1,
2,
128,
64,
device=device,
dtype=torch.float16,
requires_grad=True,
)
def score_mod(score, b, h, q_idx, kv_idx):
return score * 2
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
block_mask = torch.compile(create_block_mask)(
causal_mask, 1, 1, 128, 128, device=device
)
compiled_flex = torch.compile(
flex_attention, mode="max-autotune-no-cudagraphs"
)
out = compiled_flex(
query=query,
key=key,
value=value,
score_mod=score_mod,
block_mask=block_mask,
)
out.sum().backward()
json_file = log_file + ".json"
self.assertTrue(
os.path.exists(json_file), f"Log file {json_file} was not created"
)
with open(json_file) as f:
log_data = json.load(f)
self.assertIsInstance(log_data, list)
self.assertEqual(len(log_data), 2)
keys_seen = [next(iter(entry.keys())) for entry in log_data]
expected_fwd_key = "('forward', 1, 2, 2, 128, 128, 64, 64)"
expected_bwd_key = "('backward', 1, 2, 2, 128, 128, 64, 64)"
self.assertIn(expected_fwd_key, keys_seen)
self.assertIn(expected_bwd_key, keys_seen)
for entry in log_data:
self.assertIsInstance(entry, dict)
self.assertEqual(len(entry), 1)
dims_key = next(iter(entry.keys()))
choices = entry[dims_key]
kernel_type = eval(dims_key)[0]
self.assertIsInstance(choices, list)
self.assertGreater(len(choices), 0)
for i, choice in enumerate(choices):
self.assertIn("type", choice)
self.assertIn("time", choice)
if choice["type"] == "triton":
self.assertIn("num_warps", choice)
self.assertIn("num_stages", choice)
if kernel_type == "forward":
self.assertIn("BLOCK_M", choice)
self.assertIn("BLOCK_N", choice)
self.assertNotIn("BLOCK_M1", choice)
elif kernel_type == "backward":
self.assertIn("BLOCK_M1", choice)
self.assertIn("BLOCK_N1", choice)
self.assertIn("BLOCK_M2", choice)
self.assertIn("BLOCK_N2", choice)
self.assertNotIn("BLOCK_M", choice)
self.assertNotIn("BLOCK_N", choice)
if i > 0:
self.assertLessEqual(choices[0]["time"], choice["time"])
@skip_on_cpu
def test_inspect_bug(self, device):
# https://github.com/pytorch/pytorch/issues/139374

View File

@@ -445,7 +445,7 @@ use_numpy_random_stream = False
enable_cpp_guard_manager = True
# Use C++ guard manager for symbolic shapes
-enable_cpp_symbolic_shape_guards = False
+enable_cpp_symbolic_shape_guards = not is_fbcode()
# Enable tracing through contextlib.contextmanager
enable_trace_contextlib = True
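For reference, the flag stays overridable at runtime regardless of the new build-dependent default; a small sketch using the standard dynamo config module (the explicit `True` is just an example value):

```python
import torch._dynamo.config as dynamo_config

# Opt in to C++ guards for symbolic shapes, overriding the
# `not is_fbcode()` default set above.
dynamo_config.enable_cpp_symbolic_shape_guards = True
```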

View File

@@ -117,9 +117,9 @@ def bucket_reduce_scatter(
def is_all_gather_into_tensor(node: torch.fx.Node) -> bool: # type: ignore[arg-type]
-return (
-node.op == "call_function"
-and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
+return node.op == "call_function" and (
+node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
+or node.target == torch.ops._c10d_functional.all_gather_into_tensor_out.default
)

View File

@@ -0,0 +1,583 @@
from __future__ import annotations
import heapq
import itertools
import re
from collections import Counter, defaultdict
from typing import Any, Callable, Optional, Union
import torch
import torch.fx as fx
from torch._dynamo.graph_deduplication import _stable_topological_sort
from torch._inductor.fx_passes.bucketing import (
is_all_gather_into_tensor as is_all_gather,
is_reduce_scatter_tensor as is_reduce_scatter,
is_wait_tensor,
merge_all_gather_bucket,
merge_reduce_scatter_bucket,
)
from torch._inductor.fx_passes.overlap_preserving_bucketer import (
bucket_key,
OverlapPreservingBucketer,
)
from torch._inductor.fx_passes.overlap_scheduling import (
CollectiveInfo,
is_compute_node,
OverlapScheduler,
)
from torch.utils._ordered_set import OrderedSet
def _get_module_stack(node: fx.Node) -> list[tuple[str, type[Any]]]:
if node.meta.get("nn_module_stack", "") == "":
if node.meta.get("fwd_nn_module_stack", "") != "":
return list(node.meta["fwd_nn_module_stack"].values())
return []
return list(node.meta["nn_module_stack"].values())
def _addindent(s_: str, num_spaces: int) -> str:
s: list[str] = s_.split("\n")
# don't do anything for single-line stuff
if len(s) == 1:
return s_
first: str = s.pop(0)
s: list[str] = [(num_spaces * " ") + line for line in s]
joint_s: str = "\n".join(s)
joint_s = first + "\n" + joint_s
return joint_s
class Container:
def __init__(self, name: str, klass: type[Any]) -> None:
self.name: str = name
self.klass: type[Any] = klass
self.data: list[fx.Node] = []
self.unique_nodes: OrderedSet[fx.Node] = OrderedSet()
self.children: dict[str, Container] = {}
def add(self, data: fx.Node) -> None:
if data not in self.unique_nodes:
self.data.append(data)
self.unique_nodes.add(data)
def get_child(
self, module_stack: str, klass: Optional[type[Any]] = None
) -> Container:
if module_stack not in self.children:
new_stack = Container(module_stack, klass or self.klass)
self.children[module_stack] = new_stack
return self.children[module_stack]
def __getitem__(self, name: str) -> Container:
return self.children[name]
def __getattr__(self, name: str) -> Container:
return self.children[name]
def __repr__(self) -> str:
child_lines: list[str] = []
for name, child in self.children.items():
mod_str = repr(child)
mod_str = _addindent(mod_str, 2)
child_lines.append(f"({name}): {mod_str}")
main_str = f"{self.klass.__name__}("
if child_lines:
main_str += "\n " + "\n ".join(child_lines) + "\n"
main_str += ")"
return main_str
def graph_view(self) -> fx.Graph:
return _make_subgraph(self.data)
def _clean_stack_name(val: str) -> str:
name: str = (
val.replace("L['self']", "Model")
.replace("_modules['", "")
.replace("['", ".")
.replace("']", "")
)
return name
def _find_key_nodes(nodes: list[fx.Node]) -> tuple[list[fx.Node], list[fx.Node]]:
root = []
outputs = []
nodes_set = OrderedSet(nodes)
for node in nodes:
for x in node.all_input_nodes:
if x not in nodes_set:
root.append(x)
if all(x not in nodes_set for x in node.users):
outputs.append(node)
return root, outputs
def _make_subgraph(nodes: list[fx.Node]) -> fx.Graph:
placeholders, outputs = _find_key_nodes(nodes)
new_graph = torch.fx.Graph()
env: dict[fx.Node, fx.Node] = {}
def env_lookup(x: fx.Node) -> fx.Node:
assert x in env, f"Dependent node {x} not in env when creating downstream node"
return env[x]
def node_copy(
node: fx.Node, arg_transform: Callable[[fx.Node], fx.Node]
) -> fx.Node:
if node not in env:
new_node = new_graph.node_copy(node, arg_transform=arg_transform)
env[node] = new_node
else:
new_node = env[node]
return new_node
for node in placeholders:
env[node] = new_graph.placeholder(node.name)
for node in nodes:
if node in placeholders:
continue
else:
new_node = node_copy(node, env_lookup)
new_node.meta = node.meta.copy()
out_node = [env[x] for x in outputs]
new_graph.output(out_node)
return new_graph
def _is_root(stack: str) -> bool:
return stack == ""
def make_graph_view(graph: fx.Graph) -> Container:
"""
Code from: https://github.com/meta-pytorch/autoparallel/pull/158
Make a graph view from the fx.Graph. This is a tree structure that
represents the module hierarchy of the graph, and enables us to
easily find the nodes that belong to each module, and gives a slightly
easier way of visualize different parts of the graph by extracting
subgraphs that belong to a particular module FQN.
For example, if we have the following model with module hierarchy:
Transformer(
(tok_embeddings): Embedding(128256, 4096)
(layers): ModuleDict(
(0): TransformerBlock(
(attention): Attention(
(wq): Linear(in_features=4096, out_features=4096, bias=False)
(wk): Linear(in_features=4096, out_features=1024, bias=False)
(wv): Linear(in_features=4096, out_features=1024, bias=False)
(wo): Linear(in_features=4096, out_features=4096, bias=False)
(sdpa): ScaledDotProductAttention()
)
(feed_forward): FeedForward(
(w1): Linear(in_features=4096, out_features=14336, bias=False)
(w2): Linear(in_features=14336, out_features=4096, bias=False)
(w3): Linear(in_features=4096, out_features=14336, bias=False)
)
(attention_norm): RMSNorm((4096,), eps=1e-05, elementwise_affine=True)
(ffn_norm): RMSNorm((4096,), eps=1e-05, elementwise_affine=True)
)
)
(norm): RMSNorm((4096,), eps=1e-05, elementwise_affine=True)
(output): Linear(in_features=4096, out_features=128256, bias=False)
)
Then we can get a GraphView for the fx.Graph that enables us to do
graph_view = make_graph_view(graph)
subgraph = graph_view.layers["0"].attention.graph_view()
where subgraph is a fx.Graph that contains all the nodes that belong to
Transformer.layers['0'].attention, and whose inputs are all inputs to this
region of the graph, and whose outputs are all outputs of this region of
the graph. This returns a new graph with new nodes, so we shouldn't use it
for graph manipulations, but it is useful to visualize what a particular
part of a larger graph looks like.
Additionally, you can also query the original nodes in that region with
`graph_view.layers['0'].attention.data`, which returns a list of all the
nodes that belong to Transformer.layers['0'].attention.
"""
nodes: list[fx.Node] = list(graph.nodes)
nodes_by_module_stack_root: Container | None = None
for node in nodes:
for module_stack, module_class in _get_module_stack(node):
module_stack = _clean_stack_name(module_stack)
nodes_by_module_stack: Container | None = nodes_by_module_stack_root
for name in module_stack.split("."):
if nodes_by_module_stack is None:
nodes_by_module_stack = Container(name, module_class)
nodes_by_module_stack_root = nodes_by_module_stack
if _is_root(module_stack):
new_stack: Container = nodes_by_module_stack
else:
new_stack = nodes_by_module_stack.get_child(name, module_class)
nodes_by_module_stack = new_stack
nodes_by_module_stack.add(node)
assert nodes_by_module_stack_root is not None, "no nodes with module stack metadata in the graph"
return nodes_by_module_stack_root
def decode_module(module_bucket_plans: list[str]) -> list[list[str] | str]:
"""
Convert abbreviated FQNs to the actual FQNs.
Currently, we support the decoding of these abbreviations:
(1) layers.[0-2] -> [layers.0], [layers.1], [layers.2]
(layers are split into three separate buckets)
(2) norm+output -> [norm, output]
(norm and output are in one bucket)
"""
full_plan: list[list[str] | str] = []
for module_name in module_bucket_plans:
if "+" in module_name:
full_plan.append(module_name.split("+"))
continue
match = re.search(r"\[(\d+)-(\d+)\]", module_name)
if not match:
full_plan.append(module_name)
else:
start, end = map(int, match.groups())
prefix = module_name[: match.start()]
suffix = module_name[match.end() :]
full_plan.extend([f"{prefix}{i}{suffix}" for i in range(start, end + 1)])
return full_plan
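A hedged usage sketch for `decode_module`; the FQN strings are hypothetical and only mirror the abbreviations described in the docstring:

```python
plans = ["tok_embeddings", "layers.[0-2]", "norm+output"]
decode_module(plans)
# -> ["tok_embeddings", "layers.0", "layers.1", "layers.2", ["norm", "output"]]
```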
def get_subgraph_by_path(
graph_view: Container, paths: Union[str, list[str]]
) -> list[fx.Node]:
"""
Get subgraph by path(s).
Args:
graph_view (object): Root graph view object.
paths (str or list of str): Path(s) to subgraph.
Returns:
object: Subgraph object or node.
"""
def get_node_by_path(node: Container, path: str) -> Container:
for p in path.split("."):
if p in node.children:
node = node.children[p]
else:
return Container("", object)
return node
if isinstance(paths, list):
nodes = list(
itertools.chain.from_iterable(
get_node_by_path(graph_view, p).data for p in paths
)
)
return nodes
else:
node = get_node_by_path(graph_view, paths)
return node.data
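An illustrative combination of `make_graph_view` and `get_subgraph_by_path`, mirroring the lookup in `_obtain_nodes_in_subgraph` below (assumes `graph` is an fx.Graph whose nodes carry `nn_module_stack` metadata; the module names are hypothetical):

```python
graph_view = make_graph_view(graph)
model_view = graph_view.children["Model"]

# Nodes of a single submodule.
attention_nodes = get_subgraph_by_path(model_view, "layers.0.attention")

# Nodes of several submodules, merged into one list.
ffn_nodes = get_subgraph_by_path(
    model_view, ["layers.0.feed_forward", "layers.1.feed_forward"]
)
```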
class ManualOverlapPreservingBucketer(OverlapPreservingBucketer):
"""
Buckets collective operations based on user specifications.
The actual bucketing happens in manual_bucket_collectives, where the all-gathers/reduce-scatters in
`nodes` are bucketed into one single all-gather/reduce-scatter.
"""
def __init__(
self, node_users: dict[fx.Node, OrderedSet[fx.Node]], *args: Any, **kwargs: Any
):
super().__init__(*args, **kwargs)
self.node_users = node_users
self.wait_to_node_map: dict[fx.Node, fx.Node] = defaultdict()
def _check_recursive_dep(
self,
node: fx.Node,
target_op: str,
dep_dict: dict[torch.fx.Node, OrderedSet[torch.fx.Node]],
) -> bool:
"""
Check if the node is directly used to fetch parameters/gradients.
TODO (ruisizhang123): currently, we assume the node only pre-fetches/updates one parameter/gradient.
We should handle the multiple parameter/gradient update case by checking whether there are non-closure
computes along the path from primal/output to coll_node.
"""
deps: OrderedSet[fx.Node] = dep_dict[node]
seen_target_op = 0
for d in deps:
if d.op == target_op:
seen_target_op += 1
return seen_target_op == 1
def _bucket_group(self, coll_nodes: list[fx.Node]) -> None:
assert len(coll_nodes) > 0, "coll_nodes to bucket should contain at least one node"
waits = [self.collective_info[n].wait_node for n in coll_nodes]
# Use earliest wait insertion point
first_wait = min(waits, key=lambda w: self.node_idx[w])
# Find insertion location
first = coll_nodes[0]
next_node = first
while next_node in coll_nodes:
next_node = next_node.next
if is_all_gather(first):
new_nodes, replacements = merge_all_gather_bucket(
self.graph,
coll_nodes,
wait_insertion_point=first_wait,
insert_before=next_node,
mode="custom_ops",
)
elif is_reduce_scatter(first):
new_nodes, replacements = merge_reduce_scatter_bucket(
self.graph,
coll_nodes,
wait_insertion_point=first_wait,
insert_before=next_node,
mode="custom_ops",
)
else:
raise ValueError(
"bucket non all_gather/reduce_scatter node is not supported"
)
# Identify the new wait and start
new_waits = [n for n in new_nodes if is_wait_tensor(n)]
assert len(new_waits) == 1, f"Expected exactly one new wait, got {new_waits}"
new_wait = new_waits[0]
new_start = new_wait.args[0]
assert isinstance(new_start, fx.Node)
node_type = (
"bucketed_all_gather" if is_all_gather(first) else "bucketed_reduce_scatter"
)
for n in new_nodes:
n.meta["nn_module_stack"] = coll_nodes[0].meta.get("nn_module_stack", "")
n.meta["fwd_nn_module_stack"] = coll_nodes[0].meta.get(
"fwd_nn_module_stack", ""
)
if n == new_wait:
node_type = node_type + "_wait"
n.meta["manual_bucket_node_type"] = node_type
if "wait" in node_type:
self.wait_to_node_map[n] = new_wait
def manual_bucket_collectives(self, nodes: list[fx.Node]) -> None:
"""
Bucket all all-gather/reduce-scatter nodes from `nodes` into one all-gather/reduce-scatter.
"""
# Keep only nodes that are known collectives
collectives = [n for n in nodes if n in self.collective_info]
if collectives == []:
return
grouped_collectives: dict[object, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
for node in collectives:
key = bucket_key(node)
if not (is_all_gather(node) or is_reduce_scatter(node)):
continue
# We only want to bucket all-gathers/reduce-scatters where:
# 1. the all_gather's ancestors depend only on input placeholders (parameters)
# 2. the reduce_scatter's wait node is only used by the output (gradients)
if is_all_gather(node) and not self._check_recursive_dep(
node, "placeholder", self.node_ancestors
):
continue
if is_reduce_scatter(node) and not self._check_recursive_dep(
self.collective_info[node].wait_node, "output", self.node_users
):
continue
if key is not None:
grouped_collectives[key].add(node)
for key, nodes in grouped_collectives.items():
self._bucket_group(list(nodes))
class ManualOverlapScheduler(OverlapScheduler):
"""
Scheduler that manually buckets and reorders collective nodes based on module_bucket_plans.
"""
def __init__(self, gm: fx.GraphModule, module_bucket_plans: list[list[str] | str]):
super().__init__(
gm,
max_in_flight_gb=0.0,
max_compute_pre_fetch=0,
collective_bucketing=True,
insert_overlap_deps=True,
compute_overlap_multipler=0.0,
max_coll_distance=0,
custom_runtime_estimation=None,
)
self.module_bucket_plans = module_bucket_plans
self.nodes_in_subgraph: list[list[fx.Node]] = []
self.node_users: dict[fx.Node, OrderedSet[fx.Node]] = self._collect_node_users()
self.bucketer = ManualOverlapPreservingBucketer(
graph=self.graph,
collective_info=self.collective_info,
node_ancestors=self.node_ancestors,
node_users=self.node_users,
scheduled=OrderedSet(self.graph.nodes),
)
def _identify_collectives(self) -> None:
"""Identify all collective operations."""
for node in self.nodes:
if is_wait_tensor(node):
start = node.args[0]
info = CollectiveInfo(
start_node=start,
wait_node=node,
size_bytes=0,
estimated_time_ms=0,
exposed_time_ms=0,
)
self.collective_info[start] = info
self.wait_to_start[node] = start
self.unscheduled_collectives.add(start)
def run(self) -> torch.fx.GraphModule:
"""Entry point to run the manual bucket algorithm"""
# Bucket collectives in each bucket_module
self._manual_bucket_collectives()
# Reorder collectives with last/next bucket_module
self._manual_reorder_graph()
return self.gm
def _manual_reorder_graph(self) -> None:
"""
Reorder nodes in the FX graph to enforce manual overlap dependencies.
Enforce:
- all_gather_start_i depends on all_gather_wait_(i-1)
- reduce_scatter_wait_i must happen before reduce_scatter_start_(i+1)
"""
delayed_rs_nodes: list[fx.Node] = []
overlap_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
# schedule reduce scatter normally in self._schedule
while self.ready:
_, node = heapq.heappop(self.ready)
node_type = node.meta.get("manual_bucket_node_type", "")
if node in self.scheduled:
continue
if node_type == "bucketed_reduce_scatter":
# Ensure all delayed waits execute before this reduce_scatter
for delayed in delayed_rs_nodes:
self._schedule(delayed)
overlap_deps[delayed].add(node)
delayed_rs_nodes.clear()
elif node_type == "bucketed_reduce_scatter_wait":
# Defer until next reduce_scatter
delayed_rs_nodes.append(node)
continue
self._schedule(node)
for delayed in delayed_rs_nodes:
self._schedule(delayed)
self.scheduled = OrderedSet(reversed(list(self.scheduled)))
picked_ag: list[fx.Node] = []
last_compute: Optional[fx.Node] = None
for node in self.scheduled:
node_type = node.meta.get("manual_bucket_node_type", "")
if node_type == "bucketed_all_gather":
picked_ag.append(node)
continue
if node_type == "bucketed_all_gather_wait":
# Connect corresponding all_gather_wait -> all_gather edges
if picked_ag:
for ag in picked_ag:
overlap_deps[self.bucketer.wait_to_node_map[node]].add(ag)
picked_ag.clear()
if is_compute_node(node):
last_compute = node
if last_compute is not None and not bool(
OrderedSet(picked_ag) & OrderedSet(self.node_ancestors[last_compute])
):
for ag in picked_ag:
overlap_deps[last_compute].add(ag)
_stable_topological_sort(self.graph, overlap_deps)
self.graph.lint()
def _manual_bucket_collectives(self) -> None:
"""Bucket nodes in each module_bucket from module_bucket_plans."""
self._obtain_nodes_in_subgraph()
for i, nodes in enumerate(self.nodes_in_subgraph):
self.bucketer.manual_bucket_collectives(nodes=nodes)
_stable_topological_sort(self.graph, {})
self.graph.lint()
self.nodes = list(self.graph.nodes)
self.in_degree = Counter(user for node in self.nodes for user in node.users)
def _collect_node_users(self) -> dict[fx.Node, OrderedSet[fx.Node]]:
"""Collect all users for each node."""
node_users: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet)
for node in self.nodes:
for output_node in list(node.users.keys()):
node_users[node].add(output_node)
node_users[node] |= node_users[output_node]
return node_users
def _schedule(self, node: fx.Node) -> None:
"""Schedule a node."""
assert node not in self.scheduled
assert all(n in self.scheduled for n in node.all_input_nodes)
self.scheduled.add(node)
for user in node.users:
self.in_degree[user] -= 1
if self.in_degree[user] == 0:
heapq.heappush(self.ready, ((), user))
def _obtain_nodes_in_subgraph(self) -> None:
"""
Obtain nodes in each subgraph from module_bucket_plans
"""
graph_view: Container = make_graph_view(self.graph)
for module in self.module_bucket_plans:
subgraph_view = get_subgraph_by_path(graph_view.children["Model"], module)
self.nodes_in_subgraph.append(subgraph_view)
def manual_overlap_bucketing(
gm: torch.fx.GraphModule,
module_bucket_plans: list[str],
) -> torch.fx.GraphModule:
Schedule nodes based on user specifications in module_bucket_plans.
The manual overlapping consists of two steps:
Step 1: bucket all-gathers/reduce-scatters within each module in module_bucket_plans
Step 2: reorder all-gathers to overlap with the previous module_bucket and
reorder reduce-scatters to overlap with the next module_bucket
TODO(ruisizhang123): allow users to explicitly specify which
module_bucket they want to overlap.
Args:
gm: input graph module to optimize.
module_bucket_plans: user-specified FQNs.
"""
# decode abbreviated FQNs to actual FQNs
module_bucket_plans = decode_module(module_bucket_plans)
overlapped_gm = ManualOverlapScheduler(gm, module_bucket_plans).run()
overlapped_gm.recompile()
return overlapped_gm
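A short end-to-end sketch of calling the new entry point; `gm` and the FQNs are hypothetical and reuse the abbreviation forms that `decode_module` understands:

```python
optimized_gm = manual_overlap_bucketing(
    gm,
    module_bucket_plans=["tok_embeddings", "layers.[0-2]", "norm+output"],
)
```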

View File

@@ -17,7 +17,6 @@ import time
from collections.abc import Sequence
from concurrent.futures import as_completed, ThreadPoolExecutor
from io import StringIO
from pathlib import Path
from types import ModuleType
from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union
from typing_extensions import Self
@@ -2105,11 +2104,6 @@ class TritonTemplate(KernelTemplate):
"matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0),
"waves_per_eu": kwargs.get("waves_per_eu", 0),
"kpack": kwargs.get("kpack", 2),
**{
k: kwargs[k]
for k in AlgorithmSelectorCache.FLEX_ATTENTION_TUNABLE_KEYS
if k in kwargs
},
},
mutated_inputs=mutated_inputs,
workspace_arg=workspace_arg,
@@ -2403,17 +2397,6 @@ def get_mm_log_filename() -> Optional[str]:
return mm_file_name
@functools.cache
def get_flex_attention_log_filename() -> Optional[str]:
flex_attention_file_name = os.environ.get(
"TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE", None
)
if not flex_attention_file_name:
return None
return str(Path(flex_attention_file_name).with_suffix(".json"))
def append_to_log(filename, data):
lock_file = filename.replace(".json", ".lock")
lock = FileLock(lock_file)
@@ -2624,25 +2607,6 @@ class AlgorithmSelectorCache(PersistentCache):
doesn't depend on the output layout.
"""
FLEX_ATTENTION_TUNABLE_KEYS = tuple(
dict.fromkeys(
[
"num_warps",
"num_stages",
"BLOCK_M",
"BLOCK_N",
"BLOCK_M1",
"BLOCK_N1",
"BLOCK_M2",
"BLOCK_N2",
"USE_TMA",
"kpack",
"matrix_instr_nonkdim",
"waves_per_eu",
]
)
)
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@@ -3576,73 +3540,6 @@
)
return pruned_choices
@staticmethod
def get_flex_attention_choice_info(
choice: ChoiceCaller, timings: dict[ChoiceCaller, float]
) -> dict[str, Any]:
if isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller):
return {"type": "extern", "time": timings[choice]}
assert isinstance(choice, torch._inductor.select_algorithm.TritonTemplateCaller)
info = choice.info_dict()
result = {
"type": "triton",
"time": timings[choice],
}
for key in AlgorithmSelectorCache.FLEX_ATTENTION_TUNABLE_KEYS:
if key in info:
result[key] = info[key]
return result
@staticmethod
def maybe_log_flex_attention_results(
name: str, input_nodes: list[ir.IRNode], timings: dict[ChoiceCaller, float]
) -> None:
flex_attention_filename = get_flex_attention_log_filename()
if not flex_attention_filename or "flex_attention" not in name:
return
if len(input_nodes) < 3:
return
query_size = input_nodes[0].get_size()
key_size = input_nodes[1].get_size()
value_size = input_nodes[2].get_size()
B = query_size[0]
Hq = query_size[1]
seq_len_q = query_size[2]
qk_head_dim = query_size[3]
Hkv = key_size[1]
seq_len_kv = key_size[2]
v_head_dim = value_size[3]
kernel_type = "backward" if "backward" in name else "forward"
dims_key = str(
(
kernel_type,
B,
Hq,
Hkv,
seq_len_q,
seq_len_kv,
qk_head_dim,
v_head_dim,
)
)
sorted_choices = sorted(timings, key=timings.__getitem__)
out_dict = {
dims_key: [
AlgorithmSelectorCache.get_flex_attention_choice_info(choice, timings)
for choice in sorted_choices
]
}
append_to_log(flex_attention_filename, out_dict)
@staticmethod
def log_results(
name: str,
@@ -3653,7 +3550,6 @@ class AlgorithmSelectorCache(PersistentCache):
prescreening_elapse: Optional[float] = None,
hint_override: Optional[int] = None,
):
"""Log the autotuning results, currently only handles mm and flex"""
V.debug.log_autotuning_results(
name, input_nodes, timings, elapse, precompile_elapse
)
@@ -3722,10 +3618,6 @@
append_to_log(mm_filename, out_dict)
AlgorithmSelectorCache.maybe_log_flex_attention_results(
name, input_nodes, timings
)
best_time = timings[best]
sys.stderr.write(f"AUTOTUNE {name}({sizes})\n")
sys.stderr.write(f"strides: {strides}\n")