Compare commits


1 commit

5f30c3f55c Fix torch.cond HOP device in inductor 2025-11-07 12:13:17 -08:00
2 changed files with 32 additions and 6 deletions
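
Before this change, inductor gave every torch.cond output the device stored on the enclosing Conditional node, so branches that create their result on a different device than the predicate could end up with a wrong output layout. A minimal sketch of the affected pattern, assuming a CUDA GPU is available (the function name and the literal "cuda" device are illustrative; the new test below uses GPU_TYPE and the CondTests._run_test harness instead):

import torch

def cond_with_factory_branches(pred):
    # Both branches build their output directly on the GPU, while the
    # predicate itself is a CPU boolean tensor.
    out = torch.cond(
        pred,
        lambda: torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32).to("cuda"),
        lambda: torch.zeros(5, dtype=torch.float32).to("cuda"),
    )
    return out + 1

compiled = torch.compile(cond_with_factory_branches, dynamic=True)
print(compiled(torch.tensor(True)))  # output is expected to stay on the CUDA device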

@@ -20,9 +20,11 @@ from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
 from torch.testing._internal.triton_utils import requires_gpu
 
 
-def _prepend_product_of_values(inputs, possible_values, num_to_prepend=1):
+def _prepend_product_of_values(inputs, possible_values, num_to_prepend=1, device=None):
     result = []
-    device = inputs[0].device
+    if len(inputs) != 0:
+        device = inputs[0].device
+    assert device
     # iterate over the cartesian product of predicate values
     for values in itertools.product(*([possible_values] * num_to_prepend)):
         prepended = [torch.tensor(v, device=device) for v in values]
@@ -30,8 +32,8 @@ def _prepend_product_of_values(inputs, possible_values, num_to_prepend=1):
     return result
 
 
-def prepend_predicates(inputs, num_predicates=1):
-    return _prepend_product_of_values(inputs, [False, True], num_predicates)
+def prepend_predicates(inputs, num_predicates=1, device=None):
+    return _prepend_product_of_values(inputs, [False, True], num_predicates, device)
 
 
 def prepend_counters(inputs, num_counters=1, counter_values=(0, 1, 5)):
@@ -308,7 +310,9 @@ class CondTests(TestCase):
                     torch._dynamo.mark_dynamic(inp, 0)
 
         for inputs in input_sets:
-            for inputs_with_predicates in prepend_predicates(inputs, num_predicates):
+            for inputs_with_predicates in prepend_predicates(
+                inputs, num_predicates, device=device
+            ):
                 cloned_inputs = [inp.clone() for inp in inputs_with_predicates]
                 result = model(*inputs_with_predicates)
                 result_compiled = compiled_model(*inputs_with_predicates)
@@ -768,6 +772,26 @@ class CondTests(TestCase):
             dynamic=dynamic,
         )
 
+    @requires_gpu
+    def test_output_on_different_device(self):
+        class FactoryBranches(torch.nn.Module):
+            def forward(self, pred):
+                tensor = torch.cond(
+                    pred,
+                    lambda: torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32).to(
+                        GPU_TYPE
+                    ),
+                    lambda: torch.zeros(5, dtype=torch.float32).to(GPU_TYPE),
+                )
+                return tensor + 1
+
+        self._run_test(
+            model=FactoryBranches(),
+            inputs=(),
+            device="cpu",  # device for predicate
+            dynamic=True,
+        )
+
 
 class WhileLoopModels:
     class Simple(torch.nn.Module):
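
The helper changes above are needed because a cond test can now have no tensor inputs at all: with factory-function branches the predicate device can no longer be inferred from inputs[0], so _run_test passes it through explicitly (the device=device argument in the hunk above). A short usage sketch, assuming the updated helpers above are in scope:

# With empty inputs the device for the generated predicate tensors must be
# given explicitly; with non-empty inputs it is still taken from inputs[0].
cases = prepend_predicates((), num_predicates=1, device="cpu")
# -> [[tensor(False)], [tensor(True)]], both predicates created on the CPU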

@@ -8845,7 +8845,9 @@ class Conditional(ExternKernel):
         outputs = [
             MultiOutput(
                 FixedLayout(
-                    device=device,
+                    device=output.get_device()
+                    if output.get_device() is not None
+                    else device,  # type: ignore[arg-type]
                     dtype=output.get_dtype(),
                     size=[Conditional._maybe_expr(sz) for sz in merged_output.size()],
                     stride=[
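
The second file carries the actual fix: when inductor builds the MultiOutput layouts for a Conditional, each output now keeps the device it reports itself, and the Conditional's own device is only a fallback for outputs whose device is unknown. Restated as a standalone sketch (the helper name is hypothetical, not part of the inductor codebase):

def choose_output_device(output, conditional_device):
    # Prefer the device the branch output reports; fall back to the
    # Conditional's device only when the output does not carry one.
    out_device = output.get_device()
    return out_device if out_device is not None else conditional_device

With this, the GPU-producing branches in the new test keep GPU output layouts even though the predicate lives on the CPU.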