mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-06 09:17:11 +08:00
This PR requires a little justification, but let's start with what it does first: 1. When you have a 0d CPU scalar int64/float64 tensor input to a graph, we will preallocate a backed SymInt/SymFloat corresponding to what you would get if you call item() on this tensor. This means you can freely change your input to be a Python int/float or a Tensor with an item() call and end up with exactly the same level of expressivity (specifically, you can guard on the internal SymInt/SymFloat no matter what). By default, the source of the backed SymInt/SymFloat is `L['tensor'].item()`, but if you have promoted a float input into a Tensor, we will cancel out `torch.as_tensor(L['float']).item()` into just `L['float']`. 2. We switch wrap_symfloat to use this, instead of hand crafting the new SymNodeVariable. Everything works out, except that we carefully pass the item() result to tracked fakes (and not the fake Tensor argument). OK, so why do this at all? There is some marginal benefit where now some item() calls on scalar inputs can be guarded on, but IMO this is a pretty marginal benefit, and if it was the only reason, I wouldn't do this. The real reason for this is that I need to be able to propagate fake tensors through the graphs that are produced by Dynamo, and if I am doing the old custom wrap_symfloat logic, there's no way I can do this, because ordinarily an item() call will cause an unbacked SymInt when I reallocate. The other obvious way to solve the problem above is to make a HOP alternative to item() that "bakes in" the backed SymInt it's supposed to return. But this strategy seems more parsimonious, and it does have the marginal benefit I mentioned above. The main downside is that what I have to do next is make it so that when I run tensor computation, I also apply the equivalent operations to the SymInt/SymFloat as well. That's the next PR. Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126245 Approved by: https://github.com/eellison ghstack dependencies: #126637
644 lines
19 KiB
Python
644 lines
19 KiB
Python
import collections
|
|
import dataclasses
|
|
import enum
|
|
from typing import Any, Optional, Union
|
|
|
|
from torch._guards import ChainedSource, GuardSource, Source
|
|
|
|
from . import utils
|
|
from .bytecode_transformation import create_call_function, create_instruction
|
|
from .utils import enum_repr
|
|
|
|
# Guard-source translation table used by the nn.Module-flavoured sources in
# this file (NNModuleSource, ParamBufferSource): plain LOCAL / GLOBAL become
# their *_NN_MODULE counterparts, and sources already marked as nn.Module
# keep that marking.
#
# It shouldn't be supported to construct an NNModuleVariable inside an FSDP module,
# so those cases are omitted intentionally
_GUARD_SOURCE_NN_MODULE = {
    # plain sources are promoted
    GuardSource.LOCAL: GuardSource.LOCAL_NN_MODULE,
    GuardSource.GLOBAL: GuardSource.GLOBAL_NN_MODULE,
    # already-promoted sources map to themselves
    GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL_NN_MODULE,
    GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL_NN_MODULE,
}
|
|
|
|
# Guard-source translation table used by FSDPNNModuleSource: every
# LOCAL-flavoured source maps to LOCAL_FSDP_MODULE and every GLOBAL-flavoured
# source maps to GLOBAL_FSDP_MODULE.
_GUARD_SOURCE_FSDP_MODULE = {
    # plain sources
    GuardSource.LOCAL: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL: GuardSource.GLOBAL_FSDP_MODULE,
    # nn.Module sources
    GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
    # FSDP sources map to themselves
    GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
}
|
|
|
|
# Guard-source translation table used by NotNNModuleSource: any nn.Module or
# FSDP marking is stripped, leaving plain LOCAL / GLOBAL.
_GUARD_SOURCE_NOT_NN_MODULE = {
    # plain sources map to themselves
    GuardSource.LOCAL: GuardSource.LOCAL,
    GuardSource.GLOBAL: GuardSource.GLOBAL,
    # nn.Module sources are demoted
    GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL,
    GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL,
    # FSDP sources are demoted
    GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL,
    GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL,
}
|
|
|
|
|
|
# Returns True when `source` is a ConstantSource, or when its guard_source()
# is GuardSource.CONSTANT. A source whose guard_source() raises
# NotImplementedError is reported as not constant.
def is_constant_source(source):
    if isinstance(source, ConstantSource):
        return True

    try:
        kind = source.guard_source()
    except NotImplementedError:
        # No guard source available, so it cannot be a constant one.
        return False

    return kind == GuardSource.CONSTANT
|
|
|
|
|
|
# Shared prologue for subscript-style sources: emits the code that
# reconstructs `source.base`, followed by the code that produces the index.
#   - a Source index reconstructs itself;
#   - a slice index (stored in reduced form on a GetItemSource) is rebuilt
#     with unpack_slice() and loaded as a constant;
#   - anything else is loaded as a constant directly.
# The caller is responsible for emitting the actual subscript / call.
def reconstruct_getitem(
    source: Union["GetItemSource", "ODictGetItemSource"], codegen, index_is_slice
):
    source.base.reconstruct(codegen)

    if isinstance(source.index, Source):
        source.index.reconstruct(codegen)
        return

    if index_is_slice:
        # Only GetItemSource knows how to rebuild a slice.
        assert isinstance(source, GetItemSource)
        constant = source.unpack_slice()
    else:
        constant = source.index
    codegen.append_output(codegen.create_load_const(constant))
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class LocalSource(Source):
|
|
local_name: str
|
|
cell_or_freevar: bool = False
|
|
|
|
def reconstruct(self, codegen):
|
|
codegen.append_output(codegen.create_load(self.local_name))
|
|
|
|
def guard_source(self):
|
|
return GuardSource.LOCAL
|
|
|
|
def name(self):
|
|
return f"L[{repr(self.local_name)}]"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class SyntheticLocalSource(Source):
|
|
local_name: str
|
|
|
|
def reconstruct(self, codegen):
|
|
codegen.append_output(codegen.create_load(self.local_name))
|
|
|
|
def guard_source(self):
|
|
return GuardSource.SYNTHETIC_LOCAL
|
|
|
|
def name(self):
|
|
return f"SYNTHETIC_LOCAL[{self.local_name!r}]"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class RandomValueSource(Source):
|
|
random_call_index: int
|
|
|
|
def guard_source(self):
|
|
return GuardSource.RANDOM_VALUE
|
|
|
|
def reconstruct(self, codegen):
|
|
codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var))
|
|
codegen.append_output(codegen.create_load_const(self.random_call_index))
|
|
codegen.append_output(create_instruction("BINARY_SUBSCR"))
|
|
|
|
def name(self):
|
|
return f"random_value_{self.random_call_index}"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class GlobalSource(Source):
|
|
global_name: str
|
|
|
|
def reconstruct(self, codegen):
|
|
codegen.append_output(
|
|
codegen.create_load_global(self.global_name, False, add=True)
|
|
)
|
|
|
|
def guard_source(self):
|
|
return GuardSource.GLOBAL
|
|
|
|
def name(self):
|
|
return f"G[{repr(self.global_name)}]"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class GlobalWeakRefSource(Source):
|
|
global_name: str
|
|
|
|
def reconstruct(self, codegen):
|
|
codegen.append_output(
|
|
codegen.create_load_global(self.global_name, True, add=True)
|
|
)
|
|
codegen.extend_output(create_call_function(0, False))
|
|
|
|
def guard_source(self):
|
|
return GuardSource.GLOBAL
|
|
|
|
def name(self):
|
|
return f"G[{repr(self.global_name)}]()"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class AttrSource(ChainedSource):
|
|
member: str
|
|
|
|
def __post_init__(self):
|
|
assert self.base, "Can't construct an AttrSource without a valid base source"
|
|
if "." in self.member:
|
|
member_parts = self.member.split(".")
|
|
object.__setattr__(
|
|
self, "base", AttrSource(self.base, ".".join(member_parts[:-1]))
|
|
)
|
|
object.__setattr__(self, "member", member_parts[-1])
|
|
|
|
def reconstruct(self, codegen):
|
|
self.base.reconstruct(codegen)
|
|
codegen.extend_output(codegen.create_load_attrs(self.member))
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
if not self.member.isidentifier():
|
|
return f"getattr({self.base.name()}, {self.member!r})"
|
|
return f"{self.base.name()}.{self.member}"
|
|
|
|
|
|
# Represents tensor.grad source. It could be represented by AttrSource as well.
|
|
# But, we could access grad field on tensor directly in C++ without going
|
|
# through the Python bytecodes. Therefore, we use a separate source for grad
|
|
# field.
|
|
@dataclasses.dataclass(frozen=True)
|
|
class GradSource(ChainedSource):
|
|
member: str = "grad"
|
|
|
|
def reconstruct(self, codegen):
|
|
self.base.reconstruct(codegen)
|
|
codegen.extend_output(codegen.create_load_attrs(self.member))
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
return f"{self.base.name()}.{self.member}"
|
|
|
|
|
|
# An AttrSource whose guard source is the base's guard source translated
# through _GUARD_SOURCE_NN_MODULE. A base guard source missing from that
# table raises KeyError.
@dataclasses.dataclass(frozen=True)
class ParamBufferSource(AttrSource):
    def guard_source(self):
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_NN_MODULE[base_kind]
|
|
|
|
|
|
# This source is intended to be used in places where a source is needed but it is expected
|
|
# that the symbol will be simplified out later on. Symbols with ephemeral sources are
|
|
# prioritized to be simplified out when e.g. compared against a symbol without an ephemeral
|
|
# source. Guarding on this source is an error.
|
|
#
|
|
# Example: During subclass view fake-ification, any close-over ViewFunc state should be
|
|
# symbolicized / fake-ified to avoid invalid specialization during view replay. This source
|
|
# is useful for symbols utilized in the middle of the view chain that are not expected to be
|
|
# present within the final view shape metadata.
|
|
@dataclasses.dataclass(frozen=True)
|
|
class EphemeralSource(Source):
|
|
desc: Optional[str] = None
|
|
|
|
def guard_source(self):
|
|
return GuardSource.EPHEMERAL
|
|
|
|
def name(self):
|
|
return f"<ephemeral{': ' + self.desc if self.desc is not None else ''}>"
|
|
|
|
def make_guard(self):
|
|
raise NotImplementedError
|
|
|
|
def is_ephemeral(self):
|
|
return True
|
|
|
|
|
|
# The tensor properties that TensorPropertySource can refer to.
class TensorProperty(enum.Enum):
    SIZE = 0
    STRIDE = 1
    STORAGE_OFFSET = 2

    def method_name(self):
        # Name of the tensor method that reads this property.
        method_by_property = {
            TensorProperty.SIZE: "size",
            TensorProperty.STRIDE: "stride",
            TensorProperty.STORAGE_OFFSET: "storage_offset",
        }
        return method_by_property.get(self)
|
|
|
|
|
|
# Source for one property of a base tensor source:
#   SIZE / STRIDE    -> <base>.size()[idx] / <base>.stride()[idx]
#   STORAGE_OFFSET   -> <base>.storage_offset()
@dataclasses.dataclass(frozen=True)
class TensorPropertySource(ChainedSource):
    prop: TensorProperty
    idx: Optional[int] = None  # None for STORAGE_OFFSET

    def __post_init__(self):
        assert self.base is not None
        # idx is required for SIZE / STRIDE and forbidden for STORAGE_OFFSET.
        if self.prop is not TensorProperty.STORAGE_OFFSET:
            assert self.idx is not None
        else:
            assert self.idx is None

    def name(self):
        if self.prop is TensorProperty.STORAGE_OFFSET:
            assert self.idx is None
            return f"{self.base.name()}.storage_offset()"

        if self.prop is TensorProperty.SIZE:
            accessor = "size"
        elif self.prop is TensorProperty.STRIDE:
            accessor = "stride"
        else:
            raise AssertionError(f"unhandled {self.prop}")
        return f"{self.base.name()}.{accessor}()[{self.idx}]"

    def guard_source(self):
        return self.base.guard_source()

    def reconstruct(self, codegen):
        # Emits a call of the property's method on the base, passing idx as
        # the single argument when one is present.
        self.base.reconstruct(codegen)
        codegen.append_output(codegen.create_load_attr(self.prop.method_name()))
        has_idx = self.idx is not None
        if has_idx:
            codegen.append_output(codegen.create_load_const(self.idx))
        codegen.extend_output(create_call_function(1 if has_idx else 0, True))
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class NegateSource(ChainedSource):
|
|
def __post_init__(self):
|
|
assert self.base is not None
|
|
|
|
def reconstruct(self, codegen):
|
|
raise NotImplementedError
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
# NB: use method call so that function stripping regexes work
|
|
return f"{self.base.name()}.__neg__()"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class ConvertIntSource(ChainedSource):
|
|
def __post_init__(self):
|
|
assert self.base is not None
|
|
|
|
def reconstruct(self, codegen):
|
|
self.base.reconstruct(codegen)
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
return f"cast_symbool_to_symint_guardless({self.base.name()})"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class FlattenScriptObjectSource(ChainedSource):
|
|
def __post_init__(self):
|
|
assert self.base is not None
|
|
|
|
def reconstruct(self, codegen):
|
|
self.base.reconstruct(codegen)
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
return f"{self.base.name()}.__obj_flatten__()"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class ScriptObjectQualifiedNameSource(ChainedSource):
|
|
def __post_init__(self):
|
|
assert self.base is not None
|
|
|
|
def reconstruct(self, codegen):
|
|
self.base.reconstruct(codegen)
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
return f"{self.base.name()}._type().qualified_name()"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class DefaultsSource(ChainedSource):
|
|
idx_key: Union[int, str]
|
|
is_kw: bool = False
|
|
field: str = dataclasses.field(init=False, repr=False, compare=False)
|
|
_name: str = dataclasses.field(init=False, repr=False, compare=False)
|
|
|
|
def __post_init__(self):
|
|
assert (
|
|
self.base
|
|
), "Base must be a valid source in order to properly track and guard this Defaults to its origin."
|
|
if self.is_kw:
|
|
assert isinstance(self.idx_key, str)
|
|
object.__setattr__(self, "field", "__kwdefaults__")
|
|
object.__setattr__(
|
|
self, "_name", f"{self.base.name()}.{self.field}['{self.idx_key}']"
|
|
)
|
|
else:
|
|
assert isinstance(self.idx_key, int)
|
|
object.__setattr__(self, "field", "__defaults__")
|
|
object.__setattr__(
|
|
self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]"
|
|
)
|
|
|
|
def reconstruct(self, codegen):
|
|
self.base.reconstruct(codegen)
|
|
codegen.extend_output(codegen.create_load_attrs(self.field))
|
|
codegen.append_output(codegen.create_load_const(self.idx_key))
|
|
codegen.append_output(create_instruction("BINARY_SUBSCR"))
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
return self._name
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class GetItemSource(ChainedSource):
|
|
index: Any
|
|
index_is_slice: bool = False
|
|
|
|
def __post_init__(self):
|
|
assert self.base is not None
|
|
if isinstance(self.index, slice):
|
|
# store the hashable version of the slice so the whole GetItemSource is hashable
|
|
super().__setattr__("index", self.index.__reduce__())
|
|
super().__setattr__("index_is_slice", True)
|
|
|
|
def reconstruct(self, codegen):
|
|
reconstruct_getitem(self, codegen, index_is_slice=self.index_is_slice)
|
|
codegen.append_output(create_instruction("BINARY_SUBSCR"))
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def unpack_slice(self):
|
|
assert self.index_is_slice
|
|
slice_class, slice_args = self.index
|
|
return slice_class(*slice_args)
|
|
|
|
def name(self):
|
|
# Index can be of following types
|
|
# 1) ConstDictKeySource
|
|
# 2) enum.Enum
|
|
# 3) index is a slice - example 1:4
|
|
# 4) index is a constant - example string, integer
|
|
if isinstance(self.index, Source):
|
|
if not isinstance(self.index, ConstDictKeySource):
|
|
raise ValueError(
|
|
"GetItemSource index must be a constant, enum or ConstDictKeySource"
|
|
)
|
|
return f"{self.base.name()}[{self.index.name()}]"
|
|
elif self.index_is_slice:
|
|
return f"{self.base.name()}[{self.unpack_slice()!r}]"
|
|
elif isinstance(self.index, enum.Enum):
|
|
return f"{self.base.name()}[{enum_repr(self.index, self.guard_source().is_local())}]"
|
|
else:
|
|
return f"{self.base.name()}[{self.index!r}]"
|
|
|
|
|
|
# A GetItemSource that refers to the `index`-th key of a dict source rather
# than to a value; rendered as list(<base>.keys())[index].
@dataclasses.dataclass(frozen=True)
class ConstDictKeySource(GetItemSource):
    def name(self):
        # The list creation will be CSE'd by PyExprCSEPass
        return f"list({self.base.name()}.keys())[{repr(self.index)}]"

    def is_dict_key(self):
        return True

    def reconstruct(self, codegen):
        # Emits: dict_keys_getitem(<base>, index), with the helper imported
        # from the `utils` module.
        codegen.load_import_from(utils.__name__, "dict_keys_getitem")
        self.base.reconstruct(codegen)
        index_load = codegen.create_load_const(self.index)
        codegen.append_output(index_load)
        codegen.extend_output(create_call_function(2, True))
|
|
|
|
|
|
# A GetItemSource whose base is accessed through the tuple_iterator_getitem
# helper; rendered as ___tuple_iterator_getitem(<base>, index).
@dataclasses.dataclass(frozen=True)
class TupleIteratorGetItemSource(GetItemSource):
    def name(self):
        return f"___tuple_iterator_getitem({self.base.name()}, {repr(self.index)})"

    def reconstruct(self, codegen):
        # Emits: tuple_iterator_getitem(<base>, index), with the helper
        # imported from the `utils` module.
        codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")
        self.base.reconstruct(codegen)
        index_load = codegen.create_load_const(self.index)
        codegen.append_output(index_load)
        codegen.extend_output(create_call_function(2, True))
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class TypeSource(ChainedSource):
|
|
def __post_init__(self):
|
|
assert self.base is not None
|
|
|
|
def reconstruct(self, codegen):
|
|
codegen.load_import_from("builtins", "type")
|
|
self.base.reconstruct(codegen)
|
|
codegen.extend_output(create_call_function(1, True))
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
return f"type({self.base.name()})"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class ODictGetItemSource(ChainedSource):
|
|
index: Any
|
|
|
|
def __post_init__(self):
|
|
assert self.base is not None
|
|
|
|
def reconstruct(self, codegen):
|
|
codegen.append_output(
|
|
codegen._create_load_const(collections.OrderedDict.__getitem__)
|
|
)
|
|
reconstruct_getitem(self, codegen, index_is_slice=False)
|
|
codegen.extend_output(create_call_function(2, True))
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
if isinstance(self.index, type):
|
|
rep = f'__load_module("{self.index.__module__}").{self.index.__qualname__}'
|
|
return f"___odict_getitem({self.base.name()}, {rep})"
|
|
elif isinstance(self.index, Source):
|
|
return f"___odict_getitem({self.base.name()}, {self.index.name()})"
|
|
else:
|
|
return f"___odict_getitem({self.base.name()}, {self.index!r})"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class OptimizerSource(ChainedSource):
|
|
def reconstruct(self, codegen):
|
|
self.base.reconstruct(codegen)
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
return self.base.name()
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class NNModuleSource(ChainedSource):
|
|
def reconstruct(self, codegen):
|
|
self.base.reconstruct(codegen)
|
|
|
|
def guard_source(self):
|
|
return _GUARD_SOURCE_NN_MODULE[self.base.guard_source()]
|
|
|
|
def name(self):
|
|
return self.base.name()
|
|
|
|
|
|
# Variant of NNModuleSource whose guard source is the base's guard source
# translated through _GUARD_SOURCE_NOT_NN_MODULE.
@dataclasses.dataclass(frozen=True)
class NotNNModuleSource(NNModuleSource):
    def guard_source(self):
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_NOT_NN_MODULE[base_kind]
|
|
|
|
|
|
# Variant of NNModuleSource whose guard source is the base's guard source
# translated through _GUARD_SOURCE_FSDP_MODULE.
@dataclasses.dataclass(frozen=True)
class FSDPNNModuleSource(NNModuleSource):
    def guard_source(self):
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_FSDP_MODULE[base_kind]
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class GlobalStateSource(Source):
|
|
def name(self):
|
|
return ""
|
|
|
|
def guard_source(self):
|
|
return GuardSource.GLOBAL
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class ConstantSource(Source):
|
|
source_name: str
|
|
|
|
def reconstruct(self, codegen):
|
|
codegen.append_output(
|
|
codegen.create_load_global(self.source_name, False, add=False)
|
|
)
|
|
|
|
def guard_source(self):
|
|
return GuardSource.CONSTANT
|
|
|
|
def name(self):
|
|
return self.source_name
|
|
|
|
def make_guard(self, fn):
|
|
raise NotImplementedError
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class NumpyTensorSource(ChainedSource):
|
|
def name(self) -> str:
|
|
return f"___from_numpy({self.base.name()})"
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def reconstruct(self, codegen):
|
|
codegen.load_import_from("torch", "as_tensor")
|
|
self.base.reconstruct(codegen)
|
|
codegen.extend_output(create_call_function(1, True))
|
|
|
|
|
|
# NB: We don't expect you to actually ever generate guards against this
|
|
# source, it is ephemeral
|
|
@dataclasses.dataclass(frozen=True)
|
|
class FloatTensorSource(ChainedSource):
|
|
def name(self) -> str:
|
|
return f"___as_tensor({self.base.name()})"
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class CallMethodItemSource(ChainedSource):
|
|
def name(self) -> str:
|
|
return f"{self.base.name()}.item()"
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
|
|
# This is a synthetic source that is associated with the singleton
|
|
# shape env guard we always register for all frames. We get the actual
|
|
# guard contents from the ambient ShapeEnv
|
|
@dataclasses.dataclass(frozen=True)
|
|
class ShapeEnvSource(Source):
|
|
def name(self):
|
|
return ""
|
|
|
|
def guard_source(self):
|
|
return GuardSource.SHAPE_ENV
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class BackwardStateSource(Source):
|
|
def name(self):
|
|
return ""
|
|
|
|
def guard_source(self):
|
|
return GuardSource.BACKWARD_STATE
|
|
|
|
|
|
# Returns True when the innermost (non-chained) base of `source` is a
# LocalSource. With allow_cell_or_freevar=False, a LocalSource flagged as
# cell_or_freevar is rejected.
def is_from_local_source(source: Source, *, allow_cell_or_freevar=True):
    # Peel off every chained wrapper to reach the root source.
    root = source
    while isinstance(root, ChainedSource):
        root = root.base

    if not isinstance(root, LocalSource):
        return False
    rejected = not allow_cell_or_freevar and root.cell_or_freevar
    return not rejected
|
|
|
|
|
|
# Returns True when `source`, or any base reached by following ChainedSource
# links from it, is a FlattenScriptObjectSource.
def is_from_flatten_script_object_source(source: Source):
    current = source
    while True:
        if isinstance(current, FlattenScriptObjectSource):
            return True
        if not isinstance(current, ChainedSource):
            # Reached the root without finding one.
            return False
        current = current.base
|
|
|
|
|
|
# Returns True when `source`, or any base reached by following ChainedSource
# links from it, is an OptimizerSource.
def is_from_optimizer_source(source: Source):
    current = source
    while True:
        if isinstance(current, OptimizerSource):
            return True
        if not isinstance(current, ChainedSource):
            # Reached the root without finding one.
            return False
        current = current.base
|
|
|
|
|
|
# TODO: can probably write a generic "test this on everything in the chain"
# helper
#
# Returns True when `source`, or any base reached by following ChainedSource
# links from it, is a DefaultsSource.
def is_from_defaults(source: Source):
    current = source
    while True:
        if isinstance(current, DefaultsSource):
            return True
        if not isinstance(current, ChainedSource):
            # Reached the root without finding one.
            return False
        current = current.base
|
|
|
|
|
|
# Returns True when `source` is an AttrSource whose member is exactly
# "cell_contents". Only the outermost source is inspected.
def is_cell_contents(source: Source):
    if not isinstance(source, AttrSource):
        return False
    return source.member == "cell_contents"
|