mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Allow boolean FSDP options in fsdp_config (#30439)
* Allow boolean FSDP options in fsdp_config
* Use lower() to be safe
This commit is contained in:
parent
73014b561d
commit
80126f98d8
@ -1840,12 +1840,12 @@ class TrainingArguments:
|
||||
)
|
||||
prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
|
||||
os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper()
|
||||
os.environ[f"{prefix}FORWARD_PREFETCH"] = self.fsdp_config.get("forward_prefetch", "false")
|
||||
os.environ[f"{prefix}FORWARD_PREFETCH"] = str(self.fsdp_config.get("forward_prefetch", "false")).lower()
|
||||
|
||||
sync_module_states = self.fsdp_config.get("sync_module_states", "true")
|
||||
cpu_ram_efficient_loading = self.fsdp_config.get("cpu_ram_efficient_loading", "false")
|
||||
sync_module_states = str(self.fsdp_config.get("sync_module_states", "true")).lower()
|
||||
cpu_ram_efficient_loading = str(self.fsdp_config.get("cpu_ram_efficient_loading", "false")).lower()
|
||||
|
||||
if str(sync_module_states).lower() == "false" and str(cpu_ram_efficient_loading).lower() == "true":
|
||||
if sync_module_states == "false" and cpu_ram_efficient_loading == "true":
|
||||
# In this case, all the processes except the main process would have random weights leading
|
||||
# to unexpected behaviour during training, thus throwing error here to prevent it.
|
||||
raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`')
|
||||
@ -1853,7 +1853,7 @@ class TrainingArguments:
|
||||
os.environ[f"{prefix}SYNC_MODULE_STATES"] = sync_module_states
|
||||
os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading
|
||||
|
||||
os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true")
|
||||
os.environ[f"{prefix}USE_ORIG_PARAMS"] = str(self.fsdp_config.get("use_orig_params", "true")).lower()
|
||||
|
||||
if is_accelerate_available():
|
||||
if not isinstance(self.accelerator_config, (AcceleratorConfig)):
|
||||
|
Loading…
Reference in New Issue
Block a user