# Dataset cache location (example): C:\Users\Mart\.paddlehub\dataset

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddlehub as hub
import os, io, csv
from paddlehub.datasets.base_nlp_dataset import InputExample, TextClassificationDataset

class ThuNews(TextClassificationDataset):
    """THUCNews 14-category text-classification dataset.

    Expects tab-separated split files (train.txt / valid.txt / test.txt,
    each with a header row) under ./thu_news, laid out as
    ``text_a<TAB>label`` — the reverse of the official PaddleHub layout,
    which is why ``_read_file`` is overridden below.
    """

    def __init__(self, tokenizer, mode='train', max_seq_len=128):
        # Directory where the dataset split files live.
        DATA_DIR = "./thu_news"

        if mode == 'train':
            data_file = 'train.txt'
        elif mode == 'test':
            data_file = 'test.txt'
        else:
            # 'dev' (and any other mode) maps to the validation split.
            data_file = 'valid.txt'
        super(ThuNews, self).__init__(
            base_path=DATA_DIR,
            data_file=data_file,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            mode=mode,
            is_file_with_header=True,
            label_list=['体育', '科技', '社会', '娱乐', '股票', '房产', '教育', '时政', '财经', '星座', '游戏', '家居', '彩票', '时尚'])

    def _read_file(self, input_file, is_file_with_header: bool = False):
        """Parse one TSV split file into a list of InputExample.

        Each data row is ``text_a<TAB>label``; examples are assigned
        sequential 0-based guids in file order.

        Raises:
            RuntimeError: if ``input_file`` does not exist.
        """
        # Guard clause: fail fast on a missing file instead of nesting
        # the whole parse in an else-branch.
        if not os.path.exists(input_file):
            raise RuntimeError("The file {} is not found.".format(input_file))
        with io.open(input_file, "r", encoding="UTF-8") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=None)
            if is_file_with_header:
                next(reader)  # skip the header row
            # enumerate replaces the original hand-maintained seq_id counter.
            return [
                InputExample(guid=seq_id, text_a=row[0], label=row[1])
                for seq_id, row in enumerate(reader)
            ]


"""
官方标准的数据格式是标签在前面,现在自定义标签在后面,所以需要自定义_read_file函数把line[0]/line[1]调换下。
label	text_a
1	选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般
1	15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错
"""
if __name__ == '__main__':
    # For a multi-class task, num_classes must be given explicitly;
    # this dataset has 14 categories. The module wires a fully-connected
    # classification head onto the ERNIE-tiny pretrained backbone.
    model = hub.Module(name='ernie_tiny', version='2.0.1', task='seq-cls', num_classes=14)

    # Build the three dataset splits with the model's own tokenizer.
    train_dataset = ThuNews(model.get_tokenizer(), mode='train', max_seq_len=128)
    dev_dataset = ThuNews(model.get_tokenizer(), mode='dev', max_seq_len=128)
    test_dataset = ThuNews(model.get_tokenizer(), mode='test', max_seq_len=128)

    # Adam over all trainable parameters.
    optimizer = paddle.optimizer.Adam(learning_rate=5e-5, parameters=model.parameters())

    # Trainer drives the fine-tuning loop; checkpoints land in ./ckpt.
    trainer = hub.Trainer(model, optimizer, checkpoint_dir='./ckpt', use_gpu=False)
    trainer.train(train_dataset, epochs=3, batch_size=32, eval_dataset=dev_dataset, save_interval=1)

    # Evaluate the fine-tuned model on the held-out test split.
    result = trainer.evaluate(test_dataset, batch_size=32)
# NOTE(review): the lines below were non-code residue scraped from the web
# page this snippet was copied from; commented out so the file stays valid
# Python. They carry no program meaning.
# Logo
# 瓜分20万奖金 获得内推名额 丰厚实物奖励 易参与易上手
# 更多推荐