[export] improve error message for deserializing custom triton op (#152029)

In https://github.com/pytorch/pytorch/issues/151746, users ran into an error where a custom triton op cannot be resolved into an operator from string target. We improve the error message by reminding users to register the same custom operator at de-serialization time.

Now the error looks like this:
```python
torch._export.serde.serialize.SerializeError: We failed to resolve torch.ops.triton_kernel.add.default to an operator. If it's a custom op/custom triton op, this is usally because the custom op is not registered when deserializing. Please import the custom op to register it before deserializing. Otherwise, please file an issue on github. Unsupported target type for node Node(target='torch.ops.triton_kernel.add.default', inputs=[NamedArgument(name='x', arg=Argument(as_tensor=TensorArgument(name='linear')), kind=1), NamedArgument(name='y', arg=Argument(as_tensor=TensorArgument(name='mul')), kind=1)], outputs=[Argument(as_tensor=TensorArgument(name='add'))], metadata={'stack_trace': 'File "/data/users/yidi/pytorch/test.py", line 50, in forward\n    output = triton_add(dense_output, bias)', 'nn_module_stack': 'L__self__,,__main__.SimpleModel', 'torch_fn': 'add.default_1;OpOverload.add.default'}, is_hop_single_tensor_return=None): <class 'str'>.```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152029
Approved by: https://github.com/jingsh
This commit is contained in:
Yidi Wu
2025-04-23 11:14:09 -07:00
committed by PyTorch MergeBot
parent 24bda01a93
commit 92f125e622

View File

@ -1986,8 +1986,12 @@ class GraphModuleDeserializer(metaclass=Final):
)
self.deserialize_outputs(serialized_node, fx_node)
else:
_additional_msg = (f"We failed to resolve {target} to an operator. "
+ "If it's a custom op/custom triton op, this is usally because the custom op is not registered"
+ " when deserializing. Please import the custom op to register it before deserializing."
+ " Otherwise, please file an issue on github.") if isinstance(target, str) else ""
raise SerializeError(
f"Unsupported target type for node {serialized_node}: {type(target)}"
_additional_msg + f" Unsupported target type for node {serialized_node}: {type(target)}."
)
fx_node.meta.update(self.deserialize_metadata(serialized_node.metadata))