mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add support for capturing provenance of unary operations (#146413)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146413 Approved by: https://github.com/angelayi ghstack dependencies: #145848
This commit is contained in:
committed by
PyTorch MergeBot
parent
07b9fe0690
commit
bc33d993ac
@ -1639,19 +1639,6 @@ def _make_user_magic(method, user_type):
|
||||
other = torch.sym_float(other)
|
||||
return self, other
|
||||
|
||||
# Before and after performing the operation, check if any operands are constant.
|
||||
# If so, extract out the constant values first. If `self` itself is a
|
||||
# constant, then "redispatch" by calling back into the operator. Sometimes
|
||||
# this means that operations involving SymBool return plain bools.
|
||||
# Alternatively, we could also rewrap into constant Symbool (i.e. by
|
||||
# implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that
|
||||
# today for no particular reason.
|
||||
def unary_magic_impl(self):
|
||||
self = promote(self)
|
||||
if is_constant(self):
|
||||
return (method_to_operator(method))(get_constant(self))
|
||||
return wrap_node(getattr(self.node, method_attr)())
|
||||
|
||||
def uninteresting_files() -> Set[str]:
|
||||
import inspect
|
||||
|
||||
@ -1673,8 +1660,11 @@ def _make_user_magic(method, user_type):
|
||||
|
||||
def capture_provenance(fn):
|
||||
@functools.wraps(fn)
|
||||
def wrapper(self, other):
|
||||
result = fn(self, other)
|
||||
def wrapper(self, other=None):
|
||||
if other is None:
|
||||
result = fn(self)
|
||||
else:
|
||||
result = fn(self, other)
|
||||
if torch._logging._internal.GET_DTRACE_STRUCTURED:
|
||||
floc = None
|
||||
user_stack = None
|
||||
@ -1715,11 +1705,16 @@ def _make_user_magic(method, user_type):
|
||||
finally:
|
||||
del frame
|
||||
|
||||
if other:
|
||||
arguments = [str(self), str(other)]
|
||||
else:
|
||||
arguments = [str(self)]
|
||||
|
||||
dtrace_structured(
|
||||
"expression_created",
|
||||
metadata_fn=lambda: {
|
||||
"method": method,
|
||||
"arguments": [str(self), str(other)],
|
||||
"arguments": arguments,
|
||||
"result": str(result),
|
||||
"user_bottom_stack": str(user_bottom_stack),
|
||||
"user_top_stack": str(user_top_stack),
|
||||
@ -1731,6 +1726,20 @@ def _make_user_magic(method, user_type):
|
||||
|
||||
return wrapper
|
||||
|
||||
# Before and after performing the operation, check if any operands are constant.
|
||||
# If so, extract out the constant values first. If `self` itself is a
|
||||
# constant, then "redispatch" by calling back into the operator. Sometimes
|
||||
# this means that operations involving SymBool return plain bools.
|
||||
# Alternatively, we could also rewrap into constant Symbool (i.e. by
|
||||
# implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that
|
||||
# today for no particular reason.
|
||||
@capture_provenance
|
||||
def unary_magic_impl(self):
|
||||
self = promote(self)
|
||||
if is_constant(self):
|
||||
return (method_to_operator(method))(get_constant(self))
|
||||
return wrap_node(getattr(self.node, method_attr)())
|
||||
|
||||
@capture_provenance
|
||||
def binary_magic_impl(self, other):
|
||||
if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
|
||||
|
||||
Reference in New Issue
Block a user