Apollo 6.0
Open-source self-driving car software
argmax_plugin.h
Go to the documentation of this file.
1 /******************************************************************************
2  * Copyright 2018 The Apollo Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *****************************************************************************/
16 
17 #pragma once
18 
19 #include <limits>
20 
21 #include "modules/perception/proto/rt.pb.h"
22 
24 
25 namespace apollo {
26 namespace perception {
27 namespace inference {
28 
29 class ArgMax1Plugin : public nvinfer1::IPlugin {
30  public:
31  ArgMax1Plugin(const ArgMaxParameter &argmax_param, nvinfer1::Dims in_dims)
32  : float_min_(std::numeric_limits<float>::min()) {
33  input_dims_.nbDims = in_dims.nbDims;
34  CHECK_GT(input_dims_.nbDims, 0);
35  for (int i = 0; i < in_dims.nbDims; i++) {
36  input_dims_.d[i] = in_dims.d[i];
37  input_dims_.type[i] = in_dims.type[i];
38  }
39  axis_ = argmax_param.axis();
40  out_max_val_ = argmax_param.out_max_val();
41  top_k_ = argmax_param.top_k();
42  CHECK_GE(top_k_, static_cast<size_t>(1))
43  << "top k must not be less than 1.";
44  output_dims_ = input_dims_;
45  output_dims_.d[0] = 1;
46  if (out_max_val_) {
47  // Produces max_ind and max_val
48  output_dims_.d[0] = 2;
49  }
50  }
51 
60  virtual int initialize() { return 0; }
61  virtual void terminate() {}
62  int getNbOutputs() const override { return 1; }
63  virtual nvinfer1::Dims getOutputDimensions(int index,
64  const nvinfer1::Dims *inputs,
65  int nbInputDims) {
66  input_dims_ = inputs[0];
67  for (int i = 1; i < input_dims_.nbDims; i++) {
68  output_dims_.d[i] = input_dims_.d[i];
69  }
70  return output_dims_;
71  }
72 
73  void configure(const nvinfer1::Dims *inputDims, int nbInputs,
74  const nvinfer1::Dims *outputDims, int nbOutputs,
75  int maxBatchSize) override {
76  input_dims_ = inputDims[0];
77  for (int i = 1; i < input_dims_.nbDims; i++) {
78  output_dims_.d[i] = input_dims_.d[i];
79  }
80  }
81 
82  size_t getWorkspaceSize(int maxBatchSize) const override { return 0; }
83 
84  virtual int enqueue(int batchSize, const void *const *inputs, void **outputs,
85  void *workspace, cudaStream_t stream);
86 
87  size_t getSerializationSize() override { return 0; }
88 
89  void serialize(void *buffer) override {
90  char *d = reinterpret_cast<char *>(buffer), *a = d;
91  size_t size = getSerializationSize();
92  CHECK_EQ(d, a + size);
93  }
94 
95  virtual ~ArgMax1Plugin() {}
96 
97  private:
98  bool out_max_val_;
99  size_t top_k_;
100  int axis_;
101  float float_min_;
102  nvinfer1::Dims input_dims_;
103  nvinfer1::Dims output_dims_;
104 };
105 
106 } // namespace inference
107 } // namespace perception
108 } // namespace apollo
virtual int initialize()
Initialize the plugin for execution; this implementation does no work and returns 0.
Definition: argmax_plugin.h:60
virtual int enqueue(int batchSize, const void *const *inputs, void **outputs, void *workspace, cudaStream_t stream)
size_t getSerializationSize() override
Definition: argmax_plugin.h:87
PlanningContext is the runtime context in planning. It is persistent across multiple frames...
Definition: atomic_hash_map.h:25
Definition: future.h:29
int getNbOutputs() const override
Definition: argmax_plugin.h:62
virtual void terminate()
Definition: argmax_plugin.h:61
virtual ~ArgMax1Plugin()
Definition: argmax_plugin.h:95
Definition: argmax_plugin.h:29
ArgMax1Plugin(const ArgMaxParameter &argmax_param, nvinfer1::Dims in_dims)
Definition: argmax_plugin.h:31
void configure(const nvinfer1::Dims *inputDims, int nbInputs, const nvinfer1::Dims *outputDims, int nbOutputs, int maxBatchSize) override
Definition: argmax_plugin.h:73
void serialize(void *buffer) override
Definition: argmax_plugin.h:89
size_t getWorkspaceSize(int maxBatchSize) const override
Definition: argmax_plugin.h:82
virtual nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs, int nbInputDims)
Definition: argmax_plugin.h:63