# Summary

This adds a few more render functions available to template writers, specifically `get_output` and `modification`. The reasons why will become clearer in the next PR in this stack.

<img width="1645" height="364" alt="Screenshot 2025-08-21 at 1 48 50 PM" src="https://github.com/user-attachments/assets/2d508fda-4273-43ef-9edf-086e592e9249" />

The majority of the new code is in the OpOverrides for the CuTe DSL. It is a lot to test, and most of the actual testing has been done via score_mods applied to flash_attention at the next layer of this stack. Below is a set of score mods, worked out with Claude, that exercise the actual ops:

```python
# Imports needed by the score_mods below.
import math

import torch


def causal_mask(score, b, h, q_idx, kv_idx):
    """Causal attention mask."""
    return torch.where(q_idx >= kv_idx, score, float("-inf"))


def relative_bias(score, b, h, token_q, token_kv):
    """Relative position bias."""
    return score + torch.abs(token_q - token_kv)


def relative_bias_v2(score, b, h, token_q, token_kv):
    """Relative position bias with factor of 2."""
    return score + 2 * torch.abs(token_q - token_kv)


def times_two(score, b, h, q_idx, kv_idx):
    """Simple score modification that doubles the score."""
    return score * 2


def alibi_bias(score, b, h, q_idx, kv_idx):
    """ALiBi (Attention with Linear Biases) - used in some modern models."""
    # Different slopes for different heads
    slope = 2 ** (-8 * (h + 1) / 8)  # Simplified version
    return score - slope * torch.abs(q_idx - kv_idx)


def sliding_window(score, b, h, q_idx, kv_idx, window_size=256):
    """Sliding window attention - only attend to nearby tokens."""
    return torch.where(
        torch.abs(q_idx - kv_idx) <= window_size, score, float("-inf")
    )


def block_diagonal(score, b, h, q_idx, kv_idx, block_size=64):
    """Block diagonal attention pattern."""
    q_block = q_idx // block_size
    kv_block = kv_idx // block_size
    return torch.where(q_block == kv_block, score, float("-inf"))


def additive_bias(score, b, h, q_idx, kv_idx):
    """Test simple addition with position-based bias."""
    return score + (q_idx + kv_idx) * 0.01


def multiplicative_decay(score, b, h, q_idx, kv_idx):
    """Test multiplication with distance-based decay."""
    distance = torch.abs(q_idx - kv_idx)
    return score * torch.exp(-0.1 * distance)


def sine_wave_bias(score, b, h, q_idx, kv_idx):
    """Test trigonometric functions."""
    return score + 0.1 * torch.sin(2 * math.pi * (q_idx - kv_idx) / 64)


def log_distance_penalty(score, b, h, q_idx, kv_idx):
    """Test logarithmic operations."""
    distance = torch.abs(q_idx - kv_idx).float()
    return score - torch.log(1 + distance)


def alternating_mask(score, b, h, q_idx, kv_idx):
    """Test with alternating pattern - good for branch prediction."""
    return torch.where((q_idx + kv_idx) % 2 == 0, score, float("-inf"))


def head_specific_pattern(score, b, h, q_idx, kv_idx):
    """Different behavior per attention head."""
    even_head = h % 2 == 0
    causal = q_idx >= kv_idx
    return torch.where(even_head & causal, score, float("-inf"))


def sparse_strided(score, b, h, q_idx, kv_idx, stride=4):
    """Sparse attention with strided pattern."""
    return torch.where(
        (kv_idx % stride == 0) | (q_idx == kv_idx), score, float("-inf")
    )


def causal_with_global(score, b, h, q_idx, kv_idx):
    """Causal mask but first few tokens are globally attended."""
    is_causal = q_idx >= kv_idx
    is_global = kv_idx < 4
    return torch.where(is_causal | is_global, score, float("-inf"))


def dilated_attention(score, b, h, q_idx, kv_idx, dilation_rate=2):
    """Dilated attention pattern - exponentially increasing gaps."""
    distance = torch.abs(q_idx - kv_idx)
    is_attended = (distance == 0) | (
        (distance > 0) & ((distance & (distance - 1)) == 0)
    )
    return torch.where(is_attended, score, float("-inf"))
```

Example outputs:

```
[Test Suite] Config: batch=4, heads=32, seq_q=8192, seq_kv=8192, dim=128

[Test 1: none]
[No score_mod, flash='enabled'] Found flash_attncute: True
[No score_mod, flash='disabled'] Found flash_attncute: False
✓ Outputs match between flash enabled/disabled
✓ Output matches eager SDPA (rtol=0.001, atol=0.001)

[Test 2: causal]
[With score_mod, flash='enabled'] Found flash_attncute: True
[With score_mod, flash='disabled'] Found flash_attncute: False
✗ Outputs differ between flash modes: Tensor-likes are not close!
Mismatched elements: 17879 / 134217728 (0.0%)
Greatest absolute difference: 0.0078125 at index (0, 15, 15, 60) (up to 0.001 allowed)
Greatest relative difference: 2.5 at index (3, 22, 153, 126) (up to 0.001 allowed)

[Test 3: rel_bias]
[With score_mod, flash='enabled'] Found flash_attncute: True
[With score_mod, flash='disabled'] Found flash_attncute: False
✗ Outputs differ between flash modes: Tensor-likes are not close!
Mismatched elements: 12836 / 134217728 (0.0%)
Greatest absolute difference: 0.015625 at index (0, 3, 2775, 84) (up to 0.001 allowed)
Greatest relative difference: 11.8125 at index (3, 28, 4095, 76) (up to 0.001 allowed)

[Test 4: rel_bias_v2]
```

These runs use bfloat16, and the mismatches above are not major differences. The list of pointwise ops exercised here isn't exhaustive, but the coverage is fairly broad.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161117
Approved by: https://github.com/mlazos
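For context, score_mods like the ones above are consumed through PyTorch's FlexAttention API. A minimal usage sketch follows; the shapes and dtype are illustrative and not taken from this PR's test harness:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention


def causal_mask(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, float("-inf"))


# (batch, heads, seq_len, head_dim) -- illustrative sizes
q, k, v = (
    torch.randn(4, 32, 8192, 128, device="cuda", dtype=torch.bfloat16)
    for _ in range(3)
)

# score_mod is applied pointwise to each attention score before softmax;
# compiling routes the score_mod through inductor's codegen.
out = torch.compile(flex_attention)(q, k, v, score_mod=causal_mask)
```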
# Owner(s): ["module: inductor"]
import unittest
from unittest.mock import MagicMock, patch

from expecttest import assert_expected_inline

import torch
from torch._inductor.test_case import TestCase
from torch._inductor.virtualized import V
from torch.testing._internal.inductor_utils import MockGraphHandler


try:
    import cutlass  # noqa: F401
    import cutlass.cute as cute  # noqa: F401

    HAS_CUTLASS = True
except ImportError:
    HAS_CUTLASS = False

if HAS_CUTLASS:
    from torch._inductor.codegen.cutedsl.cutedsl_kernel import CuteDSLTemplateKernel
    from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate
    from torch._inductor.select_algorithm import PartialRender
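

# CuteDSL template source for a simple elementwise-add kernel. The Jinja-style
# {{...}} placeholders are render hooks resolved by CuteDSLTemplateKernel (see
# the tests below); the kernel maps one thread to one output element.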
CUTEDSL_ADD_TEMPLATE = r"""
{{gen_defines()}}

@cute.kernel
def {{kernel_name}}_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    bdim, _, _ = cute.arch.block_dim()

    thread_idx = bidx * bdim + tidx
    m, n = gA.shape

    if thread_idx < m * n:
        mi = thread_idx // n
        ni = thread_idx % n

        if mi < m and ni < n:
            gC[mi, ni] = gA[mi, ni] + gB[mi, ni]

@cute.jit
def {{kernel_name}}_jit(mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor, stream):
    {{gen_defines()}}
    m, n = mA.shape
    total_threads = m * n
    num_blocks = (total_threads + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK

    kernel = {{kernel_name}}_kernel(mA, mB, mC)
    kernel.launch(
        grid=[num_blocks, 1, 1],
        block=[THREADS_PER_BLOCK, 1, 1],
        stream=stream
    )

{{def_kernel("input_a", "input_b")}}
    cute_a = from_dlpack(input_a)
    cute_b = from_dlpack(input_b)
    cute_c = from_dlpack({{get_output()}})

    {{kernel_name}}_jit(cute_a, cute_b, cute_c, cuda.CUstream(stream))
    return {{get_output()}}
"""
@unittest.skipUnless(HAS_CUTLASS, "requires cutlass")
class TestCuteDSLTemplate(TestCase):
    """Test cases for CuteDSL template functionality."""

    def test_gen_imports(self):
        kernel = CuteDSLTemplateKernel(
            kernel_name="test_kernel",
            input_nodes=[],
            output_node=None,
        )

        imports = kernel.gen_imports()

        self.assertIn("import torch", imports)
        self.assertIn("import cutlass", imports)
        self.assertIn("import cutlass.cute as cute", imports)
        self.assertIn("from cutlass.cute.runtime import from_dlpack", imports)
        self.assertIsInstance(imports, str)

        lines = imports.strip().split("\n")
        self.assertEqual(len(lines), 7)

    def test_render_includes_imports(self):
        template_source = """@cute.kernel
def {{kernel_name}}_kernel():
    pass

{{def_kernel("input", "output")}}
    return output"""

        mock_template = MagicMock()
        mock_template.render = MagicMock(return_value=template_source)

        kernel = CuteDSLTemplateKernel(
            kernel_name="test_kernel",
            input_nodes=[],
            output_node=None,
        )

        result = kernel.render(mock_template)
        self.assertIsInstance(result, PartialRender)

        rendered_code = result._code

        # The imports might have leading whitespace, so strip it
        rendered_code_stripped = rendered_code.lstrip()

        self.assertTrue(
            rendered_code_stripped.startswith("import torch"),
            f"Code should start with 'import torch', got: {rendered_code_stripped[:50]}",
        )
        self.assertIn("import cutlass", rendered_code)
        self.assertIn("import cutlass.cute as cute", rendered_code)
        self.assertIn("from cutlass.cute.runtime import from_dlpack", rendered_code)
        self.assertIn("@cute.kernel", rendered_code)

    def test_template_env_contains_hooks(self):
        kernel = CuteDSLTemplateKernel(
            kernel_name="test_kernel",
            input_nodes=[],
            output_node=None,
        )

        captured_env = {}

        def mock_render(**kwargs):
            captured_env.update(kwargs)
            return "rendered"

        mock_template = MagicMock()
        mock_template.render = mock_render

        kernel.render(mock_template)

        self.assertIn("def_kernel", captured_env)
        self.assertIn("kernel_name", captured_env)
        self.assertTrue(callable(captured_env["def_kernel"]))

    def test_multiple_templates_unique_names(self):
        # Clean registry first
        test_name = f"unique_test_{id(self)}"
        if test_name in CuteDSLTemplate.all_templates:
            del CuteDSLTemplate.all_templates[test_name]

        _ = CuteDSLTemplate(
            name=test_name,
            source="template1",
        )

        with self.assertRaises(AssertionError):
            _ = CuteDSLTemplate(
                name=test_name,
                source="template2",
            )

    def test_indented_buffer_usage(self):
        kernel = CuteDSLTemplateKernel(
            kernel_name="test_kernel",
            input_nodes=[],
            output_node=None,
        )

        imports = kernel.gen_imports()

        lines = imports.strip().split("\n")
        for line in lines:
            if line:
                self.assertFalse(
                    line.startswith(" "), f"Line should not be indented: '{line}'"
                )

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_cutedsl_add_e2e(self):
        """End-to-end test with CuteDSL template including code generation verification."""
        from torch._inductor.ir import TensorBox
        from torch._inductor.lowering import lowerings
        from torch._inductor.utils import run_and_get_code

        template = CuteDSLTemplate(
            name="test_add_e2e",
            source=CUTEDSL_ADD_TEMPLATE,
        )

        def cutedsl_add_lowering(a: TensorBox, b: TensorBox) -> TensorBox:
            choices = []
            error = template.maybe_append_choice(
                choices,
                input_nodes=[a, b],
                layout=a.get_layout(),
                THREADS_PER_BLOCK=256,
            )

            if error or not choices:
                default_lowering = lowerings[torch.ops.aten.add.Tensor]
                return default_lowering(a, b)

            # Use the single choice directly (no autotuning)
            return choices[0].output_node()

        with patch.dict(lowerings, {torch.ops.aten.add.Tensor: cutedsl_add_lowering}):
            # Test function
            def test_add(x, y):
                return x + y

            device = "cuda"
            x = torch.randn(128, 4, device=device, dtype=torch.float32)
            y = torch.randn(128, 4, device=device, dtype=torch.float32)

            # Compile and get generated code
            compiled_fn = torch.compile(test_add, backend="inductor")
            result, (code,) = run_and_get_code(compiled_fn, x, y)

            # Verify CuteDSL code is present
            self.assertIn(
                "cute", code.lower(), "CuteDSL code should be in generated code"
            )
            # Verify parameter generation worked
            self.assertIn(
                "THREADS_PER_BLOCK", code, "Parameter should be in generated code"
            )

            # Verify correctness
            expected = x + y
            self.assertTrue(torch.allclose(result, expected, atol=1e-5))

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_cutedsl_add_e2e_autotune(self):
        """E2E test with multiple CuteDSL template variants for autotuning."""
        from torch._inductor.ir import TensorBox
        from torch._inductor.lowering import lowerings
        from torch._inductor.select_algorithm import autotune_select_algorithm

        template = CuteDSLTemplate(
            name="test_add_autotune",
            source=CUTEDSL_ADD_TEMPLATE,
        )

        def cutedsl_add_lowering(a: TensorBox, b: TensorBox) -> TensorBox:
            choices = []

            # Add multiple variants with different thread counts for autotuning
            thread_variants = [128, 256, 512]
            for threads in thread_variants:
                error = template.maybe_append_choice(
                    choices,
                    input_nodes=[a, b],
                    layout=a.get_layout(),
                    THREADS_PER_BLOCK=threads,
                )
                if error:
                    # Skip this variant if it fails
                    continue

            if not choices:
                default_lowering = lowerings[torch.ops.aten.add.Tensor]
                return default_lowering(a, b)

            # Use autotuning to select the best variant
            return autotune_select_algorithm(
                "cutedsl_add_autotune",
                choices,
                [a, b],
                a.get_layout(),
            )

        with patch.dict(lowerings, {torch.ops.aten.add.Tensor: cutedsl_add_lowering}):
            # Test function
            def test_add(x, y):
                return x + y

            device = "cuda"
            x = torch.randn(128, 128, device=device, dtype=torch.float32)
            y = torch.randn(128, 128, device=device, dtype=torch.float32)

            # Compile and run
            compiled_fn = torch.compile(test_add, backend="inductor")
            result = compiled_fn(x, y)

            # Verify correctness
            expected = x + y
            self.assertTrue(torch.allclose(result, expected, atol=1e-5))

    def test_gen_defines(self):
        """Test that gen_defines correctly generates CuteDSL parameter definitions."""
        kernel = CuteDSLTemplateKernel(
            kernel_name="test_kernel",
            input_nodes=[],
            output_node=None,
        )

        # Test integer parameters
        params = kernel.gen_defines(
            THREADS_PER_BLOCK=256,
            BLOCK_SIZE=128,
            ENABLE_FEATURE=True,
        )

        assert_expected_inline(
            params,
            """\
THREADS_PER_BLOCK: cutlass.Constexpr = 256
BLOCK_SIZE: cutlass.Constexpr = 128
ENABLE_FEATURE: cutlass.Constexpr = True
""",
        )

        params_float = kernel.gen_defines(SCALE_FACTOR=1.5)
        assert_expected_inline(
            params_float,
            """\
SCALE_FACTOR: cutlass.Constexpr = 1.5
""",
        )

    def test_template_aliasing(self):
        """Test that template variables are correctly aliased to function arguments."""
        from torch._inductor.ir import Buffer

        mock_input1 = MagicMock(spec=Buffer)
        mock_input1.get_name.return_value = "buf_input1"

        mock_input2 = MagicMock(spec=Buffer)
        mock_input2.get_name.return_value = "buf_input2"

        mock_output = MagicMock(spec=Buffer)
        mock_output.get_name.return_value = "buf_output"

        mock_graph = MockGraphHandler()
        with V.set_graph_handler(mock_graph):
            kernel = CuteDSLTemplateKernel(
                kernel_name="test_aliasing",
                input_nodes=[mock_input1, mock_input2],
                output_node=mock_output,
            )

            def_kernel_hook = kernel.def_kernel("custom_a", "custom_b")
            self.assertEqual(def_kernel_hook, "<DEF_KERNEL>")

            self.assertIn("<DEF_KERNEL>", kernel.render_hooks)

            hook_fn = kernel.render_hooks["<DEF_KERNEL>"]
            generated_code = hook_fn()

            # Check that the generated code contains the expected aliasing statements
            self.assertIn("custom_a = arg_custom_a", generated_code)
            self.assertIn("custom_b = arg_custom_b", generated_code)

    def test_get_output_hook(self):
        """Test the get_output() template hook."""
        from torch._inductor.ir import Buffer

        mock_output = MagicMock(spec=Buffer)
        mock_output.get_name.return_value = "buf_test_output"

        mock_graph = MockGraphHandler()
        with V.set_graph_handler(mock_graph):
            kernel = CuteDSLTemplateKernel(
                kernel_name="test_output",
                input_nodes=[],
                output_node=mock_output,
            )

            with self.assertRaises(ValueError):
                # error if no output buffer
                result = kernel.get_output()

            kernel.args.output_buffers["buf_test_output"] = "arg_buf_test_output"
            result = kernel.get_output()
            self.assertEqual(result, "arg_buf_test_output")

    def test_modification_subgraph(self):
        """Test the modification() method and subgraph processing."""

        from torch._inductor.ir import Buffer

        mock_subgraph1 = MagicMock(spec=Buffer)
        mock_subgraph2 = MagicMock(spec=Buffer)
        subgraphs = [mock_subgraph1, mock_subgraph2]

        mock_output = MagicMock(spec=Buffer)
        mock_output.get_name.return_value = "buf_output"

        kernel = CuteDSLTemplateKernel(
            kernel_name="test_modification",
            input_nodes=[],
            output_node=mock_output,
            subgraphs=subgraphs,
        )

        result = kernel._get_subgraph(0)
        self.assertEqual(result, mock_subgraph1)

        result = kernel._get_subgraph(1)
        self.assertEqual(result, mock_subgraph2)

        with self.assertRaises(AssertionError):
            kernel._get_subgraph(2)

    def test_cutedsl_op_overrides(self):
        """Test the new CuteDSLOpOverrides class."""
        import torch
        from torch._inductor.codegen.common import CSEVariable
        from torch._inductor.codegen.cutedsl.cutedsl_op_overrides import (
            CuteDSLOpOverrides,
        )
        from torch.utils._sympy.value_ranges import ValueRanges

        mock_cse_a = MagicMock(spec=CSEVariable)
        mock_cse_a.__str__.return_value = "tensor_a"
        mock_cse_a.dtype = torch.float32
        mock_cse_a.bounds = ValueRanges.unknown()

        mock_cse_b = MagicMock(spec=CSEVariable)
        mock_cse_b.__str__.return_value = "tensor_b"
        mock_cse_b.dtype = torch.float32
        mock_cse_b.bounds = ValueRanges.unknown()

        mock_graph = MockGraphHandler()
        with V.set_graph_handler(mock_graph):
            kernel = CuteDSLTemplateKernel(
                kernel_name="test_ops",
                input_nodes=[],
                output_node=None,
            )
            with V.set_kernel_handler(kernel):
                result = CuteDSLOpOverrides.add(mock_cse_a, mock_cse_b)
                self.assertIsInstance(result, CSEVariable)

                result = CuteDSLOpOverrides.mul(mock_cse_a, mock_cse_b)
                self.assertIsInstance(result, CSEVariable)

                result = CuteDSLOpOverrides.truediv(mock_cse_a, mock_cse_b)
                self.assertIsInstance(result, CSEVariable)

                result = CuteDSLOpOverrides.exp(mock_cse_a)
                self.assertIsInstance(result, CSEVariable)

                result = CuteDSLOpOverrides.sqrt(mock_cse_a)
                self.assertIsInstance(result, CSEVariable)

                with self.assertRaises(NotImplementedError):
                    result = CuteDSLOpOverrides.maximum(mock_cse_a, mock_cse_b)
                # minimum gets its own assertRaises block: once maximum raises,
                # any further statements in the same block would never execute.
                with self.assertRaises(NotImplementedError):
                    result = CuteDSLOpOverrides.minimum(mock_cse_a, mock_cse_b)

        scalar_result = CuteDSLOpOverrides._ensure_tensor_ssa("5.0", mock_cse_a)
        self.assertEqual(scalar_result, "cute.full_like(tensor_a, 5.0)")

        tensor_result = CuteDSLOpOverrides._ensure_tensor_ssa(mock_cse_a, mock_cse_b)
        self.assertEqual(tensor_result, "tensor_a")

    def test_cse_integration(self):
        """Test CSE (Common Subexpression Elimination) integration."""
        from torch._inductor.codegen.common import CSE

        mock_graph = MockGraphHandler()
        with V.set_graph_handler(mock_graph):
            kernel = CuteDSLTemplateKernel(
                kernel_name="test_cse",
                input_nodes=[],
                output_node=None,
            )

            self.assertIsInstance(kernel.cse, CSE)
            self.assertEqual(kernel.cse.name_prefix, "tmp")

            with V.set_kernel_handler(kernel):
                test_expr = "x"
                var = kernel.cse.generate(kernel.body, test_expr, dtype=None)
                self.assertTrue(str(var).startswith("tmp"))


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    run_tests()