[AOTI] Handle empty input args (#114682)

Summary: When a model takes no inputs, AOTInductor has to inspect the weights to figure out which device to compile the model for. Previously, the buffer device type was recorded too late in lowering for that check to see it; this PR records device info when each buffer is registered instead.
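
A minimal sketch of the scenario this fixes, assuming a CUDA build; the module here is illustrative, and torch._export.aot_compile is one entry point that exercises AOTInductor (the exact export API may differ by version):

    import torch

    class NoArgModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # The only device information lives in the parameters;
            # forward() takes no inputs at all.
            self.weight = torch.nn.Parameter(torch.randn(6, 4, device="cuda"))

        def forward(self):
            return self.weight * 2.0

    # With an empty args tuple, AOTInductor must infer the target device
    # from registered parameters/buffers rather than from example inputs.
    so_path = torch._export.aot_compile(NoArgModel(), ())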

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114682
Approved by: https://github.com/chenyang78
Author: Bin Bao
Date: 2023-12-04 19:56:18 -08:00
Committed by: PyTorch MergeBot
Commit: e06bff8bbe (parent 3d8c174069)
4 changed files with 26 additions and 16 deletions

@@ -1465,6 +1465,20 @@ class AOTInductorTestsTemplate:
         inputs = (torch.rand(4, 4, 4, 4, device=self.device),)
         self.check_model(Model(4), inputs)
 
+    def test_no_args(self):
+        class Model(torch.nn.Module):
+            def __init__(self, m, n):
+                super().__init__()
+                self.weight = torch.nn.Parameter(
+                    torch.randn(m, n),
+                )
+                self.alpha = torch.nn.Parameter(torch.randn(m, n))
+
+            def forward(self):
+                return self.weight * self.alpha
+
+        self.check_model(Model(6, 4), ())
+
 
 common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)

@@ -472,9 +472,10 @@ class GraphLowering(torch.fx.Interpreter):
         self._warned_fallback.add(name)
         perf_hint_log.info("Using FallbackKernel: %s", name)
 
-    def add_device_idx(self, idx: Optional[int]):
-        if idx is not None:
-            self.device_idxs.add(idx)
+    def add_device_info(self, device: torch.device):
+        self.device_types.add(device.type)
+        if device.index is not None:
+            self.device_idxs.add(device.index)
 
     @property
     def fake_mode(self):
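
Note on the consolidated helper: torch.device("cuda").index is None, while torch.device("cuda:1").index == 1, which is why the index is recorded only when it is explicit. An illustrative check (plain PyTorch, independent of this patch):

    import torch

    assert torch.device("cuda").index is None   # no explicit index to record
    assert torch.device("cuda:1").index == 1    # explicit index is recorded
    assert torch.device("cpu").type == "cpu"    # only the type is recorded
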
@@ -521,6 +522,9 @@ class GraphLowering(torch.fx.Interpreter):
         name = f"buf{len(self.buffers)}"
         self.buffers.append(buffer)
         self.name_to_buffer[name] = buffer
+        # Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144
+        if not isinstance(buffer, ir.ComputedBuffer) or not buffer.is_zero_elements():
+            self.add_device_info(buffer.get_device())
         return name
 
     def register_list(self, buffer_names: List[str]):
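
Why the zero-element skip matters: registering a zero-element CPU buffer would otherwise record "cpu" in the graph's device types, tripping the mixed-device assertion and breaking CUDA graphs (see the linked #114144). An illustrative check, not part of the patch:

    import torch

    buf = torch.empty(0)             # zero-element tensor, defaults to CPU
    assert buf.numel() == 0
    assert buf.device.type == "cpu"  # without the skip, "cpu" would be recorded
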
@@ -645,8 +649,7 @@ class GraphLowering(torch.fx.Interpreter):
         )
         self.graph_inputs[target] = tensor
         self.graph_inputs_original[target] = tensor.data.data
-        self.device_types.add(example.device.type)
-        self.add_device_idx(example.device.index)
+        self.add_device_info(example.device)
         return tensor
 
     def call_function(self, target, args, kwargs):
@@ -979,10 +982,6 @@ class GraphLowering(torch.fx.Interpreter):
             return
 
         device_types = self.device_types.copy()
-        # In terms of some operations that don't have input tensors, we need to
-        # check the device of the buffers.
-        for buffer in self.buffers:
-            device_types.add(buffer.get_device().type)
         device_types.discard("cpu")
         # TODO(Eikan): Only support mixing cpu and other device now.
         assert len(device_types) <= 1, "Does not support mixing {}".format(
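
This removed scan is the "too late" recording mentioned in the summary: buffer devices were only folded in here, after the graph's device types had already been consulted. With device info now captured up front in register_buffer (see the hunk above), the scan is redundant.
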
@@ -1015,7 +1014,7 @@ class GraphLowering(torch.fx.Interpreter):
             else:
                 assert isinstance(
                     x, torch.Tensor
-                ), "Unknown type when creating real inputs"
+                ), "Unknown type when creating real inputs" + str(type(x))
                 return x
 
         with torch.utils._python_dispatch._disable_current_modes():

@@ -4160,10 +4160,8 @@ class DeviceCopy(ExternKernelOut):
         ):
             return x.constant_to_device(device)
 
-        V.graph.device_types.add(device.type)
-        V.graph.add_device_idx(device.index)
-        V.graph.device_types.add(x.get_device().type)
-        V.graph.add_device_idx(x.get_device().index)
+        V.graph.add_device_info(device)
+        V.graph.add_device_info(x.get_device())
 
         developer_warning("DeviceCopy in input program")
         return DeviceCopy(

@@ -2130,8 +2130,7 @@ class Scheduler:
         assert (
             device.type != "cuda" or device.index is not None
         ), f"{device} should have been normalized in lowering"
-        V.graph.device_types.add(device.type)
-        V.graph.add_device_idx(device.index)
+        V.graph.add_device_info(device)
 
         device_scheduling = get_scheduling_for_device(device.type)
         if device_scheduling is None: