mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Enable batch samples in sparse tests (#164677)
The test cases are enabled because the issue was fixed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164677 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
ab01a0d7d3
commit
ee5389d520
@ -4913,9 +4913,6 @@ class TestSparseAny(TestCase):
|
||||
lambda i, v, sz: cnstr(i, v, sz, **kwargs_).to_dense(masked_grad=masked),
|
||||
args_, masked=masked)
|
||||
else:
|
||||
if layout in {torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} and 0:
|
||||
# TODO: remove this if-block after gh-107370 is resolved
|
||||
continue
|
||||
torch.autograd.gradcheck(
|
||||
lambda ci, pi, v: cnstr(ci, pi, v, **kwargs).to_dense(masked_grad=masked),
|
||||
args, masked=masked)
|
||||
@ -5494,7 +5491,6 @@ class TestSparseAny(TestCase):
|
||||
layout, device=device, dtype=torch.float64,
|
||||
enable_zero_sized=False, # pinning zero-sized tensors is a no-op
|
||||
pin_memory=True,
|
||||
enable_batch=False, # TODO: remove after gh-104868 is resolved
|
||||
):
|
||||
if layout is torch.sparse_coo:
|
||||
self.assertTrue(t._indices().is_pinned())
|
||||
@ -5524,7 +5520,6 @@ class TestSparseAny(TestCase):
|
||||
layout, device=device, dtype=torch.float64,
|
||||
enable_zero_sized=False, # pinning zero-sized tensors is a no-op
|
||||
pin_memory=False, # no pinning
|
||||
enable_batch=False, # TODO: remove after gh-104868 is resolved
|
||||
):
|
||||
t = t_.pin_memory()
|
||||
self.assertTrue(t.is_pinned())
|
||||
@ -5575,7 +5570,6 @@ class TestSparseAny(TestCase):
|
||||
enable_zero_sized=False, # pinning zero-sized tensors is a no-op
|
||||
pin_memory=None, # constructor does not specify pin_memory=...
|
||||
members_pin_memory=True, # indices and values are pinned
|
||||
enable_batch=False, # TODO: remove after gh-104868 is resolved
|
||||
):
|
||||
if layout is torch.sparse_coo:
|
||||
self.assertTrue(t._indices().is_pinned())
|
||||
@ -5613,7 +5607,6 @@ class TestSparseAny(TestCase):
|
||||
for args, kwargs in self.generate_simple_inputs(
|
||||
layout, device=device, dtype=torch.float64,
|
||||
enable_zero_sized=False, # pinning zero-sized tensors is a no-op
|
||||
enable_batch=False, # TODO: remove after gh-104868 is resolved
|
||||
output_tensor=False):
|
||||
|
||||
# indices are pinned, values is a non-pinned tensor
|
||||
|
Reference in New Issue
Block a user