News Topic Classification Task
About the news topic classification task
- The input is a passage of text from a news report, and a model is used to judge which type of news it most likely belongs to. This is a classic text classification problem; here we assume the types are mutually exclusive, i.e. each text belongs to exactly one type.
News topic classification data
Obtaining the data via torchtext
```python
import os

import torch
import torchtext

load_data_path = "./data"
if not os.path.isdir(load_data_path):
    os.mkdir(load_data_path)

# Download AG_NEWS and build the train/test datasets. The DATASETS
# registry lives in torchtext.datasets.text_classification in the
# older torchtext releases this code targets.
train_dataset, test_dataset = torchtext.datasets.text_classification.DATASETS["AG_NEWS"](
    root=load_data_path)
```
File description:
- train.csv is the training data, 120,000 rows in total; test.csv is the validation data, 7,600 rows in total. A quick way to peek at the raw rows is sketched below.
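Each row holds three fields: the class index (1-4), the news title, and the description. A minimal sketch to inspect the first training row; the ag_news_csv path is an assumption about where the downloaded archive extracts, so adjust it to your layout:

```python
import csv

# Path is an assumption about the extracted archive layout.
with open("./data/ag_news_csv/train.csv", newline="", encoding="utf-8") as f:
    label, title, description = next(csv.reader(f))
    print(label, title, description[:60])
```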
Processing the data
```python
from keras.preprocessing.text import Tokenizer
from keras.preprocessing import sequence


def process_datasets_by_Tokenizer(train_datasets, test_datasets, cutlen=256):
    """Map texts to fixed-length integer sequences with a Keras Tokenizer."""
    tokenizer = Tokenizer()

    train_datasets_texts = []
    train_datasets_labels = []
    test_datasets_texts = []
    test_datasets_labels = []
    # Raw labels are 1-4; shift them to 0-3 for the classifier.
    for item in train_datasets:
        train_datasets_labels.append(item[0] - 1)
        train_datasets_texts.append(item[1])
    for item in test_datasets:
        test_datasets_labels.append(item[0] - 1)
        test_datasets_texts.append(item[1])

    all_datasets_texts = train_datasets_texts + test_datasets_texts
    all_datasets_labels = train_datasets_labels + test_datasets_labels

    # Fit the vocabulary on all texts, then map each text to a sequence of ids.
    tokenizer.fit_on_texts(all_datasets_texts)
    train_datasets_seqs = tokenizer.texts_to_sequences(train_datasets_texts)
    test_datasets_seqs = tokenizer.texts_to_sequences(test_datasets_texts)

    # Truncate or left-pad every sequence to exactly cutlen ids.
    train_datasets_seqs = sequence.pad_sequences(train_datasets_seqs, cutlen)
    test_datasets_seqs = sequence.pad_sequences(test_datasets_seqs, cutlen)

    # Re-pair each padded sequence with its label.
    train_datasets = list(zip(train_datasets_seqs, train_datasets_labels))
    test_datasets = list(zip(test_datasets_seqs, test_datasets_labels))

    vocab_size = len(tokenizer.index_word.keys())
    num_class = len(set(all_datasets_labels))

    return train_datasets, test_datasets, vocab_size, num_class


train_datasets, test_datasets, vocab_size, num_class = \
    process_datasets_by_Tokenizer(train_dataset, test_dataset)
```
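To make the Tokenizer's mapping concrete, here is a tiny self-contained sketch on two toy sentences (not the AG_NEWS data). Indices are assigned by descending word frequency, and pad_sequences left-pads with 0 by default:

```python
from keras.preprocessing.text import Tokenizer
from keras.preprocessing import sequence

toy_texts = ["the cat sat", "the dog barked loudly"]
tok = Tokenizer()
tok.fit_on_texts(toy_texts)                  # "the" is most frequent, so it gets index 1
print(tok.texts_to_sequences(toy_texts))     # [[1, 2, 3], [1, 4, 5, 6]]
print(sequence.pad_sequences(tok.texts_to_sequences(toy_texts), 5))
# [[0 0 1 2 3]
#  [0 1 4 5 6]]
```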
Implementation steps for news topic classification
- Build a text classification model with an Embedding layer
- Batch the data
- Build the training and validation functions
- Run model training and validation
- Inspect the word vectors learned by the embedding layer
Build a text classification model with an Embedding layer
```python
import torch.nn as nn
import torch.nn.functional as F

BATCH_SIZE = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class TextSentiment(nn.Module):
    """Text classification model."""

    def __init__(self, vocab_size, embed_dim, num_class):
        """
        description: class initializer
        :param vocab_size: total number of distinct words in the corpus
        :param embed_dim: dimensionality of the word embeddings
        :param num_class: total number of text classes
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        """Initialize the weights."""
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text):
        """
        :param text: flat tensor of token ids for a whole batch
        :return: a tensor with one row of class scores per sample,
                 used to judge the text's class
        """
        embedded = self.embedding(text)
        # Split the flat batch into BATCH_SIZE chunks of length c and
        # average-pool each chunk down to a single embed_dim vector.
        c = embedded.size(0) // BATCH_SIZE
        embedded = embedded[:BATCH_SIZE * c]
        embedded = embedded.transpose(1, 0).unsqueeze(0)
        embedded = F.avg_pool1d(embedded, kernel_size=c)
        # Note: CrossEntropyLoss applies log-softmax itself, so returning
        # the raw logits from self.fc would be the more conventional choice.
        return F.softmax(self.fc(embedded[0].transpose(1, 0)), dim=1)
```
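The only non-obvious step in forward is F.avg_pool1d: with the default stride equal to kernel_size, it averages each consecutive chunk of c values, which is what collapses each sample's c token embeddings into one vector. A minimal sketch:

```python
import torch
import torch.nn.functional as F

# avg_pool1d expects (batch, channels, length); kernel_size=2 averages pairs.
t = torch.tensor([[[1., 3., 5., 7.]]])
print(F.avg_pool1d(t, kernel_size=2))   # tensor([[[2., 6.]]])
```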
```python
# +1 so that both index 0 (the padding id from pad_sequences) and the
# largest tokenizer index fit inside the embedding table.
VOCAB_SIZE = vocab_size + 1
EMBED_DIM = 32
NUM_CLASS = num_class

model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUM_CLASS).to(device)
```
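As a quick sanity check (a sketch using random token ids rather than real data), one flattened batch of BATCH_SIZE × cutlen ids should come back as one row of class probabilities per sample:

```python
# Random ids standing in for a real flattened batch of 16 samples x 256 tokens.
sample = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE * 256,))
output = model(sample.to(device))
print(output.shape)   # torch.Size([16, 4]): AG_NEWS has 4 classes
```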
Batch the data
```python
def generate_batch(batch):
    """Flatten a batch of (sequence, label) pairs into one long tensor of
    token ids plus a tensor of labels; forward() undoes the flattening."""
    text = []
    label = []
    for item in batch:
        text.extend(item[0])
        label.append(item[1])
    return torch.tensor(text), torch.tensor(label)
```
```python
batch = [(torch.tensor([3, 23, 2, 8]), 1),
         (torch.tensor([3, 45, 21, 6]), 0)]
print(generate_batch(batch))
```
(tensor([ 3, 23, 2, 8, 3, 45, 21, 6]), tensor([1, 0]))
Build the training and validation functions
```python
from torch.utils.data import DataLoader


def train(train_data):
    """Train the model for one epoch."""
    train_loss = 0
    train_acc = 0
    data = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True,
                      collate_fn=generate_batch)
    for i, (text, cls) in enumerate(data):
        optimizer.zero_grad()
        output = model(text.to(device))
        loss = criterion(output, cls.to(device))
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        train_acc += (output.argmax(1) == cls.to(device)).sum().item()
    # Decay the learning rate once per epoch.
    scheduler.step()
    return train_loss / len(train_data), train_acc / len(train_data)


def valid(valid_data):
    """Evaluate the model on the validation split."""
    valid_loss = 0
    valid_acc = 0
    data = DataLoader(valid_data, batch_size=BATCH_SIZE,
                      collate_fn=generate_batch)
    for text, cls in data:
        with torch.no_grad():
            output = model(text.to(device))
            loss = criterion(output, cls.to(device))
            # Separate accumulator names so the running totals are not
            # overwritten by the per-batch loss tensor.
            valid_loss += loss.item()
            valid_acc += (output.argmax(1) == cls.to(device)).sum().item()
    return valid_loss / len(valid_data), valid_acc / len(valid_data)
```
Run model training and validation
```python
import time
from torch.utils.data.dataset import random_split

N_EPOCHS = 20  # matches the 20-epoch log below
min_valid_loss = float("inf")

criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

# Hold out 5% of the training data as a validation split.
train_len = int(len(train_datasets) * 0.95)
sub_train, sub_valid = random_split(
    train_datasets, [train_len, len(train_datasets) - train_len])

for epoch in range(N_EPOCHS):
    start_time = time.time()
    train_loss, train_acc = train(sub_train)
    valid_loss, valid_acc = valid(sub_valid)

    secs = int(time.time() - start_time)
    mins = secs // 60
    secs = secs % 60

    print("Epoch: %d" % (epoch + 1), "| time in %d minutes, %d seconds" % (mins, secs))
    print(f"\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc*100:.1f}%(train)")
    print(f"\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc*100:.1f}%(valid)")
```
Epoch: 1 | time in 0 minutes, 25 seconds
    Loss: 0.0873(train) | Acc: 32.1%(train)
    Loss: 0.0004(valid) | Acc: 35.9%(valid)
Epoch: 2 | time in 0 minutes, 25 seconds
    Loss: 0.0826(train) | Acc: 38.8%(train)
    Loss: 0.0005(valid) | Acc: 39.4%(valid)
Epoch: 3 | time in 0 minutes, 25 seconds
    Loss: 0.0821(train) | Acc: 40.2%(train)
    Loss: 0.0004(valid) | Acc: 34.9%(valid)
Epoch: 4 | time in 0 minutes, 25 seconds
    Loss: 0.0804(train) | Acc: 43.0%(train)
    Loss: 0.0005(valid) | Acc: 39.1%(valid)
Epoch: 5 | time in 0 minutes, 25 seconds
    Loss: 0.0793(train) | Acc: 45.0%(train)
    Loss: 0.0005(valid) | Acc: 31.5%(valid)
Epoch: 6 | time in 0 minutes, 25 seconds
    Loss: 0.0778(train) | Acc: 47.7%(train)
    Loss: 0.0004(valid) | Acc: 56.5%(valid)
Epoch: 7 | time in 0 minutes, 25 seconds
    Loss: 0.0768(train) | Acc: 49.5%(train)
    Loss: 0.0004(valid) | Acc: 37.1%(valid)
Epoch: 8 | time in 0 minutes, 25 seconds
    Loss: 0.0754(train) | Acc: 52.0%(train)
    Loss: 0.0004(valid) | Acc: 58.4%(valid)
Epoch: 9 | time in 0 minutes, 25 seconds
    Loss: 0.0741(train) | Acc: 54.2%(train)
    Loss: 0.0004(valid) | Acc: 48.4%(valid)
Epoch: 10 | time in 0 minutes, 25 seconds
    Loss: 0.0731(train) | Acc: 56.0%(train)
    Loss: 0.0005(valid) | Acc: 34.3%(valid)
Epoch: 11 | time in 0 minutes, 25 seconds
    Loss: 0.0716(train) | Acc: 58.4%(train)
    Loss: 0.0003(valid) | Acc: 68.2%(valid)
Epoch: 12 | time in 0 minutes, 25 seconds
    Loss: 0.0706(train) | Acc: 60.1%(train)
    Loss: 0.0003(valid) | Acc: 59.5%(valid)
Epoch: 13 | time in 0 minutes, 26 seconds
    Loss: 0.0694(train) | Acc: 62.2%(train)
    Loss: 0.0003(valid) | Acc: 69.2%(valid)
Epoch: 14 | time in 0 minutes, 25 seconds
    Loss: 0.0684(train) | Acc: 63.9%(train)
    Loss: 0.0003(valid) | Acc: 64.7%(valid)
Epoch: 15 | time in 0 minutes, 25 seconds
    Loss: 0.0675(train) | Acc: 65.4%(train)
    Loss: 0.0004(valid) | Acc: 65.5%(valid)
Epoch: 16 | time in 0 minutes, 25 seconds
    Loss: 0.0664(train) | Acc: 67.2%(train)
    Loss: 0.0003(valid) | Acc: 64.6%(valid)
Epoch: 17 | time in 0 minutes, 25 seconds
    Loss: 0.0657(train) | Acc: 68.5%(train)
    Loss: 0.0003(valid) | Acc: 70.4%(valid)
Epoch: 18 | time in 0 minutes, 25 seconds
    Loss: 0.0650(train) | Acc: 69.5%(train)
    Loss: 0.0004(valid) | Acc: 69.0%(valid)
Epoch: 19 | time in 0 minutes, 25 seconds
    Loss: 0.0643(train) | Acc: 70.8%(train)
    Loss: 0.0003(valid) | Acc: 76.3%(valid)
Epoch: 20 | time in 0 minutes, 25 seconds
    Loss: 0.0636(train) | Acc: 72.0%(train)
    Loss: 0.0003(valid) | Acc: 70.5%(valid)
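A side note on the schedule used above: StepLR with step_size=1 and gamma=0.9 multiplies the learning rate by 0.9 after every epoch, so the aggressive initial lr of 4.0 has shrunk to roughly 0.49 by the time the 20 epochs finish. A quick check:

```python
# How the learning rate evolves under StepLR(step_size=1, gamma=0.9).
lr = 4.0
for _ in range(20):
    lr *= 0.9
print(round(lr, 4))   # 0.4863
```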
Inspect the word vectors learned by the embedding layer
```python
print(model.state_dict()['embedding.weight'])
```
tensor([[-3.9321e-02,  1.2770e-02, -1.2725e-02,  ..., -3.7640e-02,  5.0681e-02,  3.4286e-03],
        [-1.6661e+00, -5.6520e+00, -6.9105e-03,  ..., -7.4342e-01,  1.5925e+00, -3.9538e-01],
        [ 1.2449e+00,  1.8321e+00,  8.1467e-01,  ...,  4.5453e-01, -1.1000e+00,  8.3954e-01],
        ...,
        [-2.6404e-01, -4.9704e-01,  8.3933e-02,  ..., -4.8199e-01,  3.0737e-01,  4.4653e-01],
        [-1.4254e-01,  2.1912e-01, -3.5175e-01,  ...,  1.7252e-01, -4.0052e-01, -1.5885e-02],
        [-1.5442e-02, -2.2085e-01, -3.8362e-01,  ..., -3.5968e-01,  3.6406e-01,  3.7704e-01]],
       device='cuda:0')
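To read off the learned vector for an individual word, you can go back through the tokenizer's word index. A sketch; "news" is just a hypothetical query word, and it assumes the tokenizer from the preprocessing step is still in scope:

```python
# "news" is a hypothetical example; any word the tokenizer has seen works.
idx = tokenizer.word_index["news"]
vector = model.state_dict()["embedding.weight"][idx]
print(vector.shape)   # torch.Size([32]), i.e. EMBED_DIM
```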