Compare commits

...

1 commit

Author SHA1 Message Date
cf1a2abf35 force kwarg=copy on .to in proxy_call in export 2025-07-24 16:48:22 -07:00
4 changed files with 51 additions and 2 deletions
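
For context (this snippet is not part of the commit), here is a minimal eager-mode sketch of the aliasing behavior that forcing copy=True on aten.to is meant to rule out during export; the tensor value and device are illustrative only.

    import torch

    x = torch.tensor(1, device="cpu")

    alias = x.to("cpu")              # no conversion needed, so .to may return an alias of x
    copied = x.to("cpu", copy=True)  # always materializes an independent tensor

    x.add_(1)

    print(alias)   # tensor(2): the in-place mutation of x is visible through the alias
    print(copied)  # tensor(1): unaffected by the mutation

The new test_device_to_agnostic test in the first file below exercises this pattern (a .to followed by an in-place op on the source tensor) and asserts that the in-place add_ in the exported graph operates on the original "add" node rather than on the result of .to.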

View File

@@ -6832,6 +6832,19 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
                ops.append(node.target)
        self.assertEqual(len(ops), 1)

    def test_device_to_agnostic(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                z = x + 1
                y = z.to("cpu")
                z.add_(1)
                return y

        ep = export(Module(), (torch.tensor(1, device="cpu"),))
        for node in ep.graph.nodes:
            if node.target == torch.ops.aten.add_.Tensor:
                self.assertEqual(node.args[0].name, "add")

    def test_device_to_mutation(self):
        class Module(torch.nn.Module):
            def forward(self, x):

View File

@@ -1060,7 +1060,7 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode):
                    frame.f_lineno,
                    frame.f_code.co_name,
                )
        # We could also intercept the call here, but that would require a lot more code.
        func, args, kwargs = self._override(func, args, kwargs)
        try:
            return func(*args, **kwargs)
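
As a hedged illustration of what "intercept here" would mean (this is not what the commit does, and ForceCopyOnTo is a made-up example class): a TorchFunctionMode can rewrite Tensor.to calls before they ever reach the dispatcher, at the cost of re-implementing argument handling that proxy_call already gets for free.

    import torch
    from torch.overrides import TorchFunctionMode

    class ForceCopyOnTo(TorchFunctionMode):
        def __torch_function__(self, func, types, args=(), kwargs=None):
            kwargs = dict(kwargs or {})
            if func is torch.Tensor.to:
                # Rewrite x.to(...) into x.to(..., copy=True) unless copy was given explicitly.
                kwargs.setdefault("copy", True)
            return func(*args, **kwargs)

    with ForceCopyOnTo():
        x = torch.tensor(1)
        y = x.to("cpu")  # behaves like x.to("cpu", copy=True)
    assert y is not x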

View File

@@ -1673,11 +1673,13 @@ def _export_to_aten_ir_make_fx(
                k.__getattribute__ = old_getattr  # type: ignore[method-assign, attr-defined]

    with ctx, override_getattribute_for_subclasses(flat_args):
        torch.fx.experimental.proxy_tensor.in_export = True
        gm = make_fx(
            wrapped_fn,
            record_module_stack=True,
            pre_dispatch=True,
        )(*flat_args)
        torch.fx.experimental.proxy_tensor.in_export = False

    if non_strict_root is not None:
        input_names = _graph_input_names(gm)
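
The in_export flag set here is a plain module-level global; as a hedged aside (this helper is not in the commit), the same toggle can be wrapped in a small context manager so the flag is reset even if make_fx raises partway through tracing:

    from contextlib import contextmanager
    from typing import Generator

    import torch.fx.experimental.proxy_tensor as proxy_tensor

    @contextmanager
    def _in_export_flag() -> Generator[None, None, None]:
        # in_export is the flag this commit adds to the proxy_tensor module.
        proxy_tensor.in_export = True
        try:
            yield
        finally:
            proxy_tensor.in_export = False

Used as `with _in_export_flag(): gm = make_fx(...)(*flat_args)`, this keeps the flag scoped to the trace.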

View File

@@ -817,6 +817,34 @@ def _maybe_record_pointwise_barrier(
        last_node.meta["low_precision_pointwise_barrier"] = True


in_export = False


@contextmanager
def _force_copy_on_aten_to_context(
    func: Any, kwargs: dict[str, Any]
) -> Generator[None, None, None]:
    """
    Context manager for handling torch.ops.aten.to.dtype_layout operations during export.

    When tracing for export (in_export is set) and func is
    torch.ops.aten.to.dtype_layout, temporarily set kwargs["copy"] = True for the
    duration of the operation, then restore the original value afterward.
    """
    if in_export and func == torch.ops.aten.to.dtype_layout:
        previous_copy_kwarg = kwargs.get("copy", None)
        kwargs["copy"] = True
        try:
            yield
        finally:
            if previous_copy_kwarg is None:
                del kwargs["copy"]
            else:
                kwargs["copy"] = previous_copy_kwarg
    else:
        yield


def proxy_call(
    proxy_mode: ProxyTorchDispatchMode,
    func: OpOverload,
@@ -970,7 +998,10 @@ def proxy_call(
        name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__),
    )

    with _enable_thunkify(proxy_mode.tracer):
    with (
        _enable_thunkify(proxy_mode.tracer),
        _force_copy_on_aten_to_context(func, kwargs),
    ):
        out = func(*args, **kwargs)

    # In some circumstances, we will be tracing in a situation where a tensor
@@ -1433,6 +1464,7 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
            return node

        # Don't actually run the function! We just want to trace the calls
        # into a graph. We don't actually want to change global autograd state.
        # Can't intercept here; by this point it is already too late.
        return func(*args, **kwargs)
@@ -1488,6 +1520,8 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
        if func in (prim.device.default,):
            return func(*args, **kwargs)

        # if func == torch.ops.aten.add_.Tensor:
        #     breakpoint()
        return proxy_call(self, func, self.pre_dispatch, args, kwargs)

    def __enter__(self) -> Self:
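
A hedged, self-contained sketch of the contract the new context manager implements (force_copy_if_to below is a simplified stand-in, since _force_copy_on_aten_to_context and in_export only exist with this patch applied): copy=True is injected into kwargs for aten.to.dtype_layout for the duration of the call, then removed or restored on exit.

    from contextlib import contextmanager
    from typing import Any, Generator

    import torch

    @contextmanager
    def force_copy_if_to(func: Any, kwargs: dict[str, Any]) -> Generator[None, None, None]:
        # Simplified stand-in: no in_export gating, same kwargs bookkeeping.
        if func == torch.ops.aten.to.dtype_layout:
            previous = kwargs.get("copy", None)
            kwargs["copy"] = True
            try:
                yield
            finally:
                if previous is None:
                    del kwargs["copy"]
                else:
                    kwargs["copy"] = previous
        else:
            yield

    x = torch.tensor(1)
    kwargs: dict[str, Any] = {"device": torch.device("cpu")}
    with force_copy_if_to(torch.ops.aten.to.dtype_layout, kwargs):
        # The op now sees copy=True, so its result cannot alias x.
        y = torch.ops.aten.to.dtype_layout(x, **kwargs)

    assert y is not x
    assert "copy" not in kwargs  # the injected kwarg is removed on exit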