Fix codegen, change str comparison operator to == for proper equality … (#150611)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150611
Approved by: https://github.com/Skylion007, https://github.com/cyyever
This commit is contained in:
Jakub Grzybek
2025-04-04 09:59:59 +00:00
committed by PyTorch MergeBot
parent 4854926aeb
commit 73358d37da
2 changed files with 18 additions and 1 deletions

View File

@ -1271,6 +1271,23 @@ class TestFX(JitTestCase):
"call_module"
).check("clamp").check("call_method").run(all_formatted)
def test_print_graph(self):
op: torch._ops.OpOverload = torch.ops.aten.relu.default
type_name: str = torch.typename(op)
graph: torch.fx.Graph = torch.fx.Graph()
a: torch.fx.Node = graph.create_node("placeholder", "x")
b: torch.fx.Node = graph.create_node("call_function", op, (a,), type_expr=type_name)
c: torch.fx.Node = graph.create_node("call_function", op, (b,), type_expr=type_name)
graph.output((b, c))
gm: torch.fx.GraphModule = torch.fx.GraphModule(
torch.nn.Module(), graph
)
gm.graph.lint()
text = gm.print_readable(False)
assert 2 == text.count("_torch__ops_aten_aten_relu_")
def test_script_tensor_constant(self):
# TorchScript seems to ignore attributes that start with `__`.
# We used to call anonymous Tensor values `__tensor_constant*`, but