paddlehub官方数据集查看地址
C:\Users\Mart\.paddlehub\dataset
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
·
C:\Users\Mart\.paddlehub\dataset
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddlehub as hub
import os, io, csv
from paddlehub.datasets.base_nlp_dataset import InputExample, TextClassificationDataset
class ThuNews(TextClassificationDataset):
    """Text classification dataset loaded from tab-separated files in ``./thu_news``.

    Each data row is expected to hold the text in the first column and the
    label in the second column. The label must be one of the 14 categories
    passed as ``label_list`` to the base class.
    """

    def __init__(self, tokenizer, mode='train', max_seq_len=128):
        """Select the split file for ``mode`` and initialise the base dataset.

        Args:
            tokenizer: Tokenizer forwarded unchanged to the base class.
            mode: ``'train'`` selects ``train.txt`` and ``'test'`` selects
                ``test.txt``; any other value (e.g. ``'dev'``) selects
                ``valid.txt``.
            max_seq_len: Maximum sequence length forwarded to the base class.
        """
        # Directory that holds the dataset files.
        DATA_DIR = "./thu_news"
        split_files = {'train': 'train.txt', 'test': 'test.txt'}
        data_file = split_files.get(mode, 'valid.txt')
        super(ThuNews, self).__init__(
            base_path=DATA_DIR,
            data_file=data_file,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            mode=mode,
            is_file_with_header=True,
            label_list=['体育', '科技', '社会', '娱乐', '股票', '房产', '教育', '时政', '财经', '星座', '游戏', '家居', '彩票', '时尚'])

    # Parse the samples stored in the text file.
    def _read_file(self, input_file, is_file_with_header: bool = False):
        """Read a tab-separated file into a list of ``InputExample`` objects.

        Args:
            input_file: Path of the file to read (decoded as UTF-8).
            is_file_with_header: When true, the first row is treated as a
                header and discarded.

        Returns:
            A list of ``InputExample`` whose ``guid`` counts up from 0, with
            ``text_a`` taken from column 0 and ``label`` from column 1.

        Raises:
            RuntimeError: If ``input_file`` does not exist.
            IndexError: If a non-empty row has fewer than two columns.
        """
        if not os.path.exists(input_file):
            raise RuntimeError("The file {} is not found.".format(input_file))

        examples = []
        with io.open(input_file, "r", encoding="UTF-8") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=None)
            if is_file_with_header:
                # Use a default so an empty file yields no examples instead
                # of leaking StopIteration.
                next(reader, None)
            for line in reader:
                # csv.reader yields [] for blank lines (e.g. a trailing empty
                # line); skip them rather than failing on line[0].
                if not line:
                    continue
                # Column order here is text first, label second.
                examples.append(
                    InputExample(guid=len(examples), text_a=line[0], label=line[1]))
        return examples
"""
官方标准的数据格式是标签在前、文本在后(格式如下例所示);本例的自定义数据集是文本在前、标签在后,所以需要重写_read_file函数,把line[0]/line[1]对调。
label text_a
1 选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般
1 15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错
"""
if __name__ == '__main__':
    # For a multi-class task num_classes must be given explicitly; it is 14
    # here to match the dataset's label list.
    # This single call initialises `model` as a text classification model:
    # the pretrained ERNIE model followed by a fully connected layer.
    model = hub.Module(
        name='ernie_tiny',
        version='2.0.1',
        task='seq-cls',
        num_classes=14,
    )

    # Build the three splits in train/dev/test order, each with its own
    # tokenizer obtained from the model.
    train_dataset, dev_dataset, test_dataset = (
        ThuNews(model.get_tokenizer(), mode=split, max_seq_len=128)
        for split in ('train', 'dev', 'test')
    )

    # Optimizer choice and hyper-parameters.
    optimizer = paddle.optimizer.Adam(
        learning_rate=5e-5,
        parameters=model.parameters(),
    )

    # The trainer runs the fine-tuning task.
    trainer = hub.Trainer(
        model,
        optimizer,
        checkpoint_dir='./ckpt',
        use_gpu=False,
    )

    # Configure and start training, using the dev split for evaluation.
    trainer.train(
        train_dataset,
        epochs=3,
        batch_size=32,
        eval_dataset=dev_dataset,
        save_interval=1,
    )

    # Evaluate the trained model on the test split.
    result = trainer.evaluate(test_dataset, batch_size=32)
更多推荐
已为社区贡献10条内容
所有评论(0)