Related issue: #125077
### Feature
Inductor tries to remove dimensions with stride 0 from block pointers. Rather than loading with stride 0, it's more efficient to load a smaller block pointer, then use `tl.broadcast_to` to broadcast it up to the desired size. This already worked for simpler block pointers, but it was disabled for more complex block pointers which used `tl.reshape` to change the dimensionality after loading.
This PR generalizes the approach to work for all block pointers. The idea is to first reshape, adding singleton dimensions, then broadcast those singletons up to something larger, then reshape again to the final output shape. For readability, we emit this code only if it actually does something. Simpler loads will just have `tl.load`.
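Before diving into the generated code, here's a distilled, hypothetical sketch of the pattern with made-up shapes (not actual Inductor output): rather than loading a `[YBLOCK, XBLOCK]` block through a stride-0 dimension, we load the small `[YBLOCK]` block and expand it in registers.
```
import triton
import triton.language as tl

@triton.jit
def broadcast_pattern(in_ptr, out_ptr, YBLOCK: tl.constexpr, XBLOCK: tl.constexpr):
    src = tl.make_block_ptr(in_ptr, shape=[YBLOCK], strides=[1],
                            block_shape=[YBLOCK], order=[0], offsets=[0])
    val = tl.load(src)                                      # small load: shape [YBLOCK]
    val = tl.broadcast_to(val[:, None], [YBLOCK, XBLOCK])   # add singleton, broadcast it up
    val = tl.reshape(val, [YBLOCK * XBLOCK])                # reshape to the final output shape
    dst = tl.make_block_ptr(out_ptr, shape=[YBLOCK * XBLOCK], strides=[1],
                            block_shape=[YBLOCK * XBLOCK], order=[0], offsets=[0])
    tl.store(dst, val)
```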
Here's an example of a complicated kernel that emits the full `load` -> `reshape` -> `broadcast_to` -> `reshape` pattern. (The first reshape, which adds the singleton dimensions, is spelled as the slice `[:, None, None]` on the load.)
```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 64
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x2 = xindex
x1 = (xindex // 8)
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0])
tmp1 = tl.reshape(tl.broadcast_to(tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[((7 + XBLOCK) // 8)], order=[0], offsets=[(xoffset // 8)]), boundary_check=[0], eviction_policy='evict_last')[:, None, None], [((7 + XBLOCK) // 8), ((1) * ((1) <= (((7 + XBLOCK) // 8))) + (((7 + XBLOCK) // 8)) * ((((7 + XBLOCK) // 8)) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))]), [XBLOCK])
tmp2 = tmp0 + tmp1
tl.store(tl.make_block_ptr(out_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tmp2.to(tl.float32), boundary_check=[0])
```
Before this PR, we would have stride-0 dimensions:
```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 64
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x2 = xindex
x1 = (xindex // 8)
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0])
tmp1 = tl.reshape(tl.load(tl.make_block_ptr(in_ptr1, shape=[8, 1, 8], strides=[8, 0, 0], block_shape=[((7 + XBLOCK) // 8), ((1) * ((1) <= (((7 + XBLOCK) // 8))) + (((7 + XBLOCK) // 8)) * ((((7 + XBLOCK) // 8)) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))], order=[2, 1, 0], offsets=[(xoffset // 8), 0, xoffset % 8]), boundary_check=[0], eviction_policy='evict_last'), [XBLOCK])
tmp2 = tmp0 + tmp1
tl.store(tl.make_block_ptr(out_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp2, [XBLOCK]).to(tl.float32), boundary_check=[0])
```
Here's a simpler example that uses 2D tiling. In this case we don't actually need the broadcast: it's implied by a slice that adds a new singleton dimension. This code is not changed by this PR, but it's important to check that we don't accidentally insert unnecessary broadcasts.
```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
ynumel = 8
xnumel = 8
yoffset = tl.program_id(1) * YBLOCK
yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
ymask = yindex < ynumel
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
x1 = xindex
y0 = yindex
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[8, 8], strides=[1, 8], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), boundary_check=[0, 1])
tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[YBLOCK], order=[0], offsets=[yoffset]), boundary_check=[0], eviction_policy='evict_last')[None, :]
tmp2 = tmp0 + tmp1
tl.store(tl.make_block_ptr(out_ptr0, shape=[8, 8], strides=[1, 8], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), tmp2.to(tl.float32), boundary_check=[0, 1])
```
### Test Plan
Added a new expecttest to check the emitted code for broadcast addition. Looking at the test, we can see that stride-0 dimensions are removed. (This test generated the example kernels in the previous section.)
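For reference, a hypothetical reconstruction of the broadcast addition behind the kernels above (the PR's actual expecttest may differ): `y[:, 0]` is a stride-8 view of an (8, 8) tensor, broadcast against a 64-element contiguous input.
```
import torch

@torch.compile
def broadcast_add(x, y):
    col = y[:, 0]                          # shape (8,), stride (8,)
    return (x.view(8, 8) + col[:, None]).view(64)

x = torch.randn(64, device="cuda")
y = torch.randn(8, 8, device="cuda")
out = broadcast_add(x, y)
```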
This change also removed a stride-0 dimension in an existing block pointer test. I updated the expected code accordingly.
Bonus: I noticed that the test parametrization for `config.prefer_nd_tiling` wasn't working as intended. It ended up always setting this option to `True`. Fixed it so we get the intended test coverage.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135557
Approved by: https://github.com/shunting314, https://github.com/jansel
Co-authored-by: Yueming Hao <yhao@meta.com>
Summary:
1. Move the debug printer call one level lower, to here: https://www.internalfb.com/code/fbsource/[931d7bbb9e7cf2dcb926f42718f56fc940903eec]/fbcode/caffe2/torch/_inductor/codegen/cpp_wrapper_cuda.py?lines=335
2. Add a UT validating the debug printer for user-defined Triton kernel codegen
The benefit of invoking the debug printer in a more centralized place is that 1) it reduces the duplicated debug-printer logic scattered around the codebase, and 2) it covers any Triton kernel codegen path that goes through `generate_kernel_call()`. For example, it automatically handles debug printing for user-defined kernels, a pretty common use case we encounter while debugging.
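For intuition, here's a hedged sketch (not the actual Inductor code; names and structure are assumptions) of what centralizing buys: every codegen path that launches a kernel through the shared helper gets before/after-launch printing for free.
```
from contextlib import contextmanager

# Hedged sketch: wrap each kernel launch emitted by generate_kernel_call()
# so every codegen path gets before/after-launch value printing.
@contextmanager
def debug_print_launch(kernel_name, args, enabled):
    if enabled:
        for arg in args:
            print(f"[ before_launch - {kernel_name} - {arg} ]:")
    yield  # the kernel launch is emitted/executed here
    if enabled:
        for arg in args:
            print(f"[ after_launch - {kernel_name} - {arg} ]:")
```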
Test Plan:
```
AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=2 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+graph, inductor, +schedule, output_code" buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_aoti_debug_printer_user_defined_triton_kernel_abi_compatible_cuda
```
Also verified that the `templateKernel` codegen path still works.
Differential Revision: D61949020
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134789
Approved by: https://github.com/ColinPeppler
Summary:
Small follow-up diff to fix a couple of issues:
- Add a condition for the CUDA/GPU case so the kernel name list is only printed in the second pass, i.e. when we do the cpp wrapper codegen
- Other minor fixes around the `AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT` option
Test Plan:
```
AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT="triton_poi_fused_0" AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=1 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+graph, inductor, +schedule, output_code" buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:test_aot_inductor -- -r test_addmm_abi_compatible_cuda
```
Differential Revision: D60954888
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133016
Approved by: https://github.com/ColinPeppler
Fixes #125077
**Feature**
This PR creates a new Inductor config, `config.triton.prefer_nd_tiling`, which is disabled by default. When enabled, this encourages the Triton code to use as many tiling dimensions as possible. This simplifies indexing expressions for discontiguous tensors, resulting in expressions like `5 * x + 8 * y` as opposed to `5 * (x // 7) + 8 * (y % 9)`. This allows us to find more block pointers than we normally would. We should now see simplified indexing expressions as long as:
1. All discontiguous reads/writes have the same shape.
2. The number of discontiguous dimensions is less than `config.triton.max_tiles`.
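For reference, a minimal usage sketch (the config path `config.triton.prefer_nd_tiling` is taken from this description; the rest is generic `torch.compile` usage):
```
import torch
import torch._inductor.config as inductor_config

# Opt in to ND tiling; the flag is disabled by default.
inductor_config.triton.prefer_nd_tiling = True

f = torch.compile(lambda a, b: a + b)
```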
Here's an example kernel (elementwise add of views) with ND tiling disabled:
```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 21
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex % 7
x1 = (xindex // 7)
x2 = xindex
tmp0 = tl.load(in_ptr0 + (x0 + (9*x1)), xmask)
tmp1 = tl.load(in_ptr1 + (x0 + (9*x1)), xmask)
tmp2 = tmp0 + tmp1
tl.store(tl.make_block_ptr(out_ptr0, shape=[21], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp2, [XBLOCK]).to(tl.float32), boundary_check=[0])
```
And here's the version with it enabled:
```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
ynumel = 3
xnumel = 7
yoffset = tl.program_id(1) * YBLOCK
yindex = yoffset + tl.arange(0, YBLOCK)[None, :]
ymask = yindex < ynumel
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
xmask = xindex < xnumel
x1 = xindex
y0 = yindex
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[7, 3], strides=[1, 9], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), boundary_check=[0, 1], eviction_policy='evict_last')
tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[7, 3], strides=[1, 9], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), boundary_check=[0, 1], eviction_policy='evict_last')
tmp2 = tmp0 + tmp1
tl.store(tl.make_block_ptr(out_ptr0, shape=[7, 3], strides=[1, 7], block_shape=[XBLOCK, YBLOCK], order=[1, 0], offsets=[xoffset, yoffset]), tl.broadcast_to(tmp2, [XBLOCK, YBLOCK]).to(tl.float32), boundary_check=[0, 1])
```
With this feature enabled, we get a discontiguous strided block pointer. Previously, this would only have worked for specific shapes, like powers of 2 or multiples of the maximum block size. With this PR, we can support arbitrary shapes so long as we have enough tiles to cover all discontiguous dimensions.
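For reference, a hypothetical repro of the "elementwise add of views" that yields kernels like the ones above (the actual test inputs may differ): (3, 7) slices of contiguous (3, 9) tensors have strides [9, 1], which is exactly the discontiguous index `x0 + 9*x1` in the 1D kernel.
```
import torch

@torch.compile
def add_views(a, b):
    return a[:, :7] + b[:, :7]

a = torch.randn(3, 9, device="cuda")
b = torch.randn(3, 9, device="cuda")
out = add_views(a, b)  # xnumel = 21, or ynumel = 3 / xnumel = 7 with ND tiling
```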
**Test plan**
This PR adds some tests for pointwise ops with discontiguous tensors.
- Test that we can generate block pointers for views with odd shapes like `(5,7)`, `(9,3,5)`, etc.
- Test that we can generate block pointers for a single discontiguous dim in 3D and 4D tensors.
- Test that we generate a 2D tiling for a 5D tensor with two discontiguous dims. This case doesn't generate a block pointer, but it checks that the output code is at least correct.
This PR also parametrizes some existing tests to run with and without `triton.prefer_nd_tiling`. That way, we ensure this feature doesn't break existing usage.
Since this setting isn't enabled on most tests, I also created https://github.com/pytorch/pytorch/pull/132935 to test what happens when `triton.prefer_nd_tiling=True` by default. None of the failures seem related to invalid tiling, so I think this feature is safe to merge.
**Limitations and follow-ups**
I can see two main improvements which would expand the usefulness of this feature:
1. This feature currently only works for pointwise kernels, since reductions are never tiled. As a follow-up, we could enable tiled reductions to extend these benefits to reduction kernels.
2. The usefulness of this feature depends on `config.triton.max_tiles`. This is currently restricted to 2 by default, although it can be increased to 3 in certain cases. To support more discontiguous dims, we might consider expanding support for 3D tiling, or even supporting ND tiling by mapping an ND "virtual" launch grid onto Triton's 3D launch grid.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132937
Approved by: https://github.com/jansel, https://github.com/eellison
Summary:
**Context:**
Currently we have a helper to print out AtenTensor in [shim_common.cpp](https://github.com/pytorch/pytorch/blob/v2.4.0-rc4/torch/csrc/inductor/aoti_torch/shim_common.cpp#L866)
The way we were using this function was a “manual” process: we would inject calls to it into the generated output.cpp file, then recompile and reload the file. This diff automates the value-printing process.
**Changes:**
1. Added a simple initial debug printer helper to print out tensor values
2. Added a filter option to selectively dump tensor values.
**Usage:**
Sample cmd:
```
AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=1 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+inductor, +schedule, output_code" python test/inductor/test_aot_inductor.py -k test_addmm_abi_compatible_cuda
```
Sample outputs:
```
[ before_launch - triton_poi_fused_0 - buf0 ]:
0.6331
1.6358
-0.3459
1.0196
-0.4122
1.4279
[ CUDAFloatType{6} ]
Min value: -0.412198
Max value: 1.63582
Device: cuda:0
Size: [6]
Stride: [1]
Dtype: float
Layout: Strided
Number of elements: 6
Is contiguous: 1
Requires grad: 0
[ after_launch - triton_poi_fused_0 - buf0 ]:
0.6331
1.6358
-0.3459
1.0196
-0.4122
1.4279
[ CUDAFloatType{6} ]
Min value: -0.412198
Max value: 1.63582
Device: cuda:0
Size: [6]
Stride: [1]
Dtype: float
Layout: Strided
Number of elements: 6
Is contiguous: 1
Requires grad: 0
[ before_launch - aoti_torch_cuda_addmm_out - buf1 ]:
Min value: -2.25655
Max value: 2.32996
Device: cuda:0
Size: [16, 6]
Stride: [6, 1]
Dtype: float
Layout: Strided
Number of elements: 96
Is contiguous: 1
Requires grad: 0
[ before_launch - aoti_torch_cuda_addmm_out - buf0 ]:
0.6331
1.6358
-0.3459
1.0196
-0.4122
1.4279
[ CUDAFloatType{6} ]
Min value: -0.412198
Max value: 1.63582
Device: cuda:0
Size: [6]
Stride: [1]
Dtype: float
Layout: Strided
Number of elements: 6
Is contiguous: 1
Requires grad: 0
[ after_launch - aoti_torch_cuda_addmm_out - buf1 ]:
Min value: -12.0839
Max value: 11.6878
Device: cuda:0
Size: [16, 6]
Stride: [6, 1]
Dtype: float
Layout: Strided
Number of elements: 96
Is contiguous: 1
Requires grad: 0
[ after_launch - aoti_torch_cuda_addmm_out - buf0 ]:
0.6331
1.6358
-0.3459
1.0196
-0.4122
1.4279
[ CUDAFloatType{6} ]
Min value: -0.412198
Max value: 1.63582
Device: cuda:0
Size: [6]
Stride: [1]
Dtype: float
Layout: Strided
Number of elements: 6
Is contiguous: 1
Requires grad: 0
stats [('calls_captured', 1), ('unique_graphs', 1)]
inductor [('pattern_matcher_count', 2), ('pattern_matcher_nodes', 2), ('extern_calls', 2)]
.
----------------------------------------------------------------------
Ran 1 test in 10.867s
OK
```
The user can choose which kernels' values get printed by setting the env var `AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT`; the available kernel names are listed in a log message like the one below:
```
torch/_inductor/graph.py:1642] Finished codegen for all nodes. The list of kernel names available: ['triton_poi_fused_0', 'aoti_torch_cuda_addmm_out']
```
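A hedged sketch of the resulting filtering behavior (the env var names come from this description; the actual Inductor implementation differs in details):
```
import os

def should_print(kernel_name: str) -> bool:
    # Printing is off unless the printer env var is set to a non-zero level.
    if os.environ.get("AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER", "0") == "0":
        return False
    # If a filter is given, only the listed kernels are printed.
    allowed = os.environ.get("AOT_INDUCTOR_FILTERED_KERNELS_TO_PRINT")
    return allowed is None or kernel_name in {k.strip() for k in allowed.split(",")}
```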
In a follow-up diff, we will add `torch.save()` to dump the intermediate tensors into individual `.pt` files that can then be inspected with `torch.load()`.
Test Plan:
Run unit tests in OSS (similar cmd to the one in the usage section above):
`AOT_INDUCTOR_DEBUG_INTERMEDIATE_VALUE_PRINTER=1 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+inductor, output_code" python test/inductor/test_aot_inductor.py -k test_addmm_abi_compatible_cuda`
Differential Revision: D60538496
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132323
Approved by: https://github.com/ColinPeppler
Summary:
Reland #124969 by backing out D60397377 "Back out "[1/2] PT2 Inductor ComboKernels - Foreach cases (#124969)""
The original diff D54134695 was reverted because of failures in the ads nightly cogwheel tests.
Root cause: the mask-generation logic in the Triton kernels needed updating after a recent refactoring of triton.py. This diff includes the fix.
See D54134695 or #124969 for more details.
Test Plan:
Originally failed tests:
f585704630
f585733786
Diff patched:
f586664028
f586663820
Differential Revision: D60458597
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132182
Approved by: https://github.com/Yuzhen11
Python's `set` iteration order is non-deterministic. We recently ran into an internal failure that, because of this, did not fail consistently.
See the repro here: P1453035092.
With these changes, it now fails consistently. As a follow-up, we could also consider adding a lint rule against uses of either `set()` or set literals.
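A hedged illustration of the failure mode (not the internal code): with string hash randomization, iterating a set can order its elements differently on every interpreter run, so any codegen that iterates a set is non-deterministic across runs.
```
# Iteration order over a set of strings may change between interpreter runs
# (PYTHONHASHSEED); sorting restores determinism.
buffers = {"buf2", "buf0", "buf1"}
print(list(buffers))    # order may differ from one run to the next
print(sorted(buffers))  # always ['buf0', 'buf1', 'buf2']
```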
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130004
Approved by: https://github.com/oulgen
Persistent kernels are sometimes able to remove intermediate buffers that would otherwise be needed for the non-persistent reduction kernel. This makes multi-kernel's codegen more complicated, as it needs to drop these extra arguments at runtime after selecting the correct kernel to run.
Instead, this PR updates the persistent kernel's `must_keep_buffers` so these buffers aren't dropped during codegen, and both kernels end up with the same signature.
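A hedged illustration of the mismatch (buffer and argument names are made up):
```
# Before: the persistent variant eliminated tmp_buf0, so the two candidate
# kernels took different argument lists and multi-kernel had to drop the
# extra argument at runtime after picking a kernel.
args_non_persistent = ["in_ptr0", "tmp_buf0", "out_ptr0", "xnumel", "rnumel"]
args_persistent_old = ["in_ptr0", "out_ptr0", "xnumel", "rnumel"]

# After: must_keep_buffers retains tmp_buf0, so the signatures match.
args_persistent_new = ["in_ptr0", "tmp_buf0", "out_ptr0", "xnumel", "rnumel"]
assert args_persistent_new == args_non_persistent
```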
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127724
Approved by: https://github.com/shunting314
ghstack dependencies: #131044
Currently a buffer represents both a tensor with physical storage and a computation that produces the tensor as a result.
This PR attempts to split these into two different concepts in the scheduler.
This should allow us to have multiple outputs from a single operation.
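A hedged sketch of the split (the class names are assumptions, not the actual scheduler types):
```
from dataclasses import dataclass, field

@dataclass
class Buffer:
    name: str                # a tensor with physical storage

@dataclass
class Operation:
    name: str                # the computation that fills buffers
    reads: list              # Buffers consumed as inputs
    outputs: list = field(default_factory=list)  # more than one output is now possible

mm = Operation("addmm_1", reads=[Buffer("buf0")],
               outputs=[Buffer("buf1"), Buffer("buf2")])
```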
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128893
Approved by: https://github.com/lezcano
This is a short term fix for: https://github.com/pytorch/pytorch/issues/124002
We found that the bad perf of the int8_unpack kernel is due to sub-optimal indexing. In this PR we introduce 2 indexing optimizations:
1. expand FloorDiv to the entire expression when feasible. E.g. `x1 * 1024 + x2 // 2` will be transformed to `(x1 * 2048 + x2) // 2`. The motivation is that we have a better chance of simplifying loops for `x1 * 2048 + x2`.
2. merge ModularIndexing pairs: `ModularIndexing(ModularIndexing(x, 1, a), 1, b)` can be simplified to `ModularIndexing(x, 1, b)` if a is a multiple of b.
With both indexing optimizations, we improve int8_unpack perf by 1.54x (183us -> 119us).
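As a sanity check, both rewrites can be verified with plain Python integer arithmetic (`//` and `%` stand in for Inductor's `FloorDiv` and `ModularIndexing`):
```
# Optimization 1: expanding FloorDiv across the expression is exact for
# nonnegative integer indices: x1*1024 + x2//2 == (x1*2048 + x2)//2.
for x1 in range(4):
    for x2 in range(4096):
        assert x1 * 1024 + x2 // 2 == (x1 * 2048 + x2) // 2

# Optimization 2: nested ModularIndexing collapses when a is a multiple
# of b: (x % a) % b == x % b, e.g. a = 12, b = 4.
for x in range(10_000):
    assert (x % 12) % 4 == x % 4
```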
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127661
Approved by: https://github.com/jansel
This pass was broken in a number of ways: we were not generating asserts whenever we took it, even though we need to. While fixing this, we found that the analysis we were using to choose whether to generate asserts for dynamic shapes was completely broken.
Eliminating indirect indexing in this way allows for a number of optimisations. In particular, we can now fuse against these kernels (indirect indexing disallows fusions).
The new strategy is as follows:
- We always propagate sympy expressions if we can.
- If an expression was an indirect_indexing, we call `check_bounds` (sketched below).
- We also call `check_bounds` within `CSEProxy.indirect_indexing`.
- The checks are issued in the buffer where they would go if they were used in a load.
- This makes them always be codegen'd before the loads and stores.
- In the case of stores, they will potentially be generated much earlier than the stores themselves, which is fine.
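Here's a hedged sketch of the kind of guarded indirect load this strategy produces (variable names and the exact assert message are assumptions; real Inductor output differs in details):
```
import triton
import triton.language as tl

@triton.jit
def gather_kernel(idx_ptr, src_ptr, out_ptr, xnumel, XBLOCK: tl.constexpr):
    xindex = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel
    idx = tl.load(idx_ptr + xindex, xmask)
    # check_bounds: issued before the load that consumes the indirect index
    tl.device_assert(((0 <= idx) & (idx < 64)) | ~xmask, "index out of bounds: 0 <= idx < 64")
    val = tl.load(src_ptr + idx, xmask)
    tl.store(out_ptr + xindex, val, xmask)
```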
We add quite a few asserts to preexisting tests to strengthen them. In particular, we make sure that issuing an assert plays well with all kinds of C++ vectorisation.
For now, we rely on the logic within `_maybe_evaluate_static` to prove these bounds. This logic is rather limited, though. In the future, we might want to rely on Z3 here to be able to prove bounds in a more general way.
Supersedes https://github.com/pytorch/pytorch/pull/113068
Fixes https://github.com/pytorch/pytorch/issues/121251
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114471
Approved by: https://github.com/peterbell10
The main motivation for this refactor is that today, when generating templates, this is what happens.
```
def_kernel() # registers hook for fully generating function definition
store_output() # registers hook for generating the output store. *also* keeps a number of things generated on `self.body`.
```
Later on, when we codegen the template (see `torch/_inductor/codegen/simd.py:1402` at f8c4c268da):
```
epilogue_node.codegen() # Also writes to body!
template.finalize() # Calls the above two hooks for def_kernel and store_output, which then reads from the accumulated `self.body`
```
Today, this is fine, as long as `store_output` is the last function called in the template. However, there are a couple of things we probably want to do with kernels that make this annoying.
1. In FlexAttention backwards, we might want a `modification` to be positioned *after* the `store_output` (just logically from a code organization POV). This doesn't work today because `modification` also needs to codegen a subgraph, but writing to `body` here conflicts with `store_output`'s implicit saved state on `self.body`.
2. If we want to support prologue fusion, we need to go through a bunch of contortions today to call the template hook finalization a couple times (https://github.com/pytorch/pytorch/pull/121211/files#diff-73b89475038a5b4705da805f1217783883fb90398ee1164995db392fc4a342c1R322)
3. The current code also makes it quite difficult to support fusion into multiple output nodes.
To resolve this, I do two things:
1. I *remove* the default `self.body` on `TritonTemplateKernel`. Instead, I have a dict of `self.subgraph_bodies`, which can be enabled in a context with `TritonTemplateKernel.set_subgraph_body`. This allows multiple different template functions to write to their own isolated bodies.
2. I add functions that allow you to finalize specific hooks on `PartialRender`.
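A hedged, heavily simplified sketch of the first change (the real `TritonTemplateKernel` tracks far more state):
```
from contextlib import contextmanager

class TemplateKernelSketch:
    def __init__(self):
        self.subgraph_bodies = {}   # name -> isolated body lines
        self._body = None

    @contextmanager
    def set_subgraph_body(self, name):
        # Each template function writes to its own named body, so e.g.
        # `modification` can no longer clobber `store_output`'s saved state.
        prev = self._body
        self._body = self.subgraph_bodies.setdefault(name, [])
        try:
            yield self._body
        finally:
            self._body = prev
```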
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127144
Approved by: https://github.com/jansel
Add `# mypy: disallow-untyped-defs` to scheduler.py and then fix the resulting fallout.
We probably should eventually add a new node between BaseSchedulerNode and all the non-FusedSchedulerNode types to indicate the split between nodes that have a valid `self.node` and ones that don't. That would cause a lot of the `assert self.node is not None` churn to go away, but that was a bigger change because a lot of code makes assumptions about types that aren't reflected in the types themselves.
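For reference, a minimal illustration of what the directive enforces:
```
# mypy: disallow-untyped-defs
# With the file-level directive above, mypy rejects any function in the
# file that is missing annotations:
def can_fuse(a, b):  # error: Function is missing a type annotation
    return a is b

def can_fuse_typed(a: object, b: object) -> bool:  # OK
    return a is b
```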
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126656
Approved by: https://github.com/eellison