Skip to content

PyTorch数据集制作

简介

PyTorch提供了强大的数据集(Dataset)和数据加载器(DataLoader)API,用于处理各种类型的数据。本文档将介绍如何使用PyTorch创建自定义数据集,特别是针对NLP情感分析任务。

Dataset类基本结构

在PyTorch中,创建自定义数据集需要继承torch.utils.data.Dataset类并实现以下方法:

  1. __init__: 初始化数据集,加载数据
  2. __len__: 返回数据集中样本的数量
  3. __getitem__: 根据索引返回单个样本

情感分析数据集示例

下面是一个简单的情感分析数据集实现示例:

python
from torch.utils.data import Dataset, DataLoader
import torch

class TestDateset(Dataset):
    def __init__(self, samples):
        """
        初始化情感分析数据集
        Args:
            samples: 样本列表,每个样本包含句子和标签
        """
        self.samples = samples

    def __len__(self):
        """返回数据集大小"""
        return len(self.samples)
    
    def __getitem__(self, idx):
        """根据索引返回单个样本"""
        sample = self.samples[idx]
        sentence = sample["sentence"]
        label = sample["label"]
        
        # 直接返回句子文本和标签
        return sentence, label

数据准备

对于情感分析任务,我们可以准备不同情感类别的样本:

python
# 正样本
positive_samples = [
    {"sentence": "我非常喜欢这部电影", "label": 1},
    {"sentence": "这个产品质量很好", "label": 1},
    {"sentence": "服务态度非常棒", "label": 1},
    {"sentence": "真是太棒了", "label": 1},
    {"sentence": "我非常满意", "label": 1},
]

# 负样本
negative_samples = [
    {"sentence": "我非常讨厌这部电影", "label": 0},
    {"sentence": "这个产品质量很差", "label": 0},
    {"sentence": "服务态度极其糟糕", "label": 0},
    {"sentence": "真是太失望了", "label": 0},
    {"sentence": "我感到非常不满", "label": 0},
]

# 中性样本
neutral_samples = [
    {"sentence": "这部电影一般般", "label": 2},
    {"sentence": "这个产品还行吧", "label": 2},
    {"sentence": "服务态度一般", "label": 2},
    {"sentence": "没什么特别的感觉", "label": 2},
    {"sentence": "不好不坏吧", "label": 2}
]

# 合并所有样本
all_samples = positive_samples + negative_samples + neutral_samples

DataLoader使用

创建数据集后,可以使用DataLoader进行批量加载:

python
# 创建数据集实例
dataset = TestDateset(all_samples)

# 创建DataLoader
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# 遍历批次数据
for batch_sentences, batch_labels in dataloader:
    # 处理批次数据
    print(f"批次大小: {len(batch_sentences)}")
    print(f"第一个句子: {batch_sentences[0]}")
    print(f"第一个标签: {batch_labels[0]}")
    # 在模型训练中使用...

数据集增强功能

可以为基本数据集添加更多功能:

  1. 标签分布统计
python
def get_label_distribution(self):
    """获取标签分布情况"""
    label_count = {}
    for sample in self.samples:
        label = sample["label"]
        label_count[label] = label_count.get(label, 0) + 1
    return label_count

完整示例

完整的情感分析数据集示例可以在pytorchLearn/testDateset.py文件中找到,它展示了如何实现一个简单的数据集类,并使用DataLoader加载批量数据。

总结

创建自定义PyTorch数据集的关键点:

  1. 继承Dataset类并实现必要的方法
  2. 根据任务需求准备和组织数据
  3. 使用DataLoader批量加载数据
  4. 根据需要添加数据预处理和增强功能

注意问题

1. __getitem__ 里面转张量的问题

    1. DataLoader 的默认行为 PyTorch 的 DataLoader 默认会尝试把 batch 里的每一项“拼接”成张量,但前提是这些项本身可以直接拼成张量。 对于 batch_labels,它们本来就是 int 或 tensor,能直接拼成一个一维张量。 对于 batch_sentences,它们是字符串(文本),字符串之间无法直接拼成张量,所以 DataLoader 会把它们组合成一个元组或列表(通常是元组)。
    1. 如果你希望 batch_sentences 也能变成张量,需要在数据集的 getitem 方法里就把句子转成张量(比如转成索引序列),而不是直接返回字符串。
python
class TestDateset(Dataset):
    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        """根据索引返回单个样本"""
        sample = self.samples[idx]
        sentence = sample["sentence"]
        label = sample["label"]

        # 假设你有一个 vocab,把句子转成索引
        indices = [vocab.get(char, vocab["<UNK>"]) for char in sentence]
        # 填充到固定长度
        indices = indices + [vocab["<PAD>"]] * (max_length - len(indices))
        indices_tensor = torch.tensor(indices, dtype=torch.long)
        return indices_tensor, label