[Cutlass 3.2.2 submodule upgrade] Adapt Inductor cutlass backend to Cutlass 3.2.2 (#112762)

The inductor cutlass backend was written against Cutlass version 3.1.x,
there are some incompatible changes in Cutlass 3.2.2 which the
Inductor cutlass backend needs to adapt to.

Test plan:

If third_party/cutlass is upgraded to Cutlass tag v3.2.2,
several tests within test/inductor/test_max_autotune.py start to
fail. With this diff applied, they pass again.

Differential Revision: [D50986555](https://our.internmc.facebook.com/intern/diff/D50986555)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112762
Approved by: https://github.com/ipiszy, https://github.com/drisspg
This commit is contained in:
Kai Londenberg
2023-11-03 13:41:13 -07:00
committed by PyTorch MergeBot
parent 8f10a2321d
commit e36dba3a94
4 changed files with 67 additions and 83 deletions

View File

@ -479,6 +479,17 @@ if(LINUX)
endif()
if(MSVC)
# MSVC by default does not apply the correct __cplusplus version as specified by the C++ standard
# because MSVC is not a completely compliant implementation. This option forces MSVC to use the
# appropriate value given the requested --std option. This fixes a compilation issue mismatch
# between GCC/Clang and MSVC.
#
# See:
# * https://learn.microsoft.com/en-us/cpp/build/reference/zc-cplusplus?view=msvc-170
# * https://en.cppreference.com/w/cpp/preprocessor/replace#Predefined_macros
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:__cplusplus")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /Zc:__cplusplus")
set(CMAKE_NINJA_CMCLDEPS_RC OFF)
foreach(flag_var
CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE

View File

@ -17,66 +17,37 @@ from .cuda_env import get_cuda_arch, get_cuda_version
log = logging.getLogger(__name__)
def _rename_cutlass_import(content: str, cutlass_modules: List[str]) -> str:
for cutlass_module in cutlass_modules:
content = content.replace(
f"from {cutlass_module} import ", f"from cutlass_{cutlass_module} import "
)
return content
def _gen_cutlass_file(
file_name: str, cutlass_modules: List[str], src_dir: str, dst_dir: str
) -> None:
orig_full_path = os.path.abspath(os.path.join(src_dir, file_name))
text = ""
with open(orig_full_path) as f:
text = f.read()
text = _rename_cutlass_import(text, cutlass_modules)
dst_full_path = os.path.abspath(
os.path.join(
dst_dir,
f"cutlass_{file_name}" if file_name != "__init__.py" else file_name,
)
)
with open(dst_full_path, "w") as f:
f.write(text)
@functools.lru_cache(None)
def try_import_cutlass() -> bool:
# Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path.
# This is a temporary hack to avoid CUTLASS module naming conflicts.
# TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.
cutlass_py_full_path = os.path.join(
inductor_cuda_config.cutlass_dir, "tools/library/scripts"
cutlass_py_full_path = os.path.abspath(
os.path.join(inductor_cuda_config.cutlass_dir, "python/cutlass_library")
)
tmp_cutlass_py_full_path = os.path.abspath(
os.path.join(cache_dir(), "torch_cutlass_script")
os.path.join(cache_dir(), "torch_cutlass_library")
)
dst_link = os.path.join(tmp_cutlass_py_full_path, "cutlass_library")
if os.path.isdir(cutlass_py_full_path):
cutlass_file_names = [
file_name
for file_name in os.listdir(cutlass_py_full_path)
if file_name.endswith(".py")
]
cutlass_module_names = [file_name[:-3] for file_name in cutlass_file_names]
if not os.path.isdir(tmp_cutlass_py_full_path):
os.mkdir(tmp_cutlass_py_full_path)
for file_name in cutlass_file_names:
_gen_cutlass_file(
file_name,
cutlass_module_names,
cutlass_py_full_path,
tmp_cutlass_py_full_path,
)
sys.path.append(tmp_cutlass_py_full_path)
if tmp_cutlass_py_full_path not in sys.path:
if os.path.exists(dst_link):
assert os.path.islink(
dst_link
), f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
cutlass_py_full_path
), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"
else:
os.makedirs(tmp_cutlass_py_full_path, exist_ok=True)
os.symlink(cutlass_py_full_path, dst_link)
sys.path.append(tmp_cutlass_py_full_path)
try:
import cutlass_generator # type: ignore[import] # noqa: F401
import cutlass_library # type: ignore[import] # noqa: F401
import cutlass_manifest # type: ignore[import] # noqa: F401
import cutlass_library.generator # type: ignore[import] # noqa: F401
import cutlass_library.library # type: ignore[import] # noqa: F401
import cutlass_library.manifest # type: ignore[import] # noqa: F401
return True
@ -136,18 +107,14 @@ class CUTLASSArgs:
@functools.lru_cache(None)
def gen_ops() -> List[Any]:
"""
Generates all supported CUTLASS operations.
"""
def _gen_ops_cached(arch, version) -> List[Any]:
# Note: Cache needs to be specific for cuda architecture and version
# Import cutlass python scripts.
assert try_import_cutlass()
import cutlass_generator # type: ignore[import]
import cutlass_manifest # type: ignore[import]
import cutlass_library.generator as cutlass_generator # type: ignore[import]
import cutlass_library.manifest as cutlass_manifest # type: ignore[import]
arch = get_cuda_arch()
version = get_cuda_version()
if arch is None or version is None:
log.error(
"Cannot detect cuda arch %s or cuda version %s. "
@ -172,13 +139,21 @@ def gen_ops() -> List[Any]:
raise NotImplementedError(
"Arch " + arch + " is not supported by current cutlass lib."
) from e
return manifest.operations
def gen_ops() -> List[Any]:
"""
Generates all supported CUTLASS operations.
"""
arch = get_cuda_arch()
version = get_cuda_version()
return _gen_ops_cached(arch, version)
def dtype_match(
torch_dtype: torch.dtype,
cutlass_dtype: "cutlass_library.DataType", # type: ignore[name-defined]
cutlass_dtype: "cutlass_library.library.DataType", # type: ignore[name-defined]
) -> bool:
# Import cutlass python scripts.
assert try_import_cutlass()
@ -186,13 +161,13 @@ def dtype_match(
if torch_dtype == torch.float:
return (
cutlass_dtype == cutlass_library.DataType.f32
or cutlass_dtype == cutlass_library.DataType.tf32
cutlass_dtype == cutlass_library.library.DataType.f32
or cutlass_dtype == cutlass_library.library.DataType.tf32
)
elif torch_dtype == torch.half:
return cutlass_dtype == cutlass_library.DataType.f16
return cutlass_dtype == cutlass_library.library.DataType.f16
elif torch_dtype == torch.bfloat16:
return cutlass_dtype == cutlass_library.DataType.bf16
return cutlass_dtype == cutlass_library.library.DataType.bf16
else:
return False

View File

@ -131,7 +131,6 @@ GEMM_ARGS_CUTLASS_3X = r"""
};
"""
GEMM_ARGS_CUTLASS_3X_EPILOGUE = r"""
{
{ElementComputeEpilogue({{alpha}}), ElementComputeEpilogue({{beta}})}, // typename ThreadEpilogueOp::Params thread
@ -183,7 +182,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
@staticmethod
def cutlass_layout(torch_layout) -> "Optional[cutlass_lib.LayoutType]": # type: ignore[name-defined]
assert cutlass_utils.try_import_cutlass()
import cutlass_library as cutlass_lib # type: ignore[import]
import cutlass_library.library as cutlass_lib # type: ignore[import]
if torch_layout.stride[-1] == 1:
return cutlass_lib.LayoutType.RowMajor
@ -197,7 +196,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined]
) -> "cutlass_lib.LayoutType": # type: ignore[name-defined]
assert cutlass_utils.try_import_cutlass()
import cutlass_library as cutlass_lib # type: ignore[import]
import cutlass_library.library as cutlass_lib # type: ignore[import]
if cutlass_layout == cutlass_lib.LayoutType.RowMajor:
return cutlass_lib.LayoutType.ColumnMajor
@ -220,7 +219,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
@staticmethod
def has_tma_epilogue(op) -> bool:
assert cutlass_utils.try_import_cutlass()
import cutlass_library as cutlass_lib # type: ignore[import]
import cutlass_library.library as cutlass_lib # type: ignore[import]
result = False
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
@ -233,8 +232,8 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
op: "cutlass_gemm_op.GemmOperation", # type: ignore[name-defined]
) -> Tuple[str, str]:
assert cutlass_utils.try_import_cutlass()
import cutlass_gemm_operation as cutlass_gemm_op # type: ignore[import]
import cutlass_library as cutlass_lib # type: ignore[import]
import cutlass_library.gemm_operation as cutlass_gemm_op # type: ignore[import]
import cutlass_library.library as cutlass_lib # type: ignore[import]
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
emitter = cutlass_gemm_op.EmitGemmUniversal3xInstance()
@ -291,7 +290,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
op: "cutlass_gemm_op.GemmOperation", # type: ignore[name-defined]
) -> "cutlass_gemm_op.GemmOperation": # type: ignore[name-defined]
assert cutlass_utils.try_import_cutlass()
import cutlass_library as cutlass_lib # type: ignore[import]
import cutlass_library.library as cutlass_lib # type: ignore[import]
# Skip simt kernels
if (
@ -306,7 +305,6 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
cutlass_lib.GemmKind.Universal3x,
}:
return None
# Filter ops by dtypes.
X = self.input_nodes[0]
W = self.input_nodes[1]
@ -372,21 +370,22 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
def gen_ops(self) -> "List[cutlass_gemm_op.GemmOperation]": # type: ignore[name-defined]
assert cutlass_utils.try_import_cutlass()
import cutlass_gemm_operation as cutlass_gemm_op # type: ignore[import]
import cutlass_library as cutlass_lib # type: ignore[import]
import cutlass_library.gemm_operation as cutlass_gemm_op # type: ignore[import]
import cutlass_library.library as cutlass_lib # type: ignore[import]
ops = cutlass_utils.gen_ops()[cutlass_lib.OperationKind.Gemm]
res: Dict[str, cutlass_gemm_op.GemmOperation] = dict()
num_3x_ops = 0
num_2x_ops = 0
for op_list in ops.values():
for op in op_list:
filter_res = self.filter_op(op)
if (
filter_res is not None
and res.get(filter_res.configuration_name(), None) is None
):
res[filter_res.configuration_name()] = filter_res
for op_dict in ops.values():
for op_list in op_dict.values():
for op in op_list:
filter_res = self.filter_op(op)
if (
filter_res is not None
and res.get(filter_res.configuration_name(), None) is None
):
res[filter_res.configuration_name()] = filter_res
for op in res.values():
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
num_3x_ops += 1
@ -481,7 +480,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
output_node: Optional[Buffer] = None,
) -> str:
assert cutlass_utils.try_import_cutlass()
import cutlass_library as cutlass_lib # type: ignore[import]
import cutlass_library.library as cutlass_lib # type: ignore[import]
if output_node is not None:
self.output_node = output_node
@ -523,6 +522,5 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
instance_type=instance_type,
input_reorder=self.input_reorder,
)
res = self._template_from_string(GEMM_TEMPLATE).render(**options)
return res