mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[submodule] CUTLASS upgrade to 4.2.0 and change cutlass to cutlass_cppgen (#163092)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163092 Approved by: https://github.com/drisspg, https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
4b7aed89d8
commit
a81a2e54ed
@ -257,7 +257,7 @@ class TestCutlassBackend(TestCase):
|
||||
if config.is_fbcode():
|
||||
import python_cutlass
|
||||
else:
|
||||
import cutlass as python_cutlass # noqa: F401
|
||||
import cutlass_cppgen as python_cutlass # noqa: F401
|
||||
import cutlass_library # noqa: F401
|
||||
|
||||
def test_cutlass_key(self):
|
||||
|
@ -36,7 +36,7 @@ if try_import_cutlass():
|
||||
if config.is_fbcode():
|
||||
import python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401
|
||||
else:
|
||||
import cutlass as python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401
|
||||
import cutlass_cppgen as python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401
|
||||
CutlassTensor = python_cutlass.backend.evt.ir.tensor.Tensor
|
||||
|
||||
BIAS_CODE = """def example_epilogue(accum, C, aux, bias):
|
||||
|
2
third_party/cutlass
vendored
2
third_party/cutlass
vendored
Submodule third_party/cutlass updated: e51efbfe18...57e3cfb47a
@ -38,7 +38,7 @@ if try_import_cutlass():
|
||||
if config.is_fbcode():
|
||||
import python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401
|
||||
else:
|
||||
import cutlass as python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401
|
||||
import cutlass_cppgen as python_cutlass # type: ignore[import-untyped, import-not-found] # noqa: F401
|
||||
|
||||
from torch._inductor.codegen.cuda import cuda_env
|
||||
from torch._inductor.utils import IndentedBuffer
|
||||
@ -174,7 +174,7 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
|
||||
def is_nested_visitor_type(t: type) -> bool:
|
||||
return ".".join([t.__module__, t.__qualname__]) in {
|
||||
"python_cutlass.backend.c_types.visitor_factory.<locals>.VisitorType",
|
||||
"cutlass.backend.c_types.visitor_factory.<locals>.VisitorType",
|
||||
"cutlass_cppgen.backend.c_types.visitor_factory.<locals>.VisitorType",
|
||||
}
|
||||
|
||||
buffer = IndentedBuffer()
|
||||
@ -235,7 +235,7 @@ non-contiguous layout, received stride: {stride} and shape: {shape}"
|
||||
# Once again, need to check for local class type for stride tuple
|
||||
if str(arg_ty) in {
|
||||
"<class 'python_cutlass.backend.c_types.tuple_factory_.<locals>.TupleType'>",
|
||||
"<class 'cutlass.backend.c_types.tuple_factory_.<locals>.TupleType'>",
|
||||
"<class 'cutlass_cppgen.backend.c_types.tuple_factory_.<locals>.TupleType'>",
|
||||
}:
|
||||
DEFAULT_STRIDE_LEN = 3
|
||||
assert len(node.get_layout().stride) <= DEFAULT_STRIDE_LEN
|
||||
|
@ -43,7 +43,7 @@ def move_cutlass_compiled_cache() -> None:
|
||||
if config.is_fbcode():
|
||||
import python_cutlass # type: ignore[import-not-found]
|
||||
else:
|
||||
import cutlass as python_cutlass # type: ignore[import-not-found] # noqa: F401
|
||||
import cutlass_cppgen as python_cutlass # type: ignore[import-not-found] # noqa: F401
|
||||
|
||||
# Check if the CACHE_FILE attribute exists in python_cutlass and if the file exists
|
||||
if not hasattr(python_cutlass, "CACHE_FILE") or not os.path.exists(
|
||||
@ -118,7 +118,7 @@ def try_import_cutlass() -> bool:
|
||||
tmp_cutlass_full_path = os.path.abspath(os.path.join(cache_dir(), "torch_cutlass"))
|
||||
|
||||
dst_link_library = path_join(tmp_cutlass_full_path, "cutlass_library")
|
||||
dst_link_cutlass = path_join(tmp_cutlass_full_path, "cutlass")
|
||||
dst_link_cutlass = path_join(tmp_cutlass_full_path, "cutlass_cppgen")
|
||||
dst_link_pycute = path_join(tmp_cutlass_full_path, "pycute")
|
||||
|
||||
# mock modules to import cutlass
|
||||
@ -156,7 +156,7 @@ def try_import_cutlass() -> bool:
|
||||
)
|
||||
|
||||
try:
|
||||
import cutlass # noqa: F401, F811
|
||||
import cutlass_cppgen # noqa: F401, F811
|
||||
import cutlass_library.generator # noqa: F401
|
||||
import cutlass_library.library # noqa: F401
|
||||
import cutlass_library.manifest # noqa: F401
|
||||
|
Reference in New Issue
Block a user