Files
pytorch/test/inductor/test_kernel_optimization.py
Catherine Lee 561193e5f2 [CI][testing] Use 3 processes for testing on sm89 and sm90 jobs (#158691)
3 procs were used on sm86, but after we switched to sm89 the sm86 check no longer matched, so it fell back to 2 procs.

sm90 is H100; I don't know exactly which unit tests run there, but I assume those runners also have a lot of memory.

These jobs use larger runners, which have more GPU memory, so it's usually OK. I think it's ~22GB total -> ~10GB per proc with 2 procs, or ~6GB per proc with 3 (the CUDA context is maybe 1GB per proc).
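
Rough per-proc math (the ~22GB total and ~1GB context are the rough assumptions above, not measured values):

    num_procs = 3                 # test processes sharing one GPU
    total_gb = 22.0               # approximate GPU memory on the larger runners
    context_gb = 1.0              # assumed per-process CUDA context overhead
    per_proc_gb = total_gb / num_procs - context_gb
    # -> ~6.3GB each with 3 procs, ~10GB each with 2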

I've applied skips to the tests that OOMed.

Time per test job decreases from ~2.7hr to ~2hr.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158691
Approved by: https://github.com/huydhn
2025-07-25 15:26:29 +00:00

95 lines
3.4 KiB
Python

# Owner(s): ["module: inductor"]
import torch
import torch._inductor
from torch._dynamo.utils import counters
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import serialTest
from torch.testing._internal.inductor_utils import GPU_TYPE, requires_gpu


class TestEinsumtoPointwise(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(
        self,
        input: torch.Tensor,
        weights: torch.Tensor,
        bias: torch.Tensor,
        input2: torch.Tensor,
        weights2: torch.Tensor,
        bias2: torch.Tensor,
    ) -> torch.Tensor:
        # Two einsum contractions, each followed by a pointwise add that the
        # einsum_to_pointwise pass targets.
        output = torch.functional.einsum("bni, nio -> bno", input, weights)
        add1 = output.add(bias)
        output2 = torch.functional.einsum("bni, bnio -> bno", input2, weights2)
        add2 = output2 + bias2
        return add1 + add2
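
# For reference, a sketch of the pointwise form that "bni, nio -> bno" admits
# (an assumed illustration of what the pass name suggests, not necessarily the
# pass's actual rewrite):
#   (input.unsqueeze(-1) * weights.unsqueeze(0)).sum(dim=-2)
# broadcasts (b, n, i, 1) * (1, n, i, o) and reduces over i, giving (b, n, o).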


class TestKernelOptimization(TestCase):
    def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
        # torch.compile wraps the module, so traced parameter names gain an
        # "_orig_mod." prefix relative to the eager module's names.
        if len(set(ref_dict.keys())) != len(set(res_dict.keys())):
            return False
        for key1 in ref_dict.keys():
            key2 = "_orig_mod." + key1
            assert key2 in res_dict, f"{key1} does not exist in traced module"
            if not torch.allclose(ref_dict[key1], res_dict[key2], rtol=rtol, atol=atol):
                return False
        return True

    def compare_pred(self, module, traced, input, rtol=1e-3, atol=1e-3):
        ref = module(*input)
        res = traced(*input)
        self.assertEqual(ref, res, rtol=rtol, atol=atol)

    def compare_parameters(self, module, traced, rtol=1e-3, atol=1e-3):
        ref_params = dict(module.named_parameters())
        res_params = dict(traced.named_parameters())
        self.assertTrue(self.compare_dict_tensors(ref_params, res_params, rtol, atol))

    def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3):
        ref_grad = {key: param.grad for key, param in module.named_parameters()}
        res_grad = {key: param.grad for key, param in traced.named_parameters()}
        self.assertTrue(
            self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol)
        )

    @requires_gpu()
    @torch._inductor.config.patch(
        pre_grad_fusion_options={
            "einsum_to_pointwise_pass": {},
        },
        post_grad_fusion_options={},
    )
    @serialTest()  # Needs slightly more memory on GPUs
    def test_einsum_to_pointwise(self):
        counters.clear()
        module = TestEinsumtoPointwise().to(GPU_TYPE)
        input = [
            torch.randn(4096, 9, 512, device=GPU_TYPE, requires_grad=True),
            torch.randn(9, 512, 96, device=GPU_TYPE, requires_grad=True),
            torch.randn(9, 96, device=GPU_TYPE, requires_grad=True),
            torch.randn(4096, 9, 160, device=GPU_TYPE, requires_grad=True),
            torch.randn(4096, 9, 160, 96, device=GPU_TYPE, requires_grad=True),
            torch.randn(4096, 9, 96, device=GPU_TYPE, requires_grad=True),
        ]
        traced = torch.compile(module)
        ref = module(*input)
        res = traced(*input)
        ref.sum().backward()
        res.sum().backward()
        self.compare_pred(module, traced, input)
        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)
        # The counter confirms the einsum -> pointwise rewrite fired exactly once.
        self.assertEqual(
            counters["inductor"]["einsum_to_pointwise_pass"],
            1,
        )
        counters.clear()


if __name__ == "__main__":
    run_tests()