Apollo  6.0
Open source self driving car software
gated_hungarian_bigraph_matcher.h
Go to the documentation of this file.
1 /******************************************************************************
2  * Copyright 2018 The Apollo Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *****************************************************************************/
16 
17 #pragma once
18 
19 #include <algorithm>
20 #include <functional>
21 #include <map>
22 #include <utility>
23 #include <vector>
24 
25 #include "cyber/common/log.h"
26 
29 
30 namespace apollo {
31 namespace perception {
32 namespace common {
33 
34 template <typename T>
36  public:
37  enum class OptimizeFlag { OPTMAX, OPTMIN };
38 
39  explicit GatedHungarianMatcher(int max_matching_size = 1000) {
40  global_costs_.Reserve(max_matching_size, max_matching_size);
41  optimizer_.costs()->Reserve(max_matching_size, max_matching_size);
42  }
44 
45  /* @brief: global_costs is the memory we reserved for the updating of
46  * costs of matching. it could & need be updated outside the matcher,
47  * before each matching. use it carefully, and make sure all the
48  * elements of the global_costs is updated as you presumed. resize it
49  * and update it completely is STRONG RECOMMENDED!! P.S. resizing SecureMat
50  * would not alloc new memory, if the resizing size is smaller than the
51  * size reserved. */
52  const SecureMat<T>& global_costs() const { return global_costs_; }
53  SecureMat<T>* mutable_global_costs() { return &global_costs_; }
54 
55  void Match(T cost_thresh, OptimizeFlag opt_flag,
56  std::vector<std::pair<size_t, size_t>>* assignments,
57  std::vector<size_t>* unassigned_rows,
58  std::vector<size_t>* unassigned_cols);
59 
60  void Match(T cost_thresh, T bound_value, OptimizeFlag opt_flag,
61  std::vector<std::pair<size_t, size_t>>* assignments,
62  std::vector<size_t>* unassigned_rows,
63  std::vector<size_t>* unassigned_cols);
64 
65  private:
66  /* Step 1:
67  * a. get number of rows & cols
68  * b. determine function of comparison */
69  void MatchInit();
70 
71  /* Step 2:
72  * to acclerate matching process, split input cost graph into several
73  * small sub-parts. */
74  void ComputeConnectedComponents(
75  std::vector<std::vector<size_t>>* row_components,
76  std::vector<std::vector<size_t>>* col_components) const;
77 
78  /* Step 3:
79  * optimize single connected component, which is part of the global one */
80  void OptimizeConnectedComponent(const std::vector<size_t>& row_component,
81  const std::vector<size_t>& col_component);
82 
83  /* Step 4:
84  * generate the set of unassigned row or col index. */
85  void GenerateUnassignedData(std::vector<size_t>* unassigned_rows,
86  std::vector<size_t>* unassigned_cols) const;
87 
88  /* @brief: core function for updating the local cost matrix from global one,
89  * we get queryed local costs and write them in the memeory of costs of
90  * optimizer directly
91  * @params[IN] row_component: the set of index of rows of sub-graph
92  * @params[IN] col_component: the set of index of cols of sub-graph
93  * @return: nothing */
94  void UpdateGatingLocalCostsMat(const std::vector<size_t>& row_component,
95  const std::vector<size_t>& col_component);
96 
97  void OptimizeAdapter(
98  std::vector<std::pair<size_t, size_t>>* local_assignments);
99 
100  /* Hungarian optimizer */
101  HungarianOptimizer<T> optimizer_;
102 
103  /* global costs matrix */
104  SecureMat<T> global_costs_;
105 
106  /* input data */
107  T cost_thresh_ = 0.0;
108  T bound_value_ = 0.0;
110 
111  /* output data */
112  mutable std::vector<std::pair<size_t, size_t>>* assignments_ptr_ = nullptr;
113 
114  /* size of component */
115  size_t rows_num_ = 0;
116  size_t cols_num_ = 0;
117 
118  /* the rhs is always better than lhs */
119  std::function<bool(T, T)> compare_fun_;
120  std::function<bool(T)> is_valid_cost_;
121 }; // class GatedHungarianMatcher
122 
123 template <typename T>
125  T cost_thresh, OptimizeFlag opt_flag,
126  std::vector<std::pair<size_t, size_t>>* assignments,
127  std::vector<size_t>* unassigned_rows,
128  std::vector<size_t>* unassigned_cols) {
129  Match(cost_thresh, cost_thresh, opt_flag, assignments, unassigned_rows,
130  unassigned_cols);
131 }
132 
133 template <typename T>
135  T cost_thresh, T bound_value, OptimizeFlag opt_flag,
136  std::vector<std::pair<size_t, size_t>>* assignments,
137  std::vector<size_t>* unassigned_rows,
138  std::vector<size_t>* unassigned_cols) {
139  CHECK_NOTNULL(assignments);
140  CHECK_NOTNULL(unassigned_rows);
141  CHECK_NOTNULL(unassigned_cols);
142 
143  /* initialize matcher */
144  cost_thresh_ = cost_thresh;
145  opt_flag_ = opt_flag;
146  bound_value_ = bound_value;
147  assignments_ptr_ = assignments;
148  MatchInit();
149 
150  /* compute components */
151  std::vector<std::vector<size_t>> row_components;
152  std::vector<std::vector<size_t>> col_components;
153  this->ComputeConnectedComponents(&row_components, &col_components);
154  CHECK_EQ(row_components.size(), col_components.size());
155 
156  /* compute assignments */
157  assignments_ptr_->clear();
158  assignments_ptr_->reserve(std::max(rows_num_, cols_num_));
159  for (size_t i = 0; i < row_components.size(); ++i) {
160  this->OptimizeConnectedComponent(row_components[i], col_components[i]);
161  }
162 
163  this->GenerateUnassignedData(unassigned_rows, unassigned_cols);
164 }
165 
166 template <typename T>
168  /* get number of rows & cols */
169  rows_num_ = global_costs_.height();
170  cols_num_ = (rows_num_ == 0) ? 0 : global_costs_.width();
171 
172  /* determine function of comparison */
173  static std::map<OptimizeFlag, std::function<bool(T, T)>> compare_fun_map = {
174  {OptimizeFlag::OPTMAX, std::less<T>()},
175  {OptimizeFlag::OPTMIN, std::greater<T>()},
176  };
177  auto find_ret = compare_fun_map.find(opt_flag_);
178  ACHECK(find_ret != compare_fun_map.end());
179  compare_fun_ = find_ret->second;
180  is_valid_cost_ = std::bind1st(compare_fun_, cost_thresh_);
181 
182  /* check the validity of bound_value */
183  ACHECK(!is_valid_cost_(bound_value_));
184 }
185 
186 template <typename T>
188  std::vector<std::vector<size_t>>* row_components,
189  std::vector<std::vector<size_t>>* col_components) const {
190  CHECK_NOTNULL(row_components);
191  CHECK_NOTNULL(col_components);
192 
193  std::vector<std::vector<int>> nb_graph;
194  nb_graph.resize(rows_num_ + cols_num_);
195  for (size_t i = 0; i < rows_num_; ++i) {
196  for (size_t j = 0; j < cols_num_; ++j) {
197  if (is_valid_cost_(global_costs_(i, j))) {
198  nb_graph[i].push_back(static_cast<int>(rows_num_) + j);
199  nb_graph[j + rows_num_].push_back(i);
200  }
201  }
202  }
203 
204  std::vector<std::vector<int>> components;
205  ConnectedComponentAnalysis(nb_graph, &components);
206  row_components->clear();
207  row_components->resize(components.size());
208  col_components->clear();
209  col_components->resize(components.size());
210  for (size_t i = 0; i < components.size(); ++i) {
211  for (size_t j = 0; j < components[i].size(); ++j) {
212  int id = components[i][j];
213  if (id < static_cast<int>(rows_num_)) {
214  row_components->at(i).push_back(id);
215  } else {
216  id -= static_cast<int>(rows_num_);
217  col_components->at(i).push_back(id);
218  }
219  }
220  }
221 }
222 
223 template <typename T>
225  const std::vector<size_t>& row_component,
226  const std::vector<size_t>& col_component) {
227  size_t local_rows_num = row_component.size();
228  size_t local_cols_num = col_component.size();
229 
230  /* simple case 1: no possible matches */
231  if (!local_rows_num || !local_cols_num) {
232  return;
233  }
234  /* simple case 2: 1v1 pair with no ambiguousness */
235  if (local_rows_num == 1 && local_cols_num == 1) {
236  size_t idx_r = row_component[0];
237  size_t idx_c = col_component[0];
238  if (is_valid_cost_(global_costs_(idx_r, idx_c))) {
239  assignments_ptr_->push_back(std::make_pair(idx_r, idx_c));
240  }
241  return;
242  }
243 
244  /* update local cost matrix */
245  UpdateGatingLocalCostsMat(row_component, col_component);
246 
247  /* get local assignments */
248  std::vector<std::pair<size_t, size_t>> local_assignments;
249  OptimizeAdapter(&local_assignments);
250 
251  /* parse local assginments into global ones */
252  for (size_t i = 0; i < local_assignments.size(); ++i) {
253  auto local_assignment = local_assignments[i];
254  size_t global_row_idx = row_component[local_assignment.first];
255  size_t global_col_idx = col_component[local_assignment.second];
256  if (!is_valid_cost_(global_costs_(global_row_idx, global_col_idx))) {
257  continue;
258  }
259  assignments_ptr_->push_back(std::make_pair(global_row_idx, global_col_idx));
260  }
261 }
262 
263 template <typename T>
265  std::vector<size_t>* unassigned_rows,
266  std::vector<size_t>* unassigned_cols) const {
267  CHECK_NOTNULL(unassigned_rows);
268  CHECK_NOTNULL(unassigned_cols);
269 
270  const auto assignments = *assignments_ptr_;
271  unassigned_rows->clear(), unassigned_rows->reserve(rows_num_);
272  unassigned_cols->clear(), unassigned_cols->reserve(cols_num_);
273  std::vector<bool> row_assignment_flags(rows_num_, false);
274  std::vector<bool> col_assignment_flags(cols_num_, false);
275  for (const auto& assignment : assignments) {
276  row_assignment_flags[assignment.first] = true;
277  col_assignment_flags[assignment.second] = true;
278  }
279  for (size_t i = 0; i < row_assignment_flags.size(); ++i) {
280  if (!row_assignment_flags[i]) {
281  unassigned_rows->push_back(i);
282  }
283  }
284  for (size_t i = 0; i < col_assignment_flags.size(); ++i) {
285  if (!col_assignment_flags[i]) {
286  unassigned_cols->push_back(i);
287  }
288  }
289 }
290 
291 template <typename T>
293  const std::vector<size_t>& row_component,
294  const std::vector<size_t>& col_component) {
295  /* set the invalid cost to bound value */
296  SecureMat<T>* local_costs = optimizer_.costs();
297  local_costs->Resize(row_component.size(), col_component.size());
298  for (size_t i = 0; i < row_component.size(); ++i) {
299  for (size_t j = 0; j < col_component.size(); ++j) {
300  T& current_cost = global_costs_(row_component[i], col_component[j]);
301  if (is_valid_cost_(current_cost)) {
302  (*local_costs)(i, j) = current_cost;
303  } else {
304  (*local_costs)(i, j) = bound_value_;
305  }
306  }
307  }
308 }
309 
310 template <typename T>
312  std::vector<std::pair<size_t, size_t>>* local_assignments) {
313  CHECK_NOTNULL(local_assignments);
314  if (opt_flag_ == OptimizeFlag::OPTMAX) {
315  optimizer_.Maximize(local_assignments);
316  } else {
317  optimizer_.Minimize(local_assignments);
318  }
319 }
320 
321 } // namespace common
322 } // namespace perception
323 } // namespace apollo
Definition: hungarian_optimizer.h:33
void ConnectedComponentAnalysis(const std::vector< std::vector< int >> &graph, std::vector< std::vector< int >> *components)
Definition: secure_matrix.h:29
#define ACHECK(cond)
Definition: log.h:80
PlanningContext is the runtime context in planning. It is persistent across multiple frames...
Definition: atomic_hash_map.h:25
Definition: gated_hungarian_bigraph_matcher.h:35
const SecureMat< T > & global_costs() const
Definition: gated_hungarian_bigraph_matcher.h:52
SecureMat< T > * mutable_global_costs()
Definition: gated_hungarian_bigraph_matcher.h:53
void Match(T cost_thresh, OptimizeFlag opt_flag, std::vector< std::pair< size_t, size_t >> *assignments, std::vector< size_t > *unassigned_rows, std::vector< size_t > *unassigned_cols)
Definition: gated_hungarian_bigraph_matcher.h:124
OptimizeFlag
Definition: gated_hungarian_bigraph_matcher.h:37
GatedHungarianMatcher(int max_matching_size=1000)
Definition: gated_hungarian_bigraph_matcher.h:39
~GatedHungarianMatcher()
Definition: gated_hungarian_bigraph_matcher.h:43
void Resize(const size_t resize_height, const size_t resize_width)
Definition: secure_matrix.h:53