mirror of
https://github.com/huggingface/accelerate.git
synced 2025-10-20 18:13:46 +08:00
TST Add test for FSDP ignored_modules as str (#3719)
Follow up to #3698.
This commit is contained in:
@ -398,6 +398,21 @@ class FSDPPluginIntegration(AccelerateTestCase):
|
||||
assert fsdp_plugin.cpu_ram_efficient_loading is False
|
||||
assert os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING") == "False"
|
||||
|
||||
def test_ignored_modules_regex(self):
|
||||
# Check that FSDP's ignored_modules can be a string, in which case it is treated as a regex
|
||||
if self.current_fsdp_version != 1:
|
||||
self.skipTest("ignored_modules only relevant for FSDP1")
|
||||
|
||||
env = self.fsdp_envs[1].copy()
|
||||
env["FSDP_IGNORED_MODULES"] = ".*\\.q_proj$"
|
||||
with patch_environment(**env):
|
||||
accelerator = Accelerator()
|
||||
model = AutoModel.from_pretrained(LLAMA_TESTING)
|
||||
# model has 2 layers
|
||||
layers_to_ignore = {model.layers[0].self_attn.q_proj, model.layers[1].self_attn.q_proj}
|
||||
model = accelerator.prepare(model)
|
||||
assert model._ignored_modules == layers_to_ignore
|
||||
|
||||
|
||||
@require_fsdp2
|
||||
@require_non_cpu
|
||||
|
Reference in New Issue
Block a user