From a171b0636a058d0cd059d39f39e37d5cc1d38df1 Mon Sep 17 00:00:00 2001
From: Wonjoo Lee
Date: Fri, 4 Nov 2022 08:23:54 +0000
Subject: [PATCH] Add use_lazy_shape flag to GenLazyIr class (#88444)

Add use_lazy_shape flag to GenLazyIr class to allow XLA to use its custom
shape class. The default value is kept to use lazy shape, so this PR does
not introduce any new behaviors.

PyTorch/XLA companion PR: https://github.com/pytorch/xla/pull/4111

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88444
Approved by: https://github.com/alanwaketan, https://github.com/wconstab
---
 torchgen/dest/lazy_ir.py    | 3 ++-
 torchgen/gen_lazy_tensor.py | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/torchgen/dest/lazy_ir.py b/torchgen/dest/lazy_ir.py
index 41b32b81dbd8..33043a5780d7 100644
--- a/torchgen/dest/lazy_ir.py
+++ b/torchgen/dest/lazy_ir.py
@@ -172,6 +172,7 @@ class GenLazyIR(ABC):
     backend_index: BackendIndex
     backend_name: str
     node_base: str
+    use_lazy_shape: bool
 
     @method_with_native_function
     def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
@@ -252,7 +253,7 @@ class GenLazyIR(ABC):
         ctor_args = [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args]
         reuse_ctor_args = ", ".join(ctor_args)
 
-        if schema.properties.ShapePrecompute:
+        if self.use_lazy_shape and schema.properties.ShapePrecompute:
             ctor_args.append("std::vector<torch::lazy::Shape>&& shapes")
         node_ctor_args = ", ".join(ctor_args)
 
diff --git a/torchgen/gen_lazy_tensor.py b/torchgen/gen_lazy_tensor.py
index 5207681cf5c8..b2b24111b0f9 100644
--- a/torchgen/gen_lazy_tensor.py
+++ b/torchgen/gen_lazy_tensor.py
@@ -313,6 +313,7 @@ def run_gen_lazy_tensor(
     per_operator_headers: bool = False,
     backend_name: str = default_args.backend_name,
     gen_forced_fallback_code: bool = False,
+    use_lazy_shape: bool = True,
     # the following arguments are temporary customization points for xla backend migration.
     # do not rely on them otherwise, they should be removed once migration is complete
     backend_namespace: str = "torch::lazy",
@@ -533,7 +534,7 @@ def run_gen_lazy_tensor(
         )
 
     # Generate IR node classes
     lazy_ir_obj = lazy_ir_generator(
-        backend_indices[backend_key], backend_name, node_base
+        backend_indices[backend_key], backend_name, node_base, use_lazy_shape
     )
     fm.write_with_template(
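
For context, below is a minimal sketch of how a downstream backend might
consume the new flag. Only run_gen_lazy_tensor, GenLazyIR, and the
use_lazy_shape parameter come from this patch; the GenXlaLazyIR subclass, its
lowering body, and every file path are illustrative placeholders (the real
wiring lives in the PyTorch/XLA companion PR linked above):

    # Hypothetical backend-side codegen driver; the paths and names below
    # are placeholders, not part of this patch.
    from torchgen.dest.lazy_ir import GenLazyIR
    from torchgen.gen_lazy_tensor import run_gen_lazy_tensor

    class GenXlaLazyIR(GenLazyIR):
        # Illustrative override: a real backend would emit the C++ Lower()
        # body for its IR nodes here instead of an empty string.
        def lowering_function(self, schema) -> str:
            return ""

    run_gen_lazy_tensor(
        aten_path="aten/src/ATen",                # placeholder path
        source_yaml="xla_native_functions.yaml",  # placeholder path
        output_dir="generated",                   # placeholder path
        dry_run=False,
        impl_path=None,
        node_base="XlaNode",                      # placeholder node base class
        lazy_ir_generator=GenXlaLazyIR,
        use_lazy_shape=False,  # new in this patch: skip the lazy shape ctor arg
    )

With the default use_lazy_shape=True, IR node constructors for ops with the
ShapePrecompute property keep the std::vector<torch::lazy::Shape>&& shapes
parameter exactly as before; passing use_lazy_shape=False omits it so a
backend such as XLA can substitute its own shape class.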