forked from TheAlgorithms/Java
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathHungarianAlgorithm.java
More file actions
150 lines (138 loc) · 4.92 KB
/
HungarianAlgorithm.java
File metadata and controls
150 lines (138 loc) · 4.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
package com.thealgorithms.graph;
import java.util.Arrays;
/**
 * Hungarian algorithm (a.k.a. Kuhn-Munkres) for the Assignment Problem.
 *
 * <p>Given an r x c matrix of non-negative {@code int} costs, finds a minimum-cost
 * one-to-one assignment of rows to columns. A rectangular matrix is padded internally
 * with zero-cost cells up to n x n, where n = max(r, c); a row that ends up matched to a
 * padded column is reported as unassigned.
 *
 * <p>Time complexity: O(n^3) with n = max(rows, cols).
 *
 * <p>The API returns the assignment as an array where {@code assignment[i]} is the column
 * chosen for row i (or -1 if the row is unassigned, which can only happen when
 * rows &gt; cols), together with the total minimal cost.
 *
 * <p>Potentials and reduced costs are computed in {@code long} arithmetic so that costs
 * close to {@link Integer#MAX_VALUE} do not overflow inside the solver.
 *
 * @see <a href="https://en.wikipedia.org/wiki/Hungarian_algorithm">Wikipedia: Hungarian algorithm</a>
 */
public final class HungarianAlgorithm {

    private HungarianAlgorithm() {
    }

    /** Result holder for the Hungarian algorithm. */
    public static final class Result {
        /** {@code assignment[row]} is the column chosen for that row, or -1 if unassigned. */
        public final int[] assignment;
        /** Sum of the original costs over all assigned (row, column) pairs. */
        public final int minCost;

        /**
         * Creates a result. The array reference is stored as given (no copy is made).
         *
         * @param assignment row-to-column mapping, -1 marking an unassigned row
         * @param minCost total cost of the assignment
         */
        public Result(int[] assignment, int minCost) {
            this.assignment = assignment;
            this.minCost = minCost;
        }
    }

    /**
     * Solves the assignment problem for a non-negative cost matrix.
     *
     * @param cost an r x c matrix of non-negative costs
     * @return Result with row-to-column assignment and minimal total cost
     * @throws IllegalArgumentException if the matrix is null or empty, has a null row, has
     *         rows of differing lengths, or contains a negative cost
     * @throws ArithmeticException if the minimal total cost does not fit in an {@code int}
     */
    public static Result solve(int[][] cost) {
        validate(cost);
        final int rows = cost.length;
        final int cols = cost[0].length;
        final int n = Math.max(rows, cols);

        final int[] columnOfRow = matchRowsToColumns(padToSquare(cost, n), n);

        // Report only the original rows; a match to a padded column means "unassigned".
        final int[] assignment = new int[rows];
        int total = 0;
        for (int i = 0; i < rows; i++) {
            final int j = columnOfRow[i];
            if (j >= 0 && j < cols) {
                assignment[i] = j;
                // addExact: fail loudly instead of returning a silently wrapped total.
                total = Math.addExact(total, cost[i][j]);
            } else {
                assignment[i] = -1;
            }
        }
        return new Result(assignment, total);
    }

    /** Copies {@code cost} into the top-left corner of a zero-filled n x n matrix. */
    private static int[][] padToSquare(int[][] cost, int n) {
        final int[][] square = new int[n][n]; // Java zero-initialises, so padding is already 0
        for (int i = 0; i < cost.length; i++) {
            System.arraycopy(cost[i], 0, square[i], 0, cost[i].length);
        }
        return square;
    }

    /**
     * Runs the potentials-based Hungarian method on a square n x n matrix and returns, for
     * each 0-based row, the 0-based column it is matched to (-1 if none).
     *
     * <p>Internally rows and columns are 1-based; index 0 is a virtual column that holds the
     * row currently being inserted.
     */
    private static int[] matchRowsToColumns(int[][] a, int n) {
        final long[] u = new long[n + 1]; // row potentials
        final long[] v = new long[n + 1]; // column potentials
        final int[] p = new int[n + 1]; // p[j] = row matched to column j (0 = free)
        final int[] way = new int[n + 1]; // way[j] = previous column on the alternating path
        final long[] minv = new long[n + 1]; // smallest reduced cost seen so far per column
        final boolean[] used = new boolean[n + 1]; // columns already on the alternating tree

        for (int i = 1; i <= n; i++) {
            p[0] = i;
            int j0 = 0;
            Arrays.fill(minv, Long.MAX_VALUE);
            Arrays.fill(used, false);

            // Grow the alternating tree until a free column is reached.
            do {
                used[j0] = true;
                final int i0 = p[j0];
                long delta = Long.MAX_VALUE;
                int j1 = 0;
                for (int j = 1; j <= n; j++) {
                    if (!used[j]) {
                        // long arithmetic: int cost minus long potentials cannot wrap here
                        final long cur = a[i0 - 1][j - 1] - u[i0] - v[j];
                        if (cur < minv[j]) {
                            minv[j] = cur;
                            way[j] = j0;
                        }
                        if (minv[j] < delta) {
                            delta = minv[j];
                            j1 = j;
                        }
                    }
                }
                // Shift potentials by delta so that column j1 becomes tight.
                for (int j = 0; j <= n; j++) {
                    if (used[j]) {
                        u[p[j]] += delta;
                        v[j] -= delta;
                    } else {
                        minv[j] -= delta;
                    }
                }
                j0 = j1;
            } while (p[j0] != 0);

            // Flip the matching along the path back to the virtual column 0.
            do {
                final int j1 = way[j0];
                p[j0] = p[j1];
                j0 = j1;
            } while (j0 != 0);
        }

        final int[] columnOfRow = new int[n];
        Arrays.fill(columnOfRow, -1);
        for (int j = 1; j <= n; j++) {
            if (p[j] != 0) {
                columnOfRow[p[j] - 1] = j - 1;
            }
        }
        return columnOfRow;
    }

    /**
     * Checks that {@code cost} is a non-empty rectangular matrix of non-negative values.
     *
     * @throws IllegalArgumentException if any of those conditions is violated
     */
    private static void validate(int[][] cost) {
        if (cost == null || cost.length == 0) {
            throw new IllegalArgumentException("Cost matrix must not be null or empty");
        }
        // Guard the first row before reading its length; otherwise a null first row
        // surfaces as a NullPointerException rather than the documented exception.
        if (cost[0] == null) {
            throw new IllegalArgumentException("Cost matrix must be rectangular with equal row lengths");
        }
        final int c = cost[0].length;
        if (c == 0) {
            throw new IllegalArgumentException("Cost matrix must have at least 1 column");
        }
        for (int i = 0; i < cost.length; i++) {
            if (cost[i] == null || cost[i].length != c) {
                throw new IllegalArgumentException("Cost matrix must be rectangular with equal row lengths");
            }
            for (int j = 0; j < c; j++) {
                if (cost[i][j] < 0) {
                    throw new IllegalArgumentException("Costs must be non-negative");
                }
            }
        }
    }
}