mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[Profiler] Torch Profiler distributed info is not JSON serializable (#135548)
Summary: To fix https://github.com/pytorch/pytorch/issues/133308 we must create an encoder for numpy values so we can serialize the distributed metadata to JSON. Test Plan: Added unit test to check that numpy values can be serialized Differential Revision: D62411619 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135548 Approved by: https://github.com/aaronenyeshi, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
8c356ce3da
commit
062681a0ed
@ -36,6 +36,28 @@ __all__ = [
|
||||
PROFILER_STEP_NAME = "ProfilerStep"
|
||||
|
||||
|
||||
class _NumpyEncoder(json.JSONEncoder):
|
||||
"""
|
||||
Json encoder for numpy types (np.int, np.float, np.array etc.)
|
||||
Returns default encoder if numpy is not available
|
||||
"""
|
||||
|
||||
def default(self, obj):
|
||||
"""Encode NumPy types to JSON"""
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
if isinstance(obj, np.integer):
|
||||
return int(obj)
|
||||
elif isinstance(obj, np.floating):
|
||||
return float(obj)
|
||||
elif isinstance(obj, np.ndarray):
|
||||
return obj.tolist()
|
||||
else:
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
|
||||
|
||||
def supported_activities():
|
||||
"""
|
||||
Returns a set of supported profiler tracing activities.
|
||||
@ -187,7 +209,9 @@ class _KinetoProfile:
|
||||
if kineto_available():
|
||||
dist_info = self._get_distributed_info()
|
||||
if dist_info:
|
||||
self.add_metadata_json("distributedInfo", json.dumps(dist_info))
|
||||
self.add_metadata_json(
|
||||
"distributedInfo", json.dumps(dist_info, cls=_NumpyEncoder)
|
||||
)
|
||||
|
||||
if hasattr(torch, "_inductor"):
|
||||
import torch._inductor.config as inductor_config
|
||||
@ -931,5 +955,6 @@ class ExecutionTraceObserver(_ITraceObserver):
|
||||
):
|
||||
pg_config_info = torch.distributed.distributed_c10d._world.pg_config_info
|
||||
torch.autograd._record_function_with_args_enter(
|
||||
"## process_group:init ##", json.dumps(pg_config_info)
|
||||
"## process_group:init ##",
|
||||
json.dumps(pg_config_info, cls=_NumpyEncoder),
|
||||
)
|
||||
|
Reference in New Issue
Block a user