Search K
Appearance
Appearance
PyTorch提供了强大的数据集(Dataset)和数据加载器(DataLoader)API,用于处理各种类型的数据。本文档将介绍如何使用PyTorch创建自定义数据集,特别是针对NLP情感分析任务。
在PyTorch中,创建自定义数据集需要继承torch.utils.data.Dataset类并实现以下方法:
__init__: 初始化数据集,加载数据__len__: 返回数据集中样本的数量__getitem__: 根据索引返回单个样本下面是一个简单的情感分析数据集实现示例:
from torch.utils.data import Dataset, DataLoader
import torch
class TestDateset(Dataset):
def __init__(self, samples):
"""
初始化情感分析数据集
Args:
samples: 样本列表,每个样本包含句子和标签
"""
self.samples = samples
def __len__(self):
"""返回数据集大小"""
return len(self.samples)
def __getitem__(self, idx):
"""根据索引返回单个样本"""
sample = self.samples[idx]
sentence = sample["sentence"]
label = sample["label"]
# 直接返回句子文本和标签
return sentence, label对于情感分析任务,我们可以准备不同情感类别的样本:
# 正样本
positive_samples = [
{"sentence": "我非常喜欢这部电影", "label": 1},
{"sentence": "这个产品质量很好", "label": 1},
{"sentence": "服务态度非常棒", "label": 1},
{"sentence": "真是太棒了", "label": 1},
{"sentence": "我非常满意", "label": 1},
]
# 负样本
negative_samples = [
{"sentence": "我非常讨厌这部电影", "label": 0},
{"sentence": "这个产品质量很差", "label": 0},
{"sentence": "服务态度极其糟糕", "label": 0},
{"sentence": "真是太失望了", "label": 0},
{"sentence": "我感到非常不满", "label": 0},
]
# 中性样本
neutral_samples = [
{"sentence": "这部电影一般般", "label": 2},
{"sentence": "这个产品还行吧", "label": 2},
{"sentence": "服务态度一般", "label": 2},
{"sentence": "没什么特别的感觉", "label": 2},
{"sentence": "不好不坏吧", "label": 2}
]
# 合并所有样本
all_samples = positive_samples + negative_samples + neutral_samples创建数据集后,可以使用DataLoader进行批量加载:
# 创建数据集实例
dataset = TestDateset(all_samples)
# 创建DataLoader
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
# 遍历批次数据
for batch_sentences, batch_labels in dataloader:
# 处理批次数据
print(f"批次大小: {len(batch_sentences)}")
print(f"第一个句子: {batch_sentences[0]}")
print(f"第一个标签: {batch_labels[0]}")
# 在模型训练中使用...可以为基本数据集添加更多功能:
def get_label_distribution(self):
"""获取标签分布情况"""
label_count = {}
for sample in self.samples:
label = sample["label"]
label_count[label] = label_count.get(label, 0) + 1
return label_count完整的情感分析数据集示例可以在pytorchLearn/testDateset.py文件中找到,它展示了如何实现一个简单的数据集类,并使用DataLoader加载批量数据。
创建自定义PyTorch数据集的关键点:
Dataset类并实现必要的方法DataLoader批量加载数据__getitem__ 里面转张量的问题 class TestDateset(Dataset):
def __init__(self, samples):
self.samples = samples
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
"""根据索引返回单个样本"""
sample = self.samples[idx]
sentence = sample["sentence"]
label = sample["label"]
# 假设你有一个 vocab,把句子转成索引
indices = [vocab.get(char, vocab["<UNK>"]) for char in sentence]
# 填充到固定长度
indices = indices + [vocab["<PAD>"]] * (max_length - len(indices))
indices_tensor = torch.tensor(indices, dtype=torch.long)
return indices_tensor, label