31 namespace perception {
// NOTE(review): this is a fragmentary listing; the class header and several
// member declarations are not visible here. The tooltip residue further down
// names the class GatedHungarianMatcher, whose constructor takes
// max_matching_size (default 1000).
//
// Constructor body (fragment): reserve max_matching_size x max_matching_size
// storage for both the global cost matrix and the optimizer's cost matrix.
40 global_costs_.Reserve(max_matching_size, max_matching_size);
41 optimizer_.costs()->Reserve(max_matching_size, max_matching_size);
// Output parameters of a Match overload (leading parameters not visible):
// matched (row, col) index pairs, plus the row and column indices that end
// up unassigned.
56 std::vector<std::pair<size_t, size_t>>* assignments,
57 std::vector<size_t>* unassigned_rows,
58 std::vector<size_t>* unassigned_cols);
// Output parameters of the other Match overload (same meaning as above).
61 std::vector<std::pair<size_t, size_t>>* assignments,
62 std::vector<size_t>* unassigned_rows,
63 std::vector<size_t>* unassigned_cols);
// Partitions rows and columns into connected components of the bipartite
// graph whose edges are the entries accepted by is_valid_cost_.
// row_components->at(k) and col_components->at(k) describe the same
// component k.
74 void ComputeConnectedComponents(
75 std::vector<std::vector<size_t>>* row_components,
76 std::vector<std::vector<size_t>>* col_components)
const;
// Solves a single component and appends the accepted (row, col) pairs, in
// global indices, to *assignments_ptr_.
80 void OptimizeConnectedComponent(
const std::vector<size_t>& row_component,
81 const std::vector<size_t>& col_component);
// Writes every row / column index that does not occur in *assignments_ptr_
// into the corresponding output vector, in ascending order.
85 void GenerateUnassignedData(std::vector<size_t>* unassigned_rows,
86 std::vector<size_t>* unassigned_cols)
const;
// Copies the component's sub-matrix of global_costs_ into the local cost
// matrix (presumably optimizer_.costs(); its declaration is not visible),
// substituting bound_value_ for invalid costs.
94 void UpdateGatingLocalCostsMat(
const std::vector<size_t>& row_component,
95 const std::vector<size_t>& col_component);
// Runs optimizer_ and returns local (row, col) index pairs. Whether
// Maximize or Minimize is used is presumably chosen by opt_flag_ (the branch
// condition is not visible here).
98 std::vector<std::pair<size_t, size_t>>* local_assignments);
// Gating threshold: a cost c is valid iff compare_fun_(cost_thresh_, c).
107 T cost_thresh_ = 0.0;
// Filler for gated-out cells of the local cost matrix; the init routine
// checks that it is itself an invalid cost.
108 T bound_value_ = 0.0;
// Non-owning pointer to the caller's assignments vector for the current
// Match call; nullptr until Match stores it.
112 mutable std::vector<std::pair<size_t, size_t>>* assignments_ptr_ =
nullptr;
// Dimensions of global_costs_, set by the init routine.
115 size_t rows_num_ = 0;
116 size_t cols_num_ = 0;
// Comparator looked up from the OptimizeFlag -> comparator map.
119 std::function<bool(T, T)> compare_fun_;
// compare_fun_ with cost_thresh_ bound as its first argument.
120 std::function<bool(T)> is_valid_cost_;
// Convenience overload of Match (its parameter list is only partly visible).
// It forwards to the full overload, passing cost_thresh for both of the first
// two arguments -- presumably cost_thresh and bound_value; confirm against the
// full signature. The init routine ACHECKs that bound_value_ is *not* a valid
// cost, so this overload relies on compare_fun_ being a strict comparison
// (compare_fun_(x, x) == false); the comparator map entries are not visible
// here.
123 template <
typename T>
126 std::vector<std::pair<size_t, size_t>>* assignments,
127 std::vector<size_t>* unassigned_rows,
128 std::vector<size_t>* unassigned_cols) {
129 Match(cost_thresh, cost_thresh, opt_flag, assignments, unassigned_rows,
// Full Match: gates global_costs_ with cost_thresh, optimizes each connected
// component separately, and reports the assignments plus the unassigned rows
// and columns.
133 template <
typename T>
136 std::vector<std::pair<size_t, size_t>>* assignments,
137 std::vector<size_t>* unassigned_rows,
138 std::vector<size_t>* unassigned_cols) {
// All three output pointers must be non-null.
139 CHECK_NOTNULL(assignments);
140 CHECK_NOTNULL(unassigned_rows);
141 CHECK_NOTNULL(unassigned_cols);
// Cache the call parameters; the helper methods read them from members.
144 cost_thresh_ = cost_thresh;
145 opt_flag_ = opt_flag;
146 bound_value_ = bound_value;
147 assignments_ptr_ = assignments;
// NOTE(review): rows_num_ / cols_num_ (used below) are set by the init
// routine; presumably it is invoked on the lines not visible here (148-150).
151 std::vector<std::vector<size_t>> row_components;
152 std::vector<std::vector<size_t>> col_components;
153 this->ComputeConnectedComponents(&row_components, &col_components);
154 CHECK_EQ(row_components.size(), col_components.size());
// Start from an empty result; max(rows, cols) is used as a capacity hint.
157 assignments_ptr_->clear();
158 assignments_ptr_->reserve(std::max(rows_num_, cols_num_));
// Optimize component i using its row set and its column set.
159 for (
size_t i = 0; i < row_components.size(); ++i) {
160 this->OptimizeConnectedComponent(row_components[i], col_components[i]);
// Derive the unassigned rows/cols from the accumulated assignments.
163 this->GenerateUnassignedData(unassigned_rows, unassigned_cols);
// Match initialization (the function header is not visible in this listing).
// It captures the matrix size, selects the comparator for opt_flag_, and
// builds the cost-validity predicate.
166 template <
typename T>
// A matrix with no rows is treated as having no columns either.
169 rows_num_ = global_costs_.height();
170 cols_num_ = (rows_num_ == 0) ? 0 : global_costs_.width();
// Function-local static lookup table from OptimizeFlag to comparator; its
// entries are on lines not visible here.
173 static std::map<OptimizeFlag, std::function<bool(T, T)>> compare_fun_map = {
177 auto find_ret = compare_fun_map.find(opt_flag_);
// Abort on an opt_flag_ that has no comparator.
178 ACHECK(find_ret != compare_fun_map.end());
179 compare_fun_ = find_ret->second;
// is_valid_cost_(c) evaluates compare_fun_(cost_thresh_, c).
// NOTE(review): std::bind1st was deprecated in C++11 and removed in C++17,
// so this line does not compile with -std=c++17 or later. An exact
// equivalent (bind1st copies both arguments at bind time) is:
//   is_valid_cost_ = [cmp = compare_fun_, thresh = cost_thresh_](T cost) {
//     return cmp(thresh, cost);
//   };
180 is_valid_cost_ = std::bind1st(compare_fun_, cost_thresh_);
// bound_value_ fills gated-out cells of the local cost matrix. It must be
// rejected by is_valid_cost_ so that any optimizer assignment landing on
// such a cell is discarded afterwards.
183 ACHECK(!is_valid_cost_(bound_value_));
// Builds an undirected bipartite adjacency list and then splits each
// connected component's node ids back into row and column indices. Row i is
// node i and column j is node rows_num_ + j; there is one edge for every
// valid cost.
186 template <
typename T>
188 std::vector<std::vector<size_t>>* row_components,
189 std::vector<std::vector<size_t>>* col_components)
const {
190 CHECK_NOTNULL(row_components);
191 CHECK_NOTNULL(col_components);
193 std::vector<std::vector<int>> nb_graph;
194 nb_graph.resize(rows_num_ + cols_num_);
195 for (
size_t i = 0; i < rows_num_; ++i) {
196 for (
size_t j = 0; j < cols_num_; ++j) {
197 if (is_valid_cost_(global_costs_(i, j))) {
// Each edge is stored in both directions.
// NOTE(review): `static_cast<int>(rows_num_) + j` is evaluated as size_t and
// then implicitly narrowed back to int; this is only correct while
// rows_num_ + cols_num_ fits in an int.
198 nb_graph[i].push_back(static_cast<int>(rows_num_) + j);
199 nb_graph[j + rows_num_].push_back(i);
// Filled on a line not visible here (205), presumably by
// ConnectedComponentAnalysis(nb_graph, &components) as named in the tooltip.
204 std::vector<std::vector<int>> components;
206 row_components->clear();
207 row_components->resize(components.size());
208 col_components->clear();
209 col_components->resize(components.size());
210 for (
size_t i = 0; i < components.size(); ++i) {
211 for (
size_t j = 0; j < components[i].size(); ++j) {
212 int id = components[i][j];
// Node ids below rows_num_ are rows; the rest are columns offset by
// rows_num_.
213 if (
id < static_cast<int>(rows_num_)) {
214 row_components->at(i).push_back(
id);
216 id -=
static_cast<int>(rows_num_);
217 col_components->at(i).push_back(
id);
// Solves one connected component and appends the resulting global
// (row, col) pairs to *assignments_ptr_.
223 template <
typename T>
225 const std::vector<size_t>& row_component,
226 const std::vector<size_t>& col_component) {
227 size_t local_rows_num = row_component.size();
228 size_t local_cols_num = col_component.size();
// A component lacking rows or columns cannot produce a match. The body of
// this branch is not visible; presumably it is an early return.
231 if (!local_rows_num || !local_cols_num) {
// 1x1 fast path: accept the single pair directly if its cost is valid,
// without invoking the optimizer. The rest of this branch is not visible.
235 if (local_rows_num == 1 && local_cols_num == 1) {
236 size_t idx_r = row_component[0];
237 size_t idx_c = col_component[0];
238 if (is_valid_cost_(global_costs_(idx_r, idx_c))) {
239 assignments_ptr_->push_back(std::make_pair(idx_r, idx_c));
// General case: build the gated local cost matrix and run the optimizer.
245 UpdateGatingLocalCostsMat(row_component, col_component);
248 std::vector<std::pair<size_t, size_t>> local_assignments;
249 OptimizeAdapter(&local_assignments);
// Translate local indices back to global ones. Pairs whose global cost is
// invalid (i.e. the optimizer picked a cell filled with bound_value_) are
// rejected; the body of that branch is not visible (presumably `continue`).
252 for (
size_t i = 0; i < local_assignments.size(); ++i) {
253 auto local_assignment = local_assignments[i];
254 size_t global_row_idx = row_component[local_assignment.first];
255 size_t global_col_idx = col_component[local_assignment.second];
256 if (!is_valid_cost_(global_costs_(global_row_idx, global_col_idx))) {
259 assignments_ptr_->push_back(std::make_pair(global_row_idx, global_col_idx));
// Derives the unassigned row and column indices from *assignments_ptr_.
// Both outputs are cleared and then filled in ascending index order.
263 template <
typename T>
265 std::vector<size_t>* unassigned_rows,
266 std::vector<size_t>* unassigned_cols)
const {
267 CHECK_NOTNULL(unassigned_rows);
268 CHECK_NOTNULL(unassigned_cols);
// NOTE(review): this copies the entire assignments vector; binding
// `const auto&` would avoid the copy without changing behavior.
270 const auto assignments = *assignments_ptr_;
271 unassigned_rows->clear(), unassigned_rows->reserve(rows_num_);
272 unassigned_cols->clear(), unassigned_cols->reserve(cols_num_);
// Mark every row/column that appears in some assignment.
273 std::vector<bool> row_assignment_flags(rows_num_,
false);
274 std::vector<bool> col_assignment_flags(cols_num_,
false);
275 for (
const auto& assignment : assignments) {
276 row_assignment_flags[assignment.first] =
true;
277 col_assignment_flags[assignment.second] =
true;
// Collect the unmarked rows, then the unmarked columns.
279 for (
size_t i = 0; i < row_assignment_flags.size(); ++i) {
280 if (!row_assignment_flags[i]) {
281 unassigned_rows->push_back(i);
284 for (
size_t i = 0; i < col_assignment_flags.size(); ++i) {
285 if (!col_assignment_flags[i]) {
286 unassigned_cols->push_back(i);
// Fills the local cost matrix for one component. Entry (i, j) is
// global_costs_(row_component[i], col_component[j]) when that cost is valid,
// and bound_value_ otherwise.
// NOTE(review): `local_costs` is declared on a line not visible here
// (presumably optimizer_.costs()).
291 template <
typename T>
293 const std::vector<size_t>& row_component,
294 const std::vector<size_t>& col_component) {
297 local_costs->
Resize(row_component.size(), col_component.size());
298 for (
size_t i = 0; i < row_component.size(); ++i) {
299 for (
size_t j = 0; j < col_component.size(); ++j) {
// NOTE(review): current_cost is only read; `const T&` would state that
// intent.
300 T& current_cost = global_costs_(row_component[i], col_component[j]);
301 if (is_valid_cost_(current_cost)) {
302 (*local_costs)(i, j) = current_cost;
304 (*local_costs)(i, j) = bound_value_;
// Runs optimizer_ on the local cost matrix and writes local (row, col) index
// pairs to *local_assignments. The choice between Maximize and Minimize is
// made on lines not visible here (presumably based on opt_flag_).
310 template <
typename T>
312 std::vector<std::pair<size_t, size_t>>* local_assignments) {
313 CHECK_NOTNULL(local_assignments);
315 optimizer_.Maximize(local_assignments);
317 optimizer_.Minimize(local_assignments);
Definition: hungarian_optimizer.h:33
void ConnectedComponentAnalysis(const std::vector< std::vector< int >> &graph, std::vector< std::vector< int >> *components)
Definition: secure_matrix.h:29
#define ACHECK(cond)
Definition: log.h:80
PlanningContext is the runtime context in planning. It is persistent across multiple frames...
Definition: atomic_hash_map.h:25
Definition: gated_hungarian_bigraph_matcher.h:35
const SecureMat< T > & global_costs() const
Definition: gated_hungarian_bigraph_matcher.h:52
SecureMat< T > * mutable_global_costs()
Definition: gated_hungarian_bigraph_matcher.h:53
void Match(T cost_thresh, OptimizeFlag opt_flag, std::vector< std::pair< size_t, size_t >> *assignments, std::vector< size_t > *unassigned_rows, std::vector< size_t > *unassigned_cols)
Definition: gated_hungarian_bigraph_matcher.h:124
OptimizeFlag
Definition: gated_hungarian_bigraph_matcher.h:37
GatedHungarianMatcher(int max_matching_size=1000)
Definition: gated_hungarian_bigraph_matcher.h:39
~GatedHungarianMatcher()
Definition: gated_hungarian_bigraph_matcher.h:43
void Resize(const size_t resize_height, const size_t resize_width)
Definition: secure_matrix.h:53