Mirror of https://github.com/pytorch/pytorch.git
Use THPVariable_Unpack in python_nccl (#56016)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56016

Missed these because I don't build on CUDA.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS
Reviewed By: bdhirsh
Differential Revision: D27765124
Pulled By: ezyang
fbshipit-source-id: aa202f594659d53c903b88c9d4a4cbb0e1c0b40a
Committed by: Facebook GitHub Bot
Parent: 6ec71ed4f9
Commit: 6c65ce8ee1
@@ -288,7 +288,7 @@ at::Tensor extract_tensor(PyObject* obj) {
   if (!THPVariable_Check(obj)) {
     throw torch::TypeError("expected Tensor (got %s)", Py_TYPE(obj)->tp_name);
   }
-  return ((THPVariable*)obj)->cdata;
+  return THPVariable_Unpack(obj);
 }
 
 static inline
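For context, here is a minimal sketch of how extract_tensor reads after this hunk. The include paths, the static qualifier, and the standalone framing are assumptions for illustration; the real function lives in the python_nccl bindings and is not reproduced verbatim here.

#include <torch/csrc/python_headers.h>            // PyObject, Py_TYPE (assumed path)
#include <torch/csrc/Exceptions.h>                // torch::TypeError (assumed path)
#include <torch/csrc/autograd/python_variable.h>  // THPVariable_Check, THPVariable_Unpack (assumed path)

// Convert one Python object to an at::Tensor, erroring on non-Tensor input.
static at::Tensor extract_tensor(PyObject* obj) {
  if (!THPVariable_Check(obj)) {
    throw torch::TypeError("expected Tensor (got %s)", Py_TYPE(obj)->tp_name);
  }
  // Before this commit: return ((THPVariable*)obj)->cdata;
  // The accessor hides how the tensor is stored inside the THPVariable.
  return THPVariable_Unpack(obj);
}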
@@ -305,8 +305,7 @@ std::vector<at::Tensor> extract_tensors(PyObject* obj) {
       throw torch::TypeError(
           "expected Tensor at %d (got %s)", (int)i, Py_TYPE(item)->tp_name);
     }
-    auto var = (THPVariable*)item;
-    list.emplace_back(var->cdata);
+    list.emplace_back(THPVariable_Unpack(item));
   }
   return list;
 }
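And a hedged sketch of the list variant after this hunk, not the verbatim file: the loop over the input sequence is simplified to a plain Python-list walk, so PyList_GET_SIZE and PyList_GET_ITEM here are illustrative rather than what python_nccl necessarily uses.

// Same includes as the sketch above, plus <vector>.
// Collect every element of a Python list of Tensors into a std::vector<at::Tensor>.
static std::vector<at::Tensor> extract_tensors(PyObject* obj) {
  std::vector<at::Tensor> list;
  const Py_ssize_t n = PyList_GET_SIZE(obj);   // assumes obj was already validated as a list
  list.reserve(n);
  for (Py_ssize_t i = 0; i < n; ++i) {
    PyObject* item = PyList_GET_ITEM(obj, i);  // borrowed reference
    if (!THPVariable_Check(item)) {
      throw torch::TypeError(
          "expected Tensor at %d (got %s)", (int)i, Py_TYPE(item)->tp_name);
    }
    // Before this commit this was two lines: a cast to THPVariable* followed
    // by list.emplace_back(var->cdata). The accessor call replaces both.
    list.emplace_back(THPVariable_Unpack(item));
  }
  return list;
}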