新技能get,欢迎点赞!
近期PyTorch发布了新的版本1.11,和这次新版本同时的发布还有两个新的torch库:TorchData和functorch,其中TorchData对标TensorFlow的tf.data,而functorch对标谷歌的JAX。TorchData提供了一种新的数据构建方式:DataPipes,以替代PyTorch现有的torch.utils.data.Dataset,这篇文章将简单介绍两种方式的区别以及新方式的主要用法。
Dataset
PyTorch现有的数据构建方式是:首先定义一个Dataset,然后将Dataset送入DataLoader中,下面是一个具体的实例:
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import DataLoader
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
train_dataset = CustomImageDataset("annotations_file", "images")
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
for images, targets in train_dataloader:
pass
上面的例子我们实现的是一个map-style datasets(Dataset),这里需要实现 getitem() 和 len() 两个接口,这样就可以通过索引来获取样例:dataset[idx]。Map-style datasets是我们最常用的dataset,还有另外一种dataset是iterable-style-datasets(IterableDataset),它需要实现__iter__()接口,其实这两种数据类型都属于iterable,即可迭代对象,可迭代对象可以通过调用iter()函数生成迭代器(iterator),也可以直接用for循环来进行遍历。两种datasets适用场景有所区别:map-style dataset适合已知所有样本(比如上面的图像路径和标签),并且可以全部加载到内存中,这样可以通过索引来随机读取;但当随机读取成本较高或者不可能时,iterable-style dataset就比较适合了,比如实时产生的流式数据。从本质上说,iterable-style dataset更通用,因为map-style dataset都可以转化为iterable-style dataset,但是对于很多实际场景中,我们已知全部数据集,往往会采用map-style dataset,它的一个好处是可以用torch.utils.data.Sampler来控制采样,比如在分布式训练中用torch.utils.data.distributed.DistributedSampler来拆分数据,而iterable-style dataset则不支持samper,此外在多进程读取时,iterable-style dataset也要单独处理防止数据重复加载,下面是torch.utils.data.IterableDataset文档的一个例子,这里采用torch.utils.data.get_worker_info来拆分数据:
class MyIterableDataset(torch.utils.data.IterableDataset):
def __init__(self, start, end):
super(MyIterableDataset).__init__()
assert end > start, "this example code only works with end >= start"
self.start = start
self.end = end
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None: # single-process data loading, return the full iterator
iter_start = self.start
iter_end = self.end
else: # in a worker process
# split workload
per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
return iter(range(iter_start, iter_end))
DataPipes
对于Dataset这种数据构建方式,它的好处是比较灵活:只需要继承Dataset基类,然后实现相关的接口协议即可。但是正是这种灵活,让代码难以复用,不同的代码库或者不同的数据集都维护着自己的一套代码,它们本身有很多重复的工作。这也正是TorchData推出的意义,TorchData引入DataPipes来替代现有的Dataset,它是一种可组合的数据加载和构建流程,这和TensorFlow的tf.data非常相似,可以将DataPipes看成可组合的Dataset,这使得很多操作比如读取和解析都模块化,从而增加代码复用性。
和原有的Dataset一样,DataPipes也有两种类型:IterDataPipes和MapDataPipes,其中IterDataPipes对标IterableDataset,需要实现__iter__接口;而MapDataPipes对标Dataset,需要实现__getitem__() 和 len() 两个接口。每个DataPipe都实现一个特殊的功能,比如用于json文件解析的JsonParserIterDataPipe:
import json
class JsonParserIterDataPipe(IterDataPipe):
def __init__(self, source_datapipe, **kwargs) -> None:
self.source_datapipe = source_datapipe
self.kwargs = kwargs
def __iter__(self):
for file_name, stream in self.source_datapipe:
data = stream.read()
yield file_name, json.loads(data)
def __len__(self):
return len(self.source_datapipe)
TorchData库已经内置了很多的DataPipes,我们可以组合不同的IDataPipes来构建数据,下面为一个iterDataPipes的一个简单例子,这里用IterableWrapper创建了一个源iterDataPipe,然后可以链式地组合不同的iterDataPipes来实现更复杂的数据流,比如Mapper,Filter,Shuffler和Batcher。对于每个DataPipe,其构造函数的第一个参数是上一个DataPipe,这是链式组合的基础,大部分的DataPipes都有对应的函数格式,比如Mapper对应的函数是map,所以有两种组合DataPipes方式:一是适用类构造器比如Mapper(dp, lambda x: x + 1),二是采用函数式比如dp.map(lambda x: x + 1) ,函数式是推荐采用的方式,它更简单易懂。
>>> from torchdata.datapipes.iter import IterableWrapper, Mapper
>>> dp = IterableWrapper(range(10))
>>> map_dp_1 = Mapper(dp, lambda x: x + 1) # Using class constructor
>>> map_dp_2 = dp.map(lambda x: x + 1) # Using functional form (recommended)
>>> list(map_dp_1)
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
>>> list(map_dp_2)
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
>>> filter_dp = map_dp_1.filter(lambda x: x % 2 == 0)
>>> list(filter_dp)
[2, 4, 6, 8, 10]
>>> shuffle_dp = filter_dp.shuffle()
>>> list(shuffle_dp)
[2, 6, 10, 4, 8]
>>> batch_dp = shuffle_dp.batch(2)
>>> list(batch_dp)
[[4, 6], [8, 2], [10]]
当组合不同的DataPipes后,其实我们是得到了一个Graph,从某种意义说,原来的Dataset其实就相当于一个包含不同DataPipes的Graph。我们也可以通过一些辅助函数来得到Graph中所包含的DataPipes:
>>> graph = torch.utils.data.graph.traverse(batch_dp, only_datapipe=True)
>>> torch.utils.data.graph_settings.get_all_graph_pipes(graph) # 获取所有的datapipes
{<torch.utils.data.datapipes.iter.callable.MapperIterDataPipe at 0x7f50b2f86700>,
<torch.utils.data.datapipes.iter.combinatorics.ShufflerIterDataPipe at 0x7f50b0fccd00>,
<torch.utils.data.datapipes.iter.grouping.BatcherIterDataPipe at 0x7f50b137b490>,
<torch.utils.data.datapipes.iter.selecting.FilterIterDataPipe at 0x7f50b2fcc6d0>,
<torch.utils.data.datapipes.iter.utils.IterableWrapperIterDataPipe at 0x7f50b2f86490>}
构建好DataPipes后,可以像Dataset一样送入DataLoader中来使用:
>>> dl = torch.utils.data.DataLoader(batch_dp)
>>> for batch in dl: print(batch)
[tensor([2]), tensor([4])]
[tensor([6]), tensor([8])]
[tensor([10])]
对于IterDataPipes,特别要注意的一点是,它和IterDataset一样,当使用多进程的DataLoader即num_workers>0时,要特别处理防止数据重复加载,或者在__iter__采用torch.utils.data.get_worker_info来拆分数据,或者在DataLoader的 worker_init_fn函数中操作,目前PyTorch提供了一个兼容的函数来实现这样的功能:
dl = DataLoader(
dataset,
batch_size=1,
shuffle=False,
num_workers=2,
worker_init_fn=torch.utils.data.backward_compatibility.worker_init_fn,
)
对于MapDataPipes,它除了和IterDataPipes实现的接口不同,其它用法基本类似,如下所示:
>>> from torchdata.datapipes.map import SequenceWrapper, Mapper
>>> dp = SequenceWrapper(range(10))
>>> map_dp_1 = dp.map(lambda x: x + 1) # Using functional form (recommended)
>>> list(map_dp_1)
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
>>> map_dp_2 = Mapper(dp, lambda x: x + 1) # Using class constructor
>>> list(map_dp_2)
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
>>> batch_dp = map_dp_1.batch(batch_size=2)
>>> list(batch_dp)
[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
>>> batch_dp[1] # 使用索引来获取样例
[3, 4]
>>> dl = torch.utils.data.DataLoader(batch_dp)
>>> for batch in dl: print(batch)
[tensor([1]), tensor([2])]
[tensor([3]), tensor([4])]
[tensor([5]), tensor([6])]
[tensor([7]), tensor([8])]
[tensor([9]), tensor([10])]
MapDataPipes和MapDataset一样,当送入Dataloader后,它是通过sampler来控制数据的读取。目前TorchData内置的MapDataPipes相比IterDataPipes还较少。
TorchData中已经内置了很多DataPipes(MapDataPipes和IterDataPipes),虽然它们覆盖了大部分的常用操作,但我们还可以自定义新的DataPipes,这里以最简单的MapperIterDataPipe为例,它实现的是map功能,主要有两个要点:一是构建函数的第一个参数是source_dp,二是实现__iter__接口,此外,我们还可以用functional_datapipe来注册它对应的函数格式。
from torchdata.datapipes.iter import IterDataPipe
from torch.utils.data.datapipes._decorator import
@functional_datapipe("map")
class MapperIterDataPipe(IterDataPipe):
def __init__(self, source_dp: IterDataPipe, fn) -> None:
super().__init__()
self.source_dp = source_dp
self.fn = fn
def __iter__(self):
for d in self.dp:
yield self.fn(d)
完整示例
这里以官方的教程给出一个完整的例子,首先我们随机生成一些CSV数据:它包含label和多个特征。
import csv
import random
def generate_csv(file_label, num_rows: int = 5000, num_features: int = 20) -> None:
fieldnames = ['label'] + [f'c{i}' for i in range(num_features)]
writer = csv.DictWriter(open(f"sample_data{file_label}.csv", "w"), fieldnames=fieldnames)
writer.writerow({col: col for col in fieldnames}) # writing the header row
for i in range(num_rows):
row_data = {col: random.random() for col in fieldnames}
row_data['label'] = random.randint(0, 9)
writer.writerow(row_data)
num_files_to_generate = 3
for i in range(num_files_to_generate):
generate_csv(file_label=i)
然后我们构建解析CSV的DataPipes,如下所示:
import numpy as np
import torchdata.datapipes as dp
def build_datapipes(root_dir="."):
# 获取所有文件
datapipe = dp.iter.FileLister(root_dir)
# 筛选csv文件
datapipe = datapipe.filter(filter_fn=lambda filename: "sample_data" in filename and filename.endswith(".csv"))
# 打开文件,FileOpener没有对应的函数格式,如果安装了iopath,可以使用
# datapipe = datapipe.open_by_iopath(mode='rt')
datapipe = dp.iter.FileOpener(datapipe, mode='rt')
# 解析csv
datapipe = datapipe.parse_csv(delimiter=",", skip_lines=1)
# 分离label和特征
datapipe = datapipe.map(lambda row: {"label": np.array(row[0], np.int32),
"data": np.array(row[1:], dtype=np.float64)})
return datapipe
datapipe = build_datapipes()
然后我们将构建好的datapipes送入DataLoader中进行使用:
dl = DataLoader(dataset=datapipe, batch_size=50, shuffle=True)
first = next(iter(dl))
labels, features = first['label'], first['data']
print(f"Labels batch shape: {labels.size()}")
print(f"Feature batch shape: {features.size()}")
--------------------------------------------------
Labels batch shape: 50
Feature batch shape: torch.Size([50, 20])
TorchData也提供了更多的例子,涉及到图像,文本和语音,具体见https://pytorch.org/data/beta/examples.html。
小结
虽然目前TorchData只是Beta版本,但是它必然是未来的趋势,因为目前PyTorch的三个主要应用库:torchvision,torchaudio和torchtext都已经开始接入TorchData了。此外,PyTorch官方也在开发DataLoader V2,新的版本将更加专注多进程和分布式等功能,而去除数据处理的一些功能,比如batch和shuffle,这其实也是对TorchData的更佳适配。
评论(0)
您还未登录,请登录后发表或查看评论