[graph_manipulation] Set fused dtypes for all constant params/buffers (#77401)

Summary: We were handling constant attrs in a few different ways before, leading to confusion and missed handing for fused dtypes. This diff consolidates some of that code and unbreaks current breakage.

Test Plan: CI. Recently broken tests now pass.

Differential Revision: D36335238

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77401
Approved by: https://github.com/jaybean-dev, https://github.com/jamesr66a
This commit is contained in:
Jordan Fix
2022-05-17 07:42:29 +00:00
committed by PyTorch MergeBot
parent 942f04172a
commit 18e36a6295
2 changed files with 62 additions and 65 deletions

View File

@ -123,7 +123,7 @@ class TestFXExperimental(JitTestCase):
assert len(serialized_graph1["weights"]) == 4
assert len(serialized_graph1["modules"]) == 0
assert len(serialized_graph2["nodes"]) == 6
assert len(serialized_graph2["weights"]) == 4
assert len(serialized_graph2["weights"]) == 1
assert len(serialized_graph2["modules"]) == 1
assert serialized_graph1["weights"]["linear.weight"]["shape"] == "[4, 4]"
assert serialized_graph1["weights"]["linear.weight"]["dtype"] == "torch.float32"