add the function specialization for promote with ITensorListRef (#87756)
Fixes [#87684](https://github.com/pytorch/pytorch/issues/87684). The failure is caused by the new tensor list type `ITensorListRef`: the autocast promotion helpers need function specializations of `prioritize` and `cached_cast` for this list type.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87756
Approved by: https://github.com/jgong5, https://github.com/ezyang
Committed by: PyTorch MergeBot
Parent: 166b5d3e7c
Commit: f150e70ca2
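For context, the scenario from #87684 is a multi-tensor op such as torch.cat receiving mixed-dtype inputs under CPU autocast; cat follows autocast's "promote" policy, so the result should come out in the widest input dtype. Below is a minimal eager-mode sketch of that scenario (not part of the commit; the shapes and the float32/bfloat16 mix are illustrative and mirror the test added further down):

import torch

# cat uses the autocast "promote" policy: with mixed-dtype inputs, every
# tensor in the list is cast to the widest floating dtype before the op runs.
a = torch.rand(4, 8)                        # float32
b = torch.rand(4, 8, dtype=torch.bfloat16)  # bfloat16

with torch.cpu.amp.autocast(dtype=torch.bfloat16):
    out = torch.cat([a, b], 0)

print(out.dtype)  # expected: torch.float32, the wider of the two input dtypes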
@@ -126,6 +126,16 @@ inline at::ScalarType prioritize(
   return current;
 }
 
+inline at::ScalarType prioritize(
+    at::ScalarType current,
+    const ITensorListRef& list,
+    DeviceType device_type = DeviceType::CUDA) {
+  for (const auto& tensor : list) {
+    current = prioritize(current, tensor, device_type);
+  }
+  return current;
+}
+
 // Template to catch non-Tensor args (no-op that returns current best guess)
 template <typename T>
 inline at::ScalarType prioritize(
@@ -196,6 +206,18 @@ inline std::vector<Tensor> cached_cast(
   return vec;
 }
 
+inline std::vector<Tensor> cached_cast(
+    at::ScalarType to_type,
+    const ITensorListRef& arg,
+    DeviceType device_type = DeviceType::CUDA) {
+  std::vector<Tensor> vec;
+  vec.reserve(arg.size());
+  for (const auto& t : arg) {
+    vec.push_back(cached_cast(to_type, t, device_type));
+  }
+  return vec;
+}
+
 // Template to catch non-Tensor args.
 template <typename T>
 inline T cached_cast(
@@ -819,6 +819,32 @@ class TestJitTraceAutocast(JitTestCase):
                 continue
             test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i])
 
+    def test_cat_promote(self):
+        class TestModel(torch.nn.Module):
+            def __init__(self):
+                super(TestModel, self).__init__()
+
+            def forward(self, a, b):
+                return torch.cat([a, b], 0)
+        with torch.jit.fuser("none"):
+            # In this testcase, we will check whether cat has done the promotion in AMP with mixed dtype inputs.
+            # To avoid the fusion group from TE, we will disable the fuser here.
+            for jit_freeze_or_not in [False, True]:
+                test_model = TestModel().eval()
+                with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16), torch.no_grad():
+                    a = torch.rand(24, 128, 128)
+                    b = torch.rand(24, 128, 128, dtype=torch.bfloat16)
+                    c = test_model(a, b)
+                    traced = torch.jit.trace(test_model, (a, b))
+                if jit_freeze_or_not:
+                    traced = torch.jit.freeze(traced)
+                for _ in range(3):
+                    c2 = traced(a, b)
+                self.assertTrue(c.dtype, torch.float32)
+                self.assertTrue(c2.dtype, torch.float32)
+                traced_graph = traced.graph_for(a, b)
+                self.assertTrue(any(n.kind() == "aten::to" for n in traced_graph.nodes()))
+
     def test_script_autocast_cpu(self):
         def fn(x):
             if torch.is_autocast_cpu_enabled():
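As a rough end-to-end check mirroring the test above, a traced module can also be inspected for the cast that the promotion inserts. This is only a sketch, not part of the commit; the shapes are illustrative, and the aten::to check simply confirms that a cast node ends up in the traced graph:

import torch

class CatModel(torch.nn.Module):
    def forward(self, a, b):
        return torch.cat([a, b], 0)

model = CatModel().eval()
a = torch.rand(4, 8)                        # float32
b = torch.rand(4, 8, dtype=torch.bfloat16)  # bfloat16

with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16), torch.no_grad():
    eager_out = model(a, b)
    traced = torch.jit.trace(model, (a, b))

traced_out = traced(a, b)
print(eager_out.dtype, traced_out.dtype)  # both expected to be torch.float32
# The traced graph should contain an aten::to node for the inserted cast.
print(any(n.kind() == "aten::to" for n in traced.graph_for(a, b).nodes()))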