Enable jit fusion on ROCm (#22872)

Summary: As of ROCm 2.6 we support hiprtc, the HIP runtime compilation API. Enable the jit fusion feature when such an API is available. This entails:
* new hipification rules for API_RTC
* adding the hiprtc APIs to the shim loader
* updating the cmake infrastructure to find the hiprtc library (it is part of the HIP package)
* enabling unit tests in the jit_fuser test set
* special casing in the resource strings for HIP - the typedefs CUDA requires would be redundant there
* for now, disabling the occupancy calculation we do not yet support and using a hard-coded value instead

Thanks to t-vi for working with me on getting this integration done!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/22872
Differential Revision: D17207425
Pulled By: bddppq
fbshipit-source-id: 93409f3051ad0ea06afacc2239fd6c402152debe
Committed by: Facebook Github Bot
Parent: 82c8949a9d
Commit: 9c5a899773
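The tests re-enabled in the diff below all follow the same pattern: run a small pointwise function through the JIT on the GPU and assert that the optimized graph contains a fusion group. The following is a minimal sketch of that pattern, written for illustration rather than taken from this commit; it assumes the legacy GPU fuser that was current when this change landed, and it relies on ROCm builds exposing HIP devices through the 'cuda' device string (with torch.version.hip set), which is why dropping @skipIfRocm is enough to run these tests on ROCm.

# Minimal sketch, not part of this commit: the check the re-enabled tests perform,
# assuming the legacy JIT fuser in use at the time of this change.
import torch
from torch.testing import FileCheck

def scaleshift(x, scale, shift):
    return x * scale + shift

if torch.cuda.is_available():  # also True on ROCm/HIP builds
    backend = "HIP" if torch.version.hip else "CUDA"
    x = torch.randn(4, 4, device='cuda')
    scale = torch.randn(4, device='cuda')
    shift = torch.randn(4, device='cuda')
    traced = torch.jit.trace(scaleshift, (x, scale, shift))
    traced(x, scale, shift)  # warm up so the executor produces an optimized graph
    graph = traced.graph_for(x, scale, shift)
    # With GPU fusion enabled, the pointwise mul/add collapse into one FusionGroup.
    FileCheck().check("prim::FusionGroup").run(str(graph))
    print("fusion group found on", backend)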
@@ -43,7 +43,6 @@ class TestFuser(JitTestCase):
         self._test_fused_abs()

     @unittest.skipIf(not RUN_CUDA, "requires CUDA")
-    @skipIfRocm
     def test_abs_cuda(self):
         self._test_fused_abs(device="cuda")

@@ -75,7 +74,6 @@ class TestFuser(JitTestCase):
         self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y))

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_broadcast_cuda(self):
         def scaleshift(x, scale, shift):
             return x * scale + shift
@@ -124,7 +122,6 @@ class TestFuser(JitTestCase):
         self.assertEqual(grads_half, fusion_grads)

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_checks_cat_inputs(self):
         # We shouldn't treat cat nodes as broadcasting. All their inputs
         # need to be checked for having the same map size, before we can
@@ -142,7 +139,6 @@ class TestFuser(JitTestCase):
         self.assertAllFused(f.graph_for(x, y))

     @unittest.skipIf(not RUN_CUDA, "No CUDA")
-    @skipIfRocm
     def test_chunk_cuda(self):
         def fn(x):
             a, b, c = x.chunk(3, 1)
@@ -195,7 +191,6 @@ class TestFuser(JitTestCase):
         return self._test_chunk_correctness(self, 'cuda')

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_chunk_distributes_cuda(self):
         def f(x, y):
             z1, z2 = (x + y).chunk(2, dim=1)
@@ -210,7 +205,6 @@ class TestFuser(JitTestCase):
             .check_count('ConstantChunk', 2, exactly=True).run(str(graph))

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_chunk_motion_deduplicates_inputs(self):
         def func1(x):
             z = x * x
@@ -233,7 +227,6 @@ class TestFuser(JitTestCase):
         self.assertEqual(len(list(fusion_group.inputs())), 1)

     @unittest.skipIf(not RUN_CUDA, "No CUDA")
-    @skipIfRocm
     def test_chunk_multiple_cuda(self):
         # The arguments are intentionally used out of order as a test to see
         # if the fusion compiler adds extra args in the correct order
@@ -254,7 +247,6 @@ class TestFuser(JitTestCase):
         self.assertAllFused(ge.graph_for(*inputs))

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_clamp(self):
         def func2(a, b):
             return torch.clamp(a + b, min=0, max=2)
@@ -284,7 +276,6 @@ class TestFuser(JitTestCase):
         self.assertAllFused(graph, except_for={'aten::Float'})

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_dropout(self):
         def func(x):
             x = torch.nn.functional.dropout(x)
@@ -298,7 +289,6 @@ class TestFuser(JitTestCase):
         self.assertAllFused(graph, except_for={'aten::div', 'prim::Constant'})

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_comparison_eq_ne(self):
         def f(x, y):
             mask = (x == 0).type_as(x)
@@ -322,7 +312,6 @@ class TestFuser(JitTestCase):
             return z

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_comparison_gt_lt_cuda(self):
         x = torch.randn(4, 4, dtype=torch.float, device='cuda')
         y = torch.randn(4, 4, dtype=torch.float, device='cuda')
@@ -331,7 +320,6 @@ class TestFuser(JitTestCase):
         self.assertAllFused(ge.graph_for(x, y))

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_comparison_ge_le_cuda(self):
         def f(x, y):
             mask = (x >= 0).type_as(x)
@@ -351,7 +339,6 @@ class TestFuser(JitTestCase):
             "aten::_size_if_not_equal"))

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_addcmul_cuda(self):
         t = torch.randn(1, 4, dtype=torch.float, device='cuda')
         t1 = torch.randn(4, 1, dtype=torch.float, device='cuda')
@@ -394,7 +381,6 @@ class TestFuser(JitTestCase):
         self.assertAllFused(graph)

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_concat_cuda(self):
         hx = torch.randn(3, 20, dtype=torch.float, device='cuda')
         cx = torch.randn(3, 20, dtype=torch.float, device='cuda')
@@ -408,7 +394,6 @@ class TestFuser(JitTestCase):
         FileCheck().check("FusedConcat").check_next("return").run(str(graph))

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_concat_invariant_cuda(self):
         # Invariant: the output of prim::FusedConcat may
         # not be an input to any node inside the FusionGroup.
@@ -431,7 +416,6 @@ class TestFuser(JitTestCase):
         return (x + .5 * y).exp()

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_exp_cuda(self):
         x = torch.randn(4, 4, dtype=torch.float, device='cuda')
         y = torch.randn(4, 4, dtype=torch.float, device='cuda')
@@ -440,7 +424,6 @@ class TestFuser(JitTestCase):
         self.assertAllFused(ge.graph_for(x, y))

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     @_inline_everything
     def test_fuse_decompose_normalization(self):
         class ResLike(torch.jit.ScriptModule):
@@ -495,7 +478,6 @@ class TestFuser(JitTestCase):
             ['aten::layer_norm('], ['aten::sub', 'aten::mul', 'aten::add'])

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_threshold(self):
         def f(x):
             return torch.threshold(x, 0, -10) + x + x + x
@@ -507,7 +489,6 @@ class TestFuser(JitTestCase):
         self.assertAllFused(scripted.graph_for(x))

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_scalar_arg_cuda(self):
         def fn_test_scalar_arg(x, p):
             # type: (Tensor, float) -> Tensor
@@ -651,7 +632,6 @@ class TestFuser(JitTestCase):
         self.assertAllFused(ge.graph_for(x, y))

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_lstm_cuda(self):
         inputs = get_lstm_inputs('cuda', training=True)
         module = self.checkScript(LSTMCellS, inputs)
@@ -670,7 +650,6 @@ class TestFuser(JitTestCase):
             "aten::_grad_sum_to_size"))

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_lstm_concat_cuda(self):
         inputs = get_lstm_inputs('cuda')
         ge = self.checkTrace(LSTMCellC, inputs)
@@ -678,7 +657,6 @@ class TestFuser(JitTestCase):
         FileCheck().check("FusedConcat").check_next("return").run(str(graph))

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_lstm_gates_permutations_cuda(self):
         # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh.
         # Test that any permutation of this will still result in one FusionGroup.
@@ -702,7 +680,6 @@ class TestFuser(JitTestCase):

     # TODO: Fuser doesn't work at all when inputs require grad. Fix that
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_lstm_traced_cuda(self):
         inputs = get_lstm_inputs('cuda')
         ge = self.checkTrace(LSTMCellF, inputs)
@@ -730,7 +707,6 @@ class TestFuser(JitTestCase):
             raise

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_milstm_cuda(self):
         inputs = get_milstm_inputs('cuda', training=True)
         module = self.checkScript(MiLSTMCell, inputs)
@@ -743,7 +719,6 @@ class TestFuser(JitTestCase):
         (hy + cy).sum().backward()

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_rand_cuda(self):
         class M(torch.jit.ScriptModule):
             __constants__ = ['d']
@@ -772,7 +747,6 @@ class TestFuser(JitTestCase):
         return F.relu(x + .5 * y)

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_relu_cuda(self):
         x = torch.randn(4, 4, dtype=torch.float, device='cuda')
         y = torch.randn(4, 4, dtype=torch.float, device='cuda')
@@ -781,7 +755,6 @@ class TestFuser(JitTestCase):
         self.assertAllFused(ge.graph_for(x, y))

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_erf_cuda(self):
         def fn_test_erf(x):
             return F.relu(torch.erf(x) - torch.erfc(x))
@@ -794,7 +767,6 @@ class TestFuser(JitTestCase):
             "aten::_size_if_not_equal"))

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_rand_broadcast_cuda(self):
         def fn_test_rand(x, y):
             r = torch.rand_like(y)
@@ -827,7 +799,6 @@ class TestFuser(JitTestCase):
         self.assertAllFused(ge.graph_for(x, y))

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_small_constant_cuda(self):
         def fn_test_small_constant(x, y):
             return (1e-8 * x + 5e-9 * y) * 1e8
@@ -838,7 +809,6 @@ class TestFuser(JitTestCase):
         self.assertAllFused(ge.graph_for(x, y))

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_tensor_scalar_ops_cuda(self):
         def should_fuse(x):
             z = 3.
@@ -883,7 +853,6 @@ class TestFuser(JitTestCase):
         self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'})

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @skipIfRocm
     def test_grad_sum_to_size_elimination(self):

         def my_broadcasted_cell(a, b, c):