【Transformers基础入门篇7】基础组件之Trainer
Trainer是库中提供的训练的函数,内部封装了完整的训练、评估逻辑,并集成了多种的后端,如等,搭配对训练过程中的各项参数进行配置,可以方便快捷地启动模型 单机/分布式训练使用Trainer进行模型训练对模型的输入输出是有限制的,要求模型返回元组或者的子类如果输入中提供了labels,模型要能返回loss结果,如果是元组,要求loss为元组中的第一个值。
·
文章目录
本文为 https://space.bilibili.com/21060026/channel/collectiondetail?sid=1357748的视频学习笔记
项目地址为:https://github.com/zyds/transformers-code
一、Trainer 简介
Trainer
是transformers
库中提供的训练的函数,内部封装了完整的训练、评估逻辑,并集成了多种的后端,如DeepSpeed、Pytorch FSDP
等,搭配TrainingArguments
对训练过程中的各项参数进行配置,可以方便快捷地启动模型 单机/分布式训练- 使用
Trainer
进行模型训练对模型的输入输出是有限制的,要求模型返回元组或者ModelOutput
的子类 - 如果输入中提供了labels,模型要能返回loss结果,如果是元组,要求loss为元组中的第一个值
- 文档地址:https://huggingface.co/docs/transformers/main_classes/trainer
二、模型微调代码优化
- 任务类型:文本分类
- 使用模型 hfl/rbt3
- 使用
Trainer + TrainingArgument
优化训练流程
2.0 完整使用流程
- step1:导入包 Trainer,TrainingArguments
- step2:创建TrainingArguments
- step3: 创建Trainer
- step4:模型训练
- step5:模型评估
- step6:模型预测
2.1 导入相关包
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
2.2 加载数据集
dataset = load_dataset("csv", data_files="./ChnSentiCorp_htl_all.csv", split="train")
dataset = dataset.filter(lambda x: x["review"] is not None)
2.3 划分数据集
datasets = dataset.train_test_split(test_size=0.1)
2.4 数据集预处理
import torch
tokenizer = AutoTokenizer.from_pretrained("../../models/hfl/rbt3")
def process_function(examples):
tokenized_examples = tokenizer(examples["review"], max_length=128, truncation=True)
tokenized_examples["labels"] = examples["label"]
return tokenized_examples
tokenized_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)
2.5 创建模型
model = AutoModelForSequenceClassification.from_pretrained("../../models/hfl/rbt3")
2.6 创建评估函数
mport evaluate
acc_metric = evaluate.load("../../evaluate/metrics/accuracy")
f1_metric = evaluate.load("../../evaluate/metrics/f1")
#%%
def eval_metric(eval_predict):
predictions, labels = eval_predict
predictions = predictions.argmax(axis=-1)
acc = acc_metric.compute(predictions=predictions, references=labels)
f1 = f1_metric.compute(predictions=predictions, references=labels)
acc.update(f1) # 最终acc包含accuracy和f1
return acc
2.7 创建TrainingArguments
train_args = TrainingArguments(output_dir="./checkpoints", # 输出文件夹
per_device_train_batch_size=64, # 训练时的batch_size
per_device_eval_batch_size=128, # 验证时的batch_size
logging_steps=10, # log 打印的频率
evaluation_strategy="epoch", # 评估策略
save_strategy="epoch", # 保存策略
save_total_limit=3, # 最大保存数
learning_rate=2e-5, # 学习率
weight_decay=0.01, # weight_decay
metric_for_best_model="f1", # 设定评估指标,来评价模型是最好的
load_best_model_at_end=True) # 训练完成后加载最优模型
2.8 创建Trainer
from transformers import DataCollatorWithPadding
trainer = Trainer(model=model,
args=train_args,
train_dataset=tokenized_datasets["train"],
eval_dataset=tokenized_datasets["test"],
data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
compute_metrics=eval_metric)
2.9 模型训练
trainer.train()
2.10 模型评估
trainer.evaluate(tokenized_datasets["test"])
2.11 模型预测
trainer.predict(tokenized_datasets["test"])
2.12 pipeline测试
from transformers import pipeline
id2_label = id2_label = {0: "差评!", 1: "好评!"}
model.config.id2label = id2_label
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, device=0)
sen = "我觉得不错!"
pipe(sen)
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
已为社区贡献14条内容
所有评论(0)