Smangrul/fix failing ds ci tests (#27358)

* fix failing DeepSpeed CI tests due to `safetensors` being default

* debug

* remove debug statements

* resolve comments

* Update test_deepspeed.py
This commit is contained in:
Sourab Mangrulkar 2023-11-09 11:47:24 +05:30 committed by GitHub
parent ced9fd86f5
commit 7ecd229ba4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -48,7 +48,7 @@ from transformers.testing_utils import (
slow,
)
from transformers.trainer_utils import get_last_checkpoint, set_seed
from transformers.utils import WEIGHTS_NAME, is_torch_bf16_gpu_available
from transformers.utils import SAFE_WEIGHTS_NAME, is_torch_bf16_gpu_available
if is_torch_available():
@ -565,8 +565,7 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
def check_saved_checkpoints_deepspeed(self, output_dir, freq, total, stage, dtype):
# adapted from TrainerIntegrationCommon.check_saved_checkpoints
file_list = [WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]
file_list = [SAFE_WEIGHTS_NAME, "training_args.bin", "trainer_state.json", "config.json"]
if stage == ZERO2:
ds_file_list = ["mp_rank_00_model_states.pt"]
@ -581,7 +580,6 @@ class TrainerIntegrationDeepSpeed(TrainerIntegrationDeepSpeedWithCustomConfig, T
for step in range(freq, total, freq):
checkpoint = os.path.join(output_dir, f"checkpoint-{step}")
self.assertTrue(os.path.isdir(checkpoint), f"[{stage}] {checkpoint} dir is not found")
# common files
for filename in file_list:
path = os.path.join(checkpoint, filename)