Compare commits

...

1 Commit

Author SHA1 Message Date
fa00b8cdba my local copy of pr #137410 2024-10-22 11:09:18 -07:00
6 changed files with 35 additions and 18 deletions

View File

@ -2263,7 +2263,8 @@ class KernelTemplate:
"""
try:
choices.append(self.generate(**kwargs))
choice = self.generate(**kwargs)
choices.append(choice)
except NotImplementedError as e:
pass

View File

@ -125,6 +125,7 @@ class CUTLASSArgs:
generator_target = ""
kernels = "all"
ignore_kernels = ""
exclude_kernels = ""
# TODO: these three look dead?
kernel_filter_file: None = None
selected_kernel_list: None = None
@ -160,6 +161,7 @@ def _gen_ops_cached(arch, version) -> List[Any]:
return []
arch = _normalize_cuda_arch(arch)
args = CUTLASSArgs(architectures=arch, cuda_version=version)
#import pdb; pdb.set_trace()
manifest = cutlass_manifest.Manifest(args)
if arch == "90":
@ -255,12 +257,7 @@ def get_accumulator_dtype(
]:
torch_dtype = dtype0
if torch_dtype == torch.half:
if torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction:
return torch_dtype
else:
return torch.float
if torch_dtype in {torch.bfloat16, torch.float}:
if torch_dtype in {torch.float16, torch.bfloat16, torch.float}:
return torch.float
if torch_dtype == torch.int8:
return torch.int32

View File

@ -47,7 +47,7 @@ PT_EXPORT {{kernel_call_signature}} {
CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count);
}
{{instance_type}}::Arguments arguments;
{{template.render_gemm_arguments(argument_template, epilogue_template, should_swap_xw,
{{template.render_gemm_arguments(argument_template, epilogue_template, should_swap_xw, swizzle,
X, W, Bias, Y, alpha, beta, kernel, epilogue_args)}}
{{instance_type}} gemm_op;
if (workspace_size) {
@ -123,6 +123,9 @@ GEMM_ARGS_CUTLASS_3X = r"""
{{epilogue_arguments}},
hw_info
};
arguments.scheduler.max_swizzle_size = {{swizzle}};
"""
# Jinja template for Cutlass 3.x GEMM Kernel arguments if epilogue fusion is applied,
@ -505,11 +508,17 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
"""
ops = self.gen_ops()
for op in ops:
self.maybe_append_choice(
choices,
op=op,
)
for name, op in ops:
old_len = len(choices)
for swizzle in [1, 2, 4, 8]:
# breakpoint()
self.maybe_append_choice(
choices,
op=op,
swizzle=swizzle
)
if len(choices) > old_len:
choices[-1].debug_extra = name + f" swizzle={swizzle}"
if len(ops) == 0:
input_layouts = [node.get_layout() for node in input_nodes]
input_strides = [node.get_stride() for node in input_nodes]
@ -793,6 +802,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
# TODO: update epilogue functor according to epilogues.
op.element_epilogue = op.accumulator_type()
if inductor_cuda_config.cutlass_op_allowlist_regex is not None:
#print("boo:", op.configuration_name())
if not re.search(
inductor_cuda_config.cutlass_op_allowlist_regex, op.configuration_name()
):
@ -826,18 +836,23 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
ops = cutlass_utils.gen_ops()[cutlass_lib.OperationKind.Gemm]
res: Dict[str, cutlass_gemm_op.GemmOperation] = {}
#from pprint import pprint
#pprint(ops)
for op_dict in ops.values():
for op_list in op_dict.values():
for key, op_list in op_dict.items():
for op in op_list:
assert isinstance(op, cutlass_gemm_op.GemmOperation)
filter_res = self.filter_op(op)
# if "ping" not in key:
# continue
if (
filter_res is not None
and res.get(filter_res.configuration_name(), None) is None
):
res[filter_res.configuration_name()] = filter_res
log.debug("Got cutlass configs: total number of ops: %d, ", len(res))
return list(res.values())[: inductor_cuda_config.cutlass_max_profiling_configs]
# breakpoint()
return list(res.items())[: inductor_cuda_config.cutlass_max_profiling_configs]
def gemm_mode(self) -> str:
"""
@ -952,6 +967,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
Bias=Bias,
epilogue_template=epilogue_template,
argument_template=argument_template,
swizzle=kwargs['swizzle'],
should_swap_xw=should_swap_xw,
template=self,
kernel=kernel,
@ -1216,6 +1232,7 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
argument_template: str,
epilogue_template: str,
should_swap_xw: bool,
swizzle: int,
X: IRNode,
W: IRNode,
Bias: IRNode,
@ -1261,6 +1278,7 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
M="M",
N="N",
epilogue_args=epilogue_args,
swizzle=swizzle,
)
assert epilogue_template is not None

View File

@ -1104,7 +1104,7 @@ class cuda:
# Set this to "pingpong" to avoid numerical issues
# caused by the op ordering of the "pingpong" memory access
# pattern used by some Cutlass Kernels.
cutlass_op_denylist_regex: Optional[str] = "pingpong"
cutlass_op_denylist_regex: Optional[str] = None
class rocm:

View File

@ -194,6 +194,7 @@ class TritonBenchmarker(Benchmarker):
this is the first requested quantile. Else, if `kwargs["return_mode"]` is specified,
this is the requested return mode. Otherwise, this is the median.
"""
kwargs['rep'] = 400
if "quantiles" in kwargs:
return self.triton_do_bench(_callable, **kwargs)[0]
elif "return_mode" in kwargs:

View File

@ -33,7 +33,7 @@ try:
attrs_descriptor_available = True
# Determine if 'ids_of_folded_args' is a valid field for AttrsDescriptor
attr_desc_fields = {f.name for f in fields(AttrsDescriptor)}
attr_desc_fields = set(dir(AttrsDescriptor))
ids_of_folded_args_available = "ids_of_folded_args" in attr_desc_fields
divisible_by_8_available = "divisible_by_8" in attr_desc_fields
except ImportError: