Add function specializations for promote with ITensorListRef (#87756)

Fixes [#87684](https://github.com/pytorch/pytorch/issues/87684)
The failure is due to a new tensor list type, `ITensorListRef`, being introduced. We need function specializations of `prioritize` and `cached_cast` for this new tensor list type.
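For context, here is a minimal sketch of the behavior being fixed, distilled from the `test_cat_promote` test added below (shapes are illustrative; this is a repro sketch, not part of the PR):

```python
import torch

# Under CPU autocast, `cat` over mixed-dtype inputs goes through the
# promote list: autocast should pick the widest input dtype (float32
# here) and cast every list element to it before dispatching.
with torch.cpu.amp.autocast(dtype=torch.bfloat16):
    a = torch.rand(2, 3)                        # float32
    b = torch.rand(2, 3, dtype=torch.bfloat16)  # bfloat16
    out = torch.cat([a, b], dim=0)
print(out.dtype)  # expected: torch.float32
```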

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87756
Approved by: https://github.com/jgong5, https://github.com/ezyang
Author: leslie-fang-intel
Date: 2022-10-28 10:30:30 +00:00
Committed by: PyTorch MergeBot
Parent: 166b5d3e7c
Commit: f150e70ca2
2 changed files with 48 additions and 0 deletions


@@ -126,6 +126,16 @@ inline at::ScalarType prioritize(
   return current;
 }
 
+inline at::ScalarType prioritize(
+    at::ScalarType current,
+    const ITensorListRef& list,
+    DeviceType device_type = DeviceType::CUDA) {
+  for (const auto& tensor : list) {
+    current = prioritize(current, tensor, device_type);
+  }
+  return current;
+}
+
 // Template to catch non-Tensor args (no-op that returns current best guess)
 template <typename T>
 inline at::ScalarType prioritize(
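Conceptually, this specialization folds a whole tensor list into autocast's running "widest dtype wins" decision by calling the scalar `prioritize` once per element. A loose Python analogue of that fold (illustrative only; the real logic is the C++ above, which also restricts itself to tensors eligible for autocasting):

```python
import torch

def prioritize_list(current: torch.dtype, tensors) -> torch.dtype:
    # Widen the running dtype with each list element, mirroring the
    # per-tensor loop in the C++ specialization.
    for t in tensors:
        current = torch.promote_types(current, t.dtype)
    return current

# A float32 input forces the widest dtype even under bfloat16 autocast.
assert prioritize_list(torch.bfloat16, [torch.rand(2, 2)]) == torch.float32
```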
@@ -196,6 +206,18 @@ inline std::vector<Tensor> cached_cast(
   return vec;
 }
 
+inline std::vector<Tensor> cached_cast(
+    at::ScalarType to_type,
+    const ITensorListRef& arg,
+    DeviceType device_type = DeviceType::CUDA) {
+  std::vector<Tensor> vec;
+  vec.reserve(arg.size());
+  for (const auto& t : arg) {
+    vec.push_back(cached_cast(to_type, t, device_type));
+  }
+  return vec;
+}
+
 // Template to catch non-Tensor args.
 template <typename T>
 inline T cached_cast(
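`cached_cast` then materializes a new `std::vector<Tensor>` with each element cast to the decided dtype. A rough Python analogue of the per-element behavior (the real C++ also consults autocast's cast cache and only casts eligible floating-point tensors):

```python
import torch

def cached_cast_list(to_type: torch.dtype, tensors):
    # Cast each eligible tensor in the list to the target dtype,
    # mirroring the element-wise loop in the C++ specialization.
    return [t.to(to_type) if t.is_floating_point() else t
            for t in tensors]

mixed = [torch.rand(2, 2), torch.rand(2, 2, dtype=torch.bfloat16)]
casted = cached_cast_list(torch.float32, mixed)
assert all(t.dtype == torch.float32 for t in casted)
```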


@@ -819,6 +819,32 @@ class TestJitTraceAutocast(JitTestCase):
                 continue
             test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i])
 
+    def test_cat_promote(self):
+        class TestModel(torch.nn.Module):
+            def __init__(self):
+                super(TestModel, self).__init__()
+
+            def forward(self, a, b):
+                return torch.cat([a, b], 0)
+
+        with torch.jit.fuser("none"):
+            # In this testcase, we will check whether cat has done the promotion in AMP with mixed dtype inputs.
+            # To avoid the fusion group from TE, we will disable the fuser here.
+            for jit_freeze_or_not in [False, True]:
+                test_model = TestModel().eval()
+                with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16), torch.no_grad():
+                    a = torch.rand(24, 128, 128)
+                    b = torch.rand(24, 128, 128, dtype=torch.bfloat16)
+                    c = test_model(a, b)
+                    traced = torch.jit.trace(test_model, (a, b))
+                if jit_freeze_or_not:
+                    traced = torch.jit.freeze(traced)
+                for _ in range(3):
+                    c2 = traced(a, b)
+                self.assertTrue(c.dtype, torch.float32)
+                self.assertTrue(c2.dtype, torch.float32)
+                traced_graph = traced.graph_for(a, b)
+                self.assertTrue(any(n.kind() == "aten::to" for n in traced_graph.nodes()))
+
     def test_script_autocast_cpu(self):
         def fn(x):
             if torch.is_autocast_cpu_enabled():
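To run just the new test locally, an invocation along these lines should work (the test file path and pytest usage are assumptions, not stated in the PR):

```
pytest test/test_jit_autocast.py -k test_cat_promote
```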