diff --git a/chat/generate.py b/chat/generate.py index 64a3905..c67d7d8 100644 --- a/chat/generate.py +++ b/chat/generate.py @@ -127,7 +127,7 @@ def main(): print() raw_model_name = args.model_id.split("/")[-1] - model_name = f"{raw_model_name}-{args.prompt_type}" + model_name = f"{raw_model_name}" if args.revision is not None: model_name += f"-{args.revision}" diff --git a/finetune/finetune.py b/finetune/finetune.py index 96ab961..525b37f 100644 --- a/finetune/finetune.py +++ b/finetune/finetune.py @@ -267,6 +267,8 @@ def run_training(args, train_data, val_data): output_dir=args.output_dir, dataloader_drop_last=True, evaluation_strategy="steps", + save_strategy="steps", + load_best_model_at_end=True, max_steps=args.max_steps, eval_steps=args.eval_freq, save_steps=args.save_freq, @@ -309,4 +311,4 @@ def main(args): logging.set_verbosity_error() - main(args) \ No newline at end of file + main(args)