mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix codegen, change str comparison opeator to == for proper equality … (#150611)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150611 Approved by: https://github.com/Skylion007, https://github.com/cyyever
This commit is contained in:
committed by
PyTorch MergeBot
parent
4854926aeb
commit
73358d37da
@ -1271,6 +1271,23 @@ class TestFX(JitTestCase):
|
||||
"call_module"
|
||||
).check("clamp").check("call_method").run(all_formatted)
|
||||
|
||||
def test_print_graph(self):
|
||||
op: torch._ops.OpOverload = torch.ops.aten.relu.default
|
||||
type_name: str = torch.typename(op)
|
||||
|
||||
graph: torch.fx.Graph = torch.fx.Graph()
|
||||
a: torch.fx.Node = graph.create_node("placeholder", "x")
|
||||
b: torch.fx.Node = graph.create_node("call_function", op, (a,), type_expr=type_name)
|
||||
c: torch.fx.Node = graph.create_node("call_function", op, (b,), type_expr=type_name)
|
||||
graph.output((b, c))
|
||||
|
||||
gm: torch.fx.GraphModule = torch.fx.GraphModule(
|
||||
torch.nn.Module(), graph
|
||||
)
|
||||
gm.graph.lint()
|
||||
text = gm.print_readable(False)
|
||||
assert 2 == text.count("_torch__ops_aten_aten_relu_")
|
||||
|
||||
def test_script_tensor_constant(self):
|
||||
# TorchScript seems to ignore attributes that start with `__`.
|
||||
# We used to call anonymous Tensor values `__tensor_constant*`, but
|
||||
|
Reference in New Issue
Block a user