Compare commits

...

10 Commits

SHA1 Message Date
e6c495c440 Add check for numel 2025-11-18 15:06:41 -08:00
254ea52c80 Add x._ComptimeVar__variable.requires_grad check 2025-11-18 14:08:45 -08:00
e898efcdf2 Check for first TensorVariable, update test to use backward 2025-11-18 10:54:51 -08:00
726cde9ae1 Fullgraph=True 2025-11-18 09:10:15 -08:00
df437b7557 Add check for requires grad during compile 2025-11-17 23:42:37 -08:00
b2eb396c92 Simplify test 2025-11-17 13:32:59 -08:00
279384d38e Remove manual tests 2025-11-17 13:04:15 -08:00
7872f897a7 Add TestTensorMetaProp(TestCase) 2025-11-17 01:10:40 -08:00
ee7c3424db Cursor attempt 2025-11-11 14:19:59 -08:00
de016f8439 Add test_add_tensor_prop 2025-11-11 13:44:24 -08:00
2 changed files with 170 additions and 20 deletions

View File

@@ -15,6 +15,7 @@ from importlib import import_module
import torch
import torch._prims as prims
import torch.utils._pytree as pytree
from torch._dynamo.comptime import comptime, ComptimeContext
from torch._prims.context import TorchRefsMode
from torch._prims_common.wrappers import _maybe_remove_out_wrapper
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
@@ -3029,6 +3030,128 @@ class TestForwardADWithScalars(TestCase):
        )


@unMarkDynamoStrictTest
class TestTensorMetaProp(TestCase):
    """
    Test that inplace operations correctly propagate tensor metadata during Dynamo tracing.
    """

    @ops([op for op in op_db if op.get_inplace() is not None])
    def test_inplace_ops_propagate_requires_grad_metadata(self, device, dtype, op):
        """
        Test that inplace ops from OpInfo propagate requires_grad correctly.

        This test ensures that when an inplace operation is performed on a tensor
        without requires_grad, using an argument with requires_grad=True, the metadata
        is correctly propagated in both eager and compiled modes.

        This is critical because if metadata is traced incorrectly, code that branches
        on requires_grad (like custom autograd functions) will take the wrong path,
        leading to silent incorrectness.
        """
        inplace_op = op.get_inplace()
        if inplace_op is None:
            self.skipTest("No inplace variant for this op")

        samples = list(op.sample_inputs(device, dtype, requires_grad=False))

        class CustomAutograd(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                return x * 2

            @staticmethod
            def backward(ctx, grad_out):
                # Return an obviously wrong gradient (fixed value) to detect
                # when composite implicit autograd is used vs the custom backward
                (x,) = ctx.saved_tensors
                return torch.full_like(x, 123.0)

        for i, sample in enumerate(samples):
            if sample.broadcasts_input or sample.input.numel() == 0:
                continue
            try:
                torch.compiler.reset()

                # Setup: x starts with requires_grad=False, one arg has requires_grad=True
                x_eager = sample.input.clone().detach()
                args_eager = [
                    arg.clone().detach() if isinstance(arg, torch.Tensor) else arg
                    for arg in sample.args
                ]

                # Find a floating point tensor arg to set requires_grad=True
                requires_grad_idx = None
                for idx, arg in enumerate(args_eager):
                    if isinstance(arg, torch.Tensor) and arg.dtype.is_floating_point:
                        arg.requires_grad_(True)
                        requires_grad_idx = idx
                        break

                if requires_grad_idx is None or x_eager.requires_grad:
                    continue

                # Apply inplace op in eager mode
                inplace_op(x_eager, *args_eager, **sample.kwargs)
                output_eager = CustomAutograd.apply(x_eager)
                output_eager.sum().backward()

                # Setup compiled version
                x_compiled = sample.input.clone().detach()
                args_compiled = [
                    arg.clone().detach() if isinstance(arg, torch.Tensor) else arg
                    for arg in sample.args
                ]
                args_compiled[requires_grad_idx].requires_grad_(True)

                # Test 1: Verify that the metadata is propagated after the inplace op at compile time
                def compile_time_check(ctx: ComptimeContext) -> None:
                    x = ctx.get_local("x")
                    x_fake = x.as_fake()
                    self.assertTrue(x_fake.requires_grad)
                    self.assertTrue(x._ComptimeVar__variable.requires_grad)

                def fn(x, *args):
                    inplace_op(x, *args, **sample.kwargs)
                    comptime(compile_time_check)
                    r = CustomAutograd.apply(x)
                    return r

                compiled_fn = torch.compile(fn, backend="eager", fullgraph=True)
                output_compiled = compiled_fn(x_compiled, *args_compiled)
                output_compiled.sum().backward()

                # Test 2: Verify requires_grad was propagated at runtime
                self.assertEqual(
                    x_eager.requires_grad,
                    x_compiled.requires_grad,
                    msg=f"{op.name}: requires_grad mismatch (eager={x_eager.requires_grad}, compiled={x_compiled.requires_grad})",
                )

                # Test 3: Verify gradients match
                self.assertEqual(
                    args_eager[requires_grad_idx].grad,
                    args_compiled[requires_grad_idx].grad,
                    msg=f"{op.name}: Gradient mismatch indicates metadata was not propagated during tracing",
                )
            except Exception as e:
                # Skip known issue patterns
                error_str = str(e).lower()
                if any(
                    pattern in error_str
                    for pattern in [
                        "out=... arguments don't support automatic differentiation",
                        "the base given to",  # dtype issue
                        "derivative for",  # backward not implemented
                    ]
                ):
                    continue
                raise


instantiate_device_type_tests(TestCommon, globals(), allow_xpu=True)
instantiate_device_type_tests(TestCompositeCompliance, globals())
instantiate_device_type_tests(TestMathBits, globals())
@@ -3036,6 +3159,7 @@ instantiate_device_type_tests(TestRefsOpsInfo, globals(), only_for="cpu")
instantiate_device_type_tests(TestFakeTensor, globals())
instantiate_device_type_tests(TestTags, globals())
instantiate_device_type_tests(TestForwardADWithScalars, globals())
instantiate_device_type_tests(TestTensorMetaProp, globals())
if __name__ == "__main__":
TestCase._default_dtype_check_enabled = True
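
For context, here is a minimal standalone sketch of the behavior the new TestTensorMetaProp test guards (illustrative only, not part of this diff; the branching helper f is hypothetical, and the compiled path assumes the metadata propagation added on the Dynamo side below): an in-place op on a tensor with requires_grad=False picks up requires_grad=True from a grad-requiring argument in eager mode, and code that branches on that flag must observe the same value when traced.

import torch

def f(x, y):
    # In-place op: x should pick up requires_grad=True from y.
    x.add_(y)
    # Branch on the propagated flag; if tracing recorded stale metadata,
    # the compiled function would silently take the other path.
    return x * 2 if x.requires_grad else x * 3

y = torch.ones(4, requires_grad=True)
out_eager = f(torch.zeros(4), y)

x_compiled = torch.zeros(4)
out_compiled = torch.compile(f, backend="eager", fullgraph=True)(x_compiled, y)

torch.testing.assert_close(out_eager, out_compiled)
assert x_compiled.requires_grad  # runtime metadata matches eager semantics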

View File

@@ -724,15 +724,28 @@ class TensorVariable(VariableTracker):
        from .builder import wrap_fx_proxy

        return wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_method",
                name,
                *proxy_args_kwargs([self, *args], kwargs),
            ),
        proxy = tx.output.create_proxy(
            "call_method",
            name,
            *proxy_args_kwargs([self, *args], kwargs),
        )

        # [Note: Inplace ops and VariableTracker metadata]
        # For inplace operations (methods ending with _), we need to propagate
        # tensor metadata from the arguments to self. For example:
        # x.add_(y) where y.requires_grad=True => x.requires_grad becomes True
        # This is similar to the fix in method___setitem__.
        if name.endswith("_") and args:
            inplace_idx = 0
            while inplace_idx < len(args) and not isinstance(
                args[inplace_idx], TensorVariable
            ):
                inplace_idx += 1
            if inplace_idx < len(args):
                self._propagate_inplace_metadata(tx, proxy, args[inplace_idx])

        return wrap_fx_proxy(tx, proxy)

    def method_size(self, *args, **kwargs):
        return self._method_size_stride("size", *args, **kwargs)
@@ -1101,6 +1114,31 @@
            {},
        )

    def _propagate_inplace_metadata(self, tx, proxy, source_var):
        """
        Propagate tensor metadata from source_var to self after an inplace operation.
        This ensures that properties like requires_grad are correctly tracked during tracing.

        Args:
            tx: InstructionTranslator instance
            proxy: The proxy node representing the inplace operation
            source_var: The source TensorVariable whose metadata should be propagated
        """
        # Ignore fresh unbacked symbols that could arise during the operation
        with (
            torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(),
            tx.fake_mode.shape_env.ignore_fresh_unbacked_symbols()
            if tx.fake_mode and tx.fake_mode.shape_env
            else nullcontext(),
        ):
            get_fake_value(proxy.node, tx, allow_non_graph_fake=False)

        vt = source_var
        if isinstance(vt, variables.lazy.LazyVariableTracker):
            vt = variables.lazy.LazyVariableTracker.realize_all(vt)
        self.synchronize_attributes(tx, type(vt))

    def method___setitem__(self, key, value):
        from ..symbolic_convert import InstructionTranslator
@@ -1127,19 +1165,7 @@
        # When the selection happens if idx is unbacked we allocate a new unbacked symbol for the
        # storage offset in select_meta, but the output of the operation 'setitem' does not depend
        # on the selection.
        with (
            torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(),
            tx.fake_mode.shape_env.ignore_fresh_unbacked_symbols()
            if tx.fake_mode and tx.fake_mode.shape_env
            else nullcontext(),
        ):
            get_fake_value(proxy.node, tx, allow_non_graph_fake=False)

        vt = value
        if isinstance(vt, variables.lazy.LazyVariableTracker):
            vt = variables.lazy.LazyVariableTracker.realize_all(vt)
        self.synchronize_attributes(tx, type(vt))
        self._propagate_inplace_metadata(tx, proxy, value)

        if config.use_graph_deduplication or config.track_nodes_for_deduplication:
            tx.output.region_tracker.add_node_mutation(proxy.node, 0)
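
As a usage sketch (illustrative, not part of this diff; demo is a hypothetical function), the propagated flag can be observed at trace time through the same comptime hook the new test uses: after an in-place op whose argument requires grad, the traced value of the target should report requires_grad=True.

import torch
from torch._dynamo.comptime import comptime

def demo(x, y):
    x.mul_(y)  # y requires grad, x initially does not
    # Runs only at compile time: with the propagation above, the traced
    # (fake) value of x should now report requires_grad=True.
    comptime(lambda ctx: print("traced requires_grad:", ctx.get_local("x").as_fake().requires_grad))
    return x.sum()

x = torch.randn(4)
y = torch.randn(4, requires_grad=True)
torch.compile(demo, backend="eager", fullgraph=True)(x, y)

Without the new _propagate_inplace_metadata path, the hook would report False even though the runtime tensor does require grad, which is the kind of silent divergence the test's CustomAutograd function is designed to surface.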