mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Cutlass 3.2.2 submodule upgrade] Adapt Inductor cutlass backend to Cutlass 3.2.2 (#112762)
The inductor cutlass backend was written against Cutlass version 3.1.x, there are some incompatible changes in Cutlass 3.2.2 which the Inductor cutlass backend needs to adapt to. Test plan: If third_party/cutlass is upgraded to Cutlass tag v3.2.2, several tests within test/inductor/test_max_autotune.py start to fail. With this diff applied, they pass again. Differential Revision: [D50986555](https://our.internmc.facebook.com/intern/diff/D50986555) Pull Request resolved: https://github.com/pytorch/pytorch/pull/112762 Approved by: https://github.com/ipiszy, https://github.com/drisspg
This commit is contained in:
committed by
PyTorch MergeBot
parent
8f10a2321d
commit
e36dba3a94
@ -479,6 +479,17 @@ if(LINUX)
|
||||
endif()
|
||||
|
||||
if(MSVC)
|
||||
# MSVC by default does not apply the correct __cplusplus version as specified by the C++ standard
|
||||
# because MSVC is not a completely compliant implementation. This option forces MSVC to use the
|
||||
# appropriate value given the requested --std option. This fixes a compilation issue mismatch
|
||||
# between GCC/Clang and MSVC.
|
||||
#
|
||||
# See:
|
||||
# * https://learn.microsoft.com/en-us/cpp/build/reference/zc-cplusplus?view=msvc-170
|
||||
# * https://en.cppreference.com/w/cpp/preprocessor/replace#Predefined_macros
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:__cplusplus")
|
||||
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /Zc:__cplusplus")
|
||||
|
||||
set(CMAKE_NINJA_CMCLDEPS_RC OFF)
|
||||
foreach(flag_var
|
||||
CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE
|
||||
|
2
third_party/cutlass
vendored
2
third_party/cutlass
vendored
Submodule third_party/cutlass updated: 6f47420213...44c704eae8
@ -17,66 +17,37 @@ from .cuda_env import get_cuda_arch, get_cuda_version
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _rename_cutlass_import(content: str, cutlass_modules: List[str]) -> str:
|
||||
for cutlass_module in cutlass_modules:
|
||||
content = content.replace(
|
||||
f"from {cutlass_module} import ", f"from cutlass_{cutlass_module} import "
|
||||
)
|
||||
return content
|
||||
|
||||
|
||||
def _gen_cutlass_file(
|
||||
file_name: str, cutlass_modules: List[str], src_dir: str, dst_dir: str
|
||||
) -> None:
|
||||
orig_full_path = os.path.abspath(os.path.join(src_dir, file_name))
|
||||
text = ""
|
||||
with open(orig_full_path) as f:
|
||||
text = f.read()
|
||||
text = _rename_cutlass_import(text, cutlass_modules)
|
||||
dst_full_path = os.path.abspath(
|
||||
os.path.join(
|
||||
dst_dir,
|
||||
f"cutlass_{file_name}" if file_name != "__init__.py" else file_name,
|
||||
)
|
||||
)
|
||||
with open(dst_full_path, "w") as f:
|
||||
f.write(text)
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def try_import_cutlass() -> bool:
|
||||
# Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path.
|
||||
# This is a temporary hack to avoid CUTLASS module naming conflicts.
|
||||
# TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.
|
||||
|
||||
cutlass_py_full_path = os.path.join(
|
||||
inductor_cuda_config.cutlass_dir, "tools/library/scripts"
|
||||
cutlass_py_full_path = os.path.abspath(
|
||||
os.path.join(inductor_cuda_config.cutlass_dir, "python/cutlass_library")
|
||||
)
|
||||
tmp_cutlass_py_full_path = os.path.abspath(
|
||||
os.path.join(cache_dir(), "torch_cutlass_script")
|
||||
os.path.join(cache_dir(), "torch_cutlass_library")
|
||||
)
|
||||
dst_link = os.path.join(tmp_cutlass_py_full_path, "cutlass_library")
|
||||
|
||||
if os.path.isdir(cutlass_py_full_path):
|
||||
cutlass_file_names = [
|
||||
file_name
|
||||
for file_name in os.listdir(cutlass_py_full_path)
|
||||
if file_name.endswith(".py")
|
||||
]
|
||||
cutlass_module_names = [file_name[:-3] for file_name in cutlass_file_names]
|
||||
if not os.path.isdir(tmp_cutlass_py_full_path):
|
||||
os.mkdir(tmp_cutlass_py_full_path)
|
||||
for file_name in cutlass_file_names:
|
||||
_gen_cutlass_file(
|
||||
file_name,
|
||||
cutlass_module_names,
|
||||
cutlass_py_full_path,
|
||||
tmp_cutlass_py_full_path,
|
||||
)
|
||||
sys.path.append(tmp_cutlass_py_full_path)
|
||||
if tmp_cutlass_py_full_path not in sys.path:
|
||||
if os.path.exists(dst_link):
|
||||
assert os.path.islink(
|
||||
dst_link
|
||||
), f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
|
||||
assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
|
||||
cutlass_py_full_path
|
||||
), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"
|
||||
else:
|
||||
os.makedirs(tmp_cutlass_py_full_path, exist_ok=True)
|
||||
os.symlink(cutlass_py_full_path, dst_link)
|
||||
sys.path.append(tmp_cutlass_py_full_path)
|
||||
try:
|
||||
import cutlass_generator # type: ignore[import] # noqa: F401
|
||||
import cutlass_library # type: ignore[import] # noqa: F401
|
||||
import cutlass_manifest # type: ignore[import] # noqa: F401
|
||||
import cutlass_library.generator # type: ignore[import] # noqa: F401
|
||||
import cutlass_library.library # type: ignore[import] # noqa: F401
|
||||
import cutlass_library.manifest # type: ignore[import] # noqa: F401
|
||||
|
||||
return True
|
||||
|
||||
@ -136,18 +107,14 @@ class CUTLASSArgs:
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def gen_ops() -> List[Any]:
|
||||
"""
|
||||
Generates all supported CUTLASS operations.
|
||||
"""
|
||||
def _gen_ops_cached(arch, version) -> List[Any]:
|
||||
# Note: Cache needs to be specific for cuda architecture and version
|
||||
|
||||
# Import cutlass python scripts.
|
||||
assert try_import_cutlass()
|
||||
import cutlass_generator # type: ignore[import]
|
||||
import cutlass_manifest # type: ignore[import]
|
||||
import cutlass_library.generator as cutlass_generator # type: ignore[import]
|
||||
import cutlass_library.manifest as cutlass_manifest # type: ignore[import]
|
||||
|
||||
arch = get_cuda_arch()
|
||||
version = get_cuda_version()
|
||||
if arch is None or version is None:
|
||||
log.error(
|
||||
"Cannot detect cuda arch %s or cuda version %s. "
|
||||
@ -172,13 +139,21 @@ def gen_ops() -> List[Any]:
|
||||
raise NotImplementedError(
|
||||
"Arch " + arch + " is not supported by current cutlass lib."
|
||||
) from e
|
||||
|
||||
return manifest.operations
|
||||
|
||||
|
||||
def gen_ops() -> List[Any]:
|
||||
"""
|
||||
Generates all supported CUTLASS operations.
|
||||
"""
|
||||
arch = get_cuda_arch()
|
||||
version = get_cuda_version()
|
||||
return _gen_ops_cached(arch, version)
|
||||
|
||||
|
||||
def dtype_match(
|
||||
torch_dtype: torch.dtype,
|
||||
cutlass_dtype: "cutlass_library.DataType", # type: ignore[name-defined]
|
||||
cutlass_dtype: "cutlass_library.library.DataType", # type: ignore[name-defined]
|
||||
) -> bool:
|
||||
# Import cutlass python scripts.
|
||||
assert try_import_cutlass()
|
||||
@ -186,13 +161,13 @@ def dtype_match(
|
||||
|
||||
if torch_dtype == torch.float:
|
||||
return (
|
||||
cutlass_dtype == cutlass_library.DataType.f32
|
||||
or cutlass_dtype == cutlass_library.DataType.tf32
|
||||
cutlass_dtype == cutlass_library.library.DataType.f32
|
||||
or cutlass_dtype == cutlass_library.library.DataType.tf32
|
||||
)
|
||||
elif torch_dtype == torch.half:
|
||||
return cutlass_dtype == cutlass_library.DataType.f16
|
||||
return cutlass_dtype == cutlass_library.library.DataType.f16
|
||||
elif torch_dtype == torch.bfloat16:
|
||||
return cutlass_dtype == cutlass_library.DataType.bf16
|
||||
return cutlass_dtype == cutlass_library.library.DataType.bf16
|
||||
else:
|
||||
return False
|
||||
|
||||
|
@ -131,7 +131,6 @@ GEMM_ARGS_CUTLASS_3X = r"""
|
||||
};
|
||||
"""
|
||||
|
||||
|
||||
GEMM_ARGS_CUTLASS_3X_EPILOGUE = r"""
|
||||
{
|
||||
{ElementComputeEpilogue({{alpha}}), ElementComputeEpilogue({{beta}})}, // typename ThreadEpilogueOp::Params thread
|
||||
@ -183,7 +182,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
|
||||
@staticmethod
|
||||
def cutlass_layout(torch_layout) -> "Optional[cutlass_lib.LayoutType]": # type: ignore[name-defined]
|
||||
assert cutlass_utils.try_import_cutlass()
|
||||
import cutlass_library as cutlass_lib # type: ignore[import]
|
||||
import cutlass_library.library as cutlass_lib # type: ignore[import]
|
||||
|
||||
if torch_layout.stride[-1] == 1:
|
||||
return cutlass_lib.LayoutType.RowMajor
|
||||
@ -197,7 +196,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
|
||||
cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined]
|
||||
) -> "cutlass_lib.LayoutType": # type: ignore[name-defined]
|
||||
assert cutlass_utils.try_import_cutlass()
|
||||
import cutlass_library as cutlass_lib # type: ignore[import]
|
||||
import cutlass_library.library as cutlass_lib # type: ignore[import]
|
||||
|
||||
if cutlass_layout == cutlass_lib.LayoutType.RowMajor:
|
||||
return cutlass_lib.LayoutType.ColumnMajor
|
||||
@ -220,7 +219,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
|
||||
@staticmethod
|
||||
def has_tma_epilogue(op) -> bool:
|
||||
assert cutlass_utils.try_import_cutlass()
|
||||
import cutlass_library as cutlass_lib # type: ignore[import]
|
||||
import cutlass_library.library as cutlass_lib # type: ignore[import]
|
||||
|
||||
result = False
|
||||
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
|
||||
@ -233,8 +232,8 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
|
||||
op: "cutlass_gemm_op.GemmOperation", # type: ignore[name-defined]
|
||||
) -> Tuple[str, str]:
|
||||
assert cutlass_utils.try_import_cutlass()
|
||||
import cutlass_gemm_operation as cutlass_gemm_op # type: ignore[import]
|
||||
import cutlass_library as cutlass_lib # type: ignore[import]
|
||||
import cutlass_library.gemm_operation as cutlass_gemm_op # type: ignore[import]
|
||||
import cutlass_library.library as cutlass_lib # type: ignore[import]
|
||||
|
||||
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
|
||||
emitter = cutlass_gemm_op.EmitGemmUniversal3xInstance()
|
||||
@ -291,7 +290,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
|
||||
op: "cutlass_gemm_op.GemmOperation", # type: ignore[name-defined]
|
||||
) -> "cutlass_gemm_op.GemmOperation": # type: ignore[name-defined]
|
||||
assert cutlass_utils.try_import_cutlass()
|
||||
import cutlass_library as cutlass_lib # type: ignore[import]
|
||||
import cutlass_library.library as cutlass_lib # type: ignore[import]
|
||||
|
||||
# Skip simt kernels
|
||||
if (
|
||||
@ -306,7 +305,6 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
|
||||
cutlass_lib.GemmKind.Universal3x,
|
||||
}:
|
||||
return None
|
||||
|
||||
# Filter ops by dtypes.
|
||||
X = self.input_nodes[0]
|
||||
W = self.input_nodes[1]
|
||||
@ -372,21 +370,22 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
|
||||
|
||||
def gen_ops(self) -> "List[cutlass_gemm_op.GemmOperation]": # type: ignore[name-defined]
|
||||
assert cutlass_utils.try_import_cutlass()
|
||||
import cutlass_gemm_operation as cutlass_gemm_op # type: ignore[import]
|
||||
import cutlass_library as cutlass_lib # type: ignore[import]
|
||||
import cutlass_library.gemm_operation as cutlass_gemm_op # type: ignore[import]
|
||||
import cutlass_library.library as cutlass_lib # type: ignore[import]
|
||||
|
||||
ops = cutlass_utils.gen_ops()[cutlass_lib.OperationKind.Gemm]
|
||||
res: Dict[str, cutlass_gemm_op.GemmOperation] = dict()
|
||||
num_3x_ops = 0
|
||||
num_2x_ops = 0
|
||||
for op_list in ops.values():
|
||||
for op in op_list:
|
||||
filter_res = self.filter_op(op)
|
||||
if (
|
||||
filter_res is not None
|
||||
and res.get(filter_res.configuration_name(), None) is None
|
||||
):
|
||||
res[filter_res.configuration_name()] = filter_res
|
||||
for op_dict in ops.values():
|
||||
for op_list in op_dict.values():
|
||||
for op in op_list:
|
||||
filter_res = self.filter_op(op)
|
||||
if (
|
||||
filter_res is not None
|
||||
and res.get(filter_res.configuration_name(), None) is None
|
||||
):
|
||||
res[filter_res.configuration_name()] = filter_res
|
||||
for op in res.values():
|
||||
if op.gemm_kind == cutlass_lib.GemmKind.Universal3x:
|
||||
num_3x_ops += 1
|
||||
@ -481,7 +480,7 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
|
||||
output_node: Optional[Buffer] = None,
|
||||
) -> str:
|
||||
assert cutlass_utils.try_import_cutlass()
|
||||
import cutlass_library as cutlass_lib # type: ignore[import]
|
||||
import cutlass_library.library as cutlass_lib # type: ignore[import]
|
||||
|
||||
if output_node is not None:
|
||||
self.output_node = output_node
|
||||
@ -523,6 +522,5 @@ class CUTLASSGemmTemplate(CUTLASSTemplate):
|
||||
instance_type=instance_type,
|
||||
input_reorder=self.input_reorder,
|
||||
)
|
||||
|
||||
res = self._template_from_string(GEMM_TEMPLATE).render(**options)
|
||||
return res
|
||||
|
Reference in New Issue
Block a user