MinCostMatching.cc 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. ///////////////////////////////////////////////////////////////////////////
  2. // Min cost bipartite matching via shortest augmenting paths
  3. //
  4. // This is an O(n^3) implementation of a shortest augmenting path
  5. // algorithm for finding min cost perfect matchings in dense
  6. // graphs. In practice, it solves 1000x1000 problems in around 1
  7. // second.
  8. //
  9. // cost[i][j] = cost for pairing left node i with right node j
  10. // Lmate[i] = index of right node that left node i pairs with
  11. // Rmate[j] = index of left node that right node j pairs with
  12. //
  13. // The values in cost[i][j] may be positive or negative. To perform
  14. // maximization, simply negate the cost[][] matrix.
  15. ///////////////////////////////////////////////////////////////////////////
  16. #include <algorithm>
  17. #include <cstdio>
  18. #include <cmath>
  19. #include <vector>
  using namespace std;

typedef vector<double> VD;
typedef vector<VD> VVD;
typedef vector<int> VI;

// Computes a minimum-cost perfect matching of a dense n x n bipartite graph
// by successive shortest augmenting paths (Dijkstra on reduced costs).
//
// Parameters:
//   cost  - n x n matrix, n = cost.size(); cost[i][j] is the cost of pairing
//           left node i with right node j. Only the first n entries of each
//           row are read, so every row must have at least n entries.
//           Values may be negative; negate the matrix to maximise instead.
//   Lmate - output; reset to size n, Lmate[i] = right node paired with left i.
//   Rmate - output; reset to size n, Rmate[j] = left node paired with right j.
// Returns the total cost of the matching (0 for an empty matrix).
// Runs in O(n^3) time: at most n augmentations, each an O(n^2) Dijkstra.
double MinCostMatching(const VVD &cost, VI &Lmate, VI &Rmate) {
  const int n = int(cost.size());

  // Dual feasible solution: u[i] + v[j] <= cost[i][j] for every pair.
  VD u(n);
  VD v(n);
  for (int i = 0; i < n; i++) {
    u[i] = cost[i][0];
    for (int j = 1; j < n; j++) u[i] = min(u[i], cost[i][j]);
  }
  for (int j = 0; j < n; j++) {
    v[j] = cost[0][j] - u[0];
    for (int i = 1; i < n; i++) v[j] = min(v[j], cost[i][j] - u[i]);
  }

  // Greedy primal solution over tight edges only (zero reduced cost), so that
  // complementary slackness holds. The test is deliberately exact. For an
  // entry that attained the minimum above, cost[i][j] - u[i] - v[j] recomputes
  // the very expression stored in v[j] and yields 0.0.
  // A fixed absolute tolerance would accept genuinely non-tight edges whenever
  // the costs are small in magnitude. The augmentation phase below assumes
  // matched edges are tight and never revisits them, so such edges could make
  // the final matching suboptimal.
  // Missing a tight edge here is harmless: that node is matched by an
  // augmentation instead.
  Lmate = VI(n, -1);
  Rmate = VI(n, -1);
  int mated = 0;
  for (int i = 0; i < n; i++) {
    for (int j = 0; j < n; j++) {
      if (Rmate[j] != -1) continue;
      if (cost[i][j] - u[i] - v[j] == 0.0) {
        Lmate[i] = j;
        Rmate[j] = i;
        mated++;
        break;
      }
    }
  }

  VD dist(n);
  VI dad(n);
  VI seen(n);
  // Each pass grows the matching by exactly one pair.
  while (mated < n) {
    // Source: the first unmatched left node.
    int s = 0;
    while (Lmate[s] != -1) s++;

    // Dijkstra over right nodes, with reduced costs as edge weights.
    fill(dad.begin(), dad.end(), -1);
    fill(seen.begin(), seen.end(), 0);
    for (int k = 0; k < n; k++) dist[k] = cost[s][k] - u[s] - v[k];

    int j = -1;
    while (true) {
      // Closest unvisited right node (lowest index wins ties).
      j = -1;
      for (int k = 0; k < n; k++) {
        if (seen[k]) continue;
        if (j == -1 || dist[k] < dist[j]) j = k;
      }
      seen[j] = 1;
      // A free right node ends the shortest augmenting path.
      if (Rmate[j] == -1) break;
      // Otherwise continue through j's current partner i.
      const int i = Rmate[j];
      for (int k = 0; k < n; k++) {
        if (seen[k]) continue;
        const double new_dist = dist[j] + cost[i][k] - u[i] - v[k];
        if (dist[k] > new_dist) {
          dist[k] = new_dist;
          dad[k] = j;
        }
      }
    }

    // Adjust the duals of the nodes visited by the search so reduced costs
    // stay non-negative and the edges on the augmenting path become tight.
    for (int k = 0; k < n; k++) {
      if (k == j || !seen[k]) continue;
      const int i = Rmate[k];
      v[k] += dist[k] - dist[j];
      u[i] -= dist[k] - dist[j];
    }
    u[s] += dist[j];

    // Flip matched/unmatched edges along the path ending at j.
    while (dad[j] >= 0) {
      const int d = dad[j];
      Rmate[j] = Rmate[d];
      Lmate[Rmate[j]] = j;
      j = d;
    }
    Rmate[j] = s;
    Lmate[s] = j;
    mated++;
  }

  double value = 0;
  for (int i = 0; i < n; i++) value += cost[i][Lmate[i]];
  return value;
}