mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-23 23:04:52 +08:00
All the changes brought by the original PR have been addressed in alternative ways in the stack. Why the original PR has to be reverted requires more effort because there is some bad interaction with export. Pull Request resolved: https://github.com/pytorch/pytorch/pull/131058 Approved by: https://github.com/williamwen42
674 lines
20 KiB
Python
674 lines
20 KiB
Python
# mypy: allow-untyped-defs
|
|
import collections
|
|
import dataclasses
|
|
import enum
|
|
from typing import Any, Optional, Union
|
|
|
|
from torch._guards import ChainedSource, GuardSource, Source
|
|
|
|
from . import utils
|
|
from .bytecode_transformation import create_call_function, create_instruction
|
|
from .utils import enum_repr
|
|
|
|
# It shouldn't be supported to construct an NNModuleVariable inside an FSDP module,
|
|
# so those cases are omitted intentionally
|
|
_GUARD_SOURCE_NN_MODULE = {
|
|
GuardSource.LOCAL: GuardSource.LOCAL_NN_MODULE,
|
|
GuardSource.GLOBAL: GuardSource.GLOBAL_NN_MODULE,
|
|
GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL_NN_MODULE,
|
|
GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL_NN_MODULE,
|
|
}
|
|
|
|
_GUARD_SOURCE_FSDP_MODULE = {
|
|
GuardSource.LOCAL: GuardSource.LOCAL_FSDP_MODULE,
|
|
GuardSource.GLOBAL: GuardSource.GLOBAL_FSDP_MODULE,
|
|
GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL_FSDP_MODULE,
|
|
GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
|
|
GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE,
|
|
GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
|
|
}
|
|
|
|
_GUARD_SOURCE_NOT_NN_MODULE = {
|
|
GuardSource.LOCAL: GuardSource.LOCAL,
|
|
GuardSource.GLOBAL: GuardSource.GLOBAL,
|
|
GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL,
|
|
GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL,
|
|
GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL,
|
|
GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL,
|
|
}
|
|
|
|
|
|
def is_constant_source(source):
    """Tell whether ``source`` denotes a constant.

    Returns True when ``source`` is a ``ConstantSource`` instance or when its
    ``guard_source()`` equals ``GuardSource.CONSTANT``.  A source whose
    ``guard_source()`` raises ``NotImplementedError`` is reported as
    non-constant.
    """
    if isinstance(source, ConstantSource):
        return True

    try:
        kind = source.guard_source()
    except NotImplementedError:
        # no guard source available for this object -> not a constant
        return False

    return kind == GuardSource.CONSTANT
|
|
|
|
|
|
def reconstruct_getitem(
    source: Union["GetItemSource", "ODictGetItemSource"], codegen, index_is_slice
):
    """Emit the container and the index of a subscript-like source.

    ``source.base`` is reconstructed first.  The index follows: it is
    reconstructed when it is itself a ``Source``, otherwise it is emitted as a
    constant load.  When ``index_is_slice`` is true the constant is the slice
    rebuilt by ``source.unpack_slice()``.  The subscript/call instruction is
    left to the caller.
    """
    source.base.reconstruct(codegen)

    index = source.index
    if isinstance(index, Source):
        index.reconstruct(codegen)
        return

    if index_is_slice:
        # only GetItemSource knows how to rebuild its stored slice
        assert isinstance(source, GetItemSource)
        const_value = source.unpack_slice()
    else:
        const_value = index
    codegen.append_output(codegen.create_load_const(const_value))
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class LocalSource(Source):
|
|
local_name: str
|
|
cell_or_freevar: bool = False
|
|
|
|
def reconstruct(self, codegen):
|
|
codegen.append_output(codegen.create_load(self.local_name))
|
|
|
|
def guard_source(self):
|
|
return GuardSource.LOCAL
|
|
|
|
def name(self):
|
|
return f"L[{repr(self.local_name)}]"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class SyntheticLocalSource(Source):
|
|
local_name: str
|
|
|
|
def reconstruct(self, codegen):
|
|
codegen.append_output(codegen.create_load(self.local_name))
|
|
|
|
def guard_source(self):
|
|
return GuardSource.SYNTHETIC_LOCAL
|
|
|
|
def name(self):
|
|
return f"SYNTHETIC_LOCAL[{self.local_name!r}]"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class RandomValueSource(Source):
|
|
random_call_index: int
|
|
|
|
def guard_source(self):
|
|
return GuardSource.RANDOM_VALUE
|
|
|
|
def reconstruct(self, codegen):
|
|
codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var))
|
|
codegen.append_output(codegen.create_load_const(self.random_call_index))
|
|
codegen.append_output(create_instruction("BINARY_SUBSCR"))
|
|
|
|
def name(self):
|
|
return f"random_value_{self.random_call_index}"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class GlobalSource(Source):
|
|
global_name: str
|
|
|
|
def reconstruct(self, codegen):
|
|
codegen.append_output(codegen.create_load_global(self.global_name, add=True))
|
|
|
|
def guard_source(self):
|
|
return GuardSource.GLOBAL
|
|
|
|
def name(self):
|
|
return f"G[{repr(self.global_name)}]"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class GlobalWeakRefSource(Source):
|
|
global_name: str
|
|
|
|
def reconstruct(self, codegen):
|
|
codegen.add_push_null(
|
|
lambda: codegen.append_output(
|
|
codegen.create_load_global(self.global_name, add=True)
|
|
)
|
|
)
|
|
codegen.extend_output(create_call_function(0, False))
|
|
|
|
def guard_source(self):
|
|
return GuardSource.GLOBAL
|
|
|
|
def name(self):
|
|
return f"G[{repr(self.global_name)}]()"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class WeakRefCallSource(ChainedSource):
|
|
def reconstruct(self, codegen):
|
|
codegen.add_push_null(lambda: self.base.reconstruct(codegen))
|
|
codegen.extend_output(create_call_function(0, False))
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
return f"{self.base.name()}()"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class AttrSource(ChainedSource):
|
|
member: str
|
|
|
|
def __post_init__(self):
|
|
assert self.base, "Can't construct an AttrSource without a valid base source"
|
|
if "." in self.member:
|
|
member_parts = self.member.split(".")
|
|
object.__setattr__(
|
|
self, "base", AttrSource(self.base, ".".join(member_parts[:-1]))
|
|
)
|
|
object.__setattr__(self, "member", member_parts[-1])
|
|
|
|
def reconstruct(self, codegen):
|
|
self.base.reconstruct(codegen)
|
|
codegen.extend_output(codegen.create_load_attrs(self.member))
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
if not self.member.isidentifier():
|
|
return f"getattr({self.base.name()}, {self.member!r})"
|
|
return f"{self.base.name()}.{self.member}"
|
|
|
|
|
|
# Represents tensor.grad source. It could be represented by AttrSource as well.
|
|
# But, we could access grad field on tensor directly in C++ without going
|
|
# through the Python bytecodes. Therefore, we use a separate source for grad
|
|
# field.
|
|
@dataclasses.dataclass(frozen=True)
|
|
class GradSource(ChainedSource):
|
|
member: str = "grad"
|
|
|
|
def reconstruct(self, codegen):
|
|
self.base.reconstruct(codegen)
|
|
codegen.extend_output(codegen.create_load_attrs(self.member))
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
return f"{self.base.name()}.{self.member}"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class ParamBufferSource(AttrSource):
    """``AttrSource`` whose guard source is the NN-module flavour of its base's."""

    def guard_source(self):
        """Translate the base's guard source through ``_GUARD_SOURCE_NN_MODULE``."""
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_NN_MODULE[base_kind]
|
|
|
|
|
|
# This source is intended to be used in places where a source is needed but it is expected
|
|
# that the symbol will be simplified out later on. Symbols with ephemeral sources are
|
|
# prioritized to be simplified out when e.g. compared against a symbol without an ephemeral
|
|
# source. Guarding on this source is an error.
|
|
#
|
|
# Example: During subclass view fake-ification, any close-over ViewFunc state should be
|
|
# symbolicized / fake-ified to avoid invalid specialization during view replay. This source
|
|
# is useful for symbols utilized in the middle of the view chain that are not expected to be
|
|
# present within the final view shape metadata.
|
|
@dataclasses.dataclass(frozen=True)
|
|
class EphemeralSource(Source):
|
|
desc: Optional[str] = None
|
|
|
|
def guard_source(self):
|
|
return GuardSource.EPHEMERAL
|
|
|
|
def name(self):
|
|
return f"<ephemeral{': ' + self.desc if self.desc is not None else ''}>"
|
|
|
|
def make_guard(self):
|
|
raise NotImplementedError
|
|
|
|
def is_ephemeral(self):
|
|
return True
|
|
|
|
|
|
class TensorProperty(enum.Enum):
    """The tensor properties a ``TensorPropertySource`` can refer to."""

    SIZE = 0
    STRIDE = 1
    STORAGE_OFFSET = 2

    def method_name(self):
        """Return the method name for this property.

        One of ``"size"``, ``"stride"`` or ``"storage_offset"``.
        """
        # Built per call: a class-level dict would be turned into an enum member.
        method_names = {
            TensorProperty.SIZE: "size",
            TensorProperty.STRIDE: "stride",
            TensorProperty.STORAGE_OFFSET: "storage_offset",
        }
        return method_names.get(self)
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class TensorPropertySource(ChainedSource):
    """Source for a size/stride entry or the storage offset of the base.

    ``idx`` must be given for ``SIZE`` and ``STRIDE`` and must be ``None`` for
    ``STORAGE_OFFSET``; this is asserted at construction time.
    """

    prop: TensorProperty
    idx: Optional[int] = None  # None for STORAGE_OFFSET

    def __post_init__(self):
        assert self.base is not None
        takes_index = self.prop is not TensorProperty.STORAGE_OFFSET
        if takes_index:
            assert self.idx is not None
        else:
            assert self.idx is None

    def reconstruct(self, codegen):
        """Emit a call of the property method on the base.

        The call receives ``idx`` as its single argument when ``idx`` is set
        and no arguments otherwise.
        """
        has_idx = self.idx is not None

        def load_method():
            self.base.reconstruct(codegen)
            codegen.append_output(codegen.create_load_attr(self.prop.method_name()))

        codegen.add_push_null(load_method)
        if has_idx:
            codegen.append_output(codegen.create_load_const(self.idx))
        arg_count = 1 if has_idx else 0
        codegen.extend_output(create_call_function(arg_count, False))

    def guard_source(self):
        """Delegate to the base source."""
        return self.base.guard_source()

    def name(self):
        """Return the expression for this property.

        ``<base>.size()[idx]``, ``<base>.stride()[idx]`` or
        ``<base>.storage_offset()``.
        """
        prop = self.prop
        if prop is TensorProperty.SIZE:
            return f"{self.base.name()}.size()[{self.idx}]"
        if prop is TensorProperty.STRIDE:
            return f"{self.base.name()}.stride()[{self.idx}]"
        if prop is TensorProperty.STORAGE_OFFSET:
            assert self.idx is None
            return f"{self.base.name()}.storage_offset()"
        raise AssertionError(f"unhandled {self.prop}")
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class NegateSource(ChainedSource):
|
|
def __post_init__(self):
|
|
assert self.base is not None
|
|
|
|
def reconstruct(self, codegen):
|
|
raise NotImplementedError
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
# NB: use method call so that function stripping regexes work
|
|
return f"{self.base.name()}.__neg__()"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class ConvertIntSource(ChainedSource):
|
|
def __post_init__(self):
|
|
assert self.base is not None
|
|
|
|
def reconstruct(self, codegen):
|
|
self.base.reconstruct(codegen)
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
return f"cast_symbool_to_symint_guardless({self.base.name()})"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class FlattenScriptObjectSource(ChainedSource):
|
|
def __post_init__(self):
|
|
assert self.base is not None
|
|
|
|
def reconstruct(self, codegen):
|
|
self.base.reconstruct(codegen)
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
return f"{self.base.name()}.__obj_flatten__()"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class ScriptObjectQualifiedNameSource(ChainedSource):
|
|
def __post_init__(self):
|
|
assert self.base is not None
|
|
|
|
def reconstruct(self, codegen):
|
|
self.base.reconstruct(codegen)
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
return f"{self.base.name()}._type().qualified_name()"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class DefaultsSource(ChainedSource):
|
|
idx_key: Union[int, str]
|
|
is_kw: bool = False
|
|
field: str = dataclasses.field(init=False, repr=False, compare=False)
|
|
_name: str = dataclasses.field(init=False, repr=False, compare=False)
|
|
|
|
def __post_init__(self):
|
|
assert (
|
|
self.base
|
|
), "Base must be a valid source in order to properly track and guard this Defaults to its origin."
|
|
if self.is_kw:
|
|
assert isinstance(self.idx_key, str)
|
|
object.__setattr__(self, "field", "__kwdefaults__")
|
|
object.__setattr__(
|
|
self, "_name", f"{self.base.name()}.{self.field}['{self.idx_key}']"
|
|
)
|
|
else:
|
|
assert isinstance(self.idx_key, int)
|
|
object.__setattr__(self, "field", "__defaults__")
|
|
object.__setattr__(
|
|
self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]"
|
|
)
|
|
|
|
def reconstruct(self, codegen):
|
|
self.base.reconstruct(codegen)
|
|
codegen.extend_output(codegen.create_load_attrs(self.field))
|
|
codegen.append_output(codegen.create_load_const(self.idx_key))
|
|
codegen.append_output(create_instruction("BINARY_SUBSCR"))
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
return self._name
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class GetItemSource(ChainedSource):
|
|
index: Any
|
|
index_is_slice: bool = False
|
|
|
|
def __post_init__(self):
|
|
assert self.base is not None
|
|
if isinstance(self.index, slice):
|
|
# store the hashable version of the slice so the whole GetItemSource is hashable
|
|
super().__setattr__("index", self.index.__reduce__())
|
|
super().__setattr__("index_is_slice", True)
|
|
|
|
def reconstruct(self, codegen):
|
|
reconstruct_getitem(self, codegen, index_is_slice=self.index_is_slice)
|
|
codegen.append_output(create_instruction("BINARY_SUBSCR"))
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def unpack_slice(self):
|
|
assert self.index_is_slice
|
|
slice_class, slice_args = self.index
|
|
return slice_class(*slice_args)
|
|
|
|
def name(self):
|
|
# Index can be of following types
|
|
# 1) ConstDictKeySource
|
|
# 2) enum.Enum
|
|
# 3) index is a slice - example 1:4
|
|
# 4) index is a constant - example string, integer
|
|
if isinstance(self.index, Source):
|
|
if not isinstance(self.index, ConstDictKeySource):
|
|
raise ValueError(
|
|
"GetItemSource index must be a constant, enum or ConstDictKeySource"
|
|
)
|
|
return f"{self.base.name()}[{self.index.name()}]"
|
|
elif self.index_is_slice:
|
|
return f"{self.base.name()}[{self.unpack_slice()!r}]"
|
|
elif isinstance(self.index, enum.Enum):
|
|
return f"{self.base.name()}[{enum_repr(self.index, self.guard_source().is_local())}]"
|
|
else:
|
|
return f"{self.base.name()}[{self.index!r}]"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class ConstDictKeySource(GetItemSource):
    """Source for the ``index``-th key of the base, named ``list(<base>.keys())[index]``."""

    def is_dict_key(self):
        """Always True."""
        return True

    def name(self):
        """Return ``list(<base>.keys())[<repr of index>]``."""
        # The list creation will be CSE'd by PyExprCSEPass
        return f"list({self.base.name()}.keys())[{self.index!r}]"

    def reconstruct(self, codegen):
        """Emit a two-argument call of ``dict_keys_getitem`` with base and index."""

        def load_helper():
            return codegen.load_import_from(utils.__name__, "dict_keys_getitem")

        codegen.add_push_null(load_helper)
        self.base.reconstruct(codegen)
        codegen.append_output(codegen.create_load_const(self.index))
        codegen.extend_output(create_call_function(2, False))
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class TupleIteratorGetItemSource(GetItemSource):
    """Source named as ``___tuple_iterator_getitem(<base>, index)``."""

    def name(self):
        """Return ``___tuple_iterator_getitem(<base>, <repr of index>)``."""
        return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"

    def reconstruct(self, codegen):
        """Emit a two-argument call of ``tuple_iterator_getitem`` with base and index."""

        def load_helper():
            return codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")

        codegen.add_push_null(load_helper)
        self.base.reconstruct(codegen)
        codegen.append_output(codegen.create_load_const(self.index))
        codegen.extend_output(create_call_function(2, False))
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class TypeSource(ChainedSource):
|
|
def __post_init__(self):
|
|
assert self.base is not None
|
|
|
|
def reconstruct(self, codegen):
|
|
codegen.add_push_null(lambda: codegen.load_import_from("builtins", "type"))
|
|
self.base.reconstruct(codegen)
|
|
codegen.extend_output(create_call_function(1, False))
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
return f"type({self.base.name()})"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class ODictGetItemSource(ChainedSource):
|
|
index: Any
|
|
|
|
def __post_init__(self):
|
|
assert self.base is not None
|
|
|
|
def reconstruct(self, codegen):
|
|
codegen.add_push_null(
|
|
lambda: codegen.append_output(
|
|
codegen._create_load_const(collections.OrderedDict.__getitem__)
|
|
)
|
|
)
|
|
reconstruct_getitem(self, codegen, index_is_slice=False)
|
|
codegen.extend_output(create_call_function(2, False))
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
if isinstance(self.index, type):
|
|
rep = f'__load_module("{self.index.__module__}").{self.index.__qualname__}'
|
|
return f"___odict_getitem({self.base.name()}, {rep})"
|
|
elif isinstance(self.index, Source):
|
|
return f"___odict_getitem({self.base.name()}, {self.index.name()})"
|
|
else:
|
|
return f"___odict_getitem({self.base.name()}, {self.index!r})"
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class OptimizerSource(ChainedSource):
|
|
def reconstruct(self, codegen):
|
|
self.base.reconstruct(codegen)
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def name(self):
|
|
return self.base.name()
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class NNModuleSource(ChainedSource):
|
|
def reconstruct(self, codegen):
|
|
self.base.reconstruct(codegen)
|
|
|
|
def guard_source(self):
|
|
return _GUARD_SOURCE_NN_MODULE[self.base.guard_source()]
|
|
|
|
def name(self):
|
|
return self.base.name()
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class UnspecializedNNModuleSource(NNModuleSource):
    """``NNModuleSource`` variant that reports a plain (non-module) guard source."""

    def guard_source(self):
        """Translate the base's guard source through ``_GUARD_SOURCE_NOT_NN_MODULE``."""
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_NOT_NN_MODULE[base_kind]
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class FSDPNNModuleSource(NNModuleSource):
    """``NNModuleSource`` variant that reports an FSDP-module guard source."""

    def guard_source(self):
        """Translate the base's guard source through ``_GUARD_SOURCE_FSDP_MODULE``."""
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_FSDP_MODULE[base_kind]
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class GlobalStateSource(Source):
|
|
def name(self):
|
|
return ""
|
|
|
|
def guard_source(self):
|
|
return GuardSource.GLOBAL
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class ConstantSource(Source):
|
|
source_name: str
|
|
|
|
def reconstruct(self, codegen):
|
|
codegen.append_output(codegen.create_load_global(self.source_name, add=False))
|
|
|
|
def guard_source(self):
|
|
return GuardSource.CONSTANT
|
|
|
|
def name(self):
|
|
return self.source_name
|
|
|
|
def make_guard(self, fn):
|
|
raise NotImplementedError
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class NumpyTensorSource(ChainedSource):
|
|
def name(self) -> str:
|
|
return f"___from_numpy({self.base.name()})"
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
def reconstruct(self, codegen):
|
|
codegen.add_push_null(lambda: codegen.load_import_from("torch", "as_tensor"))
|
|
self.base.reconstruct(codegen)
|
|
codegen.extend_output(create_call_function(1, False))
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class SubclassAttrListSource(ChainedSource):
|
|
def name(self) -> str:
|
|
return f"{self.base.name()}.__tensor_flatten__()[0]"
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
|
|
# NB: We don't expect you to actually ever generate guards against this
|
|
# source, it is ephemeral
|
|
@dataclasses.dataclass(frozen=True)
|
|
class FloatTensorSource(ChainedSource):
|
|
def name(self) -> str:
|
|
return f"___as_tensor({self.base.name()})"
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class CallMethodItemSource(ChainedSource):
|
|
def name(self) -> str:
|
|
return f"{self.base.name()}.item()"
|
|
|
|
def guard_source(self):
|
|
return self.base.guard_source()
|
|
|
|
|
|
# This is a synthetic source that is associated with the singleton
|
|
# shape env guard we always register for all frames. We get the actual
|
|
# guard contents from the ambient ShapeEnv
|
|
@dataclasses.dataclass(frozen=True)
|
|
class ShapeEnvSource(Source):
|
|
def name(self):
|
|
return ""
|
|
|
|
def guard_source(self):
|
|
return GuardSource.SHAPE_ENV
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class BackwardStateSource(Source):
|
|
def name(self):
|
|
return ""
|
|
|
|
def guard_source(self):
|
|
return GuardSource.BACKWARD_STATE
|
|
|
|
|
|
def is_from_local_source(source: Source, *, allow_cell_or_freevar=True):
    """Tell whether the innermost base of ``source`` is a ``LocalSource``.

    ``ChainedSource`` wrappers are followed through ``.base`` until a
    non-chained source is reached.  With ``allow_cell_or_freevar=False`` a
    ``LocalSource`` whose ``cell_or_freevar`` flag is set does not count.
    """
    root = source
    while isinstance(root, ChainedSource):
        root = root.base

    if not isinstance(root, LocalSource):
        return False
    rejected = not allow_cell_or_freevar and root.cell_or_freevar
    return not rejected
|
|
|
|
|
|
def is_from_flatten_script_object_source(source: Source):
    """Tell whether ``source`` or any source on its ``.base`` chain is a
    ``FlattenScriptObjectSource``.

    The chain is followed only through ``ChainedSource`` instances.
    """
    node = source
    while True:
        if isinstance(node, FlattenScriptObjectSource):
            return True
        if not isinstance(node, ChainedSource):
            return False
        node = node.base
|
|
|
|
|
|
def is_from_optimizer_source(source: Source):
    """Tell whether ``source`` or any source on its ``.base`` chain is an
    ``OptimizerSource``.

    The chain is followed only through ``ChainedSource`` instances.
    """
    node = source
    while True:
        if isinstance(node, OptimizerSource):
            return True
        if not isinstance(node, ChainedSource):
            return False
        node = node.base
|
|
|
|
|
|
# TODO: can probably write a generic "test this on everything in the chain"
# helper
def is_from_defaults(source: Source):
    """Tell whether ``source`` or any source on its ``.base`` chain is a
    ``DefaultsSource``.

    The chain is followed only through ``ChainedSource`` instances.
    """
    node = source
    while True:
        if isinstance(node, DefaultsSource):
            return True
        if not isinstance(node, ChainedSource):
            return False
        node = node.base
|
|
|
|
|
|
def is_cell_contents(source: Source):
    """Tell whether ``source`` is an ``AttrSource`` whose member is ``cell_contents``."""
    if not isinstance(source, AttrSource):
        return False
    return source.member == "cell_contents"
|