This PR updates the replace_fn function when loading inference checkpoints. The container will now be passed to the load_model_with_checkpoint() so we can call load_params() from there. load_params() is also updated to access the variables in the policy.
This PR refactors the organization of meta tensor checkpoint loading as follows:
- Move get_param_names() abstract method definition from TransformerPolicy into MetaTensorContainer
- Model-specific get_param_names() definitions moved from policy into model-specific container
- selected_policy_g, megatron_v2_g, and transformer_config_g globals replaced with a single container_g global, since the container will contain all of the information those globals previously captured
- ckpt_load_enabled flag added to containers that's set to False by default in the base.py container and gets set to True when the MetaTensorContainer feature is inherited
- Assertion added to replace_transformer_layer before performing checkpoint loading to check if ckpt_load_enabled ==True, otherwise an error message will be printed saying that the container does not support meta tensor checkpoint loading.
The aim of these changes is to more closely couple meta tensor checkpoint loading code to the MetaTensorContainer and to allow for better error reporting of load checkpoint use on model types that don't support this feature.
This PR cleans up some container items and removes an unused qkv_merging parameter:
- Remove qkv_merging=True from BERT containers
- Change containers config object to ds_model_config
- Remove qkv_merging param