本文转载自:https://wanghao.blog.csdn.net/article/details/127387556

摘要

这篇文章主要讲如何从VOC和COCO数据集中提取特定类别（比如“人”）的图片和标注。我们想做一个行人检测项目，需要从公开数据集中提取行人数据作为补充。

1、从VOC数据集中提取特定类别

# -*- coding: utf-8 -*-
# @Function:There are 20 classes in VOC data set. If you need to extract specific classes, you can use this program to extract them.

import os
import shutil

ann_filepath = r"./VOCdevkit/VOC2012/Annotations/"
img_filepath = r"./VOCdevkit/VOC2012/JPEGImages/"
img_savepath = r"./VOCdevkit/VOC2012/tte/"
ann_savepath = r"./VOCdevkit/VOC2012/xml/"

# The 20 VOC object classes; target classes must be chosen from this list.
classes = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
           'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
           'dog', 'horse', 'motorbike', 'pottedplant',
           'sheep', 'sofa', 'train', 'tvmonitor', 'person']

# Exact lines that open, close and name an object block. Matching is done on
# whole lines, so these assume the tab indentation of the VOC2012 annotation
# files (TODO confirm for other VOC releases).
OBJECT_START = "\t<object>\n"
OBJECT_END = "\t</object>\n"
NAME_LINE = "\t\t<name>%s</name>\n"


def filter_annotation_lines(lines, target_classes=('person',)):
    """Reduce the lines of one VOC annotation file to the wanted objects.

    The XML is handled line by line (not parsed) so every kept line is written
    back byte-for-byte. An object block runs from an ``OBJECT_START`` line to
    the matching ``OBJECT_END`` line. It is kept when it contains the exact
    ``<name>`` line of one of *target_classes*.

    Args:
        lines: the file's lines, as returned by ``readlines()``.
        target_classes: class names whose objects should be kept.

    Returns:
        The header (every line before the first object), the kept object
        blocks in file order, and the file's last line (the closing tag).
        Returns ``None`` when the file has no object of a target class.
    """
    wanted = {NAME_LINE % cls for cls in target_classes}
    starts = [i for i, line in enumerate(lines) if line == OBJECT_START]
    ends = [i for i, line in enumerate(lines) if line == OBJECT_END]
    if not starts:
        # The original indexed ind_start[0] here and crashed on such files.
        return None

    kept = []
    for start, end in zip(starts, ends):
        block = lines[start:end + 1]
        if any(line in wanted for line in block):
            kept.extend(block)
    if not kept:
        return None
    return lines[:starts[0]] + kept + [lines[-1]]


def extract_voc(target_classes=('person',)):
    """Copy VOC images and trimmed annotations that contain *target_classes*.

    For every ``.xml`` file in ``ann_filepath``, objects of other classes are
    dropped. The remaining annotation is written to ``ann_savepath``, and the
    matching ``.jpg`` from ``img_filepath`` is copied to ``img_savepath``.
    Files without any target object produce no output.

    Raises:
        ValueError: if a target class is not one of the 20 VOC ``classes``.
    """
    unknown = set(target_classes) - set(classes)
    if unknown:
        raise ValueError("unknown VOC classes: %s" % sorted(unknown))

    # makedirs also creates missing parents, where os.mkdir would fail.
    os.makedirs(img_savepath, exist_ok=True)
    os.makedirs(ann_savepath, exist_ok=True)

    for name in sorted(os.listdir(ann_filepath)):
        if not name.endswith('.xml'):
            continue
        print(name)
        # os.path.join replaces the Windows-only '\\' separator the original
        # inserted, which produced invalid paths on Linux/macOS.
        with open(os.path.join(ann_filepath, name), encoding='utf-8') as fp:
            lines = fp.readlines()

        filtered = filter_annotation_lines(lines, target_classes)
        if filtered is None:
            # Only write output for matching files, instead of writing the
            # file and then deleting it.
            continue
        with open(os.path.join(ann_savepath, name), 'w', encoding='utf-8') as fp_w:
            fp_w.writelines(filtered)
        name_img = os.path.join(img_filepath, os.path.splitext(name)[0] + ".jpg")
        shutil.copy(name_img, img_savepath)


if __name__ == "__main__":
    extract_voc()

2、从COCO数据集中提取特定类别

从COCO数据集中提取特定类别的图片，并将其标注转换为VOC格式的XML文件保存。

  1. from pycocotools.coco import COCO
    import os
    import shutil
    from tqdm import tqdm
    import skimage.io as io
    import matplotlib.pyplot as plt
    import cv2
    from PIL import Image, ImageDraw
    
    # Output root for the COCO -> VOC conversion; images and XML annotations
    # are written to separate sub-directories beneath it.
    savepath = "./coco2017/result/"
    img_dir = "%simages/" % savepath
    anno_dir = "%sAnnotations/" % savepath
    # COCO splits to process.
    datasets_list = ['train2017', 'val2017']

    # COCO category names to extract.
    classes_names = ['person']
    # Root that holds annotations/instances_<split>.json and the <split>/ image folders.
    dataDir = './coco2017/annotations_trainval2017'
    
    # VOC annotation header, %-formatted with (filename, width, height, depth)
    # by save_annotations_and_imgs. The leading indentation inside these
    # templates is part of the strings; it only adds whitespace around the
    # XML elements.
    headstr = """\
    <annotation>
        <folder>VOC</folder>
        <filename>%s</filename>
        <source>
            <database>My Database</database>
            <annotation>COCO</annotation>
            <image>flickr</image>
            <flickrid>NULL</flickrid>
        </source>
        <owner>
            <flickrid>NULL</flickrid>
            <name>company</name>
        </owner>
        <size>
            <width>%d</width>
            <height>%d</height>
            <depth>%d</depth>
        </size>
        <segmented>0</segmented>
    """
    # One <object> entry, %-formatted by write_xml with
    # (name, xmin, ymin, xmax, ymax).
    objstr = """\
        <object>
            <name>%s</name>
            <pose>Unspecified</pose>
            <truncated>0</truncated>
            <difficult>0</difficult>
            <bndbox>
                <xmin>%d</xmin>
                <ymin>%d</ymin>
                <xmax>%d</xmax>
                <ymax>%d</ymax>
            </bndbox>
        </object>
    """
    
    # Closing tag, written by write_xml after all objects.
    tailstr = '''\
    </annotation>
    '''
    
    
    def mkr(path):
        """Make *path* an empty directory.

        Any existing tree at *path* is removed first, so output from a
        previous run is discarded. Missing parent directories are created.
        """
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path)
    
    
    # Start each run with empty image and annotation output directories.
    for _out_dir in (img_dir, anno_dir):
        mkr(_out_dir)
    
    
    def id2name(coco):
        """Map every category id in the loaded COCO annotations to its name.

        Reads the ``categories`` list of ``coco.dataset`` and returns a dict
        ``{id: name}`` in the same order as that list.
        """
        return {category['id']: category['name']
                for category in coco.dataset['categories']}
    
    
    def write_xml(anno_path, head, objs, tail, obj_template=None):
        """Write a VOC-style annotation file.

        Args:
            anno_path: destination ``.xml`` path; an existing file is overwritten.
            head: already-formatted header text.
            objs: iterable of ``[name, xmin, ymin, xmax, ymax]`` entries.
            tail: closing text written after all objects.
            obj_template: ``%``-format string with five fields used for each
                object. Defaults to the module-level ``objstr``.
        """
        if obj_template is None:
            obj_template = objstr
        # ``with`` guarantees the file is flushed and closed; the original left
        # the handle open until garbage collection.
        with open(anno_path, "w") as f:
            f.write(head)
            for obj in objs:
                f.write(obj_template % tuple(obj[:5]))
            f.write(tail)
    
    
    def save_annotations_and_imgs(coco, dataset, filename, objs):
        """Copy one COCO image to ``img_dir`` and write its VOC XML to ``anno_dir``.

        Args:
            coco: COCO API object (unused; kept for the existing call signature).
            dataset: split name such as ``'train2017'``; the image is read from
                ``<dataDir>/<dataset>/<filename>``.
            filename: image file name; the annotation gets the same stem with
                an ``.xml`` extension (e.g. ``000000196610.jpg`` ->
                ``000000196610.xml``).
            objs: ``[name, xmin, ymin, xmax, ymax]`` boxes to write.

        Images that OpenCV cannot read, or reports as single-channel, are
        skipped with a message and nothing is written for them.
        """
        # splitext handles any extension length, unlike filename[:-3].
        anno_path = anno_dir + os.path.splitext(filename)[0] + '.xml'
        img_path = dataDir + '/' + dataset + '/' + filename
        print("img_path:", img_path)
        dst_imgpath = img_dir + filename

        img = cv2.imread(img_path)
        # cv2.imread returns None rather than raising for a missing or
        # undecodable file; the original then crashed the run on img.shape.
        if img is None:
            print(filename + " could not be read")
            return
        # NOTE(review): with the default IMREAD_COLOR flag OpenCV converts
        # grayscale files to 3 channels, so this branch rarely (if ever) fires
        # and such images are written with depth 3.
        if img.ndim < 3 or img.shape[2] == 1:
            print(filename + " not a RGB image")
            return
        shutil.copy(img_path, dst_imgpath)

        height, width, depth = img.shape[:3]
        head = headstr % (filename, width, height, depth)
        write_xml(anno_path, head, objs, tailstr)
    
    
    def showimg(coco, dataset, img, classes, cls_id, show=True):
        """Collect VOC-style boxes for one COCO image, optionally displaying them.

        Args:
            coco: COCO API object for *dataset*.
            dataset: split name such as ``'train2017'``; locates the image under
                ``dataDir`` when it is shown.
            img: image record from ``coco.loadImgs`` (uses ``'id'`` and
                ``'file_name'``).
            classes: mapping of category id to name (see ``id2name``).
            cls_id: category id(s) passed to ``coco.getAnnIds``.
            show: draw the boxes on the image and display it with matplotlib.

        Returns:
            A list of ``[class_name, xmin, ymin, xmax, ymax]`` entries, one for
            each annotation that has a ``'bbox'`` and a class in ``classes_names``.
        """
        # iscrowd=None keeps crowd annotations as well as individual ones.
        ann_ids = coco.getAnnIds(imgIds=img['id'], catIds=cls_id, iscrowd=None)
        anns = coco.loadAnns(ann_ids)

        objs = []
        for ann in anns:
            class_name = classes[ann['category_id']]
            if class_name not in classes_names:
                continue
            print(class_name)
            if 'bbox' in ann:
                # COCO stores [x, y, width, height]; VOC wants corner points.
                x, y, w, h = ann['bbox'][:4]
                objs.append([class_name, int(x), int(y), int(x + w), int(y + h)])

        if show:
            # Only open and draw on the image when it is displayed. The original
            # did this for every image even with show=False, which wasted time
            # and left the PIL file handle open.
            image = Image.open('%s/%s/%s' % (dataDir, dataset, img['file_name']))
            draw = ImageDraw.Draw(image)
            for _, xmin, ymin, xmax, ymax in objs:
                draw.rectangle([xmin, ymin, xmax, ymax])
            plt.figure()
            plt.axis('off')
            plt.imshow(image)
            plt.show()

        return objs
    
    
    for dataset in datasets_list:
        # e.g. ./coco2017/annotations_trainval2017/annotations/instances_train2017.json
        annFile = '{}/annotations/instances_{}.json'.format(dataDir, dataset)

        # Load the annotation JSON and build the pycocotools index that links
        # images, annotations and categories (it prints progress messages).
        coco = COCO(annFile)

        # Show the full id -> name mapping of COCO categories.
        classes = id2name(coco)
        print(classes)
        # Ids of the categories to extract, e.g. [1] for 'person'.
        classes_ids = coco.getCatIds(catNms=classes_names)
        print(classes_ids)

        # getImgIds with several category ids returns only images that contain
        # all of them, so build the union one class at a time. The set also
        # stops an image that holds several target classes from being processed
        # (and its files rewritten) once per class.
        img_ids = set()
        for cls in classes_names:
            cls_id = coco.getCatIds(catNms=[cls])
            cls_img_ids = coco.getImgIds(catIds=cls_id)
            print(cls, len(cls_img_ids))
            img_ids.update(cls_img_ids)

        for imgId in tqdm(sorted(img_ids)):
            img = coco.loadImgs(imgId)[0]
            filename = img['file_name']
            print("filename:", filename)
            print("dataset:", dataset)
            objs = showimg(coco, dataset, img, classes, classes_ids, show=False)
            print(objs)
            save_annotations_and_imgs(coco, dataset, filename, objs)