mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add use_strict_trace to tensorboard add_graph method (#63120)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63120 FAIM returns dictionaries as the model output, which throws an error when trying to trace using add_graph. Pass in `strict` to the tracer to make this user configurable. User post: https://fb.workplace.com/groups/pytorchLightning/permalink/1510194972650369/?comment_id=1510252919311241&reply_comment_id=1510281112641755 Test Plan: unit test Reviewed By: Reubend Differential Revision: D30265890 fbshipit-source-id: 58b25d9500b875a29a664aa9ef4c1e7f13631fa1
This commit is contained in:
committed by
Facebook GitHub Bot
parent
1022443168
commit
96fb1a56ea
@ -569,6 +569,41 @@ class TestTensorBoardPytorchGraph(BaseTestCase):
|
||||
self.assertEquals(
|
||||
sorted(expected_node.attr.keys()), sorted(actual_node.attr.keys()))
|
||||
|
||||
def test_pytorch_graph_dict_input(self):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.l = torch.nn.Linear(3, 5)
|
||||
|
||||
def forward(self, x):
|
||||
return self.l(x)
|
||||
|
||||
class ModelDict(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.l = torch.nn.Linear(3, 5)
|
||||
|
||||
def forward(self, x):
|
||||
return {"out": self.l(x)}
|
||||
|
||||
|
||||
dummy_input = torch.zeros(1, 3)
|
||||
|
||||
with self.createSummaryWriter() as w:
|
||||
w.add_graph(Model(), dummy_input)
|
||||
|
||||
with self.createSummaryWriter() as w:
|
||||
w.add_graph(Model(), dummy_input, use_strict_trace=True)
|
||||
|
||||
# expect error: Encountering a dict at the output of the tracer...
|
||||
with self.assertRaises(RuntimeError):
|
||||
with self.createSummaryWriter() as w:
|
||||
w.add_graph(ModelDict(), dummy_input, use_strict_trace=True)
|
||||
|
||||
with self.createSummaryWriter() as w:
|
||||
w.add_graph(ModelDict(), dummy_input, use_strict_trace=False)
|
||||
|
||||
|
||||
def test_mlp_graph(self):
|
||||
dummy_input = (torch.zeros(2, 1, 28, 28),)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user