mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add the function specialization for promote with ITensorListRef (#87756)
Fixes [#87684](https://github.com/pytorch/pytorch/issues/87684) It's due to a new tensor list type is introduced as `ITensorListRef`. We need the function specialization for `prioritize` and `cached_cast` for this new tensor list type. Pull Request resolved: https://github.com/pytorch/pytorch/pull/87756 Approved by: https://github.com/jgong5, https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
166b5d3e7c
commit
f150e70ca2
@ -819,6 +819,32 @@ class TestJitTraceAutocast(JitTestCase):
|
||||
continue
|
||||
test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i])
|
||||
|
||||
def test_cat_promote(self):
|
||||
class TestModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(TestModel, self).__init__()
|
||||
|
||||
def forward(self, a, b):
|
||||
return torch.cat([a, b], 0)
|
||||
with torch.jit.fuser("none"):
|
||||
# In this testcase, we will check whether cat has done the promotion in AMP with mixed dtype inputs.
|
||||
# To avoid the fusion group from TE, we will disable the fuser here.
|
||||
for jit_freeze_or_not in [False, True]:
|
||||
test_model = TestModel().eval()
|
||||
with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16), torch.no_grad():
|
||||
a = torch.rand(24, 128, 128)
|
||||
b = torch.rand(24, 128, 128, dtype=torch.bfloat16)
|
||||
c = test_model(a, b)
|
||||
traced = torch.jit.trace(test_model, (a, b))
|
||||
if jit_freeze_or_not:
|
||||
traced = torch.jit.freeze(traced)
|
||||
for _ in range(3):
|
||||
c2 = traced(a, b)
|
||||
self.assertTrue(c.dtype, torch.float32)
|
||||
self.assertTrue(c2.dtype, torch.float32)
|
||||
traced_graph = traced.graph_for(a, b)
|
||||
self.assertTrue(any(n.kind() == "aten::to" for n in traced_graph.nodes()))
|
||||
|
||||
def test_script_autocast_cpu(self):
|
||||
def fn(x):
|
||||
if torch.is_autocast_cpu_enabled():
|
||||
|
Reference in New Issue
Block a user