mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[graph_manipulation] Set fused dtypes for all constant params/buffers (#77401)
Summary: We were handling constant attrs in a few different ways before, leading to confusion and missed handing for fused dtypes. This diff consolidates some of that code and unbreaks current breakage. Test Plan: CI. Recently broken tests now pass. Differential Revision: D36335238 Pull Request resolved: https://github.com/pytorch/pytorch/pull/77401 Approved by: https://github.com/jaybean-dev, https://github.com/jamesr66a
This commit is contained in:
committed by
PyTorch MergeBot
parent
942f04172a
commit
18e36a6295
@ -123,7 +123,7 @@ class TestFXExperimental(JitTestCase):
|
||||
assert len(serialized_graph1["weights"]) == 4
|
||||
assert len(serialized_graph1["modules"]) == 0
|
||||
assert len(serialized_graph2["nodes"]) == 6
|
||||
assert len(serialized_graph2["weights"]) == 4
|
||||
assert len(serialized_graph2["weights"]) == 1
|
||||
assert len(serialized_graph2["modules"]) == 1
|
||||
assert serialized_graph1["weights"]["linear.weight"]["shape"] == "[4, 4]"
|
||||
assert serialized_graph1["weights"]["linear.weight"]["dtype"] == "torch.float32"
|
||||
|
||||
Reference in New Issue
Block a user