描述

上一篇文章介绍了匈牙利匹配算法的原理,这一篇文章我们来分析另一种好用的匹配算法——KM算法

匈牙利匹配算法是一个二分图匹配算法,但对于任意两顶点的匹配都是等价的。显然,如果我们面对特征匹配任务时,特征A可能和特征B、特征C都产生了关联。如何在B与C之间选择更佳的匹配呢。这就是KM算法优于匈牙利匹配的地方。

KM算法原理

KM算法是在匈牙利算法的衍生算法,在二分图匹配的问题上增加权重,变成了一个带权二分图匹配问题,求最优的二分图匹配。

  • 什么是二分图的带权匹配?
    二分图的带权匹配就是求出一个匹配集合,使得集合中边的权值之和最大或最小。

而二分图的最优匹配则一定为完备匹配,在此基础上,才要求匹配的边权值之和最大或最小。二分图的带权匹配与最优匹配不等价,也不互相包含。

我们可以使用KM算法实现求二分图的最优匹配。KM算法可以实现为O(N^3)。

Apollo的使用

在Apollo代码中,使用匈牙利算法去做的是检测跟踪关联(Detection-to-Track Association)。

以下是Apollo github上的解释:

“当将检测与现有跟踪列表相关联时,Apollo构建了一个二分图,然后使用 匈牙利算法以最小成本(距离)找到最佳检测跟踪匹配。

首先,建立一个关联距离矩阵。根据一系列关联特征(包括运动一致性,外观一致性等)计算给定检测和一条轨迹之间的距离。HM跟踪器距离计算中使用这样一些特征:用于评估运动一致性的location_distance和direction_distance,用于评估外观一致性的bbox_size_distance、point_num_distance和histogram_distance。此外,还有一些重要的距离权重参数,用于将上述关联特征组合成最终距离测量。

给定关联距离矩阵,Apollo构造了一个二分图,并使用 匈牙利算法通过最小化距离成本找到最佳的检测跟踪匹配。它解决了O(n^3)时间复杂度中的赋值问题。 为了提高其计算性能,通过删除距离大于合理的最大距离阈值的顶点,将原始的二分图切割成子图后实现了匈牙利算法。”

晚些再写其他文章来分析apollo的实现方式,最近没时间了,先放在这里吧

KM算法步骤

这是KM算法的官网解释:https://brc2.com/the-algorithm-workshop/

步骤0:创建一个大小nxm的矩阵,称为成本矩阵,其中矩阵中每个元素表示将n个工人分配给m个工作所需的成本。旋转矩阵,使其列数至少与行数一样多,并设k=min(n,m)。

步骤1:对于矩阵的每一行,找到最小的元素,同时该行中的每个元素中减去这个最小元素。转至步骤2。

步骤2:在得到的矩阵中找到一个零(Z)。如果行或列中没有标上星号的零,则把这个Z标上星号(以下称:star零)。对矩阵中的每个元素重复此操作。转至步骤3。

第3步:用线覆盖包含一个star零的列。如果覆盖了K列,则这组star零表示一组完整的唯一赋值。在这种情况下,算法“完成”,否则,转到步骤4。

第4步:找到一个未被覆盖的零,让这个零准备好(prime it,以下将这些零称为:prime零)。如果包含prime零的行中没有star零,请转到第5步。否则,覆盖这一行,并揭开包含star零的列。以这种方式继续,直到矩阵中的零全部被覆盖为止。记录最小的未覆盖值,然后转到步骤6。

第5步:按照如下方式构造一系列交替的prime零和star零。设Z0表示在步骤4中找到的未被覆盖的prime零。设Z1表示Z0列中的star零(如果有的话)。设Z2表示Z1行中的prime零(总会有一个)。继续这一过程(找到一个prime零,沿着它所在列找star零,再从找到的star零的所在行,找prime零),找到的是一个由star零和prime零交替组成的序列,该序列会终止于一个prime零,在它的列中没有star零。将这个序列中的star零全部取消星号,将序列的每个prime零标记为star零,将矩阵中所有的prime零变为普通零,并将整个矩阵取消覆盖。返回步骤3。

步骤6:对于步骤4中找到的值,每个被覆盖行的每个元素加上该值,每个未覆盖列的每个元素中减去该值。在加减的过程中,跳过star零、prime零或者覆盖线。返回步骤4。

完成:最终结果由成本矩阵中的star零位置表示。如果C(i,j)是星形零,工人i分配工作j。

如果不清楚的话,可以结合官网的解释来好好再看看。我下面贴一下官网的步骤简图,可以帮助理解每一步操作的含义


C++代码实现

根据官网链接的代码,我写了一版C++代码。(官网的代码虽然写的很清楚,但是在小细节上还是不够完善,比如它并没有提到一些数组越界情况的处理,当然了这是非常小的事情,只不过初学者按照它的写法需要debug罢了)

我的可运行代码贴在下面。

#include <iostream>
#include <vector>
#include <algorithm>

void step_one(int& step, std::vector<std::vector<int>>& cost_matrix, 
              std::vector<std::vector<int>>& mask_matrix) {
    int nrow = cost_matrix.size();
    int ncol = cost_matrix[0].size();
    int min_in_row;
    for (int i = 0; i < nrow; i++) {
        std::vector<int>::iterator min_in_row = 
            std::min_element(cost_matrix[i].begin(), cost_matrix[i].end());
        for (int j = 0; j < ncol; j++) {
             cost_matrix[i][j] -= *min_in_row;
        }
    }
    step = 2;
}

void step_two(int& step, std::vector<std::vector<int>>& cost_matrix, 
              std::vector<std::vector<int>>& mask_matrix,
              std::vector<int>& rowCover, std::vector<int>& colCover) {
    int nrow = cost_matrix.size();
    int ncol = cost_matrix[0].size();
    for (int i = 0; i < nrow; i++) {
        for (int j = 0; j < ncol; j++) {
            if (cost_matrix[i][j] == 0 && rowCover[i] == 0 && colCover[j] == 0) {
                mask_matrix[i][j] = 1;
                rowCover[i] = 1;
                colCover[j] = 1;
            }
        }
    }
    for (int i = 0; i < nrow; i++) {
        rowCover[i] = 0;
    }
    for (int j = 0; j < ncol; j++) {
        colCover[j] = 0;
    }
    step = 3;
}

void step_three(int& step, std::vector<std::vector<int>>& cost_matrix, 
              std::vector<std::vector<int>>& mask_matrix,
              std::vector<int>& rowCover, std::vector<int>& colCover) {
    int nrow = cost_matrix.size();
    int ncol = cost_matrix[0].size();
    for (int i = 0; i < nrow; i++) {
        for (int j = 0; j < ncol; j++) {
            if (mask_matrix[i][j] == 1) {
                colCover[j] = 1;
            }
        }
    }

    int colcount = 0;
    for (int j = 0; j < ncol; j++) {
        if (colCover[j] == 1) {
            colcount += 1;
        }
    }

    if (colcount >= ncol || colcount >= nrow) {
        step = 7;
    } else {
        step = 4;
    }
}

void find_a_zero(int& row, int& col,
                 int& step, std::vector<std::vector<int>>& cost_matrix, 
                   std::vector<std::vector<int>>& mask_matrix,
                 std::vector<int>& rowCover, std::vector<int>& colCover) {
    int i = 0;
    int j;
    bool done = false;
    row = -1;
    col = -1;

    int nrow = cost_matrix.size();
    int ncol = cost_matrix[0].size();
    while (!done) {
        j = 0;
        while (1) {
            if (cost_matrix[i][j] == 0 && rowCover[i] == 0 && colCover[j] == 0) {
                row = i;
                col = j;
                done = true;
            }
            j++;
            if (j >= ncol || done) {
                break;
            }
        }
        i++;
        if (i >= nrow) {
            done = true;
        }
    }
}

bool star_in_row(int& row, int& col, std::vector<std::vector<int>>& mask_matrix) {
    bool tmp = false;
    int ncol = mask_matrix[0].size();
    for (int j = 0; j < ncol; j++) {
        if (mask_matrix[row][j] == 1) {
            tmp = true;
        }
    }
    return tmp;
}

void find_star_in_row(int& row, int& col, std::vector<std::vector<int>>& mask_matrix) {
    col = -1;
    int ncol = mask_matrix[0].size();
    for (int j = 0; j < ncol; j++) {
        if (mask_matrix[row][j] == 1) {
            col = j;
        }
    }
}

void find_star_in_col(int& row, int& col, std::vector<std::vector<int>>& mask_matrix) {
    row = -1;
    int nrow = mask_matrix.size();
    for (int i = 0; i < nrow; i++) {
        if (mask_matrix[i][col] == 1) {
            row = i;
        }
    }
}

void find_prime_in_row(int& row, int& col, std::vector<std::vector<int>>& mask_matrix) {
    int ncol = mask_matrix[0].size();
    for (int j = 0; j < ncol; j++) {
        if (mask_matrix[row][j] == 2) {
            col = j;
        }
    }
}

void augment_path(int& path_count, std::vector<std::vector<int>> path,
    std::vector<std::vector<int>>& mask_matrix){
    for (int p = 0; p < path_count; p++) {
        if (mask_matrix[path[p][0]][path[p][1]] == 1) {
            mask_matrix[path[p][0]][path[p][1]] = 0;
        } else {
            mask_matrix[path[p][0]][path[p][1]] = 1;
        }
    }
}

void clear_covers(std::vector<int>& rowCover, std::vector<int>& colCover) {
    for (int i = 0; i < rowCover.size(); i++) {
        rowCover[i] = 0;
    }
    for (int j = 0; j < colCover.size(); j++) {
        colCover[j] = 0;
    }
}

void erase_primes(std::vector<std::vector<int>>& mask_matrix) {
    int nrow = mask_matrix.size();
    int ncol = mask_matrix[0].size();
    for (int i = 0; i < nrow; i++) {
        for (int j = 0; j < ncol; j++) {
            if (mask_matrix[i][j] == 2) {
                mask_matrix[i][j] = 0;
            }
        }
    }
}


void step_four(int& step, std::vector<std::vector<int>>& cost_matrix, 
              std::vector<std::vector<int>>& mask_matrix,
              std::vector<int>& rowCover, std::vector<int>& colCover, 
              std::vector<std::vector<int>>& path,
              int& path_row_0, int& path_col_0) {
    int row = -1;
    int col = -1;
    bool done = false;

    while (!done) {
        find_a_zero(row, col, step, cost_matrix, mask_matrix, rowCover, colCover);
        if (row == -1) {
            done = true;
            step = 6;
        } else {
            mask_matrix[row][col] = 2;
            if (star_in_row(row, col, mask_matrix)) {
                find_star_in_row(row, col, mask_matrix);
                rowCover[row] = 1;
                colCover[col] = 0;
            } else {
                done = true;
                step = 5;
                path_row_0 = row;
                path_col_0 = col;
            }
        }
    }    
}

void step_five(int& step, std::vector<std::vector<int>>& cost_matrix,
              std::vector<std::vector<int>>& mask_matrix,
              std::vector<int>& rowCover, std::vector<int>& colCover, 
              std::vector<std::vector<int>>& path,
              int& path_row_0, int& path_col_0) {
    bool done = false;
    int i = -1;
    int j = -1;

    int path_count = 1;
    path[path_count-1][0] = path_row_0;
    path[path_count-1][1] = path_col_0;

    while (!done) {
        find_star_in_col(i, path[path_count-1][1], mask_matrix);

        if (i > -1) {

            if (path_count >= path.size()) {
                done = true;
                break;
            }
            path_count += 1;
            path[path_count-1][0] = i;
            path[path_count-1][1] = path[path_count-2][1];
        } else {

            done = true;
        }

        if (!done) {
            find_prime_in_row(path[path_count-1][0], j, mask_matrix);

            if (path_count >= path.size()) {
                done = true;
                break;
            }

            path_count += 1;

            path[path_count-1][0] = path[path_count-2][0];
            path[path_count-1][1] = j;
        }
    }

    augment_path(path_count, path, mask_matrix);
    clear_covers(rowCover, colCover);
    erase_primes(mask_matrix);
    step = 3;
}

void find_smallest(int& minval, std::vector<std::vector<int>>& cost_matrix, 
                     std::vector<int>& rowCover, std::vector<int>& colCover) {
    for (int i = 0; i < rowCover.size(); i++) {
        for (int j = 0; j < colCover.size(); j++) { 
            if (rowCover[i] == 0 && colCover[j] == 0) {
                if (minval > cost_matrix[i][j]) {
                    minval = cost_matrix[i][j];
                }
            }
        }
    }
}

void step_six(int& step, std::vector<std::vector<int>>& cost_matrix, 
            std::vector<int>& rowCover, std::vector<int>& colCover){
    int minval = 0x7fffffff;
    find_smallest(minval, cost_matrix, rowCover, colCover);

    for (int i = 0; i < rowCover.size(); i++) {
        for (int j = 0; j < colCover.size(); j++) {
            if (rowCover[i] == 1) {
                cost_matrix[i][j] += minval;
            } 
            if (colCover[j] == 0) {
                cost_matrix[i][j] -= minval;
            }
        }
    }
    step = 4;
}

void step_seven() {
    std::cout<<"Munkres Assignment Algorithm finish!!!"<<std::endl;
}

void process_step(bool& done, int& step, std::vector<std::vector<int>> cost_matrix_origin, 
              std::vector<std::vector<int>>& cost_matrix, 
              std::vector<std::vector<int>>& mask_matrix,
              std::vector<int>& rowCover, std::vector<int>& colCover, 
              std::vector<std::vector<int>>& path,
              int& path_row_0, int& path_col_0) {
    std::cout<<"excute step: "<<step<<std::endl;
    switch (step) {
        case 1:
            step_one(step, cost_matrix, mask_matrix);
            break;
        case 2:
            step_two(step, cost_matrix, mask_matrix, rowCover, colCover);
            break;
        case 3:
            step_three(step, cost_matrix, mask_matrix, rowCover, colCover);
            break;
        case 4:
            step_four(step, cost_matrix, mask_matrix, rowCover, colCover, 
                    path, path_row_0, path_col_0);
            break;
        case 5:
            step_five(step, cost_matrix, mask_matrix, rowCover, colCover, 
                    path, path_row_0, path_col_0);
            break;
        case 6:
            step_six(step, cost_matrix, rowCover, colCover);
            break;
        case 7:
            step_seven();
            done = true;

            int mini_cost_value = 0;
            for (int i = 0; i < mask_matrix.size(); i++) {
                for (int j = 0; j < mask_matrix[i].size(); j++) {
                    if (mask_matrix[i][j] == 1) {
                        std::cout<<"use data on ("<<i<<", "<<j<<") value: "<<cost_matrix_origin[i][j]<<std::endl;
                        mini_cost_value += cost_matrix_origin[i][j];
                    }
                }
            }            
            std::cout<<"mini cost:" <<mini_cost_value<<std::endl;
            break;
    }
}

void showCostMatrix(std::vector<std::vector<int>>& cost_matrix) {
    std::cout<<"__________cost_________"<<std::endl;
    for (int i = 0; i < cost_matrix.size(); i++) {
        for (int j = 0; j < cost_matrix[i].size(); j++) {
            std::cout<<cost_matrix[i][j]<<" ";
        }
        std::cout<<std::endl;
    }
}

void showMaskMatrix(std::vector<std::vector<int>>& mask_matrix) {
    std::cout<<"__________mask_________"<<std::endl;
    for (int i = 0; i < mask_matrix.size(); i++) {
        for (int j = 0; j < mask_matrix[i].size(); j++) {
            std::cout<<mask_matrix[i][j]<<" ";
        }
        std::cout<<std::endl;
    }
}

void munkresAssignment(std::vector<std::vector<int>>& cost_matrix) {

    std::vector<std::vector<int>> cost_matrix_origin = cost_matrix;

    std::vector<std::vector<int>> mask_matrix;
    for (int i = 0; i < cost_matrix.size(); i++) {
        std::vector<int> a;
        for (int j = 0; j < cost_matrix[i].size(); j++) {
            a.push_back(0);
        }
        mask_matrix.push_back(a);
    }

    std::vector<int> rowCover;
    for (int i = 0; i < cost_matrix.size(); i++) {
        rowCover.push_back(0);
    }
    std::vector<int> colCover;
    for (int i = 0; i < cost_matrix[0].size(); i++) {
        colCover.push_back(0);
    }

    std::vector<std::vector<int>> path;
    for (int i = 0; i < cost_matrix.size(); i++) {
        std::vector<int> a;
        a.push_back(0);
        a.push_back(0);
        path.push_back(a);
    }

    bool done = false;
    int step = 1;
    int path_row_0, path_col_0;
    while (!done) {
        showCostMatrix(cost_matrix);
        showMaskMatrix(mask_matrix);

        process_step(done, step, cost_matrix_origin, cost_matrix, mask_matrix, rowCover, colCover, 
                    path, path_row_0, path_col_0);
    }
}

int main(int argc, char** argv) {
    int value[16] = {1, 2, 3, 4,
                     2, 4, 6, 8,
                     3, 6, 9, 12,
                     4, 8, 12,16};

    std::vector<std::vector<int>> cost_matrix;

    for (int i = 0; i < 4; i++) {
        std::vector<int> tmp;
        for (int j = 0; j < 4; j++) {
            tmp.push_back(value[i*4+j]);
        }
        cost_matrix.push_back(tmp);
    }

    munkresAssignment(cost_matrix);

    return 0;
}

如果你执行我代码中的数据

int value[16] = {1, 2, 3, 4,
                     2, 4, 6, 8,
                     3, 6, 9, 12,
                     4, 8, 12,16};

在输出的最后,可以得到这样的结果

...
...
__________mask_________
0 0 0 1 
0 0 1 0 
0 1 0 0 
1 0 0 0 
excute step: 7
Munkres Assignment Algorithm finish!!!
use data on (0, 3) value: 4
use data on (1, 2) value: 6
use data on (2, 1) value: 6
use data on (3, 0) value: 4
mini cost:20

可以很明显的看出mask矩阵为1的部分,就是工人和工作的最佳匹配了,这次最优匹配可以达到的最小值就是20。

当然了,你也可以自行修改value矩阵的赋值,来验证我写的是否正确。(时间紧急,我也就验证了两三组数据,如果代码执行有问题随时告诉我哈)

KM算法的缺点

上面已经详细介绍了KM算法,相信已经介绍的很清楚了。他是一种用于解决优化中的分配问题的常用算法。

然而,和大多数算法一样,它也有自己的一系列局限性或缺点。缺点如下:

1.KM算法的时间复杂度为O(n^3),这使得它对于大型数据集来说速度较慢。

2.KM算法需要一个平方代价矩阵,有的问题可能无法转换为该形式。

3.KM算法会得到一个解,但是如果当前匹配存在多个最优解,就不能用它来确定哪个解是最好的了。

总结

这两篇就是对匈牙利匹配算法及KM算法的分析与总结,后续有时间看看apollo是怎么实现的。