新技能get,欢迎点赞!


近期PyTorch发布了新的版本1.11,和这次新版本同时的发布还有两个新的torch库:TorchDatafunctorch,其中TorchData对标TensorFlow的tf.data,而functorch对标谷歌的JAX。TorchData提供了一种新的数据构建方式:DataPipes,以替代PyTorch现有的torch.utils.data.Dataset,这篇文章将简单介绍两种方式的区别以及新方式的主要用法。

Dataset

PyTorch现有的数据构建方式是:首先定义一个Dataset,然后将Dataset送入DataLoader中,下面是一个具体的实例:

import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import DataLoader

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
    

train_dataset = CustomImageDataset("annotations_file", "images")
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)

for images, targets in train_dataloader:
    pass

上面的例子我们实现的是一个map-style datasetsDataset),这里需要实现 getitem() 和 len() 两个接口,这样就可以通过索引来获取样例:dataset[idx]。Map-style datasets是我们最常用的dataset,还有另外一种dataset是iterable-style-datasetsIterableDataset),它需要实现__iter__()接口,其实这两种数据类型都属于iterable,即可迭代对象,可迭代对象可以通过调用iter()函数生成迭代器(iterator),也可以直接用for循环来进行遍历。两种datasets适用场景有所区别:map-style dataset适合已知所有样本(比如上面的图像路径和标签),并且可以全部加载到内存中,这样可以通过索引来随机读取;但当随机读取成本较高或者不可能时,iterable-style dataset就比较适合了,比如实时产生的流式数据。从本质上说,iterable-style dataset更通用,因为map-style dataset都可以转化为iterable-style dataset,但是对于很多实际场景中,我们已知全部数据集,往往会采用map-style dataset,它的一个好处是可以用torch.utils.data.Sampler来控制采样,比如在分布式训练中用torch.utils.data.distributed.DistributedSampler来拆分数据,而iterable-style dataset则不支持samper,此外在多进程读取时,iterable-style dataset也要单独处理防止数据重复加载,下面是torch.utils.data.IterableDataset文档的一个例子,这里采用torch.utils.data.get_worker_info来拆分数据:

class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
    super(MyIterableDataset).__init__()
    assert end > start, "this example code only works with end >= start"
    self.start = start
    self.end = end
    
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            iter_start = self.start
            iter_end = self.end
        else:  # in a worker process
            # split workload
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        return iter(range(iter_start, iter_end))

DataPipes

对于Dataset这种数据构建方式,它的好处是比较灵活:只需要继承Dataset基类,然后实现相关的接口协议即可。但是正是这种灵活,让代码难以复用,不同的代码库或者不同的数据集都维护着自己的一套代码,它们本身有很多重复的工作。这也正是TorchData推出的意义,TorchData引入DataPipes来替代现有的Dataset,它是一种可组合的数据加载和构建流程,这和TensorFlow的tf.data非常相似,可以将DataPipes看成可组合的Dataset,这使得很多操作比如读取和解析都模块化,从而增加代码复用性。

和原有的Dataset一样,DataPipes也有两种类型:IterDataPipes和MapDataPipes,其中IterDataPipes对标IterableDataset,需要实现__iter__接口;而MapDataPipes对标Dataset,需要实现__getitem__() 和 len() 两个接口。每个DataPipe都实现一个特殊的功能,比如用于json文件解析的JsonParserIterDataPipe:

import json

class JsonParserIterDataPipe(IterDataPipe):
    def __init__(self, source_datapipe, **kwargs) -> None:
        self.source_datapipe = source_datapipe
        self.kwargs = kwargs

    def __iter__(self):
        for file_name, stream in self.source_datapipe:
            data = stream.read()
            yield file_name, json.loads(data)

    def __len__(self):
        return len(self.source_datapipe)

TorchData库已经内置了很多的DataPipes,我们可以组合不同的IDataPipes来构建数据,下面为一个iterDataPipes的一个简单例子,这里用IterableWrapper创建了一个源iterDataPipe,然后可以链式地组合不同的iterDataPipes来实现更复杂的数据流,比如Mapper,Filter,Shuffler和Batcher。对于每个DataPipe,其构造函数的第一个参数是上一个DataPipe,这是链式组合的基础,大部分的DataPipes都有对应的函数格式,比如Mapper对应的函数是map,所以有两种组合DataPipes方式:一是适用类构造器比如Mapper(dp, lambda x: x + 1),二是采用函数式比如dp.map(lambda x: x + 1) ,函数式是推荐采用的方式,它更简单易懂

>>> from torchdata.datapipes.iter import IterableWrapper, Mapper
>>> dp = IterableWrapper(range(10))
>>> map_dp_1 = Mapper(dp, lambda x: x + 1)  # Using class constructor
>>> map_dp_2 = dp.map(lambda x: x + 1)  # Using functional form (recommended)
>>> list(map_dp_1)
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
>>> list(map_dp_2)
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
>>> filter_dp = map_dp_1.filter(lambda x: x % 2 == 0)
>>> list(filter_dp)
[2, 4, 6, 8, 10]
>>> shuffle_dp = filter_dp.shuffle()
>>> list(shuffle_dp)
[2, 6, 10, 4, 8]
>>> batch_dp = shuffle_dp.batch(2)
>>> list(batch_dp)
[[4, 6], [8, 2], [10]]

当组合不同的DataPipes后,其实我们是得到了一个Graph,从某种意义说,原来的Dataset其实就相当于一个包含不同DataPipes的Graph。我们也可以通过一些辅助函数来得到Graph中所包含的DataPipes:

>>> graph = torch.utils.data.graph.traverse(batch_dp, only_datapipe=True)
>>> torch.utils.data.graph_settings.get_all_graph_pipes(graph) # 获取所有的datapipes
{<torch.utils.data.datapipes.iter.callable.MapperIterDataPipe at 0x7f50b2f86700>,
 <torch.utils.data.datapipes.iter.combinatorics.ShufflerIterDataPipe at 0x7f50b0fccd00>,
 <torch.utils.data.datapipes.iter.grouping.BatcherIterDataPipe at 0x7f50b137b490>,
 <torch.utils.data.datapipes.iter.selecting.FilterIterDataPipe at 0x7f50b2fcc6d0>,
 <torch.utils.data.datapipes.iter.utils.IterableWrapperIterDataPipe at 0x7f50b2f86490>}

构建好DataPipes后,可以像Dataset一样送入DataLoader中来使用:

>>> dl = torch.utils.data.DataLoader(batch_dp)
>>> for batch in dl: print(batch)
[tensor([2]), tensor([4])]
[tensor([6]), tensor([8])]
[tensor([10])]

对于IterDataPipes,特别要注意的一点是,它和IterDataset一样,当使用多进程的DataLoader即num_workers>0时,要特别处理防止数据重复加载,或者在__iter__采用torch.utils.data.get_worker_info来拆分数据,或者在DataLoader的 worker_init_fn函数中操作,目前PyTorch提供了一个兼容的函数来实现这样的功能:

  dl = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        worker_init_fn=torch.utils.data.backward_compatibility.worker_init_fn,
    )

对于MapDataPipes,它除了和IterDataPipes实现的接口不同,其它用法基本类似,如下所示:

>>> from torchdata.datapipes.map import SequenceWrapper, Mapper
>>> dp = SequenceWrapper(range(10))
>>> map_dp_1 = dp.map(lambda x: x + 1)  # Using functional form (recommended)
>>> list(map_dp_1)
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
>>> map_dp_2 = Mapper(dp, lambda x: x + 1)  # Using class constructor
>>> list(map_dp_2)
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
>>> batch_dp = map_dp_1.batch(batch_size=2)
>>> list(batch_dp)
[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
>>> batch_dp[1]  # 使用索引来获取样例
[3, 4]
>>> dl = torch.utils.data.DataLoader(batch_dp)
>>> for batch in dl: print(batch)
[tensor([1]), tensor([2])]
[tensor([3]), tensor([4])]
[tensor([5]), tensor([6])]
[tensor([7]), tensor([8])]
[tensor([9]), tensor([10])]

MapDataPipes和MapDataset一样,当送入Dataloader后,它是通过sampler来控制数据的读取。目前TorchData内置的MapDataPipes相比IterDataPipes还较少。

TorchData中已经内置了很多DataPipes(MapDataPipesIterDataPipes),虽然它们覆盖了大部分的常用操作,但我们还可以自定义新的DataPipes,这里以最简单的MapperIterDataPipe为例,它实现的是map功能,主要有两个要点:一是构建函数的第一个参数是source_dp,二是实现__iter__接口,此外,我们还可以用functional_datapipe来注册它对应的函数格式。

from torchdata.datapipes.iter import IterDataPipe
from torch.utils.data.datapipes._decorator import 

@functional_datapipe("map")
class MapperIterDataPipe(IterDataPipe):
    def __init__(self, source_dp: IterDataPipe, fn) -> None:
        super().__init__()
        self.source_dp = source_dp
        self.fn = fn
        
    def __iter__(self):
        for d in self.dp:
            yield self.fn(d)

完整示例

这里以官方的教程给出一个完整的例子,首先我们随机生成一些CSV数据:它包含label和多个特征。

import csv
import random

def generate_csv(file_label, num_rows: int = 5000, num_features: int = 20) -> None:
    fieldnames = ['label'] + [f'c{i}' for i in range(num_features)]
    writer = csv.DictWriter(open(f"sample_data{file_label}.csv", "w"), fieldnames=fieldnames)
    writer.writerow({col: col for col in fieldnames})  # writing the header row
    for i in range(num_rows):
        row_data = {col: random.random() for col in fieldnames}
        row_data['label'] = random.randint(0, 9)
        writer.writerow(row_data)
        
num_files_to_generate = 3
for i in range(num_files_to_generate):
    generate_csv(file_label=i)

然后我们构建解析CSV的DataPipes,如下所示:

import numpy as np
import torchdata.datapipes as dp

def build_datapipes(root_dir="."):
    # 获取所有文件
    datapipe = dp.iter.FileLister(root_dir)
    # 筛选csv文件
    datapipe = datapipe.filter(filter_fn=lambda filename: "sample_data" in filename and filename.endswith(".csv"))
    # 打开文件,FileOpener没有对应的函数格式,如果安装了iopath,可以使用
    # datapipe = datapipe.open_by_iopath(mode='rt')
    datapipe = dp.iter.FileOpener(datapipe, mode='rt')
    # 解析csv
    datapipe = datapipe.parse_csv(delimiter=",", skip_lines=1)
    # 分离label和特征
    datapipe = datapipe.map(lambda row: {"label": np.array(row[0], np.int32),
                                         "data": np.array(row[1:], dtype=np.float64)})
    return datapipe

datapipe = build_datapipes()

然后我们将构建好的datapipes送入DataLoader中进行使用:

dl = DataLoader(dataset=datapipe, batch_size=50, shuffle=True)
first = next(iter(dl))
labels, features = first['label'], first['data']
print(f"Labels batch shape: {labels.size()}")
print(f"Feature batch shape: {features.size()}")

--------------------------------------------------
Labels batch shape: 50
Feature batch shape: torch.Size([50, 20])

TorchData也提供了更多的例子,涉及到图像,文本和语音,具体见pytorch.org/data/beta/e

小结

虽然目前TorchData只是Beta版本,但是它必然是未来的趋势,因为目前PyTorch的三个主要应用库:torchvision,torchaudio和torchtext都已经开始接入TorchData了。此外,PyTorch官方也在开发DataLoader V2,新的版本将更加专注多进程和分布式等功能,而去除数据处理的一些功能,比如batch和shuffle,这其实也是对TorchData的更佳适配。

参考