mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[RLHF] Fix torch.dtype not serializable in example (#22158)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
@ -126,7 +126,10 @@ for name, p in train_model.named_parameters():
|
||||
|
||||
# Synchronize the updated weights to the inference engine.
|
||||
for name, p in train_model.named_parameters():
|
||||
handle = llm.collective_rpc.remote("update_weight", args=(name, p.dtype, p.shape))
|
||||
dtype_name = str(p.dtype).split(".")[-1]
|
||||
handle = llm.collective_rpc.remote(
|
||||
"update_weight", args=(name, dtype_name, p.shape)
|
||||
)
|
||||
model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
|
||||
ray.get(handle)
|
||||
|
||||
|
@ -45,7 +45,8 @@ class WorkerExtension:
|
||||
self.device,
|
||||
)
|
||||
|
||||
def update_weight(self, name, dtype, shape):
|
||||
def update_weight(self, name, dtype_name, shape):
|
||||
dtype = getattr(torch, dtype_name)
|
||||
weight = torch.empty(shape, dtype=dtype, device="cuda")
|
||||
self.model_update_group.broadcast(
|
||||
weight, src=0, stream=torch.cuda.current_stream()
|
||||
|
Reference in New Issue
Block a user