Add support for list, tuple and dict in numeric debugger (#143882)
Summary: Previously the numeric debugger only supported torch.Tensor; this PR adds support for list, tuple and dict outputs as well.

Test Plan: python test/test_quantization.py -k test_extract_results_from_loggers_list_output

Differential Revision: D67660049 (https://our.internmc.facebook.com/intern/diff/D67660049)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143882
Approved by: https://github.com/dulinriley
Commit ad78edee8e (parent c3c27aef34), committed by PyTorch MergeBot.
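To illustrate the effect of the change before reading the diff (a minimal sketch, not part of the commit; the import path and toy shapes are assumptions): a comparison result over a list-valued node output now yields one SQNR value per element instead of raising.

import torch
from torch.ao.quantization.pt2e._numeric_debugger import (  # assumed path
    QuantizationComparisonResult,
)

# A node that returns several Tensors (e.g. via torch.split) can now be
# compared directly; `sqnr` returns a list with one value per element.
actual = [torch.randn(1, 2, 4, 4), torch.randn(1, 1, 4, 4)]
ref = [t + 0.001 * torch.randn_like(t) for t in actual]
result = QuantizationComparisonResult(actual=actual, ref=ref)
print(result.sqnr)  # e.g. [tensor(...), tensor(...)]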
@@ -267,6 +267,36 @@ class TestNumericDebugger(TestCase):
             if len(node_summary.results) > 0:
                 self.assertGreaterEqual(node_summary.results[0].sqnr, 35)
 
+    def test_extract_results_from_loggers_list_output(self):
+        m = TestHelperModules.Conv2dWithSplit()
+        example_inputs = m.example_inputs()
+        ep = export_for_training(m, example_inputs)
+        generate_numeric_debug_handle(ep)
+        m = ep.module()
+        m_ref_logger = prepare_for_propagation_comparison(m)
+
+        quantizer = XNNPACKQuantizer().set_global(
+            get_symmetric_quantization_config(is_per_channel=False)
+        )
+        m = prepare_pt2e(m, quantizer)
+        m(*example_inputs)
+        m = convert_pt2e(m)
+        m_quant_logger = prepare_for_propagation_comparison(m)
+
+        m_ref_logger(*example_inputs)
+        m_quant_logger(*example_inputs)
+        ref_results = extract_results_from_loggers(m_ref_logger)
+        quant_results = extract_results_from_loggers(m_quant_logger)
+        comparison_results = compare_results(ref_results, quant_results)
+        for node_summary in comparison_results.values():
+            if len(node_summary.results) > 0:
+                sqnr = node_summary.results[0].sqnr
+                if isinstance(sqnr, list):
+                    for sqnr_i in sqnr:
+                        self.assertGreaterEqual(sqnr_i, 35)
+                else:
+                    self.assertGreaterEqual(sqnr, 35)
+
     def test_added_node_gets_unique_id(self) -> None:
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
@@ -71,6 +71,53 @@ def generate_numeric_debug_handle(ep: ExportedProgram) -> None:
     bfs_trace_with_node_process(ep, _assign_debug_handle)
 
 
+def _detach(x: object) -> object:
+    detached: object = None
+    if isinstance(x, torch.Tensor):
+        detached = x.detach()
+    elif isinstance(x, (list, tuple)):
+        detached = type(x)([_detach(e) for e in x])
+    elif isinstance(x, dict):
+        detached = {k: _detach(e) for k, e in x.items()}
+    else:
+        detached = x
+    return detached
+
+
+def _tensor_shape_equals(x: object, y: object) -> bool:
+    if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
+        return x.shape == y.shape
+    elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
+        return all(_tensor_shape_equals(e1, e2) for e1, e2 in zip(x, y))
+    elif isinstance(x, dict) and isinstance(y, dict):
+        all_equal = True
+        for k in x:
+            all_equal = all_equal and k in y and (_tensor_shape_equals(x[k], y[k]))
+        return all_equal
+    else:
+        print(f"Comparing non Tensors: {x} and {y}, they must be equal")
+        return type(x) == type(y) and x == y
+
+
+def _loss_fn(
+    loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], x: object, y: object
+) -> object:
+    """The returned loss will have the same structure as `x` and `y`, e.g.
+    if both are Tensor, we'll return a Tensor
+    if both are list, we'll return a list of Tensors
+    if both are dict, we'll return a dict with the same key, and value being the loss between the
+    two Tensors
+    """
+    if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
+        return loss(x.to(torch.float32), y.to(torch.float32))
+    elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
+        return type(x)([_loss_fn(loss, e1, e2) for e1, e2 in zip(x, y)])
+    elif isinstance(x, dict) and isinstance(y, dict):
+        return {k: _loss_fn(loss, e, y[k]) for k, e in x.items()}
+    else:
+        return None
+
+
 class OutputLogger(torch.nn.Module):
     """
     Base class for capturing output values for nodes in a GraphModule, it only captures
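The three helpers mirror each other: `_detach` makes a detached copy of a nested value, `_tensor_shape_equals` checks that two nested values line up leaf-by-leaf, and `_loss_fn` applies a Tensor-to-Tensor loss leaf-wise while preserving the container structure. A quick sketch of that last behavior (toy values, calling the private helper directly; the import path is an assumption):

import torch
import torch.nn.functional as F
from torch.ao.quantization.pt2e._numeric_debugger import _loss_fn  # assumed path

x = {"out": [torch.zeros(2), torch.ones(2)]}
y = {"out": [torch.zeros(2), torch.zeros(2)]}
# The result keeps the input structure: a dict of a list of scalar Tensors,
# here {"out": [tensor(0.), tensor(1.)]}.
print(_loss_fn(F.mse_loss, x, y))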
@@ -90,11 +137,10 @@ class OutputLogger(torch.nn.Module):
         self.node_name = node_name
         self.nn_module_stack = nn_module_stack
         self.debug_handle = debug_handle
-        self.stats: List[torch.Tensor] = []
+        self.stats: List[object] = []
 
     def forward(self, x: object) -> object:
-        if isinstance(x, torch.Tensor):
-            self.stats.append(x.detach())
+        self.stats.append(_detach(x))
         return x
 
     def __extra_repr__(self) -> str:
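With `_detach` in place, the logger records whatever structure flows through it rather than silently dropping non-Tensor outputs. A hedged sketch (assuming `OutputLogger` accepts these constructor arguments by keyword; the import path is an assumption):

import torch
from torch.ao.quantization.pt2e._numeric_debugger import OutputLogger  # assumed path

logger = OutputLogger(debug_handle=0, node_name="split", nn_module_stack=None)
out = logger((torch.randn(2), torch.randn(2)))  # forward passes the value through
assert isinstance(logger.stats[0], tuple)  # a detached copy of the tuple is kept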
@@ -162,27 +208,17 @@ class QuantizationComparisonResult:
     ref: torch.Tensor
 
     @property
-    def mse_loss(self) -> torch.Tensor:
-        return F.mse_loss(
-            self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
-        )
+    def mse_loss(self) -> object:
+        return self.loss(F.mse_loss)
 
     @property
-    def sqnr(self) -> torch.Tensor:
-        return compute_sqnr(
-            self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
-        )
+    def sqnr(self) -> object:
+        return self.loss(compute_sqnr)
 
     def loss(
         self, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
-    ) -> torch.Tensor:
-        if self.actual.shape != self.ref.shape:
-            raise ValueError(
-                f"Cannot compare tensors with different shapes: {self.actual.shape} vs {self.ref.shape}"
-            )
-        return loss_function(
-            self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
-        )
+    ) -> object:
+        return _loss_fn(loss_function, self.actual, self.ref)
 
     def __repr__(self) -> str:
         # Don't include the tensors themselves as they are quite large to print
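Because `loss()` now delegates to `_loss_fn`, any metric with a `(Tensor, Tensor) -> Tensor` signature works across nested outputs, not just the built-in `mse_loss` and `sqnr` properties. A sketch with a hypothetical custom metric (import path assumed as above):

import torch
from torch.ao.quantization.pt2e._numeric_debugger import (  # assumed path
    QuantizationComparisonResult,
)

def max_abs_err(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Illustrative metric; any (Tensor, Tensor) -> Tensor callable works.
    return (a - b).abs().max()

result = QuantizationComparisonResult(
    actual=(torch.ones(3), torch.zeros(3)),
    ref=(torch.ones(3), torch.ones(3)),
)
print(result.loss(max_abs_err))  # -> (tensor(0.), tensor(1.))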
@@ -192,16 +228,19 @@ class QuantizationComparisonResult:
         )
 
     def __post_init__(self) -> None:
-        if not isinstance(self.actual, torch.Tensor):
+        if not isinstance(self.actual, (torch.Tensor, list, tuple, dict)):
             raise ValueError(
-                f"`self.actual` value must be a Tensor, got: {self.actual}"
+                f"`self.actual` value must be a Tensor, list, tuple or dict, got: {self.actual}"
             )
 
-        if not isinstance(self.ref, torch.Tensor):
-            raise ValueError(f"`self.ref` value must be a Tensor, got: {self.ref}")
-        if self.actual.shape != self.ref.shape:
+        if not isinstance(self.ref, (torch.Tensor, list, tuple, dict)):
             raise ValueError(
-                f"Cannot compare tensors with different shapes: ref={self.ref.shape} vs actual={self.actual.shape}"
+                f"`self.ref` value must be a Tensor, list, tuple or dict, got: {self.ref}"
             )
 
+        if not _tensor_shape_equals(self.ref, self.actual):
+            raise ValueError(
+                f"Cannot compare tensors with different shapes: ref={self.ref} vs actual={self.actual}"
+            )
+
 
@@ -231,14 +270,18 @@ def _module_stack_to_str(module_stack: object) -> str:
 
 def extract_results_from_loggers(
     model: GraphModule,
-) -> Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]]:
+) -> Dict[int, Tuple[Optional[str], object, List[object]]]:
     """For a given model, extract the tensors stats and related information for each debug handle.
+    The reason we have a list of object, instead of Tensor is because the output of node may not be
+    a Tensor, it could be (nested) list, tuple or dict as well.
 
     Returns:
-        A dict is keyed by the debug_handle id and the values are a list of Tensors recorded
-        in loggers"""
+        A dict is keyed by the debug_handle id and the values are a list of object recorded
+        in loggers
+
+    """
     # Results maps debug handle to a tensor list for each model being compared.
-    handles: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]] = {}
+    handles: Dict[int, Tuple[Optional[str], object, List[object]]] = {}
     for _name, module in model.named_children():
         if isinstance(module, OutputLogger) and len(module.stats) > 0:
             handles[module.debug_handle] = (
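A sketch of consuming the returned mapping; the tuple field names below are informal readings of the `Tuple[Optional[str], object, List[object]]` annotation, and `m_logger` stands for a model prepared with loggers:

from torch.ao.quantization.pt2e._numeric_debugger import (  # assumed path
    extract_results_from_loggers,
)

results = extract_results_from_loggers(m_logger)
for handle, (node_name, module_stack, stats) in results.items():
    # Each entry in `stats` may be a Tensor or a (nested) list/tuple/dict of
    # Tensors, matching what the node produced at runtime.
    print(handle, node_name, type(stats[0]))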
@@ -2930,6 +2930,22 @@ class TestHelperModules:
             w = torch.cat([z, y])
             return w
 
+    class Conv2dWithSplit(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.conv2 = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            x = self.conv1(x)
+            # use split so we get a list of Tensors
+            x1, x2 = torch.split(x, 2, dim=1)
+            y = torch.cat([x1, x2], dim=1)
+            return y
+
+        def example_inputs(self):
+            return (torch.randn(1, 3, 16, 16),)
+
     class ThreeAdd(torch.nn.Module):
         def forward(self, x1, x2, x3, x4):
             y = x1 + x2
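For reference on why this helper produces non-Tensor node outputs in the traced graph: `torch.split` with split size 2 over 3 channels returns a tuple of two chunks, which is exactly what the new test exercises. A standalone check:

import torch

x = torch.randn(1, 3, 14, 14)  # shape after the 3x3 conv on a 16x16 input
x1, x2 = torch.split(x, 2, dim=1)
print(x1.shape, x2.shape)  # torch.Size([1, 2, 14, 14]) torch.Size([1, 1, 14, 14])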