TensorBoard是一款优秀的基于浏览器的机器学习可视化工具。之前是tensorflow的御用可视化工具,由于tensorboard并不是直接读取tf张量,而是读取log进行可视化。所以,其他框架只需生成tensorboard可读的log,即可完成可视化

之前,我一直用visdom做pytorch可视化,也是非常易用。不过现在跟tensorboard对比,我还是更推荐tensorboard

visdom相比tensorboard只有一个优点,那就是自动实时刷新。而tensorboard无论从可视化美观性、可视化数据多样性等多个方面,都碾压visdom。甚至,tensorboard更加易用一些。

先给一个官方文档链接:https://pytorch.org/docs/stable/tensorboard.html

tensorboard的安装就不在这篇文章讲述了。安装1.15以上的版本即可。


1. 标量(scalars)数据可视化

标量就是数字,咱们训练过程中的loss值,测试集的accuracy,包括precision和recall等等都可以通过这个方式画出曲线。可以更直观地反映模型的训练情况。还在拿matplotlib可视化loss曲线的童鞋可以换家伙什儿了。

用tensorboard做标量可视化非常简单:

from torch.utils.tensorboard import SummaryWriter
 
log_writer = SummaryWriter()
 
def train(xxx):
    for epoch in epochs:
        loss = xxx
        log_writer.add_scalar('Loss/train', float(loss), epoch)

首先咱们需要一个写log的东西——log_writer, 然后直接add_scalar就可以了。

add_scalar('Loss/train', float(loss), epoch),第一个参数是名称,第二个参数是y值,第三个参数是x值。(用x,y画图,不用我解释x,y是啥吧?)

也就比原来的训练代码多了三行,即可收集训练的loss,用以可视化。

咱们可以先把加了三行的训练代码跑起来,中途生成的loss都会保存在当前目录下一个名为'runs/'的文件夹中。当然这个文件夹可以自定义,'runs/'只是默认名。

第二步,咱们就打开tensorboard瞅一瞅。再打开一个terminal,输入:

tensorboard --logdir=runs/

运行以后,会给你一个链接,用浏览器打开即可。一般为 https://127.0.0.1:6006,6006端口被占用的话会是另一个端口。

我实际运行的一个loss如下:


看着还是很酷炫的,可以点击右上角刷新,以查看实时训练情况。

2. 图(GRAPH)数据可视化

这个可以用来可视化网络结构,不太涉及动态变化,所以甚至比标量可视化更加简单。直接用add_graph就可以完成,

需要注意的是要定义输入的shape,类似于tf的placeholder。我们看一个官方栗子:

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
 
# Writer will output to ./runs/ directory by default
writer = SummaryWriter()
 
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = datasets.MNIST('mnist_train', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
model = torchvision.models.resnet50(False)
# Have ResNet model take in grayscale rather than RGB
model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
images, labels = next(iter(trainloader))
 
grid = torchvision.utils.make_grid(images)
writer.add_image('images', grid, 0)
writer.add_graph(model, images)
writer.close()

注意,上面的add_graph有2个输入参数,一个是模型,另一个就是类似于placeholder的东西,用来描述输入的shape。因为可视化网络结构的时候,后台会帮你计算出每一层feature map的尺寸,但这尺寸都与输入的shape有关。

文章知识点与官方知识档案匹配,可进一步学习相关知识

Python入门技能树人工智能深度学习