mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Move negative index checking to common.py - Fix issue 97365 (#108690)
Fixes https://github.com/pytorch/pytorch/issues/97365 Pull Request resolved: https://github.com/pytorch/pytorch/pull/108690 Approved by: https://github.com/lezcano
This commit is contained in:
committed by
PyTorch MergeBot
parent
f32eb9bc55
commit
f97c2dabd9
@ -813,6 +813,73 @@ class CommonTemplate:
|
||||
self.assertEqual(expect, actual)
|
||||
self.assertEqual(actual, repeat(x, 3))
|
||||
|
||||
@skipIfRocm
|
||||
def test_neg_index(self):
|
||||
def test(fn, inps, has_assert: bool, has_wrapping: bool):
|
||||
for dynamic in (True, False):
|
||||
fn_opt = torch.compile(dynamic=dynamic)(fn)
|
||||
if self.device == "cpu":
|
||||
code = run_and_get_cpp_code(fn_opt, *inps)
|
||||
found = False
|
||||
# match ternary operator
|
||||
pattern = r"\?.*:"
|
||||
if re.findall(pattern, code):
|
||||
found = True
|
||||
self.assertTrue(found is has_wrapping)
|
||||
self.assertTrue(("TORCH_CHECK" in code) is has_assert)
|
||||
else:
|
||||
code = run_and_get_triton_code(fn_opt, *inps)
|
||||
self.assertTrue(("tl.where" in code) is has_wrapping)
|
||||
self.assertTrue(("device_assert" in code) is has_assert)
|
||||
self.assertEqual(fn(*inps), fn_opt(*inps))
|
||||
|
||||
def indirect(a, b):
|
||||
return a[b - 1]
|
||||
|
||||
a = torch.rand(1024, device=self.device)
|
||||
b = torch.zeros(4, dtype=torch.long, device=self.device)
|
||||
test(indirect, (a, b), has_assert=True, has_wrapping=True)
|
||||
|
||||
def direct(x):
|
||||
return x[:, -1]
|
||||
|
||||
a = torch.rand(1, 64, 32, device=self.device)
|
||||
test(direct, (a,), has_assert=False, has_wrapping=False)
|
||||
|
||||
def flip(a, b):
|
||||
return a[b]
|
||||
|
||||
a = torch.rand(1024, device=self.device)
|
||||
b = torch.arange(start=-1, end=-a.numel() - 1, step=-1, device=self.device)
|
||||
test(flip, (a, b), has_assert=True, has_wrapping=True)
|
||||
|
||||
# Constant propagate a constant that's negative
|
||||
def flip_with_index_constant(a):
|
||||
b = torch.arange(start=-1, end=-a.numel() - 1, step=-1, device=self.device)
|
||||
return a[b]
|
||||
|
||||
# Wrapping is constant-folded
|
||||
test(flip_with_index_constant, (a,), has_assert=False, has_wrapping=False)
|
||||
|
||||
# Operation where we can't prove that the index is always positive or negative
|
||||
def pos_and_neg(a):
|
||||
b = torch.arange(start=1, end=-a.numel() - 1, step=-1, device=self.device)
|
||||
return a[b]
|
||||
|
||||
# It has wrapping but no assert
|
||||
test(pos_and_neg, (a,), has_assert=False, has_wrapping=True)
|
||||
|
||||
# We currently don't do constant propagation with float constants
|
||||
def flip_with_index(a):
|
||||
b = 1.0 * torch.arange(
|
||||
start=-1, end=-a.numel() - 1, step=-1, device=self.device
|
||||
)
|
||||
b = b.int()
|
||||
return a[b]
|
||||
|
||||
# Constant is propagated as we can prove that the result is always negative.
|
||||
test(flip_with_index_constant, (a,), has_assert=False, has_wrapping=False)
|
||||
|
||||
def test_computed_buffer_inlining(self):
|
||||
def flip(x):
|
||||
idx = torch.arange(x.size(0) - 1, -1, -1, device=x.device)
|
||||
@ -7473,63 +7540,6 @@ if HAS_CUDA and not TEST_WITH_ASAN:
|
||||
# there are a couple extra tensors created in `insertable_tensor_check`
|
||||
self.assertTrue(max_live_tensors == 4)
|
||||
|
||||
@skipIfRocm
|
||||
def test_neg_index(self):
|
||||
def test(fn, inps, has_assert: bool, has_wrapping=True):
|
||||
for dynamic in (True, False):
|
||||
fn_opt = torch.compile(dynamic=dynamic)(fn)
|
||||
code = run_and_get_triton_code(fn_opt, *inps)
|
||||
self.assertTrue(("tl.where" in code) is has_wrapping)
|
||||
self.assertTrue(("device_assert" in code) is has_assert)
|
||||
self.assertEqual(fn(*inps), fn_opt(*inps))
|
||||
|
||||
def indirect(a, b):
|
||||
return a[b - 1]
|
||||
|
||||
a = torch.rand(1024, device="cuda")
|
||||
b = torch.zeros(4, dtype=torch.long, device="cuda")
|
||||
test(indirect, (a, b), has_assert=True)
|
||||
|
||||
def direct(x):
|
||||
return x[:, -1]
|
||||
|
||||
a = torch.rand(1, 64, 32, device="cuda")
|
||||
test(direct, (a,), has_assert=False, has_wrapping=False)
|
||||
|
||||
def flip(a, b):
|
||||
return a[b]
|
||||
|
||||
a = torch.rand(1024, device="cuda")
|
||||
b = torch.arange(start=-1, end=-a.numel() - 1, step=-1, device="cuda")
|
||||
test(flip, (a, b), has_assert=True)
|
||||
|
||||
# Constant propagate a constant that's negative
|
||||
def flip_with_index_constant(a):
|
||||
b = torch.arange(start=-1, end=-a.numel() - 1, step=-1, device="cuda")
|
||||
return a[b]
|
||||
|
||||
# Wrapping is constant-folded
|
||||
test(flip_with_index_constant, (a,), has_assert=False, has_wrapping=False)
|
||||
|
||||
# Operation where we can't prove that the index is always positive or negative
|
||||
def pos_and_neg(a):
|
||||
b = torch.arange(start=1, end=-a.numel() - 1, step=-1, device="cuda")
|
||||
return a[b]
|
||||
|
||||
# It has wrapping but no assert
|
||||
test(pos_and_neg, (a,), has_assert=False, has_wrapping=True)
|
||||
|
||||
# We currently don't do constant propagation with float constants
|
||||
def flip_with_index(a):
|
||||
b = 1.0 * torch.arange(
|
||||
start=-1, end=-a.numel() - 1, step=-1, device="cuda"
|
||||
)
|
||||
b = b.int()
|
||||
return a[b]
|
||||
|
||||
# Constant is propagated as we can prove that the result is always negative.
|
||||
test(flip_with_index_constant, (a,), has_assert=False, has_wrapping=False)
|
||||
|
||||
# See https://github.com/pytorch/pytorch/issues/100348
|
||||
def test_inductor_detach_view(self):
|
||||
def fn(x: torch.Tensor) -> torch.Tensor:
|
||||
|
@ -1088,10 +1088,12 @@ def get_include_and_linking_paths(
|
||||
lpaths = cpp_extension.library_paths(cuda) + [
|
||||
sysconfig.get_config_var("LIBDIR")
|
||||
]
|
||||
|
||||
libs = []
|
||||
|
||||
# No need to manually specify libraries in fbcode.
|
||||
if not config.is_fbcode():
|
||||
libs += ["c10", "torch", "torch_cpu"]
|
||||
libs += ["torch", "torch_cpu"]
|
||||
libs += ["gomp"]
|
||||
if not aot_mode:
|
||||
libs += ["torch_python"]
|
||||
@ -1187,6 +1189,10 @@ def get_include_and_linking_paths(
|
||||
else:
|
||||
libs = ["omp"] if config.is_fbcode() else ["gomp"]
|
||||
|
||||
# Unconditionally import c10 to use TORCH_CHECK - See PyTorch #108690
|
||||
libs += ["c10"]
|
||||
lpaths += [cpp_extension.TORCH_LIB_PATH]
|
||||
|
||||
# third party libs
|
||||
if config.is_fbcode():
|
||||
ipaths.append(build_paths.sleef())
|
||||
|
@ -7,7 +7,18 @@ import operator
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from itertools import chain
|
||||
from typing import Any, Callable, ClassVar, Dict, List, NamedTuple, Optional, Set, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import sympy
|
||||
from sympy.printing.printer import Printer
|
||||
@ -24,6 +35,7 @@ from ..utils import (
|
||||
IndentedBuffer,
|
||||
sympy_dot,
|
||||
sympy_subs,
|
||||
sympy_symbol,
|
||||
unique,
|
||||
)
|
||||
from ..virtualized import ops, OpsValue, V
|
||||
@ -794,6 +806,49 @@ class CSE:
|
||||
return var
|
||||
|
||||
|
||||
class IndirectAssertLine(DeferredLineBase):
|
||||
def __init__(self, line, assert_fn, var, mask, size_map):
|
||||
self.var = var
|
||||
self.mask = mask
|
||||
self.line = line
|
||||
self.assert_fn = assert_fn
|
||||
self.size_map = size_map
|
||||
|
||||
def __call__(self):
|
||||
size, size_str = self.size_map[(self.var, self.mask)]
|
||||
|
||||
# We assert if we've not been able to prove the bound
|
||||
assert_min = (self.var.bounds.lower >= 0) != sympy.true
|
||||
assert_max = (self.var.bounds.upper < size) != sympy.true
|
||||
|
||||
# FooBar interview question
|
||||
if not (assert_min or assert_max):
|
||||
return None
|
||||
elif assert_min and assert_max:
|
||||
# The conditions need to be in parens because of Python's operator precedence.
|
||||
# It'd be less error-prone to use and/or/not, which is suported by triton
|
||||
cond = f"(0 <= {self.var}) & ({self.var} < {size_str})"
|
||||
cond_print = f"0 <= {self.var} < {size_str}"
|
||||
elif assert_min:
|
||||
cond = f"0 <= {self.var}"
|
||||
cond_print = cond
|
||||
else:
|
||||
assert assert_max
|
||||
cond = f"{self.var} < {size_str}"
|
||||
cond_print = cond
|
||||
|
||||
if self.mask:
|
||||
cond = f"({cond}) | ~{self.mask}"
|
||||
return self.line.format(
|
||||
assert_fn=self.assert_fn, cond=cond, cond_print=cond_print
|
||||
)
|
||||
|
||||
def _new_line(self, line):
|
||||
return IndirectAssertLine(
|
||||
line, self.assert_fn, self.var, self.mask, self.size_map
|
||||
)
|
||||
|
||||
|
||||
class CodeGen:
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -825,9 +880,12 @@ class Kernel(CodeGen):
|
||||
self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
|
||||
self.must_keep_buffers = set()
|
||||
self.store_buffer_names = set()
|
||||
self._load_mask = None
|
||||
# set in set_current_node
|
||||
self.current_node = None
|
||||
self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges]] = None
|
||||
# Upper bounds for indirect_indexing and their str representation
|
||||
self.indirect_max_sizes: Dict[Tuple[str, str], Tuple[sympy.Expr, str]] = {}
|
||||
|
||||
self.removed_buffers = set()
|
||||
# key: the buffer to write
|
||||
@ -929,9 +987,61 @@ class Kernel(CodeGen):
|
||||
return inner
|
||||
|
||||
@staticmethod
|
||||
def indirect_indexing(index_var, size, check=True):
|
||||
def indirect_indexing(var, size, check=True):
|
||||
# Skip CSE since this doesn't return an expression
|
||||
return self.indirect_indexing(index_var, size, check) # type: ignore[attr-defined]
|
||||
|
||||
if var.bounds.lower < 0:
|
||||
new_bounds = ValueRanges.unknown()
|
||||
if var.bounds != ValueRanges.unknown() and isinstance(
|
||||
size, sympy.Number
|
||||
):
|
||||
# Take the negative part of the bound and add size to it
|
||||
# Then take union of that and the positive part
|
||||
# This is a tighter bound than that of a generic ops.where, as we have info on the cond
|
||||
neg = var.bounds & ValueRanges(-sympy.oo, -1)
|
||||
new_bounds = ValueRanges(neg.lower + size, neg.upper + size)
|
||||
# We don't have a good way of representing the empty range
|
||||
if var.bounds.upper >= 0:
|
||||
pos = var.bounds & ValueRanges(0, sympy.oo)
|
||||
new_bounds = new_bounds | pos
|
||||
|
||||
stm = ops.add(var, self.rename_indexing(size))
|
||||
# Mixed negative and non-negative
|
||||
if var.bounds.upper >= 0:
|
||||
lt = ops.lt(var, "0")
|
||||
stm = ops.where(lt, stm, var)
|
||||
new_var = self.cse.generate(self.compute, stm, bounds=new_bounds)
|
||||
|
||||
new_var.update_on_args("index_wrap", (var,), {})
|
||||
var = new_var
|
||||
|
||||
if self.generate_assert(check):
|
||||
mask = self.load_mask(var)
|
||||
|
||||
# An assertion line may have been written already, if so just
|
||||
# update the max size.
|
||||
map_key = (var, mask)
|
||||
existing_size, _ = self.indirect_max_sizes.get(
|
||||
map_key, (None, None)
|
||||
)
|
||||
if existing_size is not None:
|
||||
size = sympy.Min(size, existing_size)
|
||||
else:
|
||||
line = (
|
||||
'{assert_fn}({cond}, "index out of bounds: {cond_print}")'
|
||||
)
|
||||
self.compute.writeline(
|
||||
IndirectAssertLine(
|
||||
line,
|
||||
self.assert_function, # type: ignore[attr-defined]
|
||||
var,
|
||||
mask,
|
||||
self.indirect_max_sizes,
|
||||
)
|
||||
)
|
||||
|
||||
self.indirect_max_sizes[map_key] = (size, self.index_to_str(size)) # type: ignore[attr-defined]
|
||||
return sympy_symbol(str(var))
|
||||
|
||||
@staticmethod
|
||||
def load(name: str, index: sympy.Expr):
|
||||
@ -1014,6 +1124,13 @@ class Kernel(CodeGen):
|
||||
V.graph.scheduler.remove_kernel_local_buffers()
|
||||
super().__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
def generate_assert(self, check):
|
||||
return (check or config.debug_index_asserts) and config.assert_indirect_indexing
|
||||
|
||||
def load_mask(self, var):
|
||||
# only the triton kernel requires mask
|
||||
return ""
|
||||
|
||||
def rename_indexing(self, index) -> sympy.Expr:
|
||||
# adds the necessary kernel args for index expressions
|
||||
# and renames variables in index expressions to kernel arg names
|
||||
|
@ -1199,7 +1199,6 @@ class CppKernel(Kernel):
|
||||
self.poststores = IndentedBuffer()
|
||||
self.num_threads = num_threads # num_threads the kernel specialized for
|
||||
self.reduction_omp_dec: Dict[Tuple[str, str], str] = {}
|
||||
self._load_mask = None
|
||||
|
||||
@contextlib.contextmanager
|
||||
def masked(self, mask):
|
||||
@ -1222,9 +1221,12 @@ class CppKernel(Kernel):
|
||||
new_index = sympy_subs(index, replacement)
|
||||
return new_index
|
||||
|
||||
@staticmethod
|
||||
def indirect_indexing(index_var, size, check=True):
|
||||
return sympy_symbol(str(index_var))
|
||||
def index_to_str(self, index: sympy.Expr) -> str:
|
||||
"""
|
||||
Convert an index expr to a string that can be used in cpp code.
|
||||
e.g. a sympy expression "s2" may actually appear as "ks1" in the cpp kernel.
|
||||
"""
|
||||
return cexpr(self.rename_indexing(index))
|
||||
|
||||
def load(self, name: str, index: sympy.Expr):
|
||||
var = self.args.input(name)
|
||||
@ -1420,6 +1422,10 @@ class CppKernel(Kernel):
|
||||
loop_nest = LoopNestWithSplit.build(self)
|
||||
self.codegen_loops_impl(loop_nest, code, worksharing)
|
||||
|
||||
@property
|
||||
def assert_function(self):
|
||||
return "TORCH_CHECK"
|
||||
|
||||
def decide_parallel_depth(self, ranges, threads):
|
||||
seq = self.size_hint()
|
||||
par = 1
|
||||
|
@ -27,7 +27,6 @@ from ..optimize_indexing import indexing_dtype_strength_reduction
|
||||
from ..scheduler import BaseScheduling
|
||||
from ..triton_heuristics import AutotuneHint
|
||||
from ..utils import (
|
||||
DeferredLineBase,
|
||||
get_fused_kernel_name,
|
||||
get_kernel_metadata,
|
||||
green_text,
|
||||
@ -824,7 +823,6 @@ class TritonKernel(Kernel):
|
||||
self.range_tree_nodes = {}
|
||||
self.iter_vars_count = itertools.count()
|
||||
self.inside_reduction = self.numels[-1] != 1
|
||||
self._load_mask = None
|
||||
self.body = IndentedBuffer()
|
||||
self.indexing_code = IndentedBuffer()
|
||||
self.suffix: IndentedBuffer = IndentedBuffer() # type: ignore[assignment]
|
||||
@ -832,8 +830,6 @@ class TritonKernel(Kernel):
|
||||
self.reduction_hint = reduction_hint
|
||||
self.index_dtype = index_dtype
|
||||
self.min_elem_per_thread = min_elem_per_thread
|
||||
# Upper bounds for indirect_indexing and their str representation
|
||||
self.indirect_max_sizes: Dict[Tuple[str, str], Tuple[sympy.Expr, str]] = {}
|
||||
self.last_usage = set()
|
||||
|
||||
self.persistent_reduction = self.should_use_persistent_reduction()
|
||||
@ -1250,101 +1246,26 @@ class TritonKernel(Kernel):
|
||||
finally:
|
||||
self._load_mask = prior
|
||||
|
||||
def indirect_indexing(self, var, size, check=True):
|
||||
# TODO(lezcano) This code should be lifted to codegen/common.py.
|
||||
# This should be easy, as now CSE variables carry bounds info
|
||||
class IndirectAssertLine(DeferredLineBase):
|
||||
def __init__(self, line, var, mask, size_map):
|
||||
self.var = var
|
||||
self.mask = mask
|
||||
self.line = line
|
||||
self.size_map = size_map
|
||||
def generate_assert(self, check):
|
||||
return torch.version.hip is None and super().generate_assert(check)
|
||||
|
||||
def __call__(self):
|
||||
size, size_str = self.size_map[(self.var, self.mask)]
|
||||
def load_mask(self, var):
|
||||
mask = ""
|
||||
mask_vars = set(var.mask_vars)
|
||||
if self._load_mask:
|
||||
mask_vars.add(self._load_mask)
|
||||
|
||||
# We assert if we've not been able to prove the bound
|
||||
assert_min = (self.var.bounds.lower >= 0) != sympy.true
|
||||
assert_max = (self.var.bounds.upper < size) != sympy.true
|
||||
if mask_vars:
|
||||
mask = (
|
||||
f"{list(mask_vars)[0]}"
|
||||
if len(mask_vars) == 1
|
||||
else f"({' & '.join(str(v) for v in mask_vars)})"
|
||||
)
|
||||
return mask
|
||||
|
||||
# FooBar interview question
|
||||
if not (assert_min or assert_max):
|
||||
return None
|
||||
elif assert_min and assert_max:
|
||||
# The conditions need to be in parens because of Python's operator precedence.
|
||||
# It'd be less error-prone to use and/or/not, which is supported by triton
|
||||
cond = f"(0 <= {self.var}) & ({self.var} < {size_str})"
|
||||
cond_print = f"0 <= {self.var} < {size_str}"
|
||||
elif assert_min:
|
||||
cond = f"0 <= {self.var}"
|
||||
cond_print = cond
|
||||
else:
|
||||
assert assert_max
|
||||
cond = f"{self.var} < {size_str}"
|
||||
cond_print = cond
|
||||
|
||||
if self.mask:
|
||||
cond = f"({cond}) | ~{self.mask}"
|
||||
return self.line.format(cond=cond, cond_print=cond_print)
|
||||
|
||||
def _new_line(self, line):
|
||||
return IndirectAssertLine(line, self.var, self.mask, self.size_map)
|
||||
|
||||
if var.bounds.lower < 0:
|
||||
new_bounds = ValueRanges.unknown()
|
||||
if var.bounds != ValueRanges.unknown() and isinstance(size, sympy.Number):
|
||||
# Take the negative part of the bound and add size to it
|
||||
# Then take union of that and the positive part
|
||||
# This is a tighter bound than that of a generic ops.where, as we have info on the cond
|
||||
neg = var.bounds & ValueRanges(-sympy.oo, -1)
|
||||
new_bounds = ValueRanges(neg.lower + size, neg.upper + size)
|
||||
# We don't have a good way of representing the empty range
|
||||
if var.bounds.upper >= 0:
|
||||
pos = var.bounds & ValueRanges(0, sympy.oo)
|
||||
new_bounds = new_bounds | pos
|
||||
|
||||
stm = f"{var} + {self.index_to_str(size)}"
|
||||
# Mixed negative and non-negative
|
||||
if var.bounds.upper >= 0:
|
||||
stm = f"tl.where({var} < 0, {stm}, {var})"
|
||||
new_var = self.cse.generate(self.compute, stm, bounds=new_bounds)
|
||||
|
||||
new_var.update_on_args("index_wrap", (var,), {})
|
||||
var = new_var
|
||||
|
||||
generate_assert = (
|
||||
(check or config.debug_index_asserts)
|
||||
and config.triton.assert_indirect_indexing
|
||||
and torch.version.hip is None
|
||||
)
|
||||
if generate_assert:
|
||||
mask_vars = set(var.mask_vars)
|
||||
if self._load_mask:
|
||||
mask_vars.add(self._load_mask)
|
||||
|
||||
mask = ""
|
||||
if mask_vars:
|
||||
mask = (
|
||||
f"{list(mask_vars)[0]}"
|
||||
if len(mask_vars) == 1
|
||||
else f"({' & '.join(str(v) for v in mask_vars)})"
|
||||
)
|
||||
|
||||
# An assertion line may have been written already, if so just
|
||||
# update the max size.
|
||||
map_key = (var, mask)
|
||||
existing_size, _ = self.indirect_max_sizes.get(map_key, (None, None))
|
||||
if existing_size is not None:
|
||||
size = sympy.Min(size, existing_size)
|
||||
else:
|
||||
line = 'tl.device_assert({cond}, "index out of bounds: {cond_print}")'
|
||||
self.compute.writeline(
|
||||
IndirectAssertLine(line, var, mask, self.indirect_max_sizes)
|
||||
)
|
||||
|
||||
self.indirect_max_sizes[map_key] = (size, self.index_to_str(size))
|
||||
|
||||
return sympy_symbol(str(var))
|
||||
@property
|
||||
def assert_function(self):
|
||||
return "tl.device_assert"
|
||||
|
||||
def get_strides_of_load(self, index: sympy.Expr):
|
||||
"""
|
||||
|
@ -230,6 +230,9 @@ constant_and_index_propagation = True
|
||||
# performing any constant-inlining optimization
|
||||
always_keep_tensor_constants = False
|
||||
|
||||
# assert that indirect indexing does not read / write out of bounds
|
||||
assert_indirect_indexing = True
|
||||
|
||||
|
||||
def is_fbcode():
|
||||
return not hasattr(torch.version, "git_version")
|
||||
@ -430,9 +433,6 @@ class triton:
|
||||
tiling_prevents_pointwise_fusion = True
|
||||
tiling_prevents_reduction_fusion = True
|
||||
|
||||
# assert that indirect indexing does not read / write out of bounds
|
||||
assert_indirect_indexing = True
|
||||
|
||||
# should we give different names to kernels
|
||||
# Note: This is orthogonal to descriptive_names - this is deciding whether
|
||||
# our triton kernel names should all be `triton_` (to maximize caching) or
|
||||
|
@ -266,7 +266,7 @@ class CachingAutotuner(KernelInterface):
|
||||
compile_meta["num_warps"] = cfg.num_warps
|
||||
compile_meta["num_stages"] = cfg.num_stages
|
||||
compile_meta["debug"] = (
|
||||
config.triton.assert_indirect_indexing and torch.version.hip is None
|
||||
config.assert_indirect_indexing and torch.version.hip is None
|
||||
)
|
||||
|
||||
# Setting device_type="hip" required on ROCm to pass down to triton
|
||||
|
Reference in New Issue
Block a user