mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Remove unnecessary type dispatches from Variable::Impl ctor (#13630)
Summary: This should improve the performance of wrapping a tensor in a Variable Pull Request resolved: https://github.com/pytorch/pytorch/pull/13630 Reviewed By: ezyang Differential Revision: D12944960 Pulled By: zou3519 fbshipit-source-id: 89fa78a563e46a747d851a90ffd1b5cf3cd2d0d7
This commit is contained in:
committed by
Facebook Github Bot
parent
2ae8e46105
commit
e70321ed9e
@ -22,7 +22,7 @@
|
||||
namespace torch {
|
||||
namespace autograd {
|
||||
Variable::Impl::Impl(at::Tensor data, bool requires_grad, Edge gradient_edge)
|
||||
: TensorImpl(data.type().type_id(), data.type().typeMeta(), data.type().allocator(), /* is variable */ true),
|
||||
: TensorImpl(data.type_id(), data.dtype(), /*allocator=*/nullptr, /* is variable */ true),
|
||||
data_(std::move(data)),
|
||||
grad_fn_(std::move(gradient_edge.function)),
|
||||
requires_grad_(false),
|
||||
|
@ -319,7 +319,7 @@ struct TORCH_API Variable::Impl : public at::TensorImpl {
|
||||
/// variables.
|
||||
void set_requires_grad(bool requires_grad) override {
|
||||
AT_CHECK(
|
||||
!requires_grad || at::isFloatingType(type().scalarType()),
|
||||
!requires_grad || at::isFloatingType(at::typeMetaToScalarType(dtype())),
|
||||
"Only Tensors of floating point dtype can require gradients");
|
||||
requires_grad_ = requires_grad;
|
||||
}
|
||||
|
Reference in New Issue
Block a user