Several improvements for skipfiles:
* Add ```FUNC_INLINELIST``` to support function level skip/inline check.
* Use ```fn.__code__``` to match function since we can't get the function object sometimes.
* Use python module string name for ```FILE_INLINELIST``` and ```SUBMODULE_INLINELIST```.
* Use filename to match file and python module, which can fundamentally resolved the circular import issues introduced by skipfiles.
* Use ```TYPE_CHECKING``` to ensure the python module string name is correct.
* Add unit tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110835
Approved by: https://github.com/ezyang
Several improvements for skipfiles:
* Add ```FUNC_INLINELIST``` to support function level skip/inline check.
* Use ```fn.__code__``` to match function since we can't get the function object sometimes.
* Use python module string name for ```FILE_INLINELIST``` and ```SUBMODULE_INLINELIST```.
* Use filename to match file and python module, which can fundamentally resolved the circular import issues introduced by skipfiles.
* Use ```TYPE_CHECKING``` to ensure the python module string name is correct.
* Add unit tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110835
Approved by: https://github.com/ezyang
In this PR:
- Adds support for strides for jagged tensor (design doc for this coming soon)
- NestedTensor skips automatic dynamic
- Make use of @bdhirsh's subclass fakification logic by adding the __tensor_{un,}flatten__ functions.
- Additional logic for fakification: since existing subclass fakification logic does not handle the case where the outer tensor has an additional dimension. We insert one-off logic to (1) insert an extra SingletonSymInt onto the fakified NestedTensor. (2) make sure we call track_symint on both the sizes on the inner and outer tensor during guard creation.
Remaining things that are weird:
- Still need to skip some logic in meta utils for some reason (I was going to write this up more, but decided not to since we're not able to do this anyway for a immediate reason: we cannot arbitrarily compare singleton ints. For now I'm just following Brian's advise from [here](https://github.com/pytorch/pytorch/pull/109171#discussion_r1328137070) )
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109171
Approved by: https://github.com/ezyang, https://github.com/bdhirsh
Our experience using `constraints` / `dynamic_dim` with the existing export API has found it to be (subjectively) clunky and (objectively) verbose in common cases.
This PR implements a new design for the export API that replaces the use of `constraints` / `dynamic_dim` with a new way of specifying dynamic shapes, involving the following concepts:
* a constructor `Dim` for first-class named dynamic dimensions with ranges (similar to `functorch.dim`, and analogous to internal symbolic sizes)
* a mechanism that uses the above in `export` calls to associate inputs to their dynamic shape specifications (`dynamic_shapes`)
Design doc: https://docs.google.com/presentation/d/168U7XK72C_WSsZpGESP6Cho9udh193fi0gfjxCNcJ4E/edit#slide=id.p (Meta-only). Note that we only implement Option 1 in that doc. An older version of this PR also implemented Option 3, which is an alternative way of specifying dynamic shapes using tensor type annotations on the exported callable; but we have moved that to future work for now.
See docs for these new features in `torch.export`. The existing `torch.export.export` is modified to use the new API, `torch._export.export__RC__`, whenever `constraints=None`. We have not deprecated the existing API yet, but will do in a follow-up.
Constraint violation errors arising through use of the new API will now contain suggested fixes using the new API. No longer do we need to report all specializations for static dimensions and suggest all constraints over dynamic dimensions to fix such errors. Instead, due to the redesign, the suggested fixes are much more concise, only involving modifying the definitions of relevant `Dim`s.
Differential Revision: [D48919204](https://our.internmc.facebook.com/intern/diff/D48919204/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108448
Approved by: https://github.com/suo, https://github.com/gmagogsfm
Fix: #107315
This PR enables dynamo to trace through the `pytree` API by inlining its functions. In
order to do so, a few details of `pytree` had to be changed.
In summary, this PR:
- Introduces `TreeSpecVariable` for representing `TreeSpec` instances
- Specializes `<type>.__bases__` call, returning a `TupleVariable`
- Enables the call to `id` builtin function for every variable that implements
`as_python_constant` method
- Specializes `ConstantVariable.call_method` for its (un)flatten functions
- Implements `UserDefinedObjectVariable.as_python_constant`
- Modifies `pytree` by:
- Make `SUPPORTED_NODES` a map of ids (instead of types) to `NodeDef`
- Removed `functools.wraps` function, since it can't be inlined
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108533
Approved by: https://github.com/ezyang, https://github.com/voznesenskym
ghstack dependencies: #109201
Summary:
The basic concept behind this diff is to modify Dynamo's tracing behavior when it encounters a KeyedJaggedTensor that is synced (aka has `_length_per_key` and `_offset_per_key` populated). These fields are lists of integers; ordinarily, Dynamo will optimistically try to specialize on integers, however, for KJTs, we know that these integers will definitely vary from run-to-run. Furthermore, ordinarily, we would also specialize these integers if they are 0/1, but we will frequently expect features in KJTs to be 0/1.
The fix is to detect KJTs and treat these integers as *unbacked integers*. This is NOT a universally sound optimization: when treating these integers as unbacked, we never report them as equal to zero or one. In return, we always generate graphs that generalize no matter the length of values on features. This is enough to trace through APS sparse arch, torchrec_dlrm and some small split-cat examples.
The special integer behavior is triggered by a dynamically scoped `force_unspec_int_unbacked_size_like` variable on TracingContext, which we trigger when we wrap a KJT. There probably are other ways to do this, but this was simple and worked.
Test Plan:
```
buck2 test mode/dev-nosan //pytorch/benchmark/fb/test_gpu:run_test_gpu
```
from aakhundov
1. first build feed_lower_benchmark:
```
buck2 build --show-output mode/opt -c python.package_style=inplace -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 -c fbcode.split-dwarf=true hpc/new/models/feed/benchmark:feed_lower_benchmark
```
2. then run the lowering of the model with it:
```
TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 TORCH_LOGS="output_code,graph_code" TORCH_COMPILE_DEBUG=1 ../buck-out/v2/gen/fbcode/79c6b019ee0f9469/hpc/new/models/feed/benchmark/__feed_lower_benchmark__/feed_lower_benchmark.par --load=manifold://ig_inference_model/tree/user/facebook/fblearner/predictor/960999465/60/gpu_lowering/input.predictor --skip-trt --skip-ait --sync-mode=0 --enable-aot-inductor --lower-presets="ig_stories" --gpu-trace
```
cf https://docs.google.com/document/d/1yD30xYrdmM8r2HTdmXnZTg0-MHVexfVrAa0294m1AUE/edit?pli=1#heading=h.qiv3fp7e6zg0
From torchrec: https://www.internalfb.com/intern/wiki/Torchrec/Development/Testing_production_models/
From ge0405
baseline (without your diff): f477293168
your diff: f477292363
```
buck2 test //caffe2/test/dynamo:test_dynamo_torchrec
buck2 run 'fbcode//mode/opt' fbcode//pytorch/benchmark/fb/test_gpu:run_test_gpu -- 'pytorch.benchmark.fb.test_gpu.test_gpu.TestBenchmarkFbGpu.test_train_blue_reels_vdd_v3_inductor_speedup'
```
Differential Revision: D49236757
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109216
Approved by: https://github.com/voznesenskym
Before this PR, if we run the following code:
```python
def true_fn(x):
return x - x.cos()
def false_fn(x):
return x + x.sin()
def foo(x):
return cond(x.shape[0] == 4, true_fn, false_fn, [x])
gm = make_fx(foo, tracing_mode='symbolic')(torch.ones(3, 4))
gm = make_fx(foo, tracing_mode='symbolic')(torch.ones(4, 5))
```
we'll have the following error:
```python
Traceback (most recent call last):
File "/home/yidi/local/pytorch/make_fx.py", line 16, in <module>
gm = make_fx(foo, tracing_mode='symbolic')(torch.ones(4, 5))
File "/home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 841, in wrapped
t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
File "/home/yidi/local/pytorch/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 461, in dispatch_trace
graph = tracer.trace(root, concrete_args)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/fx/_symbolic_trace.py", line 817, in trace
(self.create_arg(fn(*args)),),
File "/home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 497, in wrapped
out = f(*tensors)
File "/home/yidi/local/pytorch/make_fx.py", line 13, in foo
return control_flow.cond(x.shape[0] == 4, true_fn, false_fn, [x])
File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 151, in cond
return torch.compile(cond_op, backend="eager", fullgraph=True)(
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 545, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 380, in _convert_frame_assert
return _compile(
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 561, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 189, in time_wrapper
r = func(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 483, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 432, in transform
tracer = InstructionTranslator(
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2032, in __init__
self.symbolic_locals = collections.OrderedDict(
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2035, in <genexpr>
VariableBuilder(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 229, in __call__
vt = self._wrap(value).clone(**self.options())
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 374, in _wrap
return type_dispatch(self, value)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 808, in wrap_listlike
output = [
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 809, in <listcomp>
VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 229, in __call__
vt = self._wrap(value).clone(**self.options())
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 374, in _wrap
return type_dispatch(self, value)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 808, in wrap_listlike
output = [
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 809, in <listcomp>
VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 229, in __call__
vt = self._wrap(value).clone(**self.options())
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 374, in _wrap
return type_dispatch(self, value)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1040, in wrap_tensor
tensor_variable = wrap_fx_proxy(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1267, in wrap_fx_proxy
return wrap_fx_proxy_cls(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1382, in wrap_fx_proxy_cls
example_value = wrap_to_fake_tensor_and_record(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1652, in wrap_to_fake_tensor_and_record
dynamic_dims, constraint_dims = _automatic_dynamic(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1550, in _automatic_dynamic
if dim is not None and e.size()[i] != dim:
File "/home/yidi/local/pytorch/torch/__init__.py", line 352, in __bool__
return self.node.bool_()
File "/home/yidi/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 1019, in bool_
return self.guard_bool("", 0)
File "/home/yidi/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 1001, in guard_bool
r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
File "/home/yidi/local/pytorch/torch/fx/experimental/recording.py", line 227, in wrapper
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 3793, in evaluate_expr
assert orig_expr == hint, f"{orig_expr} != {hint}"
AssertionError: False != True
from user code:
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
```
It's because we record the SymInt in the frame state in _automatic_dynamic the first time we compile the function. Then In the second time, when we are given a symint sized input with different hints, the comparison fails.
Implementation:
This PR returns shape dynamism according to the dynamism of inputs: if a diemsion is SymInt, return DYNAMIC else return static.
Test Plan:
Add a test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109331
Approved by: https://github.com/ezyang
We could have SymBool inputs for torch.compile, e.g. in the following situation:
```
def f(x:torch.Tensor):
pred = x.size(0) == 3
torch.compile(f)(pred, x)
make_fx(f, tracing_mode="symbolic")(x)
```
The idea of this PR (credit to @ezyang) is to support SymBool by re-using the infra we've already had for SymInt so that we don't need to replicate a lot of stuff.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107850
Approved by: https://github.com/ezyang
ghstack dependencies: #107662
The strategy in this PR is pretty straightforward.
There are 2 kinds of hooks:
1) Hooks on objects with sources (inputs, params)
2) Hooks on objects w/o sources (intermediaries, and outputs).
Note: As outputs can be made simple by how dynamo handles residuals, they could actually be handled as if they were inputs, but, for the sake of this PR, we will refer to hooks as either hooks on inputs (sourced), or hooks on intermediaries (not sourced).
The plan:
**For tensors w/ a source:**
We record registered hooks, store them as a global, and associate them with the tensor in residuals. This means that when dynamo goes to create the frame, where we produce bytecode to stitch together our PT2 modified bytecode with the original eager code, we call `register_hook`. This registration of hooks in residuals is sound because (a) it happens right after a Pt2 frame region ends and (b) we know that the tensor is alive in f_locals, f_globals, or a module in the users invoking frame. This means we can soundly know it will be around to invoke `register_hook` on. As long as we guard on the identity of the lifted function, this is sound to do.
**For tensors w/o a source:**
Graph break - we will support this in a subsequent PR
**Handles:**
An interesting new component here is the creation of a `STORE_FAST `->`LOAD_FAST` associated with the handle, the return result of `register_hook`. If the user code stored the result of `register_hook` in a handle, we need to honor that. We do so by interceding into `STORE_FAST`, and recording the name of the local variable as directed by user code. We then honor that same name in the reconstructed bytecode. If the user did not store a hook, we merely pop the produced value to preserve the stack.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108903
Approved by: https://github.com/ezyang
ghstack dependencies: #108846, #109092
We could have SymBool inputs for torch.compile, e.g. in the following situation:
```
def f(x:torch.Tensor):
pred = x.size(0) == 3
torch.compile(f)(pred, x)
make_fx(f, tracing_mode="symbolic")(x)
```
The idea of this PR (credit to @ezyang) is to support SymBool by re-using the infra we've already had for SymInt so that we don't need to replicate a lot of stuff.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107850
Approved by: https://github.com/ezyang
ghstack dependencies: #107662
Summary:
The basic concept behind this diff is to modify Dynamo's tracing behavior when it encounters a KeyedJaggedTensor that is synced (aka has `_length_per_key` and `_offset_per_key` populated). These fields are lists of integers; ordinarily, Dynamo will optimistically try to specialize on integers, however, for KJTs, we know that these integers will definitely vary from run-to-run. Furthermore, ordinarily, we would also specialize these integers if they are 0/1, but we will frequently expect features in KJTs to be 0/1.
The fix is to detect KJTs and treat these integers as *unbacked integers*. This is NOT a universally sound optimization: when treating these integers as unbacked, we never report them as equal to zero or one. In return, we always generate graphs that generalize no matter the length of values on features. This is enough to trace through APS sparse arch, torchrec_dlrm and some small split-cat examples.
The special integer behavior is triggered by a dynamically scoped `force_unspec_int_unbacked_size_like` variable on TracingContext, which we trigger when we wrap a KJT. There probably are other ways to do this, but this was simple and worked.
Test Plan:
```
buck2 test mode/dev-nosan //pytorch/benchmark/fb/test_gpu:run_test_gpu
```
from aakhundov
1. first build feed_lower_benchmark:
```
buck2 build --show-output mode/opt -c python.package_style=inplace -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 -c fbcode.split-dwarf=true hpc/new/models/feed/benchmark:feed_lower_benchmark
```
2. then run the lowering of the model with it:
```
TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 TORCH_LOGS="output_code,graph_code" TORCH_COMPILE_DEBUG=1 ../buck-out/v2/gen/fbcode/79c6b019ee0f9469/hpc/new/models/feed/benchmark/__feed_lower_benchmark__/feed_lower_benchmark.par --load=manifold://ig_inference_model/tree/user/facebook/fblearner/predictor/960999465/60/gpu_lowering/input.predictor --skip-trt --skip-ait --sync-mode=0 --enable-aot-inductor --lower-presets="ig_stories" --gpu-trace
```
cf https://docs.google.com/document/d/1yD30xYrdmM8r2HTdmXnZTg0-MHVexfVrAa0294m1AUE/edit?pli=1#heading=h.qiv3fp7e6zg0
From torchrec: https://www.internalfb.com/intern/wiki/Torchrec/Development/Testing_production_models/
From ge0405
baseline (without your diff): f477293168
your diff: f477292363
Differential Revision: D49019987
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108960
Approved by: https://github.com/voznesenskym
The strategy for supporting functools partials is relatively straightforward.
There are 2 cases we need to support:
**1) Functools partials as input**
In this case, we are first seeing the functools partial and it is guaranteed to have a source. As such, the args, keywords, and func of the functools partial are passed through VariableBuilder. As this is the first time we are seeing these objects (as it is an input), we re-enter VariableBuilder with a source referencing the args, keywords, and func as attributes of the input to produce:
- func: A callable VariableTracker (UDF, TorchVariable, etc) depending on the value of `func`
- args: List[VariableTracker] - note, not ListVariableTracker!
- keywords: Dict[str, VariableTracker]
A major benefit of this structure is that it very elegantly matches the args to `call_function`.
We then compose a FunctoolsPartialVariable from the VariableTrackers made above.
**2) Functools partials created within compile**
In this case, we already have all the args as known VTs, and thus just compose a FunctoolsPartialVariable as we do for case (1).
For both (1) and (2) - we propagate all guards from the func, args, and keyword VTs to the FunctoolsPartialVariable
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108846
Approved by: https://github.com/ezyang, https://github.com/jansel
Summary:
Original commit changeset: 33650f7cb0fb
Original Phabricator Diff: D48833682
Test Plan: See T162942232 for how we figured out that this diff caused significant numeric difference.
Reviewed By: voznesenskym
Differential Revision: D49082219
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108823
Approved by: https://github.com/xw285cornell
Summary:
Enables dynamo eager mode tracing for the following situation:
1. we have a torch.autograd.Function
2. the input to that function is a tensor subclass which is an intermediary
This is useful for float8 training UX.
Test Plan:
```
python test/dynamo/test_autograd_function.py -k intermediary_input
```
Reviewers:
Subscribers:
Tasks:
Tags:
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108093
Approved by: https://github.com/bdhirsh, https://github.com/wanchaol
Marks all params/optimizer state as static addresses and a finalizer which cleans up the graph attributes when the optimizer goes out of scope.
**Note: this does not mark grads as static because this will increase memory usage significantly
There are two cases:
1. The upstream graph is cudagraphed - this case will work fine OOTB
2. The upstream graph is not cudagraphed - in this case, there will be a lot of copies introduced from the upstream (to copy the grads) into cudagraphed-owned memory, unless the user explicitly marks the grads as static. If the user does this, this will also require not deallocating the grads in zero_grad() (either the mod or optimizer version) by setting them to zero vs None. There is a PR (https://github.com/pytorch/pytorch/pull/107853) in flight to throw an error if zero_grad attempts to set static grads to None.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107504
Approved by: https://github.com/eellison
This PR allows dynamo to fakify FunctionalTensorWrapper by unwrapping, replacing and wrapping again for FunctionalTensorWrapper so that FunctionalTensorWrapper can be passed in as input for dynamo.optimize and we can support something like this
```python
ff = torch.func.functionalize(f)
torch.compile(ff)(x)
```
This PR didn't follow the \_\_tensor_flatten\_\_ and \_\_tensor_unflatten\_\_ protocol right now because we're not sure the plan of doing that for FunctionalTensorWrapper (it's implemented in C++).
**Test Plan:**
Add a new test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107062
Approved by: https://github.com/zou3519
ghstack dependencies: #107042
This PR allows dynamo to fakify FunctionalTensorWrapper by unwrapping, replacing and wrapping again for FunctionalTensorWrapper so that FunctionalTensorWrapper can be passed in as input for dynamo.optimize and we can support something like this
```python
ff = torch.func.functionalize(f)
torch.compile(ff)(x)
```
This PR didn't follow the \_\_tensor_flatten\_\_ and \_\_tensor_unflatten\_\_ protocol right now because we're not sure the plan of doing that for FunctionalTensorWrapper (it's implemented in C++).
**Test Plan:**
Add a new test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107062
Approved by: https://github.com/zou3519
ghstack dependencies: #107042
Adds API to mark tensor as a static input -
To make this trigger recompiles properly, I'll need to update tensor match checks to also check for this new attribute
Additional concern is memory - the tensors will be kept alive, but this is the current behavior for nn modules and parameters.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107154
Approved by: https://github.com/eellison