Add support for list, tuple and dict in numeric debugger (#143882)
Summary: Previously the numeric debugger only supported torch.Tensor; this PR adds support for list, tuple and dict as well.

Test Plan: python test/test_quantization.py -k test_extract_results_from_loggers_list_output

Differential Revision: [D67660049](https://our.internmc.facebook.com/intern/diff/D67660049)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143882
Approved by: https://github.com/dulinriley
Committed by: PyTorch MergeBot
Parent: c3c27aef34
Commit: ad78edee8e
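For context on the change: the support is implemented by recursing through the container structure of a node's output instead of assuming a single Tensor. Below is a minimal standalone sketch of that pattern; the `map_structure` helper is hypothetical and for illustration only, while the PR's actual helpers (`_detach`, `_tensor_shape_equals`, `_loss_fn`) appear in the diff further down.

import torch

def map_structure(fn, x):
    # Hypothetical illustration of the recursive dispatch the PR's helpers
    # share: apply fn to Tensors, recurse into list/tuple/dict while
    # preserving the container type, and pass anything else through.
    if isinstance(x, torch.Tensor):
        return fn(x)
    if isinstance(x, (list, tuple)):
        return type(x)([map_structure(fn, e) for e in x])
    if isinstance(x, dict):
        return {k: map_structure(fn, v) for k, v in x.items()}
    return x

nested = {"a": torch.ones(2), "b": [torch.zeros(3), (torch.ones(1),)]}
detached = map_structure(torch.Tensor.detach, nested)  # same nesting back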
@@ -267,6 +267,36 @@ class TestNumericDebugger(TestCase):
             if len(node_summary.results) > 0:
                 self.assertGreaterEqual(node_summary.results[0].sqnr, 35)
 
+    def test_extract_results_from_loggers_list_output(self):
+        m = TestHelperModules.Conv2dWithSplit()
+        example_inputs = m.example_inputs()
+        ep = export_for_training(m, example_inputs)
+        generate_numeric_debug_handle(ep)
+        m = ep.module()
+        m_ref_logger = prepare_for_propagation_comparison(m)
+
+        quantizer = XNNPACKQuantizer().set_global(
+            get_symmetric_quantization_config(is_per_channel=False)
+        )
+        m = prepare_pt2e(m, quantizer)
+        m(*example_inputs)
+        m = convert_pt2e(m)
+        m_quant_logger = prepare_for_propagation_comparison(m)
+
+        m_ref_logger(*example_inputs)
+        m_quant_logger(*example_inputs)
+        ref_results = extract_results_from_loggers(m_ref_logger)
+        quant_results = extract_results_from_loggers(m_quant_logger)
+        comparison_results = compare_results(ref_results, quant_results)
+        for node_summary in comparison_results.values():
+            if len(node_summary.results) > 0:
+                sqnr = node_summary.results[0].sqnr
+                if isinstance(sqnr, list):
+                    for sqnr_i in sqnr:
+                        self.assertGreaterEqual(sqnr_i, 35)
+                else:
+                    self.assertGreaterEqual(sqnr, 35)
+
     def test_added_node_gets_unique_id(self) -> None:
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
@@ -71,6 +71,53 @@ def generate_numeric_debug_handle(ep: ExportedProgram) -> None:
     bfs_trace_with_node_process(ep, _assign_debug_handle)
 
 
+def _detach(x: object) -> object:
+    detached: object = None
+    if isinstance(x, torch.Tensor):
+        detached = x.detach()
+    elif isinstance(x, (list, tuple)):
+        detached = type(x)([_detach(e) for e in x])
+    elif isinstance(x, dict):
+        detached = {k: _detach(e) for k, e in x.items()}
+    else:
+        detached = x
+    return detached
+
+
+def _tensor_shape_equals(x: object, y: object) -> bool:
+    if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
+        return x.shape == y.shape
+    elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
+        return all(_tensor_shape_equals(e1, e2) for e1, e2 in zip(x, y))
+    elif isinstance(x, dict) and isinstance(y, dict):
+        all_equal = True
+        for k in x:
+            all_equal = all_equal and k in y and (_tensor_shape_equals(x[k], y[k]))
+        return all_equal
+    else:
+        print(f"Comparing non Tensors: {x} and {y}, they must be equal")
+        return type(x) == type(y) and x == y
+
+
+def _loss_fn(
+    loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], x: object, y: object
+) -> object:
+    """The returned loss will have the same structure as `x` and `y`, e.g.
+    if both are Tensor, we'll return a Tensor
+    if both are list, we'll return a list of Tensors
+    if both are dict, we'll return a dict with the same key, and value being the loss between the
+    two Tensors
+    """
+    if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
+        return loss(x.to(torch.float32), y.to(torch.float32))
+    elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
+        return type(x)([_loss_fn(loss, e1, e2) for e1, e2 in zip(x, y)])
+    elif isinstance(x, dict) and isinstance(y, dict):
+        return {k: _loss_fn(loss, e, y[k]) for k, e in x.items()}
+    else:
+        return None
+
+
 class OutputLogger(torch.nn.Module):
     """
     Base class for capturing output values for nodes in a GraphModule, it only captures
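A quick sanity check of the structure-preserving behavior of the helpers above (not part of the PR; assumes `_detach` and `_loss_fn` as defined in this hunk are in scope):

import torch
import torch.nn.functional as F

x = [torch.ones(4), {"k": torch.zeros(2)}]
y = [torch.ones(4), {"k": torch.zeros(2)}]
# The loss mirrors the input nesting: a list whose second element is a dict.
print(_loss_fn(F.mse_loss, x, y))  # [tensor(0.), {'k': tensor(0.)}]
# _detach keeps the same structure, detaching every Tensor it finds.
print(_detach({"out": (torch.ones(2, requires_grad=True),)}))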
@@ -90,11 +137,10 @@ class OutputLogger(torch.nn.Module):
         self.node_name = node_name
         self.nn_module_stack = nn_module_stack
         self.debug_handle = debug_handle
-        self.stats: List[torch.Tensor] = []
+        self.stats: List[object] = []
 
     def forward(self, x: object) -> object:
-        if isinstance(x, torch.Tensor):
-            self.stats.append(x.detach())
+        self.stats.append(_detach(x))
         return x
 
     def __extra_repr__(self) -> str:
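With this `forward` change, a logger attached after a node such as `torch.split` records the whole structured output rather than silently skipping non-Tensors. A sketch (keyword construction of `OutputLogger` is an assumption here):

import torch

# Assumed keyword-arg construction; see the attributes set in __init__ above.
logger = OutputLogger(debug_handle=0, node_name="split", nn_module_stack=None)
out = logger([torch.ones(2), torch.ones(2)])  # forward still returns x unchanged
assert isinstance(logger.stats[0], list)      # the full list was recorded, detached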
@@ -162,27 +208,17 @@ class QuantizationComparisonResult:
     ref: torch.Tensor
 
     @property
-    def mse_loss(self) -> torch.Tensor:
-        return F.mse_loss(
-            self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
-        )
+    def mse_loss(self) -> object:
+        return self.loss(F.mse_loss)
 
     @property
-    def sqnr(self) -> torch.Tensor:
-        return compute_sqnr(
-            self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
-        )
+    def sqnr(self) -> object:
+        return self.loss(compute_sqnr)
 
     def loss(
         self, loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
-    ) -> torch.Tensor:
-        if self.actual.shape != self.ref.shape:
-            raise ValueError(
-                f"Cannot compare tensors with different shapes: {self.actual.shape} vs {self.ref.shape}"
-            )
-        return loss_function(
-            self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
-        )
+    ) -> object:
+        return _loss_fn(loss_function, self.actual, self.ref)
 
     def __repr__(self) -> str:
         # Don't include the tensors themselves as they are quite large to print
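Since `mse_loss` and `sqnr` now delegate to `_loss_fn`, a comparison result built from list outputs yields one value per element instead of a single Tensor. A sketch assuming the dataclass above:

import torch

result = QuantizationComparisonResult(
    actual=[torch.ones(4), torch.ones(4)],
    ref=[torch.ones(4), torch.ones(4)],
)
print(result.sqnr)  # a list of per-element SQNR values, not a single Tensor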
@@ -192,16 +228,19 @@ class QuantizationComparisonResult:
         )
 
     def __post_init__(self) -> None:
-        if not isinstance(self.actual, torch.Tensor):
+        if not isinstance(self.actual, (torch.Tensor, list, tuple, dict)):
             raise ValueError(
-                f"`self.actual` value must be a Tensor, got: {self.actual}"
+                f"`self.actual` value must be a Tensor, list, tuple or dict, got: {self.actual}"
             )
 
-        if not isinstance(self.ref, torch.Tensor):
-            raise ValueError(f"`self.ref` value must be a Tensor, got: {self.ref}")
-        if self.actual.shape != self.ref.shape:
+        if not isinstance(self.ref, (torch.Tensor, list, tuple, dict)):
             raise ValueError(
-                f"Cannot compare tensors with different shapes: ref={self.ref.shape} vs actual={self.actual.shape}"
+                f"`self.ref` value must be a Tensor, list, tuple or dict, got: {self.ref}"
             )
+
+        if not _tensor_shape_equals(self.ref, self.actual):
+            raise ValueError(
+                f"Cannot compare tensors with different shapes: ref={self.ref} vs actual={self.actual}"
+            )
@@ -231,14 +270,18 @@ def _module_stack_to_str(module_stack: object) -> str:
 
 def extract_results_from_loggers(
     model: GraphModule,
-) -> Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]]:
+) -> Dict[int, Tuple[Optional[str], object, List[object]]]:
     """For a given model, extract the tensors stats and related information for each debug handle.
+    The reason we have a list of object, instead of Tensor is because the output of node may not be
+    a Tensor, it could be (nested) list, tuple or dict as well.
 
     Returns:
-        A dict is keyed by the debug_handle id and the values are a list of Tensors recorded
-        in loggers"""
+        A dict is keyed by the debug_handle id and the values are a list of object recorded
+        in loggers
+    """
     # Results maps debug handle to a tensor list for each model being compared.
-    handles: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]] = {}
+    handles: Dict[int, Tuple[Optional[str], object, List[object]]] = {}
     for _name, module in model.named_children():
         if isinstance(module, OutputLogger) and len(module.stats) > 0:
             handles[module.debug_handle] = (
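Given the updated annotation, consuming the results might look like the following sketch (the unpacked names are illustrative labels for the `Tuple[Optional[str], object, List[object]]` entries; `m_logger` is a hypothetical logged GraphModule):

results = extract_results_from_loggers(m_logger)
for handle, (node_name, module_stack, stats) in results.items():
    # stats entries may now be Tensors or (nested) lists/tuples/dicts of Tensors
    print(handle, node_name, type(stats[0]))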
@@ -2930,6 +2930,22 @@ class TestHelperModules:
             w = torch.cat([z, y])
             return w
 
+    class Conv2dWithSplit(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.conv2 = torch.nn.Conv2d(3, 3, 3)
+
+        def forward(self, x):
+            x = self.conv1(x)
+            # use split so we get a list of Tensors
+            x1, x2 = torch.split(x, 2, dim=1)
+            y = torch.cat([x1, x2], dim=1)
+            return y
+
+        def example_inputs(self):
+            return (torch.randn(1, 3, 16, 16),)
+
     class ThreeAdd(torch.nn.Module):
         def forward(self, x1, x2, x3, x4):
             y = x1 + x2
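For reference, a hand-computed shape walk-through of the new helper module (assuming the definition above is importable from torch.testing._internal.common_quantization):

import torch
from torch.testing._internal.common_quantization import TestHelperModules

m = TestHelperModules.Conv2dWithSplit()
(x,) = m.example_inputs()  # (1, 3, 16, 16)
y = m(x)
# conv1: (1, 3, 14, 14); split(dim=1, split_size=2) -> (1, 2, 14, 14) and
# (1, 1, 14, 14); cat(dim=1) restores (1, 3, 14, 14). The split node's list
# output is exactly the case the new numeric debugger test exercises.
print(y.shape)  # torch.Size([1, 3, 14, 14])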