mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-26 16:44:54 +08:00
Add ScalarType argument to Type::options() (#19270)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19270 ghimport-source-id: a5ade6131f3260066c5750ea1fa9ed5c998bb791 Differential Revision: D14938707 Pulled By: li-roy fbshipit-source-id: 018fb3f01706531a06515d6d861e5683a455a705
This commit is contained in:
committed by
Facebook Github Bot
parent
a044ba1af5
commit
ab78449e8c
@ -46,6 +46,7 @@ namespace torch { namespace autograd {
|
||||
VariableInfo::VariableInfo(const Variable& var)
|
||||
: type(&var.dispatch_type())
|
||||
, device(var.device())
|
||||
, scalar_type(var.scalar_type())
|
||||
, size(var.sizes().vec())
|
||||
, requires_grad(var.requires_grad()) {
|
||||
}
|
||||
@ -53,7 +54,7 @@ VariableInfo::VariableInfo(const Variable& var)
|
||||
Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const {
|
||||
// NB: This will NOT work if we ever get mixed device gradients
|
||||
device_guard.reset_device(device);
|
||||
return at::zeros(size, type->options());
|
||||
return at::zeros(size, type->options(scalar_type));
|
||||
}
|
||||
|
||||
auto PyFunction::legacy_apply(const variable_list& inputs) -> variable_list {
|
||||
|
||||
Reference in New Issue
Block a user