mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[Cutlass 3.2.2 submodule upgrade] Adapt Inductor cutlass backend to Cutlass 3.2.2 (#112762)
The inductor cutlass backend was written against Cutlass version 3.1.x, there are some incompatible changes in Cutlass 3.2.2 which the Inductor cutlass backend needs to adapt to. Test plan: If third_party/cutlass is upgraded to Cutlass tag v3.2.2, several tests within test/inductor/test_max_autotune.py start to fail. With this diff applied, they pass again. Differential Revision: [D50986555](https://our.internmc.facebook.com/intern/diff/D50986555) Pull Request resolved: https://github.com/pytorch/pytorch/pull/112762 Approved by: https://github.com/ipiszy, https://github.com/drisspg
This commit is contained in:
committed by
PyTorch MergeBot
parent
8f10a2321d
commit
e36dba3a94
@ -479,6 +479,17 @@ if(LINUX)
|
|||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(MSVC)
|
if(MSVC)
|
||||||
|
# MSVC by default does not apply the correct __cplusplus version as specified by the C++ standard
|
||||||
|
# because MSVC is not a completely compliant implementation. This option forces MSVC to use the
|
||||||
|
# appropriate value given the requested --std option. This fixes a compilation issue mismatch
|
||||||
|
# between GCC/Clang and MSVC.
|
||||||
|
#
|
||||||
|
# See:
|
||||||
|
# * https://learn.microsoft.com/en-us/cpp/build/reference/zc-cplusplus?view=msvc-170
|
||||||
|
# * https://en.cppreference.com/w/cpp/preprocessor/replace#Predefined_macros
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:__cplusplus")
|
||||||
|
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /Zc:__cplusplus")
|
||||||
|
|
||||||
set(CMAKE_NINJA_CMCLDEPS_RC OFF)
|
set(CMAKE_NINJA_CMCLDEPS_RC OFF)
|
||||||
foreach(flag_var
|
foreach(flag_var
|
||||||
CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE
|
CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE
|
||||||
|
2
third_party/cutlass
vendored
2
third_party/cutlass
vendored
Submodule third_party/cutlass updated: 6f47420213...44c704eae8
@ -17,66 +17,37 @@ from .cuda_env import get_cuda_arch, get_cuda_version
|
|||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _rename_cutlass_import(content: str, cutlass_modules: List[str]) -> str:
|
|
||||||
for cutlass_module in cutlass_modules:
|
|
||||||
content = content.replace(
|
|
||||||
f"from {cutlass_module} import ", f"from cutlass_{cutlass_module} import "
|
|
||||||
)
|
|
||||||
return content
|
|
||||||
|
|
||||||
|
|
||||||
def _gen_cutlass_file(
|
|
||||||
file_name: str, cutlass_modules: List[str], src_dir: str, dst_dir: str
|
|
||||||
) -> None:
|
|
||||||
orig_full_path = os.path.abspath(os.path.join(src_dir, file_name))
|
|
||||||
text = ""
|
|
||||||
with open(orig_full_path) as f:
|
|
||||||
text = f.read()
|
|
||||||
text = _rename_cutlass_import(text, cutlass_modules)
|
|
||||||
dst_full_path = os.path.abspath(
|
|
||||||
os.path.join(
|
|
||||||
dst_dir,
|
|
||||||
f"cutlass_{file_name}" if file_name != "__init__.py" else file_name,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
with open(dst_full_path, "w") as f:
|
|
||||||
f.write(text)
|
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(None)
|
@functools.lru_cache(None)
|
||||||
def try_import_cutlass() -> bool:
|
def try_import_cutlass() -> bool:
|
||||||
# Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path.
|
# Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path.
|
||||||
# This is a temporary hack to avoid CUTLASS module naming conflicts.
|
# This is a temporary hack to avoid CUTLASS module naming conflicts.
|
||||||
# TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.
|
# TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.
|
||||||
|
|
||||||
cutlass_py_full_path = os.path.join(
|
cutlass_py_full_path = os.path.abspath(
|
||||||
inductor_cuda_config.cutlass_dir, "tools/library/scripts"
|
os.path.join(inductor_cuda_config.cutlass_dir, "python/cutlass_library")
|
||||||
)
|
)
|
||||||
tmp_cutlass_py_full_path = os.path.abspath(
|
tmp_cutlass_py_full_path = os.path.abspath(
|
||||||
os.path.join(cache_dir(), "torch_cutlass_script")
|
os.path.join(cache_dir(), "torch_cutlass_library")
|
||||||
)
|
)
|
||||||
|
dst_link = os.path.join(tmp_cutlass_py_full_path, "cutlass_library")
|
||||||
|
|
||||||
if os.path.isdir(cutlass_py_full_path):
|
if os.path.isdir(cutlass_py_full_path):
|
||||||
cutlass_file_names = [
|
if tmp_cutlass_py_full_path not in sys.path:
|
||||||
file_name
|
if os.path.exists(dst_link):
|
||||||
for file_name in os.listdir(cutlass_py_full_path)
|
assert os.path.islink(
|
||||||
if file_name.endswith(".py")
|
dst_link
|
||||||
]
|
), f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
|
||||||
cutlass_module_names = [file_name[:-3] for file_name in cutlass_file_names]
|
assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
|
||||||
if not os.path.isdir(tmp_cutlass_py_full_path):
|
cutlass_py_full_path
|
||||||
os.mkdir(tmp_cutlass_py_full_path)
|
), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"
|
||||||
for file_name in cutlass_file_names:
|
else:
|
||||||
_gen_cutlass_file(
|
os.makedirs(tmp_cutlass_py_full_path, exist_ok=True)
|
||||||
file_name,
|
os.symlink(cutlass_py_full_path, dst_link)
|
||||||
cutlass_module_names,
|
sys.path.append(tmp_cutlass_py_full_path)
|
||||||
cutlass_py_full_path,
|
|
||||||
tmp_cutlass_py_full_path,
|
|
||||||
)
|
|
||||||
sys.path.append(tmp_cutlass_py_full_path)
|
|
||||||
try:
|
try:
|
||||||
import cutlass_generator # type: ignore[import] # noqa: F401
|
import cutlass_library.generator # type: ignore[import] # noqa: F401
|
||||||
import cutlass_library # type: ignore[import] # noqa: F401
|
import cutlass_library.library # type: ignore[import] # noqa: F401
|
||||||
import cutlass_manifest # type: ignore[import] # noqa: F401
|
import cutlass_library.manifest # type: ignore[import] # noqa: F401
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -136,18 +107,14 @@ class CUTLASSArgs:
|
|||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(None)
|
@functools.lru_cache(None)
|
||||||
def gen_ops() -> List[Any]:
|
def _gen_ops_cached(arch, version) -> List[Any]:
|
||||||
"""
|
# Note: Cache needs to be specific for cuda architecture and version
|
||||||
Generates all supported CUTLASS operations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Import cutlass python scripts.
|
# Import cutlass python scripts.
|
||||||
assert try_import_cutlass()
|
assert try_import_cutlass()
|
||||||
import cutlass_generator # type: ignore[import]
|
import cutlass_library.generator as cutlass_generator # type: ignore[import]
|
||||||
import cutlass_manifest # type: ignore[import]
|
import cutlass_library.manifest as cutlass_manifest # type: ignore[import]
|
||||||
|
|
||||||
arch = get_cuda_arch()
|
|
||||||
version = get_cuda_version()
|
|
||||||
if arch is None or version is None:
|
if arch is None or version is None:
|
||||||
log.error(
|
log.error(
|
||||||
"Cannot detect cuda arch %s or cuda version %s. "
|
"Cannot detect cuda arch %s or cuda version %s. "
|
||||||
@ -172,13 +139,21 @@ def gen_ops() -> List[Any]:
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Arch " + arch + " is not supported by current cutlass lib."
|
"Arch " + arch + " is not supported by current cutlass lib."
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
return manifest.operations
|
return manifest.operations
|
||||||
|
|
||||||
|
|
||||||
|
def gen_ops() -> List[Any]:
|
||||||
|
"""
|
||||||
|
Generates all supported CUTLASS operations.
|
||||||
|
"""
|
||||||
|
arch = get_cuda_arch()
|
||||||
|
version = get_cuda_version()
|
||||||
|
return _gen_ops_cached(arch, version)
|
||||||
|
|
||||||
|
|
||||||
def dtype_match(
|
def dtype_match(
|
||||||
torch_dtype: torch.dtype,
|
torch_dtype: torch.dtype,
|
||||||
cutlass_dtype: "cutlass_library.DataType", # type: ignore[name-defined]
|
cutlass_dtype: "cutlass_library.library.DataType", # type: ignore[name-defined]
|
||||||
) -> bool:
|
) -> bool:
|
||||||
# Import cutlass python scripts.
|
# Import cutlass python scripts.
|
||||||
assert try_import_cutlass()
|
assert try_import_cutlass()
|
||||||
@ -186,13 +161,13 @@ def dtype_match(
|
|||||||
|
|
||||||
if torch_dtype == torch.float:
|
if torch_dtype == torch.float:
|
||||||
return (
|
return (
|
||||||
cutlass_dtype == cutlass_library.DataType.f32
|
cutlass_dtype == cutlass_library.library.DataType.f32
|
||||||
or cutlass_dtype == cutlass_library.DataType.tf32
|
or cutlass_dtype == cutlass_library.library.DataType.tf32
|
||||||
)
|
)
|
||||||
elif torch_dtype == torch.half:
|
elif torch_dtype == torch.half:
|
||||||
return cutlass_dtype == cutlass_library.DataType.f16
|
return cutlass_dtype == cutlass_library.library.DataType.f16
|
||||||
elif torch_dtype == torch.bfloat16:
|
elif torch_dtype == torch.bfloat16:
|
||||||
return cutlass_dtype == cutlass_library.DataType.bf16
|
return cutlass_dtype == cutlass_library.library.DataType.bf16
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -131,7 +131,6 @@ GEMM_ARGS_CUTLASS_3X = r"""
|
|||||||
};
|
};
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
GEMM_ARGS_CUTLASS_3X_EPILOGUE = r"""
|
GEMM_ARGS_CUTLASS_3X_EPILOGUE = r"""
|
||||||
{
|
{
|
||||||
{ElementComputeEpilogue({{alpha}}), ElementComputeEpilogue({{beta}})}, // typename ThreadEpilogueOp::Params thread
|
{ElementComputeEpilogue({{alpha}}), ElementComputeEpilogue({{beta}})}, // typename ThreadEpilogueOp::Params thread
|
||||||
@ -183,7 +182,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def cutlass_layout(torch_layout) -> "Optional[cutlass_lib.LayoutType]": # type: ignore[name-defined]
|
def cutlass_layout(torch_layout) -> "Optional[cutlass_lib.LayoutType]": # type: ignore[name-defined]
|
||||||
assert cutlass_utils.try_import_cutlass()
|
assert cutlass_utils.try_import_cutlass()
|
||||||
import cutlass_library as cutlass_lib # type: ignore[import]
|
import cutlass_library.library as cutlass_lib # type: ignore[import]
|
||||||
|
|
||||||
if torch_layout.stride[-1] == 1:
|
if torch_layout.stride[-1] == 1:
|
||||||
return cutlass_lib.LayoutType.RowMajor
|
return cutlass_lib.LayoutType.RowMajor
|
||||||
@ -197,7 +196,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
|
|||||||
cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined]
|
cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined]
|
||||||
) -> "cutlass_lib.LayoutType": # type: ignore[name-defined]
|
) -> "cutlass_lib.LayoutType": # type: ignore[name-defined]
|
||||||
assert cutlass_utils.try_import_cutlass()
|
assert cutlass_utils.try_import_cutlass()
|
||||||
import cutlass_library as cutlass_lib # type: ignore[import]
|
import cutlass_library.library as cutlass_lib # type: ignore[import]
|
||||||
|
|
||||||
if cutlass_layout == cutlass_lib.LayoutType.RowMajor:
|
if cutlass_layout == cutlass_lib.LayoutType.RowMajor:
|
||||||
return cutlass_lib.LayoutType.ColumnMajor
|
return cutlass_lib.LayoutType.ColumnMajor
|
||||||
@ -220,7 +219,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def has_tma_epilogue(op) -> bool:
|
def has_tma_epilogue(op) -> bool:
|
||||||
assert cutlass_utils.try_import_cutlass()
|
assert cutlass_utils.try_import_cutlass()
|
||||||
import cutlass_library as cutlass_lib # type: ignore[import]
|
import cutlass_library.library as cutlass_lib # type: ignore[import]
|
||||||
|
|
||||||
result = False
|
result = False
|
||||||
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
|
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
|
||||||
@ -233,8 +232,8 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
|
|||||||
op: "cutlass_gemm_op.GemmOperation", # type: ignore[name-defined]
|
op: "cutlass_gemm_op.GemmOperation", # type: ignore[name-defined]
|
||||||
) -> Tuple[str, str]:
|
) -> Tuple[str, str]:
|
||||||
assert cutlass_utils.try_import_cutlass()
|
assert cutlass_utils.try_import_cutlass()
|
||||||
import cutlass_gemm_operation as cutlass_gemm_op # type: ignore[import]
|
import cutlass_library.gemm_operation as cutlass_gemm_op # type: ignore[import]
|
||||||
import cutlass_library as cutlass_lib # type: ignore[import]
|
import cutlass_library.library as cutlass_lib # type: ignore[import]
|
||||||
|
|
||||||
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
|
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
|
||||||
emitter = cutlass_gemm_op.EmitGemmUniversal3xInstance()
|
emitter = cutlass_gemm_op.EmitGemmUniversal3xInstance()
|
||||||
@ -291,7 +290,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
|
|||||||
op: "cutlass_gemm_op.GemmOperation", # type: ignore[name-defined]
|
op: "cutlass_gemm_op.GemmOperation", # type: ignore[name-defined]
|
||||||
) -> "cutlass_gemm_op.GemmOperation": # type: ignore[name-defined]
|
) -> "cutlass_gemm_op.GemmOperation": # type: ignore[name-defined]
|
||||||
assert cutlass_utils.try_import_cutlass()
|
assert cutlass_utils.try_import_cutlass()
|
||||||
import cutlass_library as cutlass_lib # type: ignore[import]
|
import cutlass_library.library as cutlass_lib # type: ignore[import]
|
||||||
|
|
||||||
# Skip simt kernels
|
# Skip simt kernels
|
||||||
if (
|
if (
|
||||||
@ -306,7 +305,6 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
|
|||||||
cutlass_lib.GemmKind.Universal3x,
|
cutlass_lib.GemmKind.Universal3x,
|
||||||
}:
|
}:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Filter ops by dtypes.
|
# Filter ops by dtypes.
|
||||||
X = self.input_nodes[0]
|
X = self.input_nodes[0]
|
||||||
W = self.input_nodes[1]
|
W = self.input_nodes[1]
|
||||||
@ -372,21 +370,22 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
|
|||||||
|
|
||||||
def gen_ops(self) -> "List[cutlass_gemm_op.GemmOperation]": # type: ignore[name-defined]
|
def gen_ops(self) -> "List[cutlass_gemm_op.GemmOperation]": # type: ignore[name-defined]
|
||||||
assert cutlass_utils.try_import_cutlass()
|
assert cutlass_utils.try_import_cutlass()
|
||||||
import cutlass_gemm_operation as cutlass_gemm_op # type: ignore[import]
|
import cutlass_library.gemm_operation as cutlass_gemm_op # type: ignore[import]
|
||||||
import cutlass_library as cutlass_lib # type: ignore[import]
|
import cutlass_library.library as cutlass_lib # type: ignore[import]
|
||||||
|
|
||||||
ops = cutlass_utils.gen_ops()[cutlass_lib.OperationKind.Gemm]
|
ops = cutlass_utils.gen_ops()[cutlass_lib.OperationKind.Gemm]
|
||||||
res: Dict[str, cutlass_gemm_op.GemmOperation] = dict()
|
res: Dict[str, cutlass_gemm_op.GemmOperation] = dict()
|
||||||
num_3x_ops = 0
|
num_3x_ops = 0
|
||||||
num_2x_ops = 0
|
num_2x_ops = 0
|
||||||
for op_list in ops.values():
|
for op_dict in ops.values():
|
||||||
for op in op_list:
|
for op_list in op_dict.values():
|
||||||
filter_res = self.filter_op(op)
|
for op in op_list:
|
||||||
if (
|
filter_res = self.filter_op(op)
|
||||||
filter_res is not None
|
if (
|
||||||
and res.get(filter_res.configuration_name(), None) is None
|
filter_res is not None
|
||||||
):
|
and res.get(filter_res.configuration_name(), None) is None
|
||||||
res[filter_res.configuration_name()] = filter_res
|
):
|
||||||
|
res[filter_res.configuration_name()] = filter_res
|
||||||
for op in res.values():
|
for op in res.values():
|
||||||
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
|
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
|
||||||
num_3x_ops += 1
|
num_3x_ops += 1
|
||||||
@ -481,7 +480,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
|
|||||||
output_node: Optional[Buffer] = None,
|
output_node: Optional[Buffer] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
assert cutlass_utils.try_import_cutlass()
|
assert cutlass_utils.try_import_cutlass()
|
||||||
import cutlass_library as cutlass_lib # type: ignore[import]
|
import cutlass_library.library as cutlass_lib # type: ignore[import]
|
||||||
|
|
||||||
if output_node is not None:
|
if output_node is not None:
|
||||||
self.output_node = output_node
|
self.output_node = output_node
|
||||||
@ -523,6 +522,5 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
|
|||||||
instance_type=instance_type,
|
instance_type=instance_type,
|
||||||
input_reorder=self.input_reorder,
|
input_reorder=self.input_reorder,
|
||||||
)
|
)
|
||||||
|
|
||||||
res = self._template_from_string(GEMM_TEMPLATE).render(**options)
|
res = self._template_from_string(GEMM_TEMPLATE).render(**options)
|
||||||
return res
|
return res
|
||||||
|
Reference in New Issue
Block a user