From 81388add02a4abefe5fa6e42d46f56dc0626255d Mon Sep 17 00:00:00 2001 From: Jue Wang Date: Fri, 31 Mar 2023 03:09:29 +0000 Subject: [PATCH 1/2] fine-tune feedback --- data/OIG/prepare.py | 6 ++ training/dist_clm_train.py | 14 +-- training/dist_prefixlm_train.py | 14 +-- .../finetune_Pythia-Chat-Base-7B-feedback.sh | 89 +++++++++++++++++++ training/utils/dist_args_utils.py | 21 +++++ training/utils/dist_checkpoint_utils.py | 32 ++++--- 6 files changed, 138 insertions(+), 38 deletions(-) create mode 100644 training/finetune_Pythia-Chat-Base-7B-feedback.sh diff --git a/data/OIG/prepare.py b/data/OIG/prepare.py index 3db4fec..a429047 100644 --- a/data/OIG/prepare.py +++ b/data/OIG/prepare.py @@ -28,3 +28,9 @@ open(out_path, 'wb') as outfile ): shutil.copyfileobj(infile, outfile) + +process = subprocess.run( + f"git clone https://huggingface.co/datasets/laion/community-chat-contributions {DIR}/contributions", + shell=True, + check=True +) \ No newline at end of file diff --git a/training/dist_clm_train.py b/training/dist_clm_train.py index 3528655..ebc8b0b 100644 --- a/training/dist_clm_train.py +++ b/training/dist_clm_train.py @@ -220,24 +220,12 @@ def main(): add_training_hyper_parameter_arguments(parser) add_mixed_precision_arguments(parser) add_parallel_schema_arguments(parser) - parser.add_argument('--model-name', type=str, default='gpt2', metavar='S', - help='model name or path') - parser.add_argument('--tokenizer-name', type=str, default='gpt2', metavar='S', - help='tokenizer name or path') - parser.add_argument('--model-type', type=str, default='gpt2', metavar='S', - help='model name or path') - parser.add_argument('--checkpoint-path', type=str, default='model_checkpoints/gpt2') + add_ckpt_arguments(parser) parser.add_argument('--task-name', type=str, default='cot', metavar='S', help='task name') parser.add_argument('--warmup-steps', type=int, default=0, help='-') parser.add_argument('--train-warmup-steps', type=int, default=0, help='-') parser.add_argument('--total-steps', type=int, default=None, help='-') - parser.add_argument('--load-pretrained-model', - type=lambda x: x.lower()=='true', default=True, metavar='S', - help='load pretrained model or not.') - parser.add_argument('--load-checkpoint', - type=lambda x: x.lower()=='true', default=True, metavar='S', - help='load pretrained model or not.') parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') parser.add_argument('--profiling', type=str, default='no-profiling', metavar='S', diff --git a/training/dist_prefixlm_train.py b/training/dist_prefixlm_train.py index 86c21bf..032b834 100644 --- a/training/dist_prefixlm_train.py +++ b/training/dist_prefixlm_train.py @@ -196,24 +196,12 @@ def main(): add_training_hyper_parameter_arguments(parser) add_mixed_precision_arguments(parser) add_parallel_schema_arguments(parser) - parser.add_argument('--model-name', type=str, default='gpt2', metavar='S', - help='model name or path') - parser.add_argument('--tokenizer-name', type=str, default='gpt2', metavar='S', - help='tokenizer name or path') - parser.add_argument('--model-type', type=str, default='gpt2', metavar='S', - help='model name or path') - parser.add_argument('--checkpoint-path', type=str, default='model_checkpoints/gpt2') + add_ckpt_arguments(parser) parser.add_argument('--task-name', type=str, default='cot', metavar='S', help='task name') parser.add_argument('--warmup-steps', type=int, default=0, help='-') parser.add_argument('--train-warmup-steps', type=int, default=0, help='-') 
parser.add_argument('--total-steps', type=int, default=None, help='-') - parser.add_argument('--load-pretrained-model', - type=lambda x: x.lower()=='true', default=True, metavar='S', - help='load pretrained model or not.') - parser.add_argument('--load-checkpoint', - type=lambda x: x.lower()=='true', default=True, metavar='S', - help='load pretrained model or not.') parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') parser.add_argument('--profiling', type=str, default='no-profiling', metavar='S', diff --git a/training/finetune_Pythia-Chat-Base-7B-feedback.sh b/training/finetune_Pythia-Chat-Base-7B-feedback.sh new file mode 100644 index 0000000..451b954 --- /dev/null +++ b/training/finetune_Pythia-Chat-Base-7B-feedback.sh @@ -0,0 +1,89 @@ +DIR=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) + +netif=lo +export GLOO_SOCKET_IFNAME=${netif} +export NCCL_SOCKET_IFNAME=${netif} +export MODEL_NAME=Pythia-Chat-Base-7B-feedback + +CKPT_LOAD_PATH=${DIR}/../model_ckpts/Pythia-Chat-Base-7B +CKPT_SAVE_PATH=${DIR}/../model_ckpts/${MODEL_NAME} + +export SHOW_DATA=0 + +BASE_MODEL="${DIR}/../pretrained/Pythia-6.9B-deduped/EleutherAI_pythia-6.9b-deduped/" + +CHECKPOINT_STEPS=100 + +DATASETS="\ +${DIR}/../data/OIG/files/unified_ni.jsonl:0.2,\ +${DIR}/../data/OIG/files/unified_p3.jsonl:0.5,\ +${DIR}/../data/OIG/files/unified_flan.jsonl:0.2,\ +${DIR}/../data/OIG/files/unified_chip2.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_rallio_safety_and_prosocial.jsonl:0.1,\ +${DIR}/../data/OIG/files/unified_soda_dialog.jsonl:0.1,\ +${DIR}/../data/OIG/files/unified_unifiedskg_instructions.jsonl:0.1,\ +${DIR}/../data/OIG/files/unified_merged_code_xp3.jsonl:0.1,\ +${DIR}/../data/OIG/files/unified_oscar_en_sample_dialog.jsonl:0.1,\ +${DIR}/../data/OIG/files/unified_ul2_plus_oscar_en_sample_dialog.jsonl:0.1,\ +${DIR}/../data/OIG/files/unified_multi_news.jsonl:0.05,\ +${DIR}/../data/OIG/files/unified_openai_summarize_tldr.jsonl:0.05,\ +${DIR}/../data/OIG/files/unified_squad_v2.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_nq.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_poetry_instructions.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_sqlv2.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_unnatural_instructions.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_conv_finqa.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_essays.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_plot_screenplay_books_dialog.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_grade_school_math_instructions.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_mathqa_flanv2_kojma_cot.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_joke_explanations.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_cuad.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_abstract_infill.jsonl:0.1,\ +${DIR}/../data/OIG/files/unified_image_prompts_instructions.jsonl:0.01,\ +${DIR}/../data/OIG/contributions/together_user_feedback_v0.2.jsonl:5.0 \ +" + +ARGS="--model-name ${BASE_MODEL} \ +--tokenizer-name ${BASE_MODEL} \ +--project-name together \ +--model-type gptneox \ +--optimizer adam \ +--seed 42 \ +--load-pretrained-model true \ +--task-name \ +"${DATASETS}" \ +--checkpoint-load-path ${CKPT_LOAD_PATH} \ +--checkpoint-path ${CKPT_SAVE_PATH} \ +--init-steps true \ +--total-steps 400 --warmup-steps 10 --train-warmup-steps 0 \ +--checkpoint-steps ${CHECKPOINT_STEPS} \ +--lr 1e-5 --seq-length 2048 --batch-size 32 --micro-batch-size 1 --gradient-accumulate-step 1 \ +--dist-url tcp://127.0.0.1:7033 \ +--num-layers 8 --embedding-dim 
4096 \ +--world-size 8 --pipeline-group-size 4 --data-group-size 2 \ +--job-id 0 --net-interface ${netif} \ +--fp16 \ +--dp-backend nccl \ +--dp-mode allreduce \ +--pp-mode gpipe --profiling no-profiling" + + +(trap 'kill 0' SIGINT; \ +python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 0 --rank 0 \ + & \ +python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 1 --rank 1 \ + & \ +python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 2 --rank 2 \ + & \ +python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 3 --rank 3 \ + & \ +python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 4 --rank 4 \ + & \ +python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 5 --rank 5 \ + & \ +python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 6 --rank 6 \ + & \ +python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 7 --rank 7 \ + & \ +wait) diff --git a/training/utils/dist_args_utils.py b/training/utils/dist_args_utils.py index ec254d7..8b0ba02 100644 --- a/training/utils/dist_args_utils.py +++ b/training/utils/dist_args_utils.py @@ -52,6 +52,27 @@ def add_model_arguments(parser): help='-') parser.add_argument('--num-heads', type=int, default=12, metavar='N', help='-') + +def add_ckpt_arguments(parser): + parser.add_argument('--model-name', type=str, default='gpt2', metavar='S', + help='model name or path') + parser.add_argument('--tokenizer-name', type=str, default='gpt2', metavar='S', + help='tokenizer name or path') + parser.add_argument('--model-type', type=str, default='gpt2', metavar='S', + help='model name or path') + parser.add_argument('--checkpoint-load-path', type=str, default=None, + help='Path to the ckpt to load. If none, it will be set to `checkpoint-path`') + parser.add_argument('--checkpoint-path', type=str, default='model_checkpoints/gpt2', + help='Path where ckpts are saved.') + parser.add_argument('--load-pretrained-model', + type=lambda x: x.lower()=='true', default=True, metavar='S', + help='load pretrained model or not.') + parser.add_argument('--load-checkpoint', + type=lambda x: x.lower()=='true', default=True, metavar='S', + help='load pretrained model or not.') + parser.add_argument('--init-steps', + type=lambda x: x.lower()=='true', default=False, metavar='S', + help='init steps to 0, affect lr scheduler.') def add_training_hyper_parameter_arguments(parser): diff --git a/training/utils/dist_checkpoint_utils.py b/training/utils/dist_checkpoint_utils.py index e2c4921..19955d9 100644 --- a/training/utils/dist_checkpoint_utils.py +++ b/training/utils/dist_checkpoint_utils.py @@ -10,14 +10,20 @@ def load_checkpoint(pipe, args): - if os.path.isfile(os.path.join(args.checkpoint_path, 'latest')): - with open(os.path.join(args.checkpoint_path, 'latest')) as f: + checkpoint_load_path = getattr(args, 'checkpoint_load_path', None) + if checkpoint_load_path is None: + checkpoint_load_path = args.checkpoint_path + + init_steps = getattr(args, 'init_steps', True) + + if os.path.isfile(os.path.join(checkpoint_load_path, 'latest')): + with open(os.path.join(checkpoint_load_path, 'latest')) as f: latest_step = int(f.read()) else: print('no checkpoint available, skipping') return - checkpoint_step_path = os.path.join(args.checkpoint_path, f"checkpoint_{latest_step}") + checkpoint_step_path = os.path.join(checkpoint_load_path, f"checkpoint_{latest_step}") try: with open(os.path.join(checkpoint_step_path, 'meta.json')) as f: @@ -25,7 +31,8 @@ def load_checkpoint(pipe, args): except: print('failed to load meta.') - pipe.global_step = latest_step + if init_steps: + 
pipe.global_step = latest_step try: pipe.model.model.load_state_dict( @@ -49,16 +56,17 @@ def load_checkpoint(pipe, args): except: print('failed to load optim states.') - try: - pipe.scheduler.load_state_dict( - torch.load( - os.path.join( - checkpoint_step_path, f'prank_{get_pipeline_parallel_rank()}_scheduler.pt' + if init_steps: + try: + pipe.scheduler.load_state_dict( + torch.load( + os.path.join( + checkpoint_step_path, f'prank_{get_pipeline_parallel_rank()}_scheduler.pt' + ) ) ) - ) - except: - print('failed to load scheduler states.') + except: + print('failed to load scheduler states.') def save_checkpoint(pipe, args): From e4148883d050a31bfe8333ae451fd8060a70a683 Mon Sep 17 00:00:00 2001 From: Jue Wang Date: Fri, 31 Mar 2023 03:09:29 +0000 Subject: [PATCH 2/2] finetune feedback --- data/OIG/prepare.py | 6 ++ training/dist_clm_train.py | 14 +-- training/dist_prefixlm_train.py | 14 +-- .../finetune_Pythia-Chat-Base-7B-feedback.sh | 89 +++++++++++++++++++ .../dist_gpipe_pipeline_async.py | 1 + training/utils/dist_args_utils.py | 21 +++++ training/utils/dist_checkpoint_utils.py | 32 ++++--- 7 files changed, 139 insertions(+), 38 deletions(-) create mode 100644 training/finetune_Pythia-Chat-Base-7B-feedback.sh diff --git a/data/OIG/prepare.py b/data/OIG/prepare.py index 3db4fec..a429047 100644 --- a/data/OIG/prepare.py +++ b/data/OIG/prepare.py @@ -28,3 +28,9 @@ open(out_path, 'wb') as outfile ): shutil.copyfileobj(infile, outfile) + +process = subprocess.run( + f"git clone https://huggingface.co/datasets/laion/community-chat-contributions {DIR}/contributions", + shell=True, + check=True +) \ No newline at end of file diff --git a/training/dist_clm_train.py b/training/dist_clm_train.py index 3528655..ebc8b0b 100644 --- a/training/dist_clm_train.py +++ b/training/dist_clm_train.py @@ -220,24 +220,12 @@ def main(): add_training_hyper_parameter_arguments(parser) add_mixed_precision_arguments(parser) add_parallel_schema_arguments(parser) - parser.add_argument('--model-name', type=str, default='gpt2', metavar='S', - help='model name or path') - parser.add_argument('--tokenizer-name', type=str, default='gpt2', metavar='S', - help='tokenizer name or path') - parser.add_argument('--model-type', type=str, default='gpt2', metavar='S', - help='model name or path') - parser.add_argument('--checkpoint-path', type=str, default='model_checkpoints/gpt2') + add_ckpt_arguments(parser) parser.add_argument('--task-name', type=str, default='cot', metavar='S', help='task name') parser.add_argument('--warmup-steps', type=int, default=0, help='-') parser.add_argument('--train-warmup-steps', type=int, default=0, help='-') parser.add_argument('--total-steps', type=int, default=None, help='-') - parser.add_argument('--load-pretrained-model', - type=lambda x: x.lower()=='true', default=True, metavar='S', - help='load pretrained model or not.') - parser.add_argument('--load-checkpoint', - type=lambda x: x.lower()=='true', default=True, metavar='S', - help='load pretrained model or not.') parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') parser.add_argument('--profiling', type=str, default='no-profiling', metavar='S', diff --git a/training/dist_prefixlm_train.py b/training/dist_prefixlm_train.py index 86c21bf..032b834 100644 --- a/training/dist_prefixlm_train.py +++ b/training/dist_prefixlm_train.py @@ -196,24 +196,12 @@ def main(): add_training_hyper_parameter_arguments(parser) add_mixed_precision_arguments(parser) add_parallel_schema_arguments(parser) - 
parser.add_argument('--model-name', type=str, default='gpt2', metavar='S', - help='model name or path') - parser.add_argument('--tokenizer-name', type=str, default='gpt2', metavar='S', - help='tokenizer name or path') - parser.add_argument('--model-type', type=str, default='gpt2', metavar='S', - help='model name or path') - parser.add_argument('--checkpoint-path', type=str, default='model_checkpoints/gpt2') + add_ckpt_arguments(parser) parser.add_argument('--task-name', type=str, default='cot', metavar='S', help='task name') parser.add_argument('--warmup-steps', type=int, default=0, help='-') parser.add_argument('--train-warmup-steps', type=int, default=0, help='-') parser.add_argument('--total-steps', type=int, default=None, help='-') - parser.add_argument('--load-pretrained-model', - type=lambda x: x.lower()=='true', default=True, metavar='S', - help='load pretrained model or not.') - parser.add_argument('--load-checkpoint', - type=lambda x: x.lower()=='true', default=True, metavar='S', - help='load pretrained model or not.') parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') parser.add_argument('--profiling', type=str, default='no-profiling', metavar='S', diff --git a/training/finetune_Pythia-Chat-Base-7B-feedback.sh b/training/finetune_Pythia-Chat-Base-7B-feedback.sh new file mode 100644 index 0000000..451b954 --- /dev/null +++ b/training/finetune_Pythia-Chat-Base-7B-feedback.sh @@ -0,0 +1,89 @@ +DIR=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) + +netif=lo +export GLOO_SOCKET_IFNAME=${netif} +export NCCL_SOCKET_IFNAME=${netif} +export MODEL_NAME=Pythia-Chat-Base-7B-feedback + +CKPT_LOAD_PATH=${DIR}/../model_ckpts/Pythia-Chat-Base-7B +CKPT_SAVE_PATH=${DIR}/../model_ckpts/${MODEL_NAME} + +export SHOW_DATA=0 + +BASE_MODEL="${DIR}/../pretrained/Pythia-6.9B-deduped/EleutherAI_pythia-6.9b-deduped/" + +CHECKPOINT_STEPS=100 + +DATASETS="\ +${DIR}/../data/OIG/files/unified_ni.jsonl:0.2,\ +${DIR}/../data/OIG/files/unified_p3.jsonl:0.5,\ +${DIR}/../data/OIG/files/unified_flan.jsonl:0.2,\ +${DIR}/../data/OIG/files/unified_chip2.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_rallio_safety_and_prosocial.jsonl:0.1,\ +${DIR}/../data/OIG/files/unified_soda_dialog.jsonl:0.1,\ +${DIR}/../data/OIG/files/unified_unifiedskg_instructions.jsonl:0.1,\ +${DIR}/../data/OIG/files/unified_merged_code_xp3.jsonl:0.1,\ +${DIR}/../data/OIG/files/unified_oscar_en_sample_dialog.jsonl:0.1,\ +${DIR}/../data/OIG/files/unified_ul2_plus_oscar_en_sample_dialog.jsonl:0.1,\ +${DIR}/../data/OIG/files/unified_multi_news.jsonl:0.05,\ +${DIR}/../data/OIG/files/unified_openai_summarize_tldr.jsonl:0.05,\ +${DIR}/../data/OIG/files/unified_squad_v2.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_nq.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_poetry_instructions.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_sqlv2.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_unnatural_instructions.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_conv_finqa.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_essays.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_plot_screenplay_books_dialog.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_grade_school_math_instructions.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_mathqa_flanv2_kojma_cot.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_joke_explanations.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_cuad.jsonl:0.01,\ +${DIR}/../data/OIG/files/unified_abstract_infill.jsonl:0.1,\ 
+${DIR}/../data/OIG/files/unified_image_prompts_instructions.jsonl:0.01,\ +${DIR}/../data/OIG/contributions/together_user_feedback_v0.2.jsonl:5.0 \ +" + +ARGS="--model-name ${BASE_MODEL} \ +--tokenizer-name ${BASE_MODEL} \ +--project-name together \ +--model-type gptneox \ +--optimizer adam \ +--seed 42 \ +--load-pretrained-model true \ +--task-name \ +"${DATASETS}" \ +--checkpoint-load-path ${CKPT_LOAD_PATH} \ +--checkpoint-path ${CKPT_SAVE_PATH} \ +--init-steps true \ +--total-steps 400 --warmup-steps 10 --train-warmup-steps 0 \ +--checkpoint-steps ${CHECKPOINT_STEPS} \ +--lr 1e-5 --seq-length 2048 --batch-size 32 --micro-batch-size 1 --gradient-accumulate-step 1 \ +--dist-url tcp://127.0.0.1:7033 \ +--num-layers 8 --embedding-dim 4096 \ +--world-size 8 --pipeline-group-size 4 --data-group-size 2 \ +--job-id 0 --net-interface ${netif} \ +--fp16 \ +--dp-backend nccl \ +--dp-mode allreduce \ +--pp-mode gpipe --profiling no-profiling" + + +(trap 'kill 0' SIGINT; \ +python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 0 --rank 0 \ + & \ +python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 1 --rank 1 \ + & \ +python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 2 --rank 2 \ + & \ +python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 3 --rank 3 \ + & \ +python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 4 --rank 4 \ + & \ +python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 5 --rank 5 \ + & \ +python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 6 --rank 6 \ + & \ +python ${DIR}/dist_clm_train.py $(echo ${ARGS}) --cuda-id 7 --rank 7 \ + & \ +wait) diff --git a/training/pipeline_parallel/dist_gpipe_pipeline_async.py b/training/pipeline_parallel/dist_gpipe_pipeline_async.py index 4457fbc..ac700dd 100644 --- a/training/pipeline_parallel/dist_gpipe_pipeline_async.py +++ b/training/pipeline_parallel/dist_gpipe_pipeline_async.py @@ -515,6 +515,7 @@ def backward_stage(self, cached_output_micro_batches: List[torch.Tensor], target { 'loss': sum(tr_loss)/len(tr_loss), 'lr': self.scheduler.get_last_lr()[0], + 'step': self.global_step, }, step=self.global_step, ) diff --git a/training/utils/dist_args_utils.py b/training/utils/dist_args_utils.py index ec254d7..8b0ba02 100644 --- a/training/utils/dist_args_utils.py +++ b/training/utils/dist_args_utils.py @@ -52,6 +52,27 @@ def add_model_arguments(parser): help='-') parser.add_argument('--num-heads', type=int, default=12, metavar='N', help='-') + +def add_ckpt_arguments(parser): + parser.add_argument('--model-name', type=str, default='gpt2', metavar='S', + help='model name or path') + parser.add_argument('--tokenizer-name', type=str, default='gpt2', metavar='S', + help='tokenizer name or path') + parser.add_argument('--model-type', type=str, default='gpt2', metavar='S', + help='model name or path') + parser.add_argument('--checkpoint-load-path', type=str, default=None, + help='Path to the ckpt to load. 
If none, it will be set to `checkpoint-path`') + parser.add_argument('--checkpoint-path', type=str, default='model_checkpoints/gpt2', + help='Path where ckpts are saved.') + parser.add_argument('--load-pretrained-model', + type=lambda x: x.lower()=='true', default=True, metavar='S', + help='load pretrained model or not.') + parser.add_argument('--load-checkpoint', + type=lambda x: x.lower()=='true', default=True, metavar='S', + help='load pretrained model or not.') + parser.add_argument('--init-steps', + type=lambda x: x.lower()=='true', default=False, metavar='S', + help='init steps to 0, affect lr scheduler.') def add_training_hyper_parameter_arguments(parser): diff --git a/training/utils/dist_checkpoint_utils.py b/training/utils/dist_checkpoint_utils.py index e2c4921..44de57d 100644 --- a/training/utils/dist_checkpoint_utils.py +++ b/training/utils/dist_checkpoint_utils.py @@ -10,14 +10,20 @@ def load_checkpoint(pipe, args): - if os.path.isfile(os.path.join(args.checkpoint_path, 'latest')): - with open(os.path.join(args.checkpoint_path, 'latest')) as f: + checkpoint_load_path = getattr(args, 'checkpoint_load_path', None) + if checkpoint_load_path is None: + checkpoint_load_path = args.checkpoint_path + + init_steps = getattr(args, 'init_steps', True) + + if os.path.isfile(os.path.join(checkpoint_load_path, 'latest')): + with open(os.path.join(checkpoint_load_path, 'latest')) as f: latest_step = int(f.read()) else: print('no checkpoint available, skipping') return - checkpoint_step_path = os.path.join(args.checkpoint_path, f"checkpoint_{latest_step}") + checkpoint_step_path = os.path.join(checkpoint_load_path, f"checkpoint_{latest_step}") try: with open(os.path.join(checkpoint_step_path, 'meta.json')) as f: @@ -25,7 +31,8 @@ def load_checkpoint(pipe, args): except: print('failed to load meta.') - pipe.global_step = latest_step + if not init_steps: + pipe.global_step = latest_step try: pipe.model.model.load_state_dict( @@ -49,16 +56,17 @@ def load_checkpoint(pipe, args): except: print('failed to load optim states.') - try: - pipe.scheduler.load_state_dict( - torch.load( - os.path.join( - checkpoint_step_path, f'prank_{get_pipeline_parallel_rank()}_scheduler.pt' + if not init_steps: + try: + pipe.scheduler.load_state_dict( + torch.load( + os.path.join( + checkpoint_step_path, f'prank_{get_pipeline_parallel_rank()}_scheduler.pt' + ) ) ) - ) - except: - print('failed to load scheduler states.') + except: + print('failed to load scheduler states.') def save_checkpoint(pipe, args):
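
Note on the new checkpoint flags (an editorial sketch, not part of the patches themselves): per the second patch, `--checkpoint-load-path` points the run at an existing checkpoint directory to read weights and optimizer state from, `--checkpoint-path` is where new checkpoints are written, and `--init-steps true` skips restoring `global_step` and the scheduler state so the LR schedule restarts from step 0. A minimal launcher fragment combining them might look like the following; the paths are placeholders and the required distributed/model arguments used in finetune_Pythia-Chat-Base-7B-feedback.sh (--rank, --world-size, --model-name, ...) are omitted for brevity.

    # Sketch only: placeholder paths, distributed/model flags omitted.
    CKPT_LOAD_PATH=/path/to/model_ckpts/Pythia-Chat-Base-7B    # hypothetical load dir
    CKPT_SAVE_PATH=/path/to/model_ckpts/my-feedback-run        # hypothetical save dir

    ARGS="--load-pretrained-model true \
    --checkpoint-load-path ${CKPT_LOAD_PATH} \
    --checkpoint-path ${CKPT_SAVE_PATH} \
    --init-steps true \
    --total-steps 400 --warmup-steps 10 --checkpoint-steps 100"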