Compare commits

...

2 Commits

Author SHA1 Message Date
a29a888322 [Cutlass] Add exp and sigmoid activations
ghstack-source-id: 817cf573bf402f5942f7e18349e6fc4972f8ac0e
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162536
2025-09-09 15:26:26 -07:00
5eb16435e9 [Cutlass] Add tanh activation and test case for activations
ghstack-source-id: 4dc9301a0e822c87410fd78bee69562343a85017
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162535
2025-09-09 15:26:23 -07:00
2 changed files with 33 additions and 3 deletions

View File

@ -107,13 +107,15 @@ def _check_if_instances_equal(op1, op2) -> bool:
return True
un_ops_under_test = [torch.relu]
un_ops_under_test = [torch.relu, torch.tanh, torch.exp, torch.sigmoid]
bin_ops_under_test = [torch.add, torch.mul, torch.sub, torch.div]
evt_all_ops = parametrize(
"op", un_ops_under_test + bin_ops_under_test, name_fn=lambda f: f.__name__
)
evt_un_ops = parametrize("op", un_ops_under_test, name_fn=lambda f: f.__name__)
evt_bin_ops = parametrize("op", bin_ops_under_test, name_fn=lambda f: f.__name__)
evt_all_shapes = parametrize("shape", itertools.product([512, 1024], repeat=2))
@ -1976,6 +1978,30 @@ class TestCutlassBackend(TestCase):
)
torch.testing.assert_close(result, ref_result)
@unittest.skipIf(not SM90OrLater, "need sm_90")
@use_evt_config
@evt_un_ops
def test_evt_activations(self, op):
class TestModel(torch.nn.Module):
def forward(self, a, b, extra_args):
acc = a @ b
return acc, op(acc, *extra_args)
M = 1024
N = 512
a = torch.ones(M, N).cuda().half()
b = torch.ones(N, N).cuda().half().t()
extra_args = gen_args(op, (M, N))
model = TestModel().cuda()
result = torch.compile(model)(a, b, extra_args)
ref_result = model(a, b, extra_args)
self.assertEqual(
torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 1
)
torch.testing.assert_close(result, ref_result)
@unittest.skipIf(not SM90OrLater, "need sm_90")
@use_evt_config
@evt_all_ops

View File

@ -88,7 +88,7 @@ class CutlassEVTOpsMixIn:
@staticmethod
def sigmoid(x0: str) -> str:
raise NotImplementedError("sigmoid is not supported in CUTLASS python evt")
return CutlassEVTOpsMixIn._prefix_un_op("sigmoid", x0)
@staticmethod
def sub(x0: str, x1: str) -> str:
@ -96,7 +96,11 @@ class CutlassEVTOpsMixIn:
@staticmethod
def tanh(x0: str) -> str:
raise NotImplementedError("tanh is not supported in CUTLASS python evt")
return CutlassEVTOpsMixIn._prefix_un_op("tanh", x0)
@staticmethod
def exp(x0: str) -> str:
return CutlassEVTOpsMixIn._prefix_un_op("exp", x0)
class MockCutlassHandler(CutlassEVTOpsMixIn, WrapperHandler):