Add support for list, tuple and dict in numeric debugger (#143882)

Summary:
Previously the numeric debugger only supported torch.Tensor outputs; this PR adds support for (nested) list, tuple and dict outputs as well.
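
For orientation, the sketch below condenses the flow exercised by the new test into a standalone script. The Conv2dWithSplit module is copied from the test helper added in this PR; the import paths are assumed (the test's import block is not shown in this diff), so treat the sketch as illustrative rather than authoritative.

    # Condensed sketch of the flow exercised by the new test; import paths are
    # an assumption, since the diff below does not show the test's imports.
    import torch
    from torch.ao.quantization import (
        compare_results,
        extract_results_from_loggers,
        generate_numeric_debug_handle,
        prepare_for_propagation_comparison,
    )
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        get_symmetric_quantization_config,
        XNNPACKQuantizer,
    )
    from torch.export import export_for_training


    class Conv2dWithSplit(torch.nn.Module):
        # Copied from the TestHelperModules helper added in this PR.
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 3, 3)
            self.conv2 = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x):
            x = self.conv1(x)
            # split produces multiple Tensors, so this node's output is not a single Tensor
            x1, x2 = torch.split(x, 2, dim=1)
            return torch.cat([x1, x2], dim=1)


    example_inputs = (torch.randn(1, 3, 16, 16),)
    ep = export_for_training(Conv2dWithSplit(), example_inputs)
    generate_numeric_debug_handle(ep)
    m = ep.module()

    m_ref_logger = prepare_for_propagation_comparison(m)

    quantizer = XNNPACKQuantizer().set_global(
        get_symmetric_quantization_config(is_per_channel=False)
    )
    m = prepare_pt2e(m, quantizer)
    m(*example_inputs)
    m = convert_pt2e(m)
    m_quant_logger = prepare_for_propagation_comparison(m)

    m_ref_logger(*example_inputs)
    m_quant_logger(*example_inputs)
    ref_results = extract_results_from_loggers(m_ref_logger)
    quant_results = extract_results_from_loggers(m_quant_logger)

    # sqnr now mirrors the structure of the node output: a Tensor, or a
    # list/tuple of Tensors for nodes that return multiple values.
    for node_summary in compare_results(ref_results, quant_results).values():
        if node_summary.results:
            print(node_summary.results[0].sqnr)

When a node's output is a list or tuple of Tensors, sqnr mirrors that structure instead of being a single Tensor, which is what the new test checks against the threshold of 35.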

Test Plan:
python test/test_quantization.py -k test_extract_results_from_loggers_list_output

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D67660049](https://our.internmc.facebook.com/intern/diff/D67660049)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143882
Approved by: https://github.com/dulinriley
Authored by Jerry Zhang on 2024-12-27 10:12:34 -08:00; committed by PyTorch MergeBot
parent c3c27aef34, commit ad78edee8e
3 changed files with 118 additions and 29 deletions


@@ -267,6 +267,36 @@ class TestNumericDebugger(TestCase):
             if len(node_summary.results) > 0:
                 self.assertGreaterEqual(node_summary.results[0].sqnr, 35)
 
+    def test_extract_results_from_loggers_list_output(self):
+        m = TestHelperModules.Conv2dWithSplit()
+        example_inputs = m.example_inputs()
+        ep = export_for_training(m, example_inputs)
+        generate_numeric_debug_handle(ep)
+        m = ep.module()
+
+        m_ref_logger = prepare_for_propagation_comparison(m)
+        quantizer = XNNPACKQuantizer().set_global(
+            get_symmetric_quantization_config(is_per_channel=False)
+        )
+        m = prepare_pt2e(m, quantizer)
+        m(*example_inputs)
+        m = convert_pt2e(m)
+        m_quant_logger = prepare_for_propagation_comparison(m)
+
+        m_ref_logger(*example_inputs)
+        m_quant_logger(*example_inputs)
+        ref_results = extract_results_from_loggers(m_ref_logger)
+        quant_results = extract_results_from_loggers(m_quant_logger)
+        comparison_results = compare_results(ref_results, quant_results)
+        for node_summary in comparison_results.values():
+            if len(node_summary.results) > 0:
+                sqnr = node_summary.results[0].sqnr
+                if isinstance(sqnr, list):
+                    for sqnr_i in sqnr:
+                        self.assertGreaterEqual(sqnr_i, 35)
+                else:
+                    self.assertGreaterEqual(sqnr, 35)
+
     def test_added_node_gets_unique_id(self) -> None:
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()


@@ -71,6 +71,53 @@ def generate_numeric_debug_handle(ep: ExportedProgram) -> None:
     bfs_trace_with_node_process(ep, _assign_debug_handle)
 
 
+def _detach(x: object) -> object:
+    detached: object = None
+    if isinstance(x, torch.Tensor):
+        detached = x.detach()
+    elif isinstance(x, (list, tuple)):
+        detached = type(x)([_detach(e) for e in x])
+    elif isinstance(x, dict):
+        detached = {k: _detach(e) for k, e in x.items()}
+    else:
+        detached = x
+    return detached
+
+
+def _tensor_shape_equals(x: object, y: object) -> bool:
+    if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
+        return x.shape == y.shape
+    elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
+        return all(_tensor_shape_equals(e1, e2) for e1, e2 in zip(x, y))
+    elif isinstance(x, dict) and isinstance(y, dict):
+        all_equal = True
+        for k in x:
+            all_equal = all_equal and k in y and (_tensor_shape_equals(x[k], y[k]))
+        return all_equal
+    else:
+        print(f"Comparing non Tensors: {x} and {y}, they must be equal")
+        return type(x) == type(y) and x == y
+
+
+def _loss_fn(
+    loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], x: object, y: object
+) -> object:
+    """The returned loss will have the same structure as `x` and `y`, e.g.
+    if both are Tensor, we'll return a Tensor
+    if both are list, we'll return a list of Tensors
+    if both are dict, we'll return a dict with the same key, and value being the loss between the
+    two Tensors
+    """
+    if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
+        return loss(x.to(torch.float32), y.to(torch.float32))
+    elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
+        return type(x)([_loss_fn(loss, e1, e2) for e1, e2 in zip(x, y)])
+    elif isinstance(x, dict) and isinstance(y, dict):
+        return {k: _loss_fn(loss, e, y[k]) for k, e in x.items()}
+    else:
+        return None
+
+
 class OutputLogger(torch.nn.Module):
     """
     Base class for capturing output values for nodes in a GraphModule, it only captures
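
The recursion in _detach, _tensor_shape_equals and _loss_fn above is essentially a tree map over Tensor leaves, restricted to list, tuple and dict containers. As a point of comparison only (this PR does not use it), the same detach behavior sketched with the internal torch.utils._pytree utility:

    import torch
    from torch.utils._pytree import tree_map

    # A nested output like the one a logger might now record (a dict holding a
    # tuple of Tensors); hypothetical data, for illustration only.
    out = {"split": (torch.randn(1, 2, 14, 14), torch.randn(1, 1, 14, 14))}

    # Same effect as _detach(out): identical dict/tuple structure, leaves detached.
    detached = tree_map(
        lambda leaf: leaf.detach() if isinstance(leaf, torch.Tensor) else leaf, out
    )
    print(type(detached["split"]), detached["split"][0].requires_grad)

The hand-rolled helpers go further than a leaf-wise map: _tensor_shape_equals and _loss_fn walk two structures in lockstep and fall back to direct comparison (or None) on non-Tensor leaves.
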
@@ -90,11 +137,10 @@ class OutputLogger(torch.nn.Module):
         self.node_name = node_name
         self.nn_module_stack = nn_module_stack
         self.debug_handle = debug_handle
-        self.stats: List[torch.Tensor] = []
+        self.stats: List[object] = []
 
     def forward(self, x: object) -> object:
-        if isinstance(x, torch.Tensor):
-            self.stats.append(x.detach())
+        self.stats.append(_detach(x))
         return x
 
     def __extra_repr__(self) -> str:
@@ -162,27 +208,17 @@ class QuantizationComparisonResult:
     ref: torch.Tensor
 
     @property
-    def mse_loss(self) -> torch.Tensor:
-        return F.mse_loss(
-            self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
-        )
+    def mse_loss(self) -> object:
+        return self.loss(F.mse_loss)
 
     @property
-    def sqnr(self) -> torch.Tensor:
-        return compute_sqnr(
-            self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
-        )
+    def sqnr(self) -> object:
+        return self.loss(compute_sqnr)
 
     def loss(
         self, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
-    ) -> torch.Tensor:
-        if self.actual.shape != self.ref.shape:
-            raise ValueError(
-                f"Cannot compare tensors with different shapes: {self.actual.shape} vs {self.ref.shape}"
-            )
-        return loss_function(
-            self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
-        )
+    ) -> object:
+        return _loss_fn(loss_function, self.actual, self.ref)
 
     def __repr__(self) -> str:
         # Don't include the tensors themselves as they are quite large to print
@@ -192,16 +228,19 @@ class QuantizationComparisonResult:
         )
 
     def __post_init__(self) -> None:
-        if not isinstance(self.actual, torch.Tensor):
+        if not isinstance(self.actual, (torch.Tensor, list, tuple, dict)):
             raise ValueError(
-                f"`self.actual` value must be a Tensor, got: {self.actual}"
+                f"`self.actual` value must be a Tensor, list, tuple or dict, got: {self.actual}"
             )
-        if not isinstance(self.ref, torch.Tensor):
-            raise ValueError(f"`self.ref` value must be a Tensor, got: {self.ref}")
-        if self.actual.shape != self.ref.shape:
+        if not isinstance(self.ref, (torch.Tensor, list, tuple, dict)):
             raise ValueError(
-                f"Cannot compare tensors with different shapes: ref={self.ref.shape} vs actual={self.actual.shape}"
+                f"`self.ref` value must be a Tensor, list, tuple or dict, got: {self.ref}"
+            )
+        if not _tensor_shape_equals(self.ref, self.actual):
+            raise ValueError(
+                f"Cannot compare tensors with different shapes: ref={self.ref} vs actual={self.actual}"
             )
@@ -231,14 +270,18 @@ def _module_stack_to_str(module_stack: object) -> str:
 def extract_results_from_loggers(
     model: GraphModule,
-) -> Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]]:
+) -> Dict[int, Tuple[Optional[str], object, List[object]]]:
     """For a given model, extract the tensors stats and related information for each debug handle.
+
+    The reason we have a list of object, instead of Tensor is because the output of node may not be
+    a Tensor, it could be (nested) list, tuple or dict as well.
 
     Returns:
-        A dict is keyed by the debug_handle id and the values are a list of Tensors recorded
-        in loggers"""
+        A dict is keyed by the debug_handle id and the values are a list of object recorded
+        in loggers
+    """
 
     # Results maps debug handle to a tensor list for each model being compared.
-    handles: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]] = {}
+    handles: Dict[int, Tuple[Optional[str], object, List[object]]] = {}
     for _name, module in model.named_children():
         if isinstance(module, OutputLogger) and len(module.stats) > 0:
             handles[module.debug_handle] = (
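
Because extract_results_from_loggers now returns List[object] stats, downstream consumers should not assume each recorded entry is a Tensor. A hedged consumer sketch follows: the tuple layout (node name, module stack, recorded stats) is inferred from the OutputLogger fields, since the assignment above is truncated in this view, and iter_tensor_leaves/summarize are hypothetical helper names introduced only for illustration.

    from typing import Dict, List, Optional, Tuple

    import torch


    def iter_tensor_leaves(x: object):
        # Mirrors the recursive handling added in this PR: walk a Tensor or a
        # (nested) list/tuple/dict and yield the Tensor leaves.
        if isinstance(x, torch.Tensor):
            yield x
        elif isinstance(x, (list, tuple)):
            for e in x:
                yield from iter_tensor_leaves(e)
        elif isinstance(x, dict):
            for e in x.values():
                yield from iter_tensor_leaves(e)


    def summarize(results: Dict[int, Tuple[Optional[str], object, List[object]]]) -> None:
        # Tuple layout assumed to be (node_name, nn_module_stack, stats); see note above.
        for handle, (name, _stack, stats) in results.items():
            for recorded in stats:
                shapes = [tuple(t.shape) for t in iter_tensor_leaves(recorded)]
                print(f"debug_handle={handle} node={name} shapes={shapes}")
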


@@ -2930,6 +2930,22 @@ class TestHelperModules:
             w = torch.cat([z, y])
             return w
 
+    class Conv2dWithSplit(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.conv2 = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            x = self.conv1(x)
+            # use split so we get a list of Tensors
+            x1, x2 = torch.split(x, 2, dim=1)
+            y = torch.cat([x1, x2], dim=1)
+            return y
+
+        def example_inputs(self):
+            return (torch.randn(1, 3, 16, 16),)
+
     class ThreeAdd(torch.nn.Module):
         def forward(self, x1, x2, x3, x4):
             y = x1 + x2