mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-01 13:34:57 +08:00
Consistently check 'out' variants against specified dtype/layout/device parameters. (#6973)
We were previously doing this in the most common cases, but not consistently.
This commit is contained in:
@ -73,10 +73,9 @@ PY_VARIABLE_OUT_CHECK_TYPE = CodeTemplate("""\
|
||||
if (r.isNone(${out_idx})) {
|
||||
${call_dispatch}
|
||||
} else {
|
||||
if (!r.isNone(${type_idx})) {
|
||||
check_out_type_matches(r.tensor(${out_idx}), r.scalartype(${type_idx}), r.layout(${layout_idx}),
|
||||
r.device(${device_idx}), r.isNone(${device_idx}));
|
||||
}
|
||||
check_out_type_matches(r.tensor(${out_idx}), r.scalartype(${type_idx}), r.isNone(${type_idx}),
|
||||
r.layout(${layout_idx}), r.isNone(${layout_idx}),
|
||||
r.device(${device_idx}), r.isNone(${device_idx}));
|
||||
${call_dispatch_out}
|
||||
}
|
||||
""")
|
||||
|
||||
Reference in New Issue
Block a user