Add use_lazy_shape flag to GenLazyIR class (#88444)

Add a use_lazy_shape flag to the GenLazyIR class to allow XLA to plug in its custom shape class. The flag defaults to using the lazy shape, so this PR does not introduce any new behavior.

PyTorch/XLA companion PR: https://github.com/pytorch/xla/pull/4111
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88444
Approved by: https://github.com/alanwaketan, https://github.com/wconstab
Commit: a171b0636a (parent b3206268ac)
Author: Wonjoo Lee
Date: 2022-11-04 08:23:54 +00:00
Committed by: PyTorch MergeBot
2 changed files with 4 additions and 2 deletions
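
For context, here is a minimal sketch of how a backend's codegen driver could opt out of the torch::lazy shape class through the new flag. Only run_gen_lazy_tensor and use_lazy_shape come from this PR; the remaining argument names and values are assumptions standing in for whatever a backend such as PyTorch/XLA actually passes.

    # Illustrative only: the paths, YAML name, and other arguments below are
    # assumptions, not values from this PR.
    from torchgen.gen_lazy_tensor import run_gen_lazy_tensor

    run_gen_lazy_tensor(
        aten_path="aten/src/ATen",
        source_yaml="xla_native_functions.yaml",
        output_dir="generated",
        dry_run=False,
        impl_path=None,
        use_lazy_shape=False,  # new flag: the backend supplies its own shape class
    )

When use_lazy_shape is False, the generated IR node constructors omit the std::vector<torch::lazy::Shape>&& shapes parameter even for ops marked ShapePrecompute, as the lazy_ir change below shows.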


@@ -172,6 +172,7 @@ class GenLazyIR(ABC):
     backend_index: BackendIndex
     backend_name: str
     node_base: str
+    use_lazy_shape: bool
     @method_with_native_function
     def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
@@ -252,7 +253,7 @@
         ctor_args = [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args]
         reuse_ctor_args = ", ".join(ctor_args)
-        if schema.properties.ShapePrecompute:
+        if self.use_lazy_shape and schema.properties.ShapePrecompute:
             ctor_args.append("std::vector<torch::lazy::Shape>&& shapes")
         node_ctor_args = ", ".join(ctor_args)


@@ -313,6 +313,7 @@ def run_gen_lazy_tensor(
     per_operator_headers: bool = False,
     backend_name: str = default_args.backend_name,
     gen_forced_fallback_code: bool = False,
+    use_lazy_shape: bool = True,
     # the following arguments are temporary customization points for xla backend migration.
     # do not rely on them otherwise, they should be removed once migration is complete
     backend_namespace: str = "torch::lazy",
@@ -533,7 +534,7 @@
     )
     # Generate IR node classes
     lazy_ir_obj = lazy_ir_generator(
-        backend_indices[backend_key], backend_name, node_base
+        backend_indices[backend_key], backend_name, node_base, use_lazy_shape
     )
     fm.write_with_template(
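
As the last hunk shows, run_gen_lazy_tensor now forwards use_lazy_shape as the fourth argument when constructing the IR generator. A rough sketch of what a backend-specific generator might look like, assuming GenLazyIR is importable from torchgen.dest.lazy_ir; the GenXLALazyIR class itself is hypothetical.

    # Hypothetical subclass; only the four GenLazyIR fields shown in the diff
    # (backend_index, backend_name, node_base, use_lazy_shape) come from this PR.
    from torchgen.dest.lazy_ir import GenLazyIR

    class GenXLALazyIR(GenLazyIR):
        # Backend-specific overrides (e.g. custom lowering) would go here.
        pass

    # With this PR, the codegen driver instantiates the generator as:
    #   lazy_ir_generator(backend_indices[backend_key], backend_name, node_base, use_lazy_shape)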