mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add use_lazy_shape flag to GenLazyIr class (#88444)
Add use_lazy_shape flag to GenLazyIr class to allow XLA to use its custom shape class. The default value is kept to use lazy shape, so this PR does not introduce any new behaviors. PyTorch/XLA companion PR: https://github.com/pytorch/xla/pull/4111 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88444 Approved by: https://github.com/alanwaketan, https://github.com/wconstab
This commit is contained in:
committed by
PyTorch MergeBot
parent
b3206268ac
commit
a171b0636a
@ -172,6 +172,7 @@ class GenLazyIR(ABC):
|
||||
backend_index: BackendIndex
|
||||
backend_name: str
|
||||
node_base: str
|
||||
use_lazy_shape: bool
|
||||
|
||||
@method_with_native_function
|
||||
def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
|
||||
@ -252,7 +253,7 @@ class GenLazyIR(ABC):
|
||||
|
||||
ctor_args = [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args]
|
||||
reuse_ctor_args = ", ".join(ctor_args)
|
||||
if schema.properties.ShapePrecompute:
|
||||
if self.use_lazy_shape and schema.properties.ShapePrecompute:
|
||||
ctor_args.append("std::vector<torch::lazy::Shape>&& shapes")
|
||||
node_ctor_args = ", ".join(ctor_args)
|
||||
|
||||
|
@ -313,6 +313,7 @@ def run_gen_lazy_tensor(
|
||||
per_operator_headers: bool = False,
|
||||
backend_name: str = default_args.backend_name,
|
||||
gen_forced_fallback_code: bool = False,
|
||||
use_lazy_shape: bool = True,
|
||||
# the following arguments are temporary customization points for xla backend migration.
|
||||
# do not rely on them otherwise, they should be removed once migration is complete
|
||||
backend_namespace: str = "torch::lazy",
|
||||
@ -533,7 +534,7 @@ def run_gen_lazy_tensor(
|
||||
)
|
||||
# Generate IR node classes
|
||||
lazy_ir_obj = lazy_ir_generator(
|
||||
backend_indices[backend_key], backend_name, node_base
|
||||
backend_indices[backend_key], backend_name, node_base, use_lazy_shape
|
||||
)
|
||||
|
||||
fm.write_with_template(
|
||||
|
Reference in New Issue
Block a user