Compare commits

...

4 Commits

Author SHA1 Message Date
c263bd43e8 [inductor] use triu ref instead of lowering (#96040) (#96462)
Fixes #95958
The generated code is functionally identical with the ref and with the lowering; there are only minor differences.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96040
Approved by: https://github.com/jansel

Co-authored-by: Natalia Gimelshein <ngimel@fb.com>
2023-03-09 17:42:00 -05:00
c9913cf66f Add jinja2 as mandatory dependency (#95691) (#96450)
Should fix #95671 for the nightly wheels issue. The v2.0.0 RC does not need this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95691
Approved by: https://github.com/malfet

Co-authored-by: Wei Wang <weiwangmeta@meta.com>
2023-03-09 17:31:12 -05:00
2f7d8bbf17 Fix expired deprecation of comparison dtype for NumPy 1.24+ (#91517) (#96452)
> The `dtype=` argument to comparison ufuncs is now applied correctly. That
> means that only `bool` and `object` are valid values and `dtype=object` is
> enforced.

Source: https://numpy.org/doc/stable/release/1.24.0-notes.html#expired-deprecations

Fixes #91516

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91517
Approved by: https://github.com/zou3519, https://github.com/huydhn

Co-authored-by: Johnson <j3.soon@msa.hinet.net>
2023-03-09 14:30:00 -08:00
ca0cdf52ca dl_open_guard should restore flag even after exception (#96231) (#96457)
I.e., follow the pattern outlined in https://docs.python.org/3.8/library/contextlib.html#contextlib.contextmanager

Also, return early on non-unix platforms (when `sys.getdlopenflags` is not defined)

Fixes https://github.com/pytorch/pytorch/issues/96159

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96231
Approved by: https://github.com/atalman

(cherry picked from commit 941ff109d32d51d6e93a2c2f4a028ff3826ece31)
2023-03-09 14:29:17 -08:00
6 changed files with 12 additions and 30 deletions

View File

@@ -1024,6 +1024,7 @@ def main():
  'typing-extensions',
  'sympy',
  'networkx',
+ 'jinja2',
  ]
  extras_require = {

View File

@@ -448,6 +448,7 @@ inductor_all_samples = {
  "mT",
  "mH",
  "rsub",
+ "triu",
  }

View File

@@ -308,6 +308,7 @@ def core_aten_decompositions() -> Dict[OpOverload, Callable]:
  aten.trace,
  aten.transpose.int,
  aten.tril.default,
+ aten.triu.default,
  aten.unfold,
  aten.unfold_backward,
  aten.upsample_bilinear2d,

View File

@@ -1505,30 +1505,6 @@ def iota(
  )
- @register_lowering(aten.triu)
- def triu(x, diagonal=0):
- x_loader = x.make_loader()
- dtype = x.get_dtype()
- def inner_fn(index):
- *_, i, j = index
- return ops.where(
- ops.ge(
- ops.index_expr(j - i - diagonal, torch.int32),
- ops.constant(0, torch.int32),
- ),
- x_loader(index),
- ops.constant(0, dtype),
- )
- return Pointwise.create(
- device=x.get_device(),
- dtype=dtype,
- inner_fn=inner_fn,
- ranges=list(x.get_size()),
- )
  @register_lowering(aten.select_scatter, type_promotion_kind=None)
  def select_scatter(x, src, dim: int, index: int):
  assert x.get_dtype() == src.get_dtype()

View File

@@ -22,11 +22,14 @@ def dl_open_guard():
  Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a
  shared library to load custom operators.
  """
- if _SET_GLOBAL_FLAGS:
- old_flags = sys.getdlopenflags()
- sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
- yield
- if _SET_GLOBAL_FLAGS:
+ if not _SET_GLOBAL_FLAGS:
+ yield
+ return
+ old_flags = sys.getdlopenflags()
+ sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL)
+ try:
+ yield
+ finally:
  sys.setdlopenflags(old_flags)

View File

@@ -380,7 +380,7 @@ def make_histogram(values, bins, max_bins=None):
  limits = new_limits
  # Find the first and the last bin defining the support of the histogram:
- cum_counts = np.cumsum(np.greater(counts, 0, dtype=np.int32))
+ cum_counts = np.cumsum(np.greater(counts, 0))
  start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right")
  start = int(start)
  end = int(end) + 1