Previously, parametrized tests with class arguments, for example
```
@parametrize("this_cls", (Foo, Bar))
```
would create parametrized tests with names `test_foo_this_cls0` and `test_foo_this_cls1`. With this change, we instead should get `test_foo_this_cls_Foo` and `test_foo_this_cls_Bar`
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133546
Approved by: https://github.com/eellison
This PR
* makes changes to the workflow files and scripts so we can run CI workflows on the MI300 runners
* skips and fixes several tests, failed on MI300, observed in https://github.com/pytorch/pytorch/pull/140989
Skipped due to unsupported Float8_e4m3fn data type on MI300 (need to update test code to use datatypes supported by MI300):
- distributed.tensor.parallel.test_micro_pipeline_tp.py::MicroPipelineTPTest::test_fuse_all_gather_scaled_matmul_A_dims_\*_gather_dim_\* (24 tests across inductor/distributed configs)
- distributed.tensor.parallel.test_micro_pipeline_tp.py::test_fuse_scaled_matmul_reduce_scatter_A_dims_\*_scatter_dim_\* (12 tests across inductor/distributed configs))
- inductor.test_loop_ordering::LoopOrderingTest::test_fp8_cast_and_t
- inductor.test_loop_ordering::LoopOrderingTest::test_fp8_pattern_2
Skipped due to AssertionError on MI300:
- inductor.test_mkldnn_pattern_matcher.py::test_qconv2d_int8_mixed_bf16
- distributed._tools.test_sac_ilp::TestSACILP::test_sac_ilp_case1
Skipped:
- test_cuda.py::TestCudaMallocAsync::test_clock_speed
- test_cuda.py::TestCudaMallocAsync::test_power_draw
- test_torch.py::TestTorchDeviceTypeCUDA::test_deterministic_cumsum_cuda
Skipped flaky tests on MI300:
- distributed.test_c10d_gloo.py::ProcessGroupGlooTest::test_gather_stress_cuda
- inductor.test_cpu_repro::CPUReproTests::test_lstm_packed_unbatched_False* (256 tests)
Fixed:
- test_matmul_cuda.py::TestFP8MatmulCudaCUDA::test_float8_basics_cuda
Features:
- inductor/test_fp8.py - declare a new function to convert FP8 datatypes to ROCm supported FP8 datatypes. It keeps test names for CUDA and ROCm and allows to enable Inductor FP8 tests on CPU
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143673
Approved by: https://github.com/jeffdaily, https://github.com/malfet, https://github.com/pruthvistony
Co-authored-by: saienduri <saimanas.enduri@amd.com>
Co-authored-by: Jithun Nair <jithun.nair@amd.com>
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
See #144006
```py
__________________________________________ CudaReproTests.test_repeated_masked_load __________________________________________
RuntimeError: First class dim doesn't work with python 3.12
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/jansel/conda/envs/pytorch/lib/python3.12/unittest/case.py", line 58, in testPartExecutor
yield
File "/home/jansel/conda/envs/pytorch/lib/python3.12/unittest/case.py", line 634, in run
self._callTestMethod(testMethod)
File "/home/jansel/conda/envs/pytorch/lib/python3.12/unittest/case.py", line 589, in _callTestMethod
if method() is not None:
^^^^^^^^
File "/home/jansel/pytorch/torch/testing/_internal/common_utils.py", line 3108, in wrapper
method(*args, **kwargs)
File "/home/jansel/pytorch/test/inductor/test_cuda_repro.py", line 1678, in test_repeated_masked_load
from functorch.einops import rearrange
File "/home/jansel/pytorch/functorch/einops/__init__.py", line 1, in <module>
from .rearrange import rearrange
File "/home/jansel/pytorch/functorch/einops/rearrange.py", line 7, in <module>
from functorch._C import dim as _C
ImportError: initialization failed
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144006
Approved by: https://github.com/Skylion007
Changes by apply order:
1. Replace all `".."` and `os.pardir` usage with `os.path.dirname(...)`.
2. Replace nested `os.path.dirname(os.path.dirname(...))` call with `str(Path(...).parent.parent)`.
3. Reorder `.absolute()` ~/ `.resolve()`~ and `.parent`: always resolve the path first.
`.parent{...}.absolute()` -> `.absolute().parent{...}`
4. Replace chained `.parent x N` with `.parents[${N - 1}]`: the code is easier to read (see 5.)
`.parent.parent.parent.parent` -> `.parents[3]`
5. ~Replace `.parents[${N - 1}]` with `.parents[${N} - 1]`: the code is easier to read and does not introduce any runtime overhead.~
~`.parents[3]` -> `.parents[4 - 1]`~
6. ~Replace `.parents[2 - 1]` with `.parent.parent`: because the code is shorter and easier to read.~
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129374
Approved by: https://github.com/justinchuby, https://github.com/malfet
Changes:
1. Bump `ruff` from 0.7.4 to 0.8.4
2. Change `%`-formatted strings to f-string
3. Change arguments with the `__`-prefix to positional-only arguments with the `/` separator in function signature.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143753
Approved by: https://github.com/Skylion007
Changes by apply order:
1. Replace all `".."` and `os.pardir` usage with `os.path.dirname(...)`.
2. Replace nested `os.path.dirname(os.path.dirname(...))` call with `str(Path(...).parent.parent)`.
3. Reorder `.absolute()` ~/ `.resolve()`~ and `.parent`: always resolve the path first.
`.parent{...}.absolute()` -> `.absolute().parent{...}`
4. Replace chained `.parent x N` with `.parents[${N - 1}]`: the code is easier to read (see 5.)
`.parent.parent.parent.parent` -> `.parents[3]`
5. ~Replace `.parents[${N - 1}]` with `.parents[${N} - 1]`: the code is easier to read and does not introduce any runtime overhead.~
~`.parents[3]` -> `.parents[4 - 1]`~
6. ~Replace `.parents[2 - 1]` with `.parent.parent`: because the code is shorter and easier to read.~
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129374
Approved by: https://github.com/justinchuby, https://github.com/malfet
# Motivation
This PR enables the XPU quantized convolution. The operators it registers are `onednn::qconv_prepack`, `onednn::qconv1d_pointwise`, `onednn::qconv2d_pointwise`, `onednn::qconv3d_pointwise`. We share same operator schemas as Intel CPU backend as both would call kernels implemented in oneDNN library.
# Details
The implemented operators would be further integrated into pt2e quant flow. In this PR, we validated the kernel functionality via the UT in `test/inductor/test_mkldnn_pattern_matcher.py` where CPU backend defines a series of UT for quantized convolution. Also, we extend the device support for inductor lowering pass and inductor IR defined in `torch/_inductor/fx_passes/quantization.py` and `torch/_inductor/mkldnn_ir.py`. The overall picture would be that, CPU and GPU backend could share the general optimization pass(op fusion) and quantization inductor IR. After lowering, the final kernel would be dispatched to different implementation in oneDNN library.
In this PR, we share the same int8 quantizer in CPU, namely, `X68InductorQuantizer`. In next PR #139578, we will append a `XPUIndcutorQuantizer` which will customized the pt2e behaviors at XPU backend. The capability of `XPUInductorQuantizer` would gradually grow along with the development of quantized operators in XPU.
# Validation
* UT testing
```bash
python test/inductor/test_mkldnn_pattern_matcher.py -v \
-k test_qconv2d_xpu \
-k test_qconv2d_silu_xpu \
-k test_qconv2d_relu6_xpu \
-k test_qconv2d_hardtanh_xpu \
-k test_qconv2d_hardswish_xpu
```
* Runtime exemplification
```bash
#qconv2d
onednn_verbose,primitive,exec,gpu:0,convolution,jit:ir,forward_training,src_u8::blocked:acdb::f0 wei_s8::blocked:acdb::f0 bia_undef::undef::: dst_f32::blocked:acdb::f0,attr-scratchpad:user attr-scales:src0:0:f32+wei:1:f32 attr-zero-points:src0:0:s32 attr-post-ops:binary_add:f32:2+eltwise_linear:1,alg:convolution_direct,mb1_ic128oc128_ih6oh4kh3sh1dh0ph0_iw6ow4kw3sw1dw0pw0,0.0668945
#qconv2d_silu
onednn_verbose,primitive,exec,gpu:0,convolution,jit:ir,forward_training,src_u8::blocked:acdb::f0 wei_s8::blocked:acdb::f0 bia_undef::undef::: dst_u8::blocked:acdb::f0,attr-scratchpad:user attr-scales:src0:0:f32+wei:1:f32 attr-zero-points:src0:0:s32 attr-post-ops:eltwise_swish:1+binary_add:f32:2+eltwise_linear:0.0124779:22,alg:convolution_direct,mb1_ic3oc128_ih8oh6kh3sh1dh0ph0_iw8ow6kw3sw1dw0pw0,0.0881348
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133080
Approved by: https://github.com/guangyey, https://github.com/EikanWang, https://github.com/atalman
* Automatically applies ruff rule 401. Turns loops into equivalent list comprehensions which are faster and do not leak the scope of the loop variables.
* list comprehensions not only often have better typing, but are 50+% faster than for loops on overhead. They also preserve length information etc and are better for the interpreter to optimize.
* Manually went back and made mypy happy after the change.
* Also fixed style lints in files covered by flake8 but not by pyfmt
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140980
Approved by: https://github.com/justinchuby, https://github.com/malfet
### Background
This PR adds the functionality to xfail / skip on a per-`SampleInput` basis for `OpInfo` tests. See #89354 and #82669 for some requests asking for this type of functionality.
This was originally landed for NJT in #138370 and is generalized and slightly tweaked here.
### Design
#### Principles
* Clean separation among `SampleInput` generation logic, test logic that uses the `SampleInput`s, and xfail / skip logic (which will change as bugs are addressed).
* Flexibility in xfail / skip predicate specification - ideally each bug can be handled by a single skip / xfail, even if it surfaces across a specific class of ops.
* This is important in practice for NJT, where it's common to have a bug that affects all binary ops, for example.
* Opt-in with minimal test logic changes + no substantial impact on other tests.
#### Details
The core new concept is a `SampleRule`, which can be either an `XFailRule` or `SkipRule`.
```python
@dataclass
class SampleRule(ABC):
# function to indicate whether the rule applies to this op; return True if so
# NB: str arg of callable is device_type
op_match_fn: Callable[[str, OpInfo], bool] = None
# function to indicate whether the rule applies to this sample; return True if so
sample_match_fn: Callable[[torch.device, SampleInput], bool] = None
# optional name for identifying the rule
name: str = ""
@dataclass
class XFailRule(SampleRule):
# expected error type
error_type: TypeVar = Exception
# expected error message
error_msg: str = ".*"
@dataclass
class SkipRule(SampleRule):
...
```
* See below for example usage details, but at a high level: each test should have a corresponding list of `sample_skips_and_xfails`.
* The list of `sample_skips_and_xfails` is traversed in order, and the first rule that matches (if any) is applied, so order can matter.
* The PR includes a logging mechanism for matched rules accessible by setting the loglevel to `DEBUG`.
* The split between `op_match_fn` and `sample_match_fn` is made to allow pre-filtering of the list of rules to get only those that apply to the op under test.
* Each `SampleInput` is run within a subtest context so they can be individually skipped / xfailed as needed. This also means that a test will no longer stop after the first erroring `SampleInput`; all samples will be run through test logic.
### Example Usage
Consider the following OpInfo test:
```python
class MyTestCase(TestCase):
@ops(op_db)
def test_foo(self, device, dtype, op):
for sample in op.sample_inputs(device, dtype, requires_grad=False):
# do some SampleInput-based test logic
output = op.op(sample.input, *sample.args, **sample.kwargs)
...
```
This is a common pattern for such tests; simply generate a list of `SampleInputs` and run them through the op. Now say you want to xfail one of these `SampleInput`s for a given op. Today, you have to xfail the entire test or hack around this in the test logic.
This PR lets you do this to get very flexible xfail / skips based on op / sample input properties:
```python
# NB: Define rules for per-SampleInput xfails / skips. These can also be defined in-line in the @ops decorator, but
# it can be more readable to maintain these somewhere else. These are attempted to be matched in order and
# the first one that matches applies, so order can matter.
FOO_SKIPS_AND_XFAILS = [
XFailRule(
error_type=ValueError,
error_mg="2D inputs not supported",
op_match_fn=lambda device, op: (
# NB: logic for which ops this rule applies to goes here
op.full_name == "add"
),
sample_match_fn=lambda device, sample: (
# NB: logic which samples this rule applies to goes here
sample.input.dim() == 2
),
# NB: optional rule identifier can help with debugging matched rules
name="add_with_2D_inputs_not_supported",
),
# NB: This follows a similar structure as XFailRule but without error_type / error_msg. Obviously
# this skips a particular SampleInput instead of xfailing :)
SkipRule(...),
...
]
class MyTestCase(TestCase):
@ops(op_db)
@sample_skips_and_xfails(FOO_SKIPS_AND_XFAILS)
# NB: the @ops decorator automatically filters out any rules that don't apply to this op
def test_foo(self, device, dtype, op):
for sample, subtest_ctx in op.sample_inputs(
# NB: use_subtests=True is required for skips / xfails to work. If skips / xfails are defined and use_subtests != True,
# an informative error will be thrown.
device, dtype, requires_grad=False, use_subtests=True
):
# NB: this subtest context manager runs each sample input as a "subtest" and handles skips / xfails appropriately
with subtest_ctx(self):
# do some SampleInput-based test logic
output = op.op(sample.input, *sample.args, **sample.kwargs)
...
```
More examples can be seen in `test/test_nestedtensor.py`, where this system is used in practice.
I also demonstrate usage of syntactic sugar over this system in `test/functorch/test_vmap.py`. Here, a skip for the `to()` operator is replaced with a granular xfail for `test_vmap_exhaustive()`:
```python
...
# pre-existing xfail
xfail("item"),
# new granular xfail using syntactic sugar over the general system
xfailIf(
"to",
lambda sample: (
sample.kwargs["memory_format"] == torch.channels_last
),
),
...
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140443
Approved by: https://github.com/janeyx99, https://github.com/zou3519
ghstack dependencies: #140160, #138370
Per discussion with @malfet, only allow weights_only unpickler to load NJT if `torch.nested` and `torch._dynamo` are imported
(this is slightly weird as technically `torch.nested` is actually imported by default and `torch._dynamo.decorators._DimRange` is actually what needs to be imported)
we can't import this from `torch.nested` as this would
- undo dynamo lazy import
- cause circular import
===========================
Redo of https://github.com/pytorch/pytorch/pull/140304 caused issues as `torch.nested._internal.foo` needs to be imported, which causes issues like
```python
torch/_weights_only_unpickler.py", line 339, in load
if full_path in _get_allowed_globals():
torch/_weights_only_unpickler.py", line 188, in _get_allowed_globals
torch.nested._internal.nested_tensor.NestedTensor
AttributeError: module 'torch.nested' has no attribute '_internal'
```
**This likely wasn't caught in our CI because imports are global during unit tests(?), so we use subprocess to properly test this time**
Differential Revision: [D65961691](https://our.internmc.facebook.com/intern/diff/D65961691)
@jbschlosser
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140739
Approved by: https://github.com/malfet
Functionally two decorators are very similar, but one should rely on expectedFailure as much as possible to get signal when something is fixed.
- Move `product_version` variable from `test_mps` to common_utils, but call it `MACOS_VERSION`
- Introduce `skipIfMPSOnMacOS13` to decorate the hard crashes that happens only on MacOS13 (which at this point will not get any fixes and will be deprecated soon)
- Add `device_type='mps'` to all `skipIfMPS` per https://github.com/pytorch/pytorch/issues/140560
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139940
Approved by: https://github.com/janeyx99, https://github.com/huydhn
@ezyang noticed this exercises a multithreading bug that is causing tests to become disabled:
```
2024-11-13T21:05:55.8363582Z inductor/test_torchinductor_opinfo.py::TestInductorOpInfoCPU::test_comprehensive_fft_ihfftn_cpu_int32 /opt/conda/envs/py_3.9/lib/python3.9/site-packages/_pytest/threadexception.py:73: PytestUnhandledThreadExceptionWarning: Exception in thread Thread-3
2024-11-13T21:05:55.8364857Z
2024-11-13T21:05:55.8364974Z Traceback (most recent call last):
2024-11-13T21:05:55.8365491Z File "/opt/conda/envs/py_3.9/lib/python3.9/threading.py", line 980, in _bootstrap_inner
2024-11-13T21:05:55.8366003Z self.run()
2024-11-13T21:05:55.8366371Z File "/opt/conda/envs/py_3.9/lib/python3.9/threading.py", line 917, in run
2024-11-13T21:05:55.8366858Z self._target(*self._args, **self._kwargs)
2024-11-13T21:05:55.8367518Z File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fbscribelogger/__init__.py", line 176, in _run_event_loop
2024-11-13T21:05:55.8368189Z self.loop.run_until_complete(self.task)
2024-11-13T21:05:55.8368774Z File "/opt/conda/envs/py_3.9/lib/python3.9/asyncio/base_events.py", line 647, in run_until_complete
2024-11-13T21:05:55.8369348Z return future.result()
2024-11-13T21:05:55.8369980Z File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fbscribelogger/__init__.py", line 214, in _worker
2024-11-13T21:05:55.8370603Z message = await asyncio.wait_for(
2024-11-13T21:05:55.8371090Z File "/opt/conda/envs/py_3.9/lib/python3.9/asyncio/tasks.py", line 442, in wait_for
2024-11-13T21:05:55.8371573Z return await fut
2024-11-13T21:05:55.8372156Z File "/opt/conda/envs/py_3.9/lib/python3.9/asyncio/queues.py", line 166, in get
2024-11-13T21:05:55.8372613Z await getter
2024-11-13T21:05:55.8374010Z RuntimeError: Task <Task pending name='Task-1' coro=<FbScribeLogger._worker() running at /opt/conda/envs/py_3.9/lib/python3.9/site-packages/fbscribelogger/__init__.py:214> cb=[_run_until_complete_cb() at /opt/conda/envs/py_3.9/lib/python3.9/asyncio/base_events.py:184]> got Future <Future pending> attached to a different loop
2024-11-13T21:05:55.8375366Z
2024-11-13T21:05:55.8375603Z warnings.warn(pytest.PytestUnhandledThreadExceptionWarning(msg))
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140717
Approved by: https://github.com/ezyang, https://github.com/zxiiro
This PR updates OpInfo-based tests for NJTs:
* Adds extensive coverage across non-contiguous NJTs (both non-contiguous transposed and non-contiguous with holes)
* The `_sample_njts()` helper that `sample_input_func`s utilize now produces non-contig NJTs as well
* Utilizes a `SampleInput`-based xfail system for granular classification of bugs. For example, it's possible to indicate that a class of ops is expected to fail only on non-contig with holes NJT inputs.
* I decided on adding `SampleInput`s and utilizing this system over using test parametrization for two reasons:
* Test perf - adding `SampleInput`s is faster than generating entire new tests
* Avoiding the possibility of `sample_input_func`s not respecting the non-contig test parameter - this would result in silently incorrect passing of these tests. Keeping the responsibility for `SampleInput` generation firmly within each `OpInfo`'s `sample_input_func` means weirdness like this isn't possible
* Improves `SampleInput` naming for a bunch of `sample_input_func`s. This makes it easier to xfail them as needed. For example, binary / unary / other ops now use the new `_describe_njt()` helper to get a string repr that uniquely defines the type of NJT being passed to the op
* Adds appropriate `XFailRule`s to get tests passing for forward / backward / forward compile / backward compile. In general, each xfail corresponds to some bug that needs to be fixed
```python
# Represents a rule indicating how to xfail a particular test. It allows granularity
# at the device, dtype, op, and individual sample levels. This flexibility allows entire
# bugs to be represented by a single rule, even if this corresponds with multiple conceptual
# test cases across multiple ops.
@dataclass
class XFailRule:
# expected error type
error_type: TypeVar = Exception
# expected error message
error_msg: str = ".*"
# function to indicate whether the rule applies; return True if so
match_fn: Callable[[torch.device, torch.dtype, OpInfo, SampleInput], bool] = None
# optional name for identifying the rule
name: str = ""
def match(self, device, dtype, op, sample) -> bool:
return self.match_fn(device, dtype, op, sample)
```
Example:
```python
# Bug when broadcasting a binary op with non-contiguous with holes NJT + dense
# tensor with 1 in ragged dim.
XFailRule(
error_type=RuntimeError,
error_msg="cannot call binary pointwise function .* with inputs of shapes",
match_fn=lambda device, dtype, op, sample: (
isinstance(op, BinaryUfuncInfo)
and "noncontig_holes" in sample.name
and "broadcasting 1 over ragged" in sample.name
),
name="binary_noncontig_holes_broadcasting_1_over_ragged",
),
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138370
Approved by: https://github.com/cpuhrsch, https://github.com/soulitzer
ghstack dependencies: #140160
**Background:** The `@parametrize` decorator enjoys widespread usage as a convenient tool for ensuring extensive test coverage. One particular feature that makes this easy is the ability to stack such decorators, testing over the cross-product of inputs. Example:
```python
class MyTestClass(TestCase):
@parametrize("x", range(3))
@parametrize("y", [False, True])
def test_foo(self, x, y):
# Invoked with:
# x=0, y=False
# x=1, y=False
# x=2, y=False
# x=0, y=True
# x=1, y=True
# x=2, y=True
...
```
Note that the `@ops` and `@modules` decorators employ the same underlying machinery for parametrizing over `OpInfo` / `ModuleInfo` entries. These decorators also parametrize over op-specific `device` / `dtype` info *according to what is supported for each op*.
```python
class MyTestClass(TestCase):
@ops(op_db)
def test_foo(self, op, device, dtype):
# Invoked each OpInfo in the db along with each device / dtype that corresponds
# with this op according to the OpInfo entry.
...
```
Note that this in contrast to the naive cross product between ops and devices / dtypes, which would generate too many tests. Certain use cases benefit from a similar type of flexible parametrization that is more intelligent than simple cross-product composition. It is expensive to generate / run too many tests, even if the unneeded ones are skipped appropriately.
This PR attempts to generalize such flexible parametrization and satisfy these use cases through the introduction of a `@reparametrize` decorator, which operates on an existing parametrizer and allows for customized on-the-fly parametrization through the use of an `adapter_fn`. Examples:
```python
# adapter_fn that adds a new arg
def include_is_even_arg(test_name, param_kwargs):
x = param_kwargs["x"]
is_even = x % 2 == 0
new_param_kwargs = dict(param_kwargs)
new_param_kwargs["is_even"] = is_even
is_even_suffix = "_even" if is_even else "_odd"
new_test_name = f"{test_name}{is_even_suffix}"
yield (new_test_name, new_param_kwargs)
# adapter_fn that excludes certain values
def exclude_odds(test_name, param_kwargs):
x = param_kwargs["x"]
is_even = x % 2 == 0
yield None if not is_even else (test_name, param_kwargs)
class MyTestClass(TestCase):
@reparametrize(parametrize("x", range(5)), include_is_even_arg)
def test_foo(self, x, is_even):
# Invoked with both the x value and the new is_even arg
...
@reparametrize(parametrize("x", range(5)), exclude_odds)
def test_bar(self, x):
# Only invoked with even x values
...
```
For a more real-world use case, imagine you want to write a set of OpInfo tests that parametrize over additional op-specific things beyond `device` / `dtype` (in NJT's case, this includes contiguity type, whether to operate over the batch / ragged / other dims, etc.). The `@reparametrize` decorator allows you to customize the `@ops` parametrization to add in these additional args as they make sense on a per-op basis.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138369
Approved by: https://github.com/janeyx99
This has the benefit that
1) It's much easier to aggregate test failure repros into say a CSV or shell script from scuba
2) We can do analysis (eg. set different two sets of tests across two PRs)
3) We can get results faster at the test-level granularity instead of job-level granularity we see in the HUD/GH.
I tested this by introducing a breaking change, adding ci-scribe label and then verifying that the failed tests were logged to scuba: https://fburl.com/scuba/torch_open_source_signpost/w6qt7qr9
I then reverted the breaking change and published this PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138394
Approved by: https://github.com/ezyang