Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-31 12:15:03 +08:00)

Compare commits: revert-cpp...ruisi/manu (1 commit, f31507da3e)
| @ -129,7 +129,7 @@ function install_129 { | ||||
| } | ||||
|  | ||||
| function install_128 { | ||||
|   CUDNN_VERSION=9.8.0.87 | ||||
|   CUDNN_VERSION=9.10.2.21 | ||||
|   echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1" | ||||
|   # install CUDA 12.8.1 in the same container | ||||
|   install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux | ||||
|  | ||||
| @ -272,6 +272,18 @@ def smoke_test_cuda( | ||||
|         torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version()) | ||||
|         print(f"Torch cuDNN version: {torch_cudnn_version}") | ||||
|  | ||||
|         torch_cudnn_compile_version = torch._C._cudnn.getCompileVersion() | ||||
|         print(f"Torch cuDNN compile-time version: {torch_cudnn_compile_version}") | ||||
|         torch_cudnn_runtime_version = tuple( | ||||
|             [int(x) for x in torch_cudnn_version.split(".")] | ||||
|         ) | ||||
|         if torch_cudnn_runtime_version != torch_cudnn_compile_version: | ||||
|             raise RuntimeError( | ||||
|                 "cuDNN runtime version doesn't match compile version. " | ||||
|                 f"Loaded: {torch_cudnn_runtime_version} " | ||||
|                 f"Expected: {torch_cudnn_compile_version}" | ||||
|             ) | ||||
|  | ||||
|         if sys.platform in ["linux", "linux2"]: | ||||
|             torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version()) | ||||
|             print(f"Torch nccl version: {torch_nccl_version}") | ||||
|  | ||||
| @ -1,8 +1,3 @@ | ||||
| --- | ||||
| name: docstring | ||||
| description: Write docstrings for PyTorch functions and methods following PyTorch conventions. Use when writing or updating docstrings in PyTorch code. | ||||
| --- | ||||
|  | ||||
| # PyTorch Docstring Writing Guide | ||||
|  | ||||
| This skill describes how to write docstrings for functions and methods in the PyTorch project, following the conventions in `torch/_tensor_docs.py` and `torch/nn/functional.py`. | ||||
| @ -1,385 +0,0 @@ | ||||
| --- | ||||
| name: skill-writer | ||||
| description: Guide users through creating Agent Skills for Claude Code. Use when the user wants to create, write, author, or design a new Skill, or needs help with SKILL.md files, frontmatter, or skill structure. | ||||
| --- | ||||
|  | ||||
| # Skill Writer | ||||
|  | ||||
| This Skill helps you create well-structured Agent Skills for Claude Code that follow best practices and validation requirements. | ||||
|  | ||||
| ## When to use this Skill | ||||
|  | ||||
| Use this Skill when: | ||||
| - Creating a new Agent Skill | ||||
| - Writing or updating SKILL.md files | ||||
| - Designing skill structure and frontmatter | ||||
| - Troubleshooting skill discovery issues | ||||
| - Converting existing prompts or workflows into Skills | ||||
|  | ||||
| ## Instructions | ||||
|  | ||||
| ### Step 1: Determine Skill scope | ||||
|  | ||||
| First, understand what the Skill should do: | ||||
|  | ||||
| 1. **Ask clarifying questions**: | ||||
|    - What specific capability should this Skill provide? | ||||
|    - When should Claude use this Skill? | ||||
|    - What tools or resources does it need? | ||||
|    - Is this for personal use or team sharing? | ||||
|  | ||||
| 2. **Keep it focused**: One Skill = one capability | ||||
|    - Good: "PDF form filling", "Excel data analysis" | ||||
|    - Too broad: "Document processing", "Data tools" | ||||
|  | ||||
| ### Step 2: Choose Skill location | ||||
|  | ||||
| Determine where to create the Skill: | ||||
|  | ||||
| **Personal Skills** (`~/.claude/skills/`): | ||||
| - Individual workflows and preferences | ||||
| - Experimental Skills | ||||
| - Personal productivity tools | ||||
|  | ||||
| **Project Skills** (`.claude/skills/`): | ||||
| - Team workflows and conventions | ||||
| - Project-specific expertise | ||||
| - Shared utilities (committed to git) | ||||
|  | ||||
| ### Step 3: Create Skill structure | ||||
|  | ||||
| Create the directory and files: | ||||
|  | ||||
| ```bash | ||||
| # Personal | ||||
| mkdir -p ~/.claude/skills/skill-name | ||||
|  | ||||
| # Project | ||||
| mkdir -p .claude/skills/skill-name | ||||
| ``` | ||||
|  | ||||
| For multi-file Skills: | ||||
| ``` | ||||
| skill-name/ | ||||
| ├── SKILL.md (required) | ||||
| ├── reference.md (optional) | ||||
| ├── examples.md (optional) | ||||
| ├── scripts/ | ||||
| │   └── helper.py (optional) | ||||
| └── templates/ | ||||
|     └── template.txt (optional) | ||||
| ``` | ||||
|  | ||||
| ### Step 4: Write SKILL.md frontmatter | ||||
|  | ||||
| Create YAML frontmatter with required fields: | ||||
|  | ||||
| ```yaml | ||||
| --- | ||||
| name: skill-name | ||||
| description: Brief description of what this does and when to use it | ||||
| --- | ||||
| ``` | ||||
|  | ||||
| **Field requirements**: | ||||
|  | ||||
| - **name**: | ||||
|   - Lowercase letters, numbers, hyphens only | ||||
|   - Max 64 characters | ||||
|   - Must match directory name | ||||
|   - Good: `pdf-processor`, `git-commit-helper` | ||||
|   - Bad: `PDF_Processor`, `Git Commits!` | ||||
|  | ||||
| - **description**: | ||||
|   - Max 1024 characters | ||||
|   - Include BOTH what it does AND when to use it | ||||
|   - Use specific trigger words users would say | ||||
|   - Mention file types, operations, and context | ||||
|  | ||||
| **Optional frontmatter fields**: | ||||
|  | ||||
| - **allowed-tools**: Restrict tool access (comma-separated list) | ||||
|   ```yaml | ||||
|   allowed-tools: Read, Grep, Glob | ||||
|   ``` | ||||
|   Use for: | ||||
|   - Read-only Skills | ||||
|   - Security-sensitive workflows | ||||
|   - Limited-scope operations | ||||
|  | ||||
| ### Step 5: Write effective descriptions | ||||
|  | ||||
| The description is critical for Claude to discover your Skill. | ||||
|  | ||||
| **Formula**: `[What it does] + [When to use it] + [Key triggers]` | ||||
|  | ||||
| **Examples**: | ||||
|  | ||||
| ✅ **Good**: | ||||
| ```yaml | ||||
| description: Extract text and tables from PDF files, fill forms, merge documents. Use when working with PDF files or when the user mentions PDFs, forms, or document extraction. | ||||
| ``` | ||||
|  | ||||
| ✅ **Good**: | ||||
| ```yaml | ||||
| description: Analyze Excel spreadsheets, create pivot tables, and generate charts. Use when working with Excel files, spreadsheets, or analyzing tabular data in .xlsx format. | ||||
| ``` | ||||
|  | ||||
| ❌ **Too vague**: | ||||
| ```yaml | ||||
| description: Helps with documents | ||||
| description: For data analysis | ||||
| ``` | ||||
|  | ||||
| **Tips**: | ||||
| - Include specific file extensions (.pdf, .xlsx, .json) | ||||
| - Mention common user phrases ("analyze", "extract", "generate") | ||||
| - List concrete operations (not generic verbs) | ||||
| - Add context clues ("Use when...", "For...") | ||||
|  | ||||
| ### Step 6: Structure the Skill content | ||||
|  | ||||
| Use clear Markdown sections: | ||||
|  | ||||
| ```markdown | ||||
| # Skill Name | ||||
|  | ||||
| Brief overview of what this Skill does. | ||||
|  | ||||
| ## Quick start | ||||
|  | ||||
| Provide a simple example to get started immediately. | ||||
|  | ||||
| ## Instructions | ||||
|  | ||||
| Step-by-step guidance for Claude: | ||||
| 1. First step with clear action | ||||
| 2. Second step with expected outcome | ||||
| 3. Handle edge cases | ||||
|  | ||||
| ## Examples | ||||
|  | ||||
| Show concrete usage examples with code or commands. | ||||
|  | ||||
| ## Best practices | ||||
|  | ||||
| - Key conventions to follow | ||||
| - Common pitfalls to avoid | ||||
| - When to use vs. not use | ||||
|  | ||||
| ## Requirements | ||||
|  | ||||
| List any dependencies or prerequisites: | ||||
| ```bash | ||||
| pip install package-name | ||||
| ``` | ||||
|  | ||||
| ## Advanced usage | ||||
|  | ||||
| For complex scenarios, see [reference.md](reference.md). | ||||
| ``` | ||||
|  | ||||
| ### Step 7: Add supporting files (optional) | ||||
|  | ||||
| Create additional files for progressive disclosure: | ||||
|  | ||||
| **reference.md**: Detailed API docs, advanced options | ||||
| **examples.md**: Extended examples and use cases | ||||
| **scripts/**: Helper scripts and utilities | ||||
| **templates/**: File templates or boilerplate | ||||
|  | ||||
| Reference them from SKILL.md: | ||||
| ```markdown | ||||
| For advanced usage, see [reference.md](reference.md). | ||||
|  | ||||
| Run the helper script: | ||||
| \`\`\`bash | ||||
| python scripts/helper.py input.txt | ||||
| \`\`\` | ||||
| ``` | ||||
|  | ||||
| ### Step 8: Validate the Skill | ||||
|  | ||||
| Check these requirements: | ||||
|  | ||||
| ✅ **File structure**: | ||||
| - [ ] SKILL.md exists in correct location | ||||
| - [ ] Directory name matches frontmatter `name` | ||||
|  | ||||
| ✅ **YAML frontmatter**: | ||||
| - [ ] Opening `---` on line 1 | ||||
| - [ ] Closing `---` before content | ||||
| - [ ] Valid YAML (no tabs, correct indentation) | ||||
| - [ ] `name` follows naming rules | ||||
| - [ ] `description` is specific and < 1024 chars | ||||
|  | ||||
| ✅ **Content quality**: | ||||
| - [ ] Clear instructions for Claude | ||||
| - [ ] Concrete examples provided | ||||
| - [ ] Edge cases handled | ||||
| - [ ] Dependencies listed (if any) | ||||
|  | ||||
| ✅ **Testing**: | ||||
| - [ ] Description matches user questions | ||||
| - [ ] Skill activates on relevant queries | ||||
| - [ ] Instructions are clear and actionable | ||||
|  | ||||
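| To automate the frontmatter checks above, here is a minimal validation sketch (the script name is illustrative, and it assumes PyYAML is installed): | ||||
|  | ||||
| ```python | ||||
| # validate_skill.py -- check SKILL.md frontmatter against the rules above | ||||
| import re | ||||
| import sys | ||||
|  | ||||
| import yaml  # pip install pyyaml | ||||
|  | ||||
|  | ||||
| def validate_skill(path: str) -> list[str]: | ||||
|     errors = [] | ||||
|     text = open(path, encoding="utf-8").read() | ||||
|     match = re.match(r"^---\n(.*?)\n---\n", text, re.DOTALL) | ||||
|     if not match: | ||||
|         return ["missing YAML frontmatter delimited by --- lines"] | ||||
|     meta = yaml.safe_load(match.group(1)) or {} | ||||
|     name = str(meta.get("name", "")) | ||||
|     if not re.fullmatch(r"[a-z0-9-]{1,64}", name): | ||||
|         errors.append("name: lowercase letters, numbers, hyphens only; max 64 chars") | ||||
|     desc = str(meta.get("description", "")) | ||||
|     if not desc or len(desc) > 1024: | ||||
|         errors.append("description: required; max 1024 chars") | ||||
|     return errors | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     problems = validate_skill(sys.argv[1]) | ||||
|     print("\n".join(problems) or "OK") | ||||
| ``` | ||||
|  | ||||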
| ### Step 9: Test the Skill | ||||
|  | ||||
| 1. **Restart Claude Code** (if running) to load the Skill | ||||
|  | ||||
| 2. **Ask relevant questions** that match the description: | ||||
|    ``` | ||||
|    Can you help me extract text from this PDF? | ||||
|    ``` | ||||
|  | ||||
| 3. **Verify activation**: Claude should use the Skill automatically | ||||
|  | ||||
| 4. **Check behavior**: Confirm Claude follows the instructions correctly | ||||
|  | ||||
| ### Step 10: Debug if needed | ||||
|  | ||||
| If Claude doesn't use the Skill: | ||||
|  | ||||
| 1. **Make description more specific**: | ||||
|    - Add trigger words | ||||
|    - Include file types | ||||
|    - Mention common user phrases | ||||
|  | ||||
| 2. **Check file location**: | ||||
|    ```bash | ||||
|    ls ~/.claude/skills/skill-name/SKILL.md | ||||
|    ls .claude/skills/skill-name/SKILL.md | ||||
|    ``` | ||||
|  | ||||
| 3. **Validate YAML**: | ||||
|    ```bash | ||||
|    cat SKILL.md | head -n 10 | ||||
|    ``` | ||||
|  | ||||
| 4. **Run debug mode**: | ||||
|    ```bash | ||||
|    claude --debug | ||||
|    ``` | ||||
|  | ||||
| ## Common patterns | ||||
|  | ||||
| ### Read-only Skill | ||||
|  | ||||
| ```yaml | ||||
| --- | ||||
| name: code-reader | ||||
| description: Read and analyze code without making changes. Use for code review, understanding codebases, or documentation. | ||||
| allowed-tools: Read, Grep, Glob | ||||
| --- | ||||
| ``` | ||||
|  | ||||
| ### Script-based Skill | ||||
|  | ||||
| ```yaml | ||||
| --- | ||||
| name: data-processor | ||||
| description: Process CSV and JSON data files with Python scripts. Use when analyzing data files or transforming datasets. | ||||
| --- | ||||
|  | ||||
| # Data Processor | ||||
|  | ||||
| ## Instructions | ||||
|  | ||||
| 1. Use the processing script: | ||||
| \`\`\`bash | ||||
| python scripts/process.py input.csv --output results.json | ||||
| \`\`\` | ||||
|  | ||||
| 2. Validate output with: | ||||
| \`\`\`bash | ||||
| python scripts/validate.py results.json | ||||
| \`\`\` | ||||
| ``` | ||||
|  | ||||
| ### Multi-file Skill with progressive disclosure | ||||
|  | ||||
| ```yaml | ||||
| --- | ||||
| name: api-designer | ||||
| description: Design REST APIs following best practices. Use when creating API endpoints, designing routes, or planning API architecture. | ||||
| --- | ||||
|  | ||||
| # API Designer | ||||
|  | ||||
| Quick start: See [examples.md](examples.md) | ||||
|  | ||||
| Detailed reference: See [reference.md](reference.md) | ||||
|  | ||||
| ## Instructions | ||||
|  | ||||
| 1. Gather requirements | ||||
| 2. Design endpoints (see examples.md) | ||||
| 3. Document with OpenAPI spec | ||||
| 4. Review against best practices (see reference.md) | ||||
| ``` | ||||
|  | ||||
| ## Best practices for Skill authors | ||||
|  | ||||
| 1. **One Skill, one purpose**: Don't create mega-Skills | ||||
| 2. **Specific descriptions**: Include trigger words users will say | ||||
| 3. **Clear instructions**: Write for Claude, not humans | ||||
| 4. **Concrete examples**: Show real code, not pseudocode | ||||
| 5. **List dependencies**: Mention required packages in description | ||||
| 6. **Test with teammates**: Verify activation and clarity | ||||
| 7. **Version your Skills**: Document changes in content | ||||
| 8. **Use progressive disclosure**: Put advanced details in separate files | ||||
|  | ||||
| ## Validation checklist | ||||
|  | ||||
| Before finalizing a Skill, verify: | ||||
|  | ||||
| - [ ] Name is lowercase, hyphens only, max 64 chars | ||||
| - [ ] Description is specific and < 1024 chars | ||||
| - [ ] Description includes "what" and "when" | ||||
| - [ ] YAML frontmatter is valid | ||||
| - [ ] Instructions are step-by-step | ||||
| - [ ] Examples are concrete and realistic | ||||
| - [ ] Dependencies are documented | ||||
| - [ ] File paths use forward slashes | ||||
| - [ ] Skill activates on relevant queries | ||||
| - [ ] Claude follows instructions correctly | ||||
|  | ||||
| ## Troubleshooting | ||||
|  | ||||
| **Skill doesn't activate**: | ||||
| - Make description more specific with trigger words | ||||
| - Include file types and operations in description | ||||
| - Add "Use when..." clause with user phrases | ||||
|  | ||||
| **Multiple Skills conflict**: | ||||
| - Make descriptions more distinct | ||||
| - Use different trigger words | ||||
| - Narrow the scope of each Skill | ||||
|  | ||||
| **Skill has errors**: | ||||
| - Check YAML syntax (no tabs, proper indentation) | ||||
| - Verify file paths (use forward slashes) | ||||
| - Ensure scripts have execute permissions | ||||
| - List all dependencies | ||||
|  | ||||
| ## Examples | ||||
|  | ||||
| See the documentation for complete examples: | ||||
| - Simple single-file Skill (commit-helper) | ||||
| - Skill with tool permissions (code-reviewer) | ||||
| - Multi-file Skill (pdf-processing) | ||||
|  | ||||
| ## Output format | ||||
|  | ||||
| When creating a Skill, I will: | ||||
|  | ||||
| 1. Ask clarifying questions about scope and requirements | ||||
| 2. Suggest a Skill name and location | ||||
| 3. Create the SKILL.md file with proper frontmatter | ||||
| 4. Include clear instructions and examples | ||||
| 5. Add supporting files if needed | ||||
| 6. Provide testing instructions | ||||
| 7. Validate against all requirements | ||||
|  | ||||
| The result will be a complete, working Skill that follows all best practices and validation rules. | ||||
										
											
(Two file diffs suppressed because they were too large to display.)
							| @ -1,171 +0,0 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <ATen/core/Tensor.h> | ||||
|  | ||||
| namespace at::native { | ||||
|  | ||||
| using at::blas::ScalingType; | ||||
| using at::blas::SwizzleType; | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| // TODO: https://github.com/pytorch/pytorch/pull/59380#pullrequestreview-725310492 | ||||
| c10::MaybeOwned<Tensor> inline resolve_conj_if_indicated(const Tensor& tensor, bool resolve_conj) { | ||||
|   if (resolve_conj && tensor.is_conj()) { | ||||
|     return c10::MaybeOwned<Tensor>::owned(tensor.resolve_conj()); | ||||
|   } else { | ||||
|     return c10::MaybeOwned<Tensor>::borrowed(tensor); | ||||
|   } | ||||
| } | ||||
|  | ||||
| c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor, bool transpose_result) { | ||||
|   if (tensor.is_non_overlapping_and_dense()) { // common case | ||||
|       transpose_tensor = tensor.is_contiguous(); | ||||
|       return resolve_conj_if_indicated(tensor, transpose_result ? transpose_tensor : !transpose_tensor); | ||||
|   } | ||||
|   IntArrayRef tensor_strides = tensor.strides(); | ||||
|   IntArrayRef tensor_sizes = tensor.sizes(); | ||||
|   if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) { | ||||
|     transpose_tensor = false; | ||||
|     return resolve_conj_if_indicated(tensor, !transpose_result); | ||||
|   } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) { | ||||
|     transpose_tensor = true; | ||||
|     return resolve_conj_if_indicated(tensor, transpose_result); | ||||
|   } else { | ||||
|     transpose_tensor = true; | ||||
|     return c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous)); | ||||
|   } | ||||
| } | ||||
|  | ||||
| c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor) { | ||||
|   if (tensor.is_non_overlapping_and_dense()) { // common case | ||||
|       transpose_tensor = tensor.is_contiguous(); | ||||
|       return resolve_conj_if_indicated(tensor, true); | ||||
|   } | ||||
|  | ||||
|   IntArrayRef tensor_strides = tensor.strides(); | ||||
|   IntArrayRef tensor_sizes = tensor.sizes(); | ||||
|   if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) { | ||||
|     transpose_tensor = false; | ||||
|     return resolve_conj_if_indicated(tensor, true); | ||||
|   } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) { | ||||
|     transpose_tensor = true; | ||||
|     return resolve_conj_if_indicated(tensor, true); | ||||
|   } else { | ||||
|     transpose_tensor = true; | ||||
|     return c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous)); | ||||
|   } | ||||
| } | ||||
|  | ||||
| } // namespace | ||||
|  | ||||
| /** | ||||
|  * @brief Prepares matrices for CUBLAS operation | ||||
|  * | ||||
|  * This constructor prepares tensors for cuBLAS. | ||||
|  * The main difference is that PyTorch uses row-major as the default layout and | ||||
|  * cuBLAS expects column-major. | ||||
|  * | ||||
|  * @details | ||||
|  * To enable row-major output while using CUBLAS, | ||||
|  * we use the mathematical identity that (A × B)^T = B^T × A^T. | ||||
|  * | ||||
|  * Transpose in this context refers to cuBLAS's (Fortran) definition of transpose: | ||||
|  * T = row-major, N = col-major | ||||
|  * | ||||
|  * Example: | ||||
|  * For matrices A (M×K)(row-major) and B (K×N)(row-major): | ||||
|  *   - Standard multiplication: A × B = (M×K) × (K×N) = M×N result (row-major) | ||||
|  *   - Using our transpose trick: B^T × A^T = (N×K)(T) × (K×M)(T) = N×M (N) | ||||
|  *   - However, since the output from cuBLAS is column-major, this is | ||||
|  *     equivalent to an output of size MxN row-major, as expected | ||||
|  * | ||||
|  * The transpose flags are derived from the layouts of the passed in tensors | ||||
|  * | ||||
|  * If the operands are in packed float4 format, `k`, `lda` and `ldb` are adjusted | ||||
|  * to their unpacked values to match what cuBLAS expects. | ||||
|  * | ||||
|  * @param mat1 First input matrix | ||||
|  * @param mat2 Second input matrix | ||||
|  * @param c Output matrix (result) | ||||
|  * @param scale_a Optional scaling factor for first matrix | ||||
|  * @param scale_b Optional scaling factor for second matrix | ||||
|  * @param scale_result Optional scaling factor for result | ||||
|  */ | ||||
| struct cublasCommonArgs { | ||||
|   cublasCommonArgs( | ||||
|       const Tensor& mat1, | ||||
|       const Tensor& mat2, | ||||
|       Tensor& c, | ||||
|       const std::optional<Tensor>& scale_a = std::nullopt, | ||||
|       const std::optional<Tensor>& scale_b = std::nullopt, | ||||
|       const std::optional<Tensor>& scale_result = std::nullopt, | ||||
|       const std::optional<ScalingType>& scaling_choice_a = std::nullopt, | ||||
|       const std::optional<ScalingType>& scaling_choice_b = std::nullopt) { | ||||
|     bool transpose_result = false, transpose_a = false, transpose_b = false; | ||||
|     result = prepare_matrix_for_cublas(c, transpose_result); | ||||
|     mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_a, transpose_result); | ||||
|     matb = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_b, transpose_result); | ||||
|  | ||||
|     // Handle scale tensors if provided | ||||
|     if (scale_a && scale_b) { | ||||
|       // By default since we return in row-major we run the gemm | ||||
|       // as B.T @ A.T, check transpose_result to determine if we flip the scales | ||||
|       scale_mata_ptr = transpose_result ? scale_b->data_ptr() : scale_a->data_ptr(); | ||||
|       scale_mata_dtype = transpose_result ? scale_b->scalar_type() : scale_a->scalar_type(); | ||||
|       scaling_mata_type = transpose_result ? scaling_choice_b : scaling_choice_a; | ||||
|       scale_matb_ptr = transpose_result ? scale_a->data_ptr() : scale_b->data_ptr(); | ||||
|       scale_matb_dtype = transpose_result ? scale_a->scalar_type() : scale_b->scalar_type(); | ||||
|       scaling_matb_type = transpose_result ? scaling_choice_a : scaling_choice_b; | ||||
|     } | ||||
|  | ||||
|     if (scale_result) { | ||||
|       scale_result_ptr = scale_result->data_ptr(); | ||||
|       scale_result_dtype = scale_result->scalar_type(); | ||||
|     } | ||||
|  | ||||
|     // Update transpose flags | ||||
|     if (transpose_result) { | ||||
|       transpose_a = !transpose_a; | ||||
|       transpose_b = !transpose_b; | ||||
|     } | ||||
|  | ||||
|     auto sizes_a = mata->sizes(); | ||||
|     auto sizes_b = matb->sizes(); | ||||
|  | ||||
|     m = sizes_a[transpose_result ? 1 : 0]; | ||||
|     k = sizes_a[transpose_result ? 0 : 1]; | ||||
|     n = sizes_b[transpose_result ? 0 : 1]; | ||||
|     lda = mata->stride((transpose_a == transpose_result) ? 1 : 0); | ||||
|     ldb = matb->stride((transpose_b == transpose_result) ? 1 : 0); | ||||
|     result_ld = result->stride(transpose_result ? 0 : 1); | ||||
|     transa = transpose_a ? mata->is_conj() ? 'c' : 't' : 'n'; | ||||
|     transb = transpose_b ? matb->is_conj() ? 'c' : 't' : 'n'; | ||||
|  | ||||
|     // cuBLAS expects unpacked values of `k`, `lda` and `ldb`, adjust for 4x2 packing | ||||
|     // if the gemm operands are in packed float4 | ||||
|     if (mat1.dtype() == at::kFloat4_e2m1fn_x2 && mat2.dtype() == at::kFloat4_e2m1fn_x2) { | ||||
|       k = k * 2; | ||||
|       lda = lda * 2; | ||||
|       ldb = ldb * 2; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   // Matrix members | ||||
|   char transa, transb; | ||||
|   int64_t m, n, k; | ||||
|   int64_t lda, ldb, result_ld; | ||||
|   c10::MaybeOwned<Tensor> mata, matb, result; | ||||
|  | ||||
|   // Scale members | ||||
|   void* scale_mata_ptr = nullptr; | ||||
|   void* scale_matb_ptr = nullptr; | ||||
|   void* scale_result_ptr = nullptr; | ||||
|   std::optional<c10::ScalarType> scale_mata_dtype; | ||||
|   std::optional<ScalingType> scaling_mata_type; | ||||
|   std::optional<c10::ScalarType> scale_matb_dtype; | ||||
|   std::optional<ScalingType> scaling_matb_type; | ||||
|   std::optional<c10::ScalarType> scale_result_dtype; | ||||
| }; | ||||
|  | ||||
| } // namespace at::native | ||||
| @ -279,7 +279,6 @@ class SymmetricMemoryTest(MultiProcContinuousTest): | ||||
| # MultiProcContinuousTest will skip all the following tests if a test fails ( | ||||
| # we should fix this too). We still want to get the test signals for the core | ||||
| # symmetric memory APIs when Async TP ops fail. | ||||
| @skip_if_rocm_multiprocess  # AsyncTP is not yet supported on ROCm | ||||
| @instantiate_parametrized_tests | ||||
| @requires_cuda_p2p_access() | ||||
| class AsyncTPTest(MultiProcContinuousTest): | ||||
|  | ||||
| @ -2,11 +2,8 @@ | ||||
| # flake8: noqa: B950 | ||||
|  | ||||
| import functools | ||||
| import json | ||||
| import os | ||||
| import random | ||||
| import string | ||||
| import tempfile | ||||
| import unittest | ||||
| import warnings | ||||
| from collections import namedtuple | ||||
| @ -7048,120 +7045,6 @@ class TestLearnableBiases(InductorTestCase): | ||||
|     def test_flex_attention_with_dynamic_max_autotune_graph_partition(self, device): | ||||
|         self._test_flex_attention_with_dynamic_max_autotune(device) | ||||
|  | ||||
|     @skip_on_cpu | ||||
|     def test_flex_attention_logging(self, device): | ||||
|         with tempfile.TemporaryDirectory() as tmpdir: | ||||
|             log_file = os.path.join(tmpdir, "flex_attention_configs") | ||||
|  | ||||
|             with patch.dict( | ||||
|                 os.environ, {"TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE": log_file} | ||||
|             ): | ||||
|                 query = torch.randn( | ||||
|                     1, | ||||
|                     2, | ||||
|                     128, | ||||
|                     64, | ||||
|                     device=device, | ||||
|                     dtype=torch.float16, | ||||
|                     requires_grad=True, | ||||
|                 ) | ||||
|                 key = torch.randn( | ||||
|                     1, | ||||
|                     2, | ||||
|                     128, | ||||
|                     64, | ||||
|                     device=device, | ||||
|                     dtype=torch.float16, | ||||
|                     requires_grad=True, | ||||
|                 ) | ||||
|                 value = torch.randn( | ||||
|                     1, | ||||
|                     2, | ||||
|                     128, | ||||
|                     64, | ||||
|                     device=device, | ||||
|                     dtype=torch.float16, | ||||
|                     requires_grad=True, | ||||
|                 ) | ||||
|  | ||||
|                 def score_mod(score, b, h, q_idx, kv_idx): | ||||
|                     return score * 2 | ||||
|  | ||||
|                 def causal_mask(b, h, q_idx, kv_idx): | ||||
|                     return q_idx >= kv_idx | ||||
|  | ||||
|                 block_mask = torch.compile(create_block_mask)( | ||||
|                     causal_mask, 1, 1, 128, 128, device=device | ||||
|                 ) | ||||
|  | ||||
|                 compiled_flex = torch.compile( | ||||
|                     flex_attention, mode="max-autotune-no-cudagraphs" | ||||
|                 ) | ||||
|  | ||||
|                 out = compiled_flex( | ||||
|                     query=query, | ||||
|                     key=key, | ||||
|                     value=value, | ||||
|                     score_mod=score_mod, | ||||
|                     block_mask=block_mask, | ||||
|                 ) | ||||
|  | ||||
|                 out.sum().backward() | ||||
|  | ||||
|                 json_file = log_file + ".json" | ||||
|                 self.assertTrue( | ||||
|                     os.path.exists(json_file), f"Log file {json_file} was not created" | ||||
|                 ) | ||||
|  | ||||
|                 with open(json_file) as f: | ||||
|                     log_data = json.load(f) | ||||
|  | ||||
|                 self.assertIsInstance(log_data, list) | ||||
|                 self.assertEqual(len(log_data), 2) | ||||
|  | ||||
|                 keys_seen = [next(iter(entry.keys())) for entry in log_data] | ||||
|  | ||||
|                 expected_fwd_key = "('forward', 1, 2, 2, 128, 128, 64, 64)" | ||||
|                 expected_bwd_key = "('backward', 1, 2, 2, 128, 128, 64, 64)" | ||||
|  | ||||
|                 self.assertIn(expected_fwd_key, keys_seen) | ||||
|                 self.assertIn(expected_bwd_key, keys_seen) | ||||
|  | ||||
|                 for entry in log_data: | ||||
|                     self.assertIsInstance(entry, dict) | ||||
|                     self.assertEqual(len(entry), 1) | ||||
|  | ||||
|                     dims_key = next(iter(entry.keys())) | ||||
|                     choices = entry[dims_key] | ||||
|  | ||||
|                     kernel_type = eval(dims_key)[0] | ||||
|  | ||||
|                     self.assertIsInstance(choices, list) | ||||
|                     self.assertGreater(len(choices), 0) | ||||
|  | ||||
|                     for i, choice in enumerate(choices): | ||||
|                         self.assertIn("type", choice) | ||||
|                         self.assertIn("time", choice) | ||||
|  | ||||
|                         if choice["type"] == "triton": | ||||
|                             self.assertIn("num_warps", choice) | ||||
|                             self.assertIn("num_stages", choice) | ||||
|  | ||||
|                             if kernel_type == "forward": | ||||
|                                 self.assertIn("BLOCK_M", choice) | ||||
|                                 self.assertIn("BLOCK_N", choice) | ||||
|                                 self.assertNotIn("BLOCK_M1", choice) | ||||
|                             elif kernel_type == "backward": | ||||
|                                 self.assertIn("BLOCK_M1", choice) | ||||
|                                 self.assertIn("BLOCK_N1", choice) | ||||
|                                 self.assertIn("BLOCK_M2", choice) | ||||
|                                 self.assertIn("BLOCK_N2", choice) | ||||
|                                 self.assertNotIn("BLOCK_M", choice) | ||||
|                                 self.assertNotIn("BLOCK_N", choice) | ||||
|  | ||||
|                         if i > 0: | ||||
|                             self.assertLessEqual(choices[0]["time"], choice["time"]) | ||||
|  | ||||
|     @skip_on_cpu | ||||
|     def test_inspect_bug(self, device): | ||||
|         # https://github.com/pytorch/pytorch/issues/139374 | ||||
|  | ||||
| @ -445,7 +445,7 @@ use_numpy_random_stream = False | ||||
| enable_cpp_guard_manager = True | ||||
|  | ||||
| # Use C++ guard manager for symbolic shapes | ||||
| enable_cpp_symbolic_shape_guards = False | ||||
| enable_cpp_symbolic_shape_guards = not is_fbcode() | ||||
|  | ||||
| # Enable tracing through contextlib.contextmanager | ||||
| enable_trace_contextlib = True | ||||
|  | ||||
| @ -117,9 +117,9 @@ def bucket_reduce_scatter( | ||||
|  | ||||
|  | ||||
| def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:  # type: ignore[arg-type] | ||||
|     return ( | ||||
|         node.op == "call_function" | ||||
|         and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default | ||||
|     return node.op == "call_function" and ( | ||||
|         node.target == torch.ops._c10d_functional.all_gather_into_tensor.default | ||||
|         or node.target == torch.ops._c10d_functional.all_gather_into_tensor_out.default | ||||
|     ) | ||||
|  | ||||
|  | ||||
|  | ||||
							
								
								
									
torch/_inductor/fx_passes/overlap_manual_scheduling.py (new file, 583 lines)
							| @ -0,0 +1,583 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| import heapq | ||||
| import itertools | ||||
| import re | ||||
| from collections import Counter, defaultdict | ||||
| from typing import Any, Callable, Optional, Union | ||||
|  | ||||
| import torch | ||||
| import torch.fx as fx | ||||
| from torch._dynamo.graph_deduplication import _stable_topological_sort | ||||
| from torch._inductor.fx_passes.bucketing import ( | ||||
|     is_all_gather_into_tensor as is_all_gather, | ||||
|     is_reduce_scatter_tensor as is_reduce_scatter, | ||||
|     is_wait_tensor, | ||||
|     merge_all_gather_bucket, | ||||
|     merge_reduce_scatter_bucket, | ||||
| ) | ||||
| from torch._inductor.fx_passes.overlap_preserving_bucketer import ( | ||||
|     bucket_key, | ||||
|     OverlapPreservingBucketer, | ||||
| ) | ||||
| from torch._inductor.fx_passes.overlap_scheduling import ( | ||||
|     CollectiveInfo, | ||||
|     is_compute_node, | ||||
|     OverlapScheduler, | ||||
| ) | ||||
| from torch.utils._ordered_set import OrderedSet | ||||
|  | ||||
|  | ||||
| def _get_module_stack(node: fx.Node) -> list[tuple[str, type[Any]]]: | ||||
|     if node.meta.get("nn_module_stack", "") == "": | ||||
|         if node.meta.get("fwd_nn_module_stack", "") != "": | ||||
|             return list(node.meta["fwd_nn_module_stack"].values()) | ||||
|         return [] | ||||
|     return list(node.meta["nn_module_stack"].values()) | ||||
|  | ||||
|  | ||||
| def _addindent(s_: str, num_spaces: int) -> str: | ||||
|     s: list[str] = s_.split("\n") | ||||
|     # don't do anything for single-line stuff | ||||
|     if len(s) == 1: | ||||
|         return s_ | ||||
|     first: str = s.pop(0) | ||||
|     s: list[str] = [(num_spaces * " ") + line for line in s] | ||||
|     joint_s: str = "\n".join(s) | ||||
|     joint_s = first + "\n" + joint_s | ||||
|     return joint_s | ||||
|  | ||||
|  | ||||
| class Container: | ||||
|     def __init__(self, name: str, klass: type[Any]) -> None: | ||||
|         self.name: str = name | ||||
|         self.klass: type[Any] = klass | ||||
|         self.data: list[fx.Node] = [] | ||||
|         self.unique_nodes: OrderedSet[fx.Node] = OrderedSet() | ||||
|         self.children: dict[str, Container] = {} | ||||
|  | ||||
|     def add(self, data: fx.Node) -> None: | ||||
|         if data not in self.unique_nodes: | ||||
|             self.data.append(data) | ||||
|             self.unique_nodes.add(data) | ||||
|  | ||||
|     def get_child( | ||||
|         self, module_stack: str, klass: Optional[type[Any]] = None | ||||
|     ) -> Container: | ||||
|         if module_stack not in self.children: | ||||
|             new_stack = Container(module_stack, klass or self.klass) | ||||
|             self.children[module_stack] = new_stack | ||||
|         return self.children[module_stack] | ||||
|  | ||||
|     def __getitem__(self, name: str) -> Container: | ||||
|         return self.children[name] | ||||
|  | ||||
|     def __getattr__(self, name: str) -> Container: | ||||
|         return self.children[name] | ||||
|  | ||||
|     def __repr__(self) -> str: | ||||
|         child_lines: list[str] = [] | ||||
|         for name, child in self.children.items(): | ||||
|             mod_str = repr(child) | ||||
|             mod_str = _addindent(mod_str, 2) | ||||
|             child_lines.append(f"({name}): {mod_str}") | ||||
|         main_str = f"{self.klass.__name__}(" | ||||
|         if child_lines: | ||||
|             main_str += "\n  " + "\n  ".join(child_lines) + "\n" | ||||
|         main_str += ")" | ||||
|         return main_str | ||||
|  | ||||
|     def graph_view(self) -> fx.Graph: | ||||
|         return _make_subgraph(self.data) | ||||
|  | ||||
|  | ||||
| def _clean_stack_name(val: str) -> str: | ||||
|     name: str = ( | ||||
|         val.replace("L['self']", "Model") | ||||
|         .replace("_modules['", "") | ||||
|         .replace("['", ".") | ||||
|         .replace("']", "") | ||||
|     ) | ||||
|     return name | ||||
|  | ||||
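| # A hedged example of the cleanup (input string assumed from dynamo-style | ||||
| # module-stack names; not taken from the PR): | ||||
| #   _clean_stack_name("L['self']._modules['layers']['0']") == "Model.layers.0" | ||||
|  | ||||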
|  | ||||
| def _find_key_nodes(nodes: list[fx.Node]) -> tuple[list[fx.Node], list[fx.Node]]: | ||||
|     root = [] | ||||
|     outputs = [] | ||||
|     nodes_set = OrderedSet(nodes) | ||||
|     for node in nodes: | ||||
|         for x in node.all_input_nodes: | ||||
|             if x not in nodes_set: | ||||
|                 root.append(x) | ||||
|         if all(x not in nodes_set for x in node.users): | ||||
|             outputs.append(node) | ||||
|     return root, outputs | ||||
|  | ||||
|  | ||||
| def _make_subgraph(nodes: list[fx.Node]) -> fx.Graph: | ||||
|     placeholders, outputs = _find_key_nodes(nodes) | ||||
|  | ||||
|     new_graph = torch.fx.Graph() | ||||
|     env: dict[fx.Node, fx.Node] = {} | ||||
|  | ||||
|     def env_lookup(x: fx.Node) -> fx.Node: | ||||
|         assert x in env, f"Dependent node {x} not in env when creating downstream node" | ||||
|         return env[x] | ||||
|  | ||||
|     def node_copy( | ||||
|         node: fx.Node, arg_transform: Callable[[fx.Node], fx.Node] | ||||
|     ) -> fx.Node: | ||||
|         if node not in env: | ||||
|             new_node = new_graph.node_copy(node, arg_transform=arg_transform) | ||||
|             env[node] = new_node | ||||
|         else: | ||||
|             new_node = env[node] | ||||
|         return new_node | ||||
|  | ||||
|     for node in placeholders: | ||||
|         env[node] = new_graph.placeholder(node.name) | ||||
|  | ||||
|     for node in nodes: | ||||
|         if node in placeholders: | ||||
|             continue | ||||
|         else: | ||||
|             new_node = node_copy(node, env_lookup) | ||||
|             new_node.meta = node.meta.copy() | ||||
|  | ||||
|     out_node = [env[x] for x in outputs] | ||||
|     new_graph.output(out_node) | ||||
|     return new_graph | ||||
|  | ||||
|  | ||||
| def _is_root(stack: str) -> bool: | ||||
|     return stack == "" | ||||
|  | ||||
|  | ||||
| def make_graph_view(graph: fx.Graph) -> Container: | ||||
|     """ | ||||
|     Code from: https://github.com/meta-pytorch/autoparallel/pull/158 | ||||
|  | ||||
|     Make a graph view from the fx.Graph. This is a tree structure that | ||||
|     represents the module hierarchy of the graph, and enables us to | ||||
|     easily find the nodes that belong to each module, and gives a slightly | ||||
|     easier way to visualize different parts of the graph by extracting | ||||
|     subgraphs that belong to a particular module FQN. | ||||
|  | ||||
|     For example, if we have the following model with module hierarchy: | ||||
|  | ||||
|     Transformer( | ||||
|         (tok_embeddings): Embedding(128256, 4096) | ||||
|         (layers): ModuleDict( | ||||
|             (0): TransformerBlock( | ||||
|             (attention): Attention( | ||||
|                 (wq): Linear(in_features=4096, out_features=4096, bias=False) | ||||
|                 (wk): Linear(in_features=4096, out_features=1024, bias=False) | ||||
|                 (wv): Linear(in_features=4096, out_features=1024, bias=False) | ||||
|                 (wo): Linear(in_features=4096, out_features=4096, bias=False) | ||||
|                 (sdpa): ScaledDotProductAttention() | ||||
|             ) | ||||
|             (feed_forward): FeedForward( | ||||
|                 (w1): Linear(in_features=4096, out_features=14336, bias=False) | ||||
|                 (w2): Linear(in_features=14336, out_features=4096, bias=False) | ||||
|                 (w3): Linear(in_features=4096, out_features=14336, bias=False) | ||||
|             ) | ||||
|             (attention_norm): RMSNorm((4096,), eps=1e-05, elementwise_affine=True) | ||||
|             (ffn_norm): RMSNorm((4096,), eps=1e-05, elementwise_affine=True) | ||||
|             ) | ||||
|         ) | ||||
|         (norm): RMSNorm((4096,), eps=1e-05, elementwise_affine=True) | ||||
|         (output): Linear(in_features=4096, out_features=128256, bias=False) | ||||
|     ) | ||||
|  | ||||
|     Then we can get a GraphView for the fx.Graph that enables us to do | ||||
|  | ||||
|     graph_view = make_graph_view(graph) | ||||
|     subgraph = graph_view.layers["0"].attention.graph_view() | ||||
|  | ||||
|     where subgraph is a fx.Graph that contains all the nodes that belong to | ||||
|     Transformer.layers['0'].attention, and whose inputs are all inputs to this | ||||
|     region of the graph, and whose outputs are all outputs of this region of | ||||
|     the graph. This returns a new graph with new nodes, so we shouldn't use it | ||||
|     for graph manipulations, but it is useful to visualize what a particular | ||||
|     part of a larger graph looks like. | ||||
|  | ||||
|     Additionally, you can also query the original nodes in that region with | ||||
|     `graph_view.layers['0'].attention.data`, which returns a list of all the | ||||
|     nodes that belong to Transformer.layers['0'].attention. | ||||
|     """ | ||||
|     nodes: list[fx.Node] = list(graph.nodes) | ||||
|     nodes_by_module_stack_root: Container | None = None | ||||
|     for node in nodes: | ||||
|         for module_stack, module_class in _get_module_stack(node): | ||||
|             module_stack = _clean_stack_name(module_stack) | ||||
|             nodes_by_module_stack: Container | None = nodes_by_module_stack_root | ||||
|             for name in module_stack.split("."): | ||||
|                 if nodes_by_module_stack is None: | ||||
|                     nodes_by_module_stack = Container(name, module_class) | ||||
|                     nodes_by_module_stack_root = nodes_by_module_stack | ||||
|                 if _is_root(module_stack): | ||||
|                     new_stack: Container = nodes_by_module_stack | ||||
|                 else: | ||||
|                     new_stack = nodes_by_module_stack.get_child(name, module_class) | ||||
|                 nodes_by_module_stack = new_stack | ||||
|                 nodes_by_module_stack.add(node) | ||||
|  | ||||
|     assert nodes_by_module_stack_root is not None, "no nodes with module stack in the graph" | ||||
|     return nodes_by_module_stack_root | ||||
|  | ||||
|  | ||||
| def decode_module(module_bucket_plans: list[str]) -> list[list[str] | str]: | ||||
|     """ | ||||
|     Convert abbreviated FQNs to the actual FQNs. | ||||
|     Currently, we support the decoding of these abbreviations: | ||||
|     (1) layers.[0-2] -> [layers.0], [layers.1], [layers.2] | ||||
|         (layers are split into three separate buckets) | ||||
|     (2) norm+output -> [norm, output] | ||||
|         (norm and output are in one bucket) | ||||
|     """ | ||||
|     full_plan: list[list[str] | str] = [] | ||||
|     for module_name in module_bucket_plans: | ||||
|         if "+" in module_name: | ||||
|             full_plan.append(module_name.split("+")) | ||||
|             continue | ||||
|         match = re.search(r"\[(\d+)-(\d+)\]", module_name) | ||||
|         if not match: | ||||
|             full_plan.append(module_name) | ||||
|         else: | ||||
|             start, end = map(int, match.groups()) | ||||
|             prefix = module_name[: match.start()] | ||||
|             suffix = module_name[match.end() :] | ||||
|             full_plan.extend([f"{prefix}{i}{suffix}" for i in range(start, end + 1)]) | ||||
|     return full_plan | ||||
|  | ||||
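| # A hedged usage sketch (the input plan is assumed for illustration): | ||||
| #   decode_module(["layers.[0-2]", "norm+output"]) | ||||
| #   == ["layers.0", "layers.1", "layers.2", ["norm", "output"]] | ||||
|  | ||||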
|  | ||||
| def get_subgraph_by_path( | ||||
|     graph_view: Container, paths: Union[str, list[str]] | ||||
| ) -> list[fx.Node]: | ||||
|     """ | ||||
|     Get subgraph by path(s). | ||||
|     Args: | ||||
|         graph_view (object): Root graph view object. | ||||
|         paths (str or list of str): Path(s) to subgraph. | ||||
|     Returns: | ||||
|         object: Subgraph object or node. | ||||
|     """ | ||||
|  | ||||
|     def get_node_by_path(node: Container, path: str) -> Container: | ||||
|         for p in path.split("."): | ||||
|             if p in node.children: | ||||
|                 node = node.children[p] | ||||
|             else: | ||||
|                 return Container("", object) | ||||
|         return node | ||||
|  | ||||
|     if isinstance(paths, list): | ||||
|         nodes = list( | ||||
|             itertools.chain.from_iterable( | ||||
|                 get_node_by_path(graph_view, p).data for p in paths | ||||
|             ) | ||||
|         ) | ||||
|         return nodes | ||||
|     else: | ||||
|         node = get_node_by_path(graph_view, paths) | ||||
|         return node.data | ||||
|  | ||||
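| # A hedged example, reusing the Transformer hierarchy from make_graph_view's | ||||
| # docstring (the path is illustrative): | ||||
| #   attention_nodes = get_subgraph_by_path(graph_view, "layers.0.attention") | ||||
|  | ||||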
|  | ||||
| class ManualOverlapPreservingBucketer(OverlapPreservingBucketer): | ||||
|     """ | ||||
|     Buckets collective operations based on user specifications. | ||||
|     The actual bucketing happens in manual_bucket_collectives, where all-gathers/reduce-scatters in | ||||
|         `nodes` will be bucketed into one single all-gather/reduce-scatter. | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, node_users: dict[fx.Node, OrderedSet[fx.Node]], *args: Any, **kwargs: Any | ||||
|     ): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self.node_users = node_users | ||||
|         self.wait_to_node_map: dict[fx.Node, fx.Node] = defaultdict() | ||||
|  | ||||
|     def _check_recursive_dep( | ||||
|         self, | ||||
|         node: fx.Node, | ||||
|         target_op: str, | ||||
|         dep_dict: dict[torch.fx.Node, OrderedSet[torch.fx.Node]], | ||||
|     ) -> bool: | ||||
|         """ | ||||
|         Check if the node is directly used to fetch parameters/gradients. | ||||
|  | ||||
|         TODO (ruisizhang123): currently, we assume the node only pre-fetches/updates one parameter/gradient. | ||||
|             We should handle the multiple-parameters/gradients update case by checking if there are non-closure | ||||
|             computes along the path from primal/output to coll_node. | ||||
|         """ | ||||
|         deps: OrderedSet[fx.Node] = dep_dict[node] | ||||
|         seen_target_op = 0 | ||||
|         for d in deps: | ||||
|             if d.op == target_op: | ||||
|                 seen_target_op += 1 | ||||
|  | ||||
|         return seen_target_op == 1 | ||||
|  | ||||
|     def _bucket_group(self, coll_nodes: list[fx.Node]) -> None: | ||||
|         assert len(coll_nodes) > 0, "bucketed coll_nodes should contain at least one node" | ||||
|  | ||||
|         waits = [self.collective_info[n].wait_node for n in coll_nodes] | ||||
|         # Use earliest wait insertion point | ||||
|         first_wait = min(waits, key=lambda w: self.node_idx[w]) | ||||
|         # Find insertion location | ||||
|         first = coll_nodes[0] | ||||
|         next_node = first | ||||
|         while next_node in coll_nodes: | ||||
|             next_node = next_node.next | ||||
|  | ||||
|         if is_all_gather(first): | ||||
|             new_nodes, replacements = merge_all_gather_bucket( | ||||
|                 self.graph, | ||||
|                 coll_nodes, | ||||
|                 wait_insertion_point=first_wait, | ||||
|                 insert_before=next_node, | ||||
|                 mode="custom_ops", | ||||
|             ) | ||||
|         elif is_reduce_scatter(first): | ||||
|             new_nodes, replacements = merge_reduce_scatter_bucket( | ||||
|                 self.graph, | ||||
|                 coll_nodes, | ||||
|                 wait_insertion_point=first_wait, | ||||
|                 insert_before=next_node, | ||||
|                 mode="custom_ops", | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError( | ||||
|                 "bucketing non all_gather/reduce_scatter nodes is not supported" | ||||
|             ) | ||||
|  | ||||
|         # Identify the new wait and start | ||||
|         new_waits = [n for n in new_nodes if is_wait_tensor(n)] | ||||
|         assert len(new_waits) == 1, f"Expected exactly one new wait, got {new_waits}" | ||||
|         new_wait = new_waits[0] | ||||
|         new_start = new_wait.args[0] | ||||
|         assert isinstance(new_start, fx.Node) | ||||
|  | ||||
|         node_type = ( | ||||
|             "bucketed_all_gather" if is_all_gather(first) else "bucketed_reduce_scatter" | ||||
|         ) | ||||
|         for n in new_nodes: | ||||
|             n.meta["nn_module_stack"] = coll_nodes[0].meta.get("nn_module_stack", "") | ||||
|             n.meta["fwd_nn_module_stack"] = coll_nodes[0].meta.get( | ||||
|                 "fwd_nn_module_stack", "" | ||||
|             ) | ||||
|             # Give only the wait node the "_wait" suffix; use a per-node copy so | ||||
|             # the suffix does not leak onto later nodes in this loop. | ||||
|             n_type = node_type + "_wait" if n == new_wait else node_type | ||||
|             n.meta["manual_bucket_node_type"] = n_type | ||||
|             if "wait" in n_type: | ||||
|                 self.wait_to_node_map[n] = new_wait | ||||
|  | ||||
|     def manual_bucket_collectives(self, nodes: list[fx.Node]) -> None: | ||||
|         """ | ||||
|         Bucket all all-gather/reduce-scatter nodes from `nodes` into one all-gather/reduce-scatter. | ||||
|         """ | ||||
|         # Filter out valid collectives | ||||
|         collectives = [n for n in nodes if n in self.collective_info] | ||||
|         if collectives == []: | ||||
|             return | ||||
|         grouped_collectives: dict[object, OrderedSet[fx.Node]] = defaultdict(OrderedSet) | ||||
|         for node in collectives: | ||||
|             key = bucket_key(node) | ||||
|             if not (is_all_gather(node) or is_reduce_scatter(node)): | ||||
|                 continue | ||||
|             # We only want to bucket all-gather/reduce-scatter that | ||||
|             # 1. all_gather that have ancestors dependent only on input placeholder(parameters) | ||||
|             # 2. reduce scatter that the wait user node is returned as output(gradients) | ||||
|             if is_all_gather(node) and not self._check_recursive_dep( | ||||
|                 node, "placeholder", self.node_ancestors | ||||
|             ): | ||||
|                 continue | ||||
|             if is_reduce_scatter(node) and not self._check_recursive_dep( | ||||
|                 self.collective_info[node].wait_node, "output", self.node_users | ||||
|             ): | ||||
|                 continue | ||||
|             if key is not None: | ||||
|                 grouped_collectives[key].add(node) | ||||
|  | ||||
|         for key, nodes in grouped_collectives.items(): | ||||
|             self._bucket_group(list(nodes)) | ||||
|  | ||||
|  | ||||
| class ManualOverlapScheduler(OverlapScheduler): | ||||
|     """ | ||||
|     Scheduler that manually buckets and reorders collective nodes based on module_bucket_plans | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, gm: fx.GraphModule, module_bucket_plans: list[list[str] | str]): | ||||
|         super().__init__( | ||||
|             gm, | ||||
|             max_in_flight_gb=0.0, | ||||
|             max_compute_pre_fetch=0, | ||||
|             collective_bucketing=True, | ||||
|             insert_overlap_deps=True, | ||||
|             compute_overlap_multipler=0.0, | ||||
|             max_coll_distance=0, | ||||
|             custom_runtime_estimation=None, | ||||
|         ) | ||||
|         self.module_bucket_plans = module_bucket_plans | ||||
|         self.nodes_in_subgraph: list[list[fx.Node]] = [] | ||||
|  | ||||
|         self.node_users: dict[fx.Node, OrderedSet[fx.Node]] = self._collect_node_users() | ||||
|         self.bucketer = ManualOverlapPreservingBucketer( | ||||
|             graph=self.graph, | ||||
|             collective_info=self.collective_info, | ||||
|             node_ancestors=self.node_ancestors, | ||||
|             node_users=self.node_users, | ||||
|             scheduled=OrderedSet(self.graph.nodes), | ||||
|         ) | ||||
|  | ||||
|     def _identify_collectives(self) -> None: | ||||
|         """Identify all collective operations.""" | ||||
|         for node in self.nodes: | ||||
|             if is_wait_tensor(node): | ||||
|                 start = node.args[0] | ||||
|                 info = CollectiveInfo( | ||||
|                     start_node=start, | ||||
|                     wait_node=node, | ||||
|                     size_bytes=0, | ||||
|                     estimated_time_ms=0, | ||||
|                     exposed_time_ms=0, | ||||
|                 ) | ||||
|                 self.collective_info[start] = info | ||||
|                 self.wait_to_start[node] = start | ||||
|                 self.unscheduled_collectives.add(start) | ||||
|  | ||||
|     def run(self) -> torch.fx.GraphModule: | ||||
|         """Entry point to run the manual bucketing algorithm.""" | ||||
|         # Bucket collectives in each bucket_module | ||||
|         self._manual_bucket_collectives() | ||||
|  | ||||
|         # Reorder collectives with last/next bucket_module | ||||
|         self._manual_reorder_graph() | ||||
|  | ||||
|         return self.gm | ||||
|  | ||||
|     def _manual_reorder_graph(self) -> None: | ||||
|         """ | ||||
|         Reorder nodes in the FX graph to enforce manual overlap dependencies. | ||||
|  | ||||
|         Enforce: | ||||
|         - all_gather_start_i depends on all_gather_wait_(i-1) | ||||
|         - reduce_scatter_wait_i must happen before reduce_scatter_start_(i+1) | ||||
|         """ | ||||
|         delayed_rs_nodes: list[fx.Node] = [] | ||||
|         overlap_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet) | ||||
|  | ||||
|         # schedule reduce scatter normally in self._schedule | ||||
|         while self.ready: | ||||
|             _, node = heapq.heappop(self.ready) | ||||
|             node_type = node.meta.get("manual_bucket_node_type", "") | ||||
|  | ||||
|             if node in self.scheduled: | ||||
|                 continue | ||||
|  | ||||
|             if node_type == "bucketed_reduce_scatter": | ||||
|                 # Ensure all delayed waits execute before this reduce_scatter | ||||
|                 for delayed in delayed_rs_nodes: | ||||
|                     self._schedule(delayed) | ||||
|                     overlap_deps[delayed].add(node) | ||||
|                 delayed_rs_nodes.clear() | ||||
|  | ||||
|             elif node_type == "bucketed_reduce_scatter_wait": | ||||
|                 # Defer until next reduce_scatter | ||||
|                 delayed_rs_nodes.append(node) | ||||
|                 continue | ||||
|             self._schedule(node) | ||||
|  | ||||
|         for delayed in delayed_rs_nodes: | ||||
|             self._schedule(delayed) | ||||
|  | ||||
|         self.scheduled = OrderedSet(reversed(list(self.scheduled))) | ||||
|         picked_ag: list[fx.Node] = [] | ||||
|         last_compute: Optional[fx.Node] = None | ||||
|  | ||||
|         for node in self.scheduled: | ||||
|             node_type = node.meta.get("manual_bucket_node_type", "") | ||||
|             if node_type == "bucketed_all_gather": | ||||
|                 picked_ag.append(node) | ||||
|                 continue | ||||
|  | ||||
|             if node_type == "bucketed_all_gather_wait": | ||||
|                 # Connect corresponding all_gather_wait -> all_gather edges | ||||
|                 if picked_ag: | ||||
|                     for ag in picked_ag: | ||||
|                         overlap_deps[self.bucketer.wait_to_node_map[node]].add(ag) | ||||
|                 picked_ag.clear() | ||||
|             if is_compute_node(node): | ||||
|                 last_compute = node | ||||
|  | ||||
|         if last_compute is not None and not bool( | ||||
|             OrderedSet(picked_ag) & OrderedSet(self.node_ancestors[last_compute]) | ||||
|         ): | ||||
|             for ag in picked_ag: | ||||
|                 overlap_deps[last_compute].add(ag) | ||||
|  | ||||
|         _stable_topological_sort(self.graph, overlap_deps) | ||||
|         self.graph.lint() | ||||
|  | ||||
|     def _manual_bucket_collectives(self) -> None: | ||||
|         """Bucket nodes in each module_bucket from module_bucket_plans.""" | ||||
|         self._obtain_nodes_in_subgraph() | ||||
|         for i, nodes in enumerate(self.nodes_in_subgraph): | ||||
|             self.bucketer.manual_bucket_collectives(nodes=nodes) | ||||
|  | ||||
|         _stable_topological_sort(self.graph, {}) | ||||
|         self.graph.lint() | ||||
|         self.nodes = list(self.graph.nodes) | ||||
|         self.in_degree = Counter(user for node in self.nodes for user in node.users) | ||||
|  | ||||
|     def _collect_node_users(self) -> dict[fx.Node, OrderedSet[fx.Node]]: | ||||
|         """Collect all (transitive) users for each node.""" | ||||
|         node_users: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet) | ||||
|         # Walk in reverse topological order so each user's transitive user set is | ||||
|         # already complete before it is folded into its inputs' sets. | ||||
|         for node in reversed(self.nodes): | ||||
|             for output_node in list(node.users.keys()): | ||||
|                 node_users[node].add(output_node) | ||||
|                 node_users[node] |= node_users[output_node] | ||||
|         return node_users | ||||
|  | ||||
|     def _schedule(self, node: fx.Node) -> None: | ||||
|         """Schedule a node.""" | ||||
|         assert node not in self.scheduled | ||||
|         assert all(n in self.scheduled for n in node.all_input_nodes) | ||||
|         self.scheduled.add(node) | ||||
|         for user in node.users: | ||||
|             self.in_degree[user] -= 1 | ||||
|             if self.in_degree[user] == 0: | ||||
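|                 # Empty priority tuple: heap ties fall back to comparing | ||||
|                 # the fx.Nodes themselves. | ||||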
|                 heapq.heappush(self.ready, ((), user)) | ||||
|  | ||||
|     def _obtain_nodes_in_subgraph(self) -> None: | ||||
|         """ | ||||
|         Obtain nodes in each subgraph from module_bucket_plans | ||||
|         """ | ||||
|         graph_view: Container = make_graph_view(self.graph) | ||||
|         for module in self.module_bucket_plans: | ||||
|             subgraph_view = get_subgraph_by_path(graph_view.children["Model"], module) | ||||
|             self.nodes_in_subgraph.append(subgraph_view) | ||||
|  | ||||
|  | ||||
| def manual_overlap_bucketing( | ||||
|     gm: torch.fx.GraphModule, | ||||
|     module_bucket_plans: list[str], | ||||
| ) -> torch.fx.GraphModule: | ||||
|     """Schedule nodes based on user specifications in module_bucket_plans | ||||
|     The manual overlapping consists of two steps: | ||||
|     Step 1: bucket all-gather/reduce-scatter in each module in module_bucket_plans | ||||
|     Step 2: reorder all-gather to overlap with last module_bucket & | ||||
|         reorder reduce-scatter to overlap with next module_bucket | ||||
|     TODO(ruisizhang123): allow users to explicitly specify which | ||||
|         module_bucket they want to overlap. | ||||
|  | ||||
|     Args: | ||||
|         gm: input graph module to optimize. | ||||
|         module_bucket_plans: user specified FQNs | ||||
|     """ | ||||
|     # decode abbreviated FQNs to actual FQNs | ||||
|     module_bucket_plans = decode_module(module_bucket_plans) | ||||
|     overlapped_gm = ManualOverlapScheduler(gm, module_bucket_plans).run() | ||||
|     overlapped_gm.recompile() | ||||
|     return overlapped_gm | ||||
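For orientation, here is a minimal usage sketch of the entry point above. It is
a sketch under assumptions: the module FQN strings ("layers.0", "layers.1") are
hypothetical placeholders whose real values depend on the traced model.

    # Hedged sketch: bucket and overlap collectives for two assumed module buckets.
    import torch.fx as fx

    def overlap_transformer_blocks(gm: fx.GraphModule) -> fx.GraphModule:
        # Step 1 buckets the all-gathers/reduce-scatters inside each listed
        # module; Step 2 reorders them to overlap with neighboring buckets.
        plans = ["layers.0", "layers.1"]  # assumed module FQNs
        return manual_overlap_bucketing(gm, plans)

The returned module is already recompiled, so it can be executed or serialized
directly.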
| @ -17,7 +17,6 @@ import time | ||||
| from collections.abc import Sequence | ||||
| from concurrent.futures import as_completed, ThreadPoolExecutor | ||||
| from io import StringIO | ||||
| from pathlib import Path | ||||
| from types import ModuleType | ||||
| from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union | ||||
| from typing_extensions import Self | ||||
| @ -2105,11 +2104,6 @@ class TritonTemplate(KernelTemplate): | ||||
|                 "matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0), | ||||
|                 "waves_per_eu": kwargs.get("waves_per_eu", 0), | ||||
|                 "kpack": kwargs.get("kpack", 2), | ||||
|                 **{ | ||||
|                     k: kwargs[k] | ||||
|                     for k in AlgorithmSelectorCache.FLEX_ATTENTION_TUNABLE_KEYS | ||||
|                     if k in kwargs | ||||
|                 }, | ||||
|             }, | ||||
|             mutated_inputs=mutated_inputs, | ||||
|             workspace_arg=workspace_arg, | ||||
| @ -2403,17 +2397,6 @@ def get_mm_log_filename() -> Optional[str]: | ||||
|     return mm_file_name | ||||
|  | ||||
|  | ||||
| @functools.cache | ||||
| def get_flex_attention_log_filename() -> Optional[str]: | ||||
|     flex_attention_file_name = os.environ.get( | ||||
|         "TORCHINDUCTOR_FLEX_ATTENTION_LOGGING_FILE", None | ||||
|     ) | ||||
|     if not flex_attention_file_name: | ||||
|         return None | ||||
|  | ||||
|     return str(Path(flex_attention_file_name).with_suffix(".json")) | ||||
|  | ||||
|  | ||||
| def append_to_log(filename, data): | ||||
|     lock_file = filename.replace(".json", ".lock") | ||||
|     lock = FileLock(lock_file) | ||||
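The hunk above ends right after the lock is acquired. For context, a minimal
sketch of the read-merge-write pattern such a helper typically completes with;
the merge policy and error handling below are assumptions, not the verbatim
implementation:

    import json
    from filelock import FileLock

    def append_to_log_sketch(filename: str, data: dict) -> None:
        # Serialize concurrent writers with a sibling .lock file.
        lock = FileLock(filename.replace(".json", ".lock"))
        with lock:
            try:
                with open(filename) as f:
                    log = json.load(f)
            except (FileNotFoundError, json.JSONDecodeError):
                log = {}
            log.update(data)  # assumed merge policy: shallow dict update
            with open(filename, "w") as f:
                json.dump(log, f)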
| @ -2624,25 +2607,6 @@ class AlgorithmSelectorCache(PersistentCache): | ||||
|     doesn't depend on the output layout. | ||||
|     """ | ||||
|  | ||||
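|     # dict.fromkeys() de-duplicates the keys while preserving their order. | ||||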
|     FLEX_ATTENTION_TUNABLE_KEYS = tuple( | ||||
|         dict.fromkeys( | ||||
|             [ | ||||
|                 "num_warps", | ||||
|                 "num_stages", | ||||
|                 "BLOCK_M", | ||||
|                 "BLOCK_N", | ||||
|                 "BLOCK_M1", | ||||
|                 "BLOCK_N1", | ||||
|                 "BLOCK_M2", | ||||
|                 "BLOCK_N2", | ||||
|                 "USE_TMA", | ||||
|                 "kpack", | ||||
|                 "matrix_instr_nonkdim", | ||||
|                 "waves_per_eu", | ||||
|             ] | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     def __init__(self, *args, **kwargs) -> None: | ||||
|         super().__init__(*args, **kwargs) | ||||
|  | ||||
| @ -3576,73 +3540,6 @@ class AlgorithmSelectorCache(PersistentCache): | ||||
|         ) | ||||
|         return pruned_choices | ||||
|  | ||||
|     @staticmethod | ||||
|     def get_flex_attention_choice_info( | ||||
|         choice: ChoiceCaller, timings: dict[ChoiceCaller, float] | ||||
|     ) -> dict[str, Any]: | ||||
|         if isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller): | ||||
|             return {"type": "extern", "time": timings[choice]} | ||||
|  | ||||
|         assert isinstance(choice, torch._inductor.select_algorithm.TritonTemplateCaller) | ||||
|  | ||||
|         info = choice.info_dict() | ||||
|         result = { | ||||
|             "type": "triton", | ||||
|             "time": timings[choice], | ||||
|         } | ||||
|  | ||||
|         for key in AlgorithmSelectorCache.FLEX_ATTENTION_TUNABLE_KEYS: | ||||
|             if key in info: | ||||
|                 result[key] = info[key] | ||||
|  | ||||
|         return result | ||||
|  | ||||
|     @staticmethod | ||||
|     def maybe_log_flex_attention_results( | ||||
|         name: str, input_nodes: list[ir.IRNode], timings: dict[ChoiceCaller, float] | ||||
|     ) -> None: | ||||
|         flex_attention_filename = get_flex_attention_log_filename() | ||||
|         if not flex_attention_filename or "flex_attention" not in name: | ||||
|             return | ||||
|  | ||||
|         if len(input_nodes) < 3: | ||||
|             return | ||||
|  | ||||
|         query_size = input_nodes[0].get_size() | ||||
|         key_size = input_nodes[1].get_size() | ||||
|         value_size = input_nodes[2].get_size() | ||||
|  | ||||
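|         # query/key/value sizes follow a [batch, heads, seq_len, head_dim] layout. | ||||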
|         B = query_size[0] | ||||
|         Hq = query_size[1] | ||||
|         seq_len_q = query_size[2] | ||||
|         qk_head_dim = query_size[3] | ||||
|         Hkv = key_size[1] | ||||
|         seq_len_kv = key_size[2] | ||||
|         v_head_dim = value_size[3] | ||||
|  | ||||
|         kernel_type = "backward" if "backward" in name else "forward" | ||||
|         dims_key = str( | ||||
|             ( | ||||
|                 kernel_type, | ||||
|                 B, | ||||
|                 Hq, | ||||
|                 Hkv, | ||||
|                 seq_len_q, | ||||
|                 seq_len_kv, | ||||
|                 qk_head_dim, | ||||
|                 v_head_dim, | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         sorted_choices = sorted(timings, key=timings.__getitem__) | ||||
|         out_dict = { | ||||
|             dims_key: [ | ||||
|                 AlgorithmSelectorCache.get_flex_attention_choice_info(choice, timings) | ||||
|                 for choice in sorted_choices | ||||
|             ] | ||||
|         } | ||||
|         append_to_log(flex_attention_filename, out_dict) | ||||
|  | ||||
|     @staticmethod | ||||
|     def log_results( | ||||
|         name: str, | ||||
| @ -3653,7 +3550,6 @@ class AlgorithmSelectorCache(PersistentCache): | ||||
|         prescreening_elapse: Optional[float] = None, | ||||
|         hint_override: Optional[int] = None, | ||||
|     ): | ||||
|         """Log the autotuning results, currently only handles mm and flex""" | ||||
|         V.debug.log_autotuning_results( | ||||
|             name, input_nodes, timings, elapse, precompile_elapse | ||||
|         ) | ||||
| @ -3722,10 +3618,6 @@ class AlgorithmSelectorCache(PersistentCache): | ||||
|  | ||||
|             append_to_log(mm_filename, out_dict) | ||||
|  | ||||
|         AlgorithmSelectorCache.maybe_log_flex_attention_results( | ||||
|             name, input_nodes, timings | ||||
|         ) | ||||
|  | ||||
|         best_time = timings[best] | ||||
|         sys.stderr.write(f"AUTOTUNE {name}({sizes})\n") | ||||
|         sys.stderr.write(f"strides: {strides}\n") | ||||
|  | ||||