add support for capturing provenance of unary operations (#146413)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146413
Approved by: https://github.com/angelayi
ghstack dependencies: #145848
This commit is contained in:
bobrenjc93
2025-02-04 09:36:27 -08:00
committed by PyTorch MergeBot
parent 07b9fe0690
commit bc33d993ac

View File

@ -1639,19 +1639,6 @@ def _make_user_magic(method, user_type):
other = torch.sym_float(other)
return self, other
# Before and after performing the operation, check if any operands are constant.
# If so, extract out the constant values first. If `self` itself is a
# constant, then "redispatch" by calling back into the operator. Sometimes
# this means that operations involving SymBool return plain bools.
# Alternatively, we could also rewrap into constant Symbool (i.e. by
# implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that
# today for no particular reason.
def unary_magic_impl(self):
self = promote(self)
if is_constant(self):
return (method_to_operator(method))(get_constant(self))
return wrap_node(getattr(self.node, method_attr)())
def uninteresting_files() -> Set[str]:
import inspect
@ -1673,8 +1660,11 @@ def _make_user_magic(method, user_type):
def capture_provenance(fn):
@functools.wraps(fn)
def wrapper(self, other):
result = fn(self, other)
def wrapper(self, other=None):
if other is None:
result = fn(self)
else:
result = fn(self, other)
if torch._logging._internal.GET_DTRACE_STRUCTURED:
floc = None
user_stack = None
@ -1715,11 +1705,16 @@ def _make_user_magic(method, user_type):
finally:
del frame
if other:
arguments = [str(self), str(other)]
else:
arguments = [str(self)]
dtrace_structured(
"expression_created",
metadata_fn=lambda: {
"method": method,
"arguments": [str(self), str(other)],
"arguments": arguments,
"result": str(result),
"user_bottom_stack": str(user_bottom_stack),
"user_top_stack": str(user_top_stack),
@ -1731,6 +1726,20 @@ def _make_user_magic(method, user_type):
return wrapper
# Before and after performing the operation, check if any operands are constant.
# If so, extract out the constant values first. If `self` itself is a
# constant, then "redispatch" by calling back into the operator. Sometimes
# this means that operations involving SymBool return plain bools.
# Alternatively, we could also rewrap into constant Symbool (i.e. by
# implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that
# today for no particular reason.
@capture_provenance
def unary_magic_impl(self):
self = promote(self)
if is_constant(self):
return (method_to_operator(method))(get_constant(self))
return wrap_node(getattr(self.node, method_attr)())
@capture_provenance
def binary_magic_impl(self, other):
if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):