mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[AOTI] Handle empty input args (#114682)
Summary: When the model takes no inputs, AOTInductor relies on checking weights to figure out which device to compile the model into. Currently recording buffer device type happens too late, and this PR fixes that. Pull Request resolved: https://github.com/pytorch/pytorch/pull/114682 Approved by: https://github.com/chenyang78
This commit is contained in:
committed by
PyTorch MergeBot
parent
3d8c174069
commit
e06bff8bbe
@ -1465,6 +1465,20 @@ class AOTInductorTestsTemplate:
|
||||
inputs = (torch.rand(4, 4, 4, 4, device=self.device),)
|
||||
self.check_model(Model(4), inputs)
|
||||
|
||||
def test_no_args(self):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self, m, n):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(
|
||||
torch.randn(m, n),
|
||||
)
|
||||
self.alpha = torch.nn.Parameter(torch.randn(m, n))
|
||||
|
||||
def forward(self):
|
||||
return self.weight * self.alpha
|
||||
|
||||
self.check_model(Model(6, 4), ())
|
||||
|
||||
|
||||
common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)
|
||||
|
||||
|
@ -472,9 +472,10 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
self._warned_fallback.add(name)
|
||||
perf_hint_log.info("Using FallbackKernel: %s", name)
|
||||
|
||||
def add_device_idx(self, idx: Optional[int]):
|
||||
if idx is not None:
|
||||
self.device_idxs.add(idx)
|
||||
def add_device_info(self, device: torch.device):
|
||||
self.device_types.add(device.type)
|
||||
if device.index is not None:
|
||||
self.device_idxs.add(device.index)
|
||||
|
||||
@property
|
||||
def fake_mode(self):
|
||||
@ -521,6 +522,9 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
name = f"buf{len(self.buffers)}"
|
||||
self.buffers.append(buffer)
|
||||
self.name_to_buffer[name] = buffer
|
||||
# Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144
|
||||
if not isinstance(buffer, ir.ComputedBuffer) or not buffer.is_zero_elements():
|
||||
self.add_device_info(buffer.get_device())
|
||||
return name
|
||||
|
||||
def register_list(self, buffer_names: List[str]):
|
||||
@ -645,8 +649,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
)
|
||||
self.graph_inputs[target] = tensor
|
||||
self.graph_inputs_original[target] = tensor.data.data
|
||||
self.device_types.add(example.device.type)
|
||||
self.add_device_idx(example.device.index)
|
||||
self.add_device_info(example.device)
|
||||
return tensor
|
||||
|
||||
def call_function(self, target, args, kwargs):
|
||||
@ -979,10 +982,6 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
return
|
||||
|
||||
device_types = self.device_types.copy()
|
||||
# In terms of some operations that don't have input tensors, we need to
|
||||
# check the device of the buffers.
|
||||
for buffer in self.buffers:
|
||||
device_types.add(buffer.get_device().type)
|
||||
device_types.discard("cpu")
|
||||
# TODO(Eikan): Only support mixing cpu and other device now.
|
||||
assert len(device_types) <= 1, "Does not support mixing {}".format(
|
||||
@ -1015,7 +1014,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
else:
|
||||
assert isinstance(
|
||||
x, torch.Tensor
|
||||
), "Unknown type when creating real inputs"
|
||||
), "Unknown type when creating real inputs" + str(type(x))
|
||||
return x
|
||||
|
||||
with torch.utils._python_dispatch._disable_current_modes():
|
||||
|
@ -4160,10 +4160,8 @@ class DeviceCopy(ExternKernelOut):
|
||||
):
|
||||
return x.constant_to_device(device)
|
||||
|
||||
V.graph.device_types.add(device.type)
|
||||
V.graph.add_device_idx(device.index)
|
||||
V.graph.device_types.add(x.get_device().type)
|
||||
V.graph.add_device_idx(x.get_device().index)
|
||||
V.graph.add_device_info(device)
|
||||
V.graph.add_device_info(x.get_device())
|
||||
|
||||
developer_warning("DeviceCopy in input program")
|
||||
return DeviceCopy(
|
||||
|
@ -2130,8 +2130,7 @@ class Scheduler:
|
||||
assert (
|
||||
device.type != "cuda" or device.index is not None
|
||||
), f"{device} should have been normalized in lowering"
|
||||
V.graph.device_types.add(device.type)
|
||||
V.graph.add_device_idx(device.index)
|
||||
V.graph.add_device_info(device)
|
||||
|
||||
device_scheduling = get_scheduling_for_device(device.type)
|
||||
if device_scheduling is None:
|
||||
|
Reference in New Issue
Block a user