Add ScalarType argument to Type::options() (#19270)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19270
ghimport-source-id: a5ade6131f3260066c5750ea1fa9ed5c998bb791

Differential Revision: D14938707

Pulled By: li-roy

fbshipit-source-id: 018fb3f01706531a06515d6d861e5683a455a705
This commit is contained in:
Roy Li
2019-04-21 21:12:21 -07:00
committed by Facebook Github Bot
parent a044ba1af5
commit ab78449e8c
18 changed files with 76 additions and 73 deletions

View File

@ -46,6 +46,7 @@ namespace torch { namespace autograd {
VariableInfo::VariableInfo(const Variable& var)
: type(&var.dispatch_type())
, device(var.device())
, scalar_type(var.scalar_type())
, size(var.sizes().vec())
, requires_grad(var.requires_grad()) {
}
@ -53,7 +54,7 @@ VariableInfo::VariableInfo(const Variable& var)
Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const {
// NB: This will NOT work if we ever get mixed device gradients
device_guard.reset_device(device);
return at::zeros(size, type->options());
return at::zeros(size, type->options(scalar_type));
}
auto PyFunction::legacy_apply(const variable_list& inputs) -> variable_list {