21 #include "modules/perception/proto/rt.pb.h" 26 namespace perception {
// Constructs the plugin from an ArgMaxParameter and the input dimensions:
// copies in_dims into input_dims_, reads axis / out_max_val / top_k from the
// parameter, and derives output_dims_ from input_dims_.
//
// NOTE(review): std::numeric_limits<float>::min() is the smallest POSITIVE
// normalized float, not the most negative value. If float_min_ is meant to be
// the starting value of a max search, std::numeric_limits<float>::lowest() is
// probably what is intended -- confirm against the code that consumes
// float_min_ (not visible in this excerpt).
31 ArgMax1Plugin(
const ArgMaxParameter &argmax_param, nvinfer1::Dims in_dims)
32 : float_min_(
std::numeric_limits<float>::min()) {
33 input_dims_.nbDims = in_dims.nbDims;
// At least one input dimension is required.
34 CHECK_GT(input_dims_.nbDims, 0);
// Element-wise copy of every dimension's extent and type.
35 for (
int i = 0; i < in_dims.nbDims; i++) {
36 input_dims_.d[i] = in_dims.d[i];
37 input_dims_.type[i] = in_dims.type[i];
// Settings taken from the ArgMaxParameter message.
39 axis_ = argmax_param.axis();
40 out_max_val_ = argmax_param.out_max_val();
41 top_k_ = argmax_param.top_k();
// top_k_ must be at least 1.
42 CHECK_GE(top_k_, static_cast<size_t>(1))
43 <<
"top k must not be less than 1.";
// Output dims start as a copy of the input dims; d[0] is then overwritten.
// NOTE(review): the listing jumps from line 45 to line 48, so a condition
// guarding the second assignment is presumably missing from this excerpt --
// do not read the two writes to d[0] as unconditional.
44 output_dims_ = input_dims_;
45 output_dims_.d[0] = 1;
48 output_dims_.d[0] = 2;
64 const nvinfer1::Dims *inputs,
// Fragment: the start of this signature is not visible in this excerpt.
// The visible body stores the first input's dims in input_dims_ and then
// copies every dimension except index 0 into output_dims_, leaving
// output_dims_.d[0] as it was.
66 input_dims_ = inputs[0];
67 for (
int i = 1; i < input_dims_.nbDims; i++) {
68 output_dims_.d[i] = input_dims_.d[i];
// Overrides the base-class configure(): records the first input's dimensions
// in input_dims_ and copies dimensions 1..nbDims-1 into output_dims_.
// output_dims_.d[0] is not modified here. In the visible lines nbInputs,
// outputDims, nbOutputs and maxBatchSize are not read.
// NOTE(review): inputDims[0] is read without consulting nbInputs -- confirm
// the caller always supplies at least one input.
73 void configure(
const nvinfer1::Dims *inputDims,
int nbInputs,
74 const nvinfer1::Dims *outputDims,
int nbOutputs,
75 int maxBatchSize)
override {
76 input_dims_ = inputDims[0];
// Starts at 1 so that index 0 of output_dims_ is preserved.
77 for (
int i = 1; i < input_dims_.nbDims; i++) {
78 output_dims_.d[i] = input_dims_.d[i];
// Declaration only; the definition is not part of this excerpt.
// Takes a batch size, an array of input pointers, an array of output
// pointers, a workspace pointer and a CUDA stream, and returns an int.
84 virtual int enqueue(
int batchSize,
const void *
const *inputs,
void **outputs,
85 void *workspace, cudaStream_t stream);
// Fragment of a serialization body. `d` is a char cursor into `buffer` and
// `a` keeps the buffer's starting address.
90 char *d =
reinterpret_cast<char *
>(buffer), *a = d;
// Checks that the cursor sits exactly `size` bytes past the start. The code
// that advances `d` and defines `size` is not visible in this excerpt.
92 CHECK_EQ(d, a + size);
// Dimensions of the first input; filled in the constructor and reassigned in
// configure().
102 nvinfer1::Dims input_dims_;
// Dimensions of the output; initialised from input_dims_ in the constructor,
// with entries from index 1 upward refreshed in configure().
103 nvinfer1::Dims output_dims_;
virtual int initialize()
Get the number of outputs from the layer. (NOTE: this brief reads like a description of getNbOutputs() rather than initialize() — confirm.)
Definition: argmax_plugin.h:60
virtual int enqueue(int batchSize, const void *const *inputs, void **outputs, void *workspace, cudaStream_t stream)
size_t getSerializationSize() override
Definition: argmax_plugin.h:87
PlanningContext is the runtime context in planning. It is persistent across multiple frames...
Definition: atomic_hash_map.h:25
int getNbOutputs() const override
Definition: argmax_plugin.h:62
virtual void terminate()
Definition: argmax_plugin.h:61
virtual ~ArgMax1Plugin()
Definition: argmax_plugin.h:95
Definition: argmax_plugin.h:29
ArgMax1Plugin(const ArgMaxParameter &argmax_param, nvinfer1::Dims in_dims)
Definition: argmax_plugin.h:31
void configure(const nvinfer1::Dims *inputDims, int nbInputs, const nvinfer1::Dims *outputDims, int nbOutputs, int maxBatchSize) override
Definition: argmax_plugin.h:73
void serialize(void *buffer) override
Definition: argmax_plugin.h:89
size_t getWorkspaceSize(int maxBatchSize) const override
Definition: argmax_plugin.h:82
virtual nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs, int nbInputDims)
Definition: argmax_plugin.h:63