Apollo 6.0
Open source self driving car software
slice_plugin.h
Go to the documentation of this file.
1 /******************************************************************************
2  * Copyright 2018 The Apollo Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *****************************************************************************/
16 
17 #pragma once
18 
19 #include <algorithm>
20 #include <vector>
21 
23 
24 namespace apollo {
25 namespace perception {
26 namespace inference {
27 
28 class SLICEPlugin : public nvinfer1::IPlugin {
29  public:
30  SLICEPlugin(const SliceParameter &param, const nvinfer1::Dims &in_dims) {
31  CHECK_GT(param.slice_point_size(), 0);
32  for (int i = 0; i < param.slice_point_size(); i++) {
33  slice_point_.push_back(param.slice_point(i));
34  }
35  axis_ = std::max(param.axis() - 1, 0);
36  input_dims_.nbDims = in_dims.nbDims;
37  CHECK_GT(input_dims_.nbDims, 0);
38  for (int i = 0; i < in_dims.nbDims; i++) {
39  input_dims_.d[i] = in_dims.d[i];
40  input_dims_.type[i] = in_dims.type[i];
41  }
42 
43  for (size_t i = 0; i < slice_point_.size(); i++) {
44  if (i == 0) {
45  out_slice_dims_.push_back(slice_point_[i]);
46  } else {
47  out_slice_dims_.push_back(slice_point_[i] - slice_point_[i - 1]);
48  }
49  }
50  out_slice_dims_.push_back(input_dims_.d[axis_] -
51  slice_point_[slice_point_.size() - 1]);
52  }
55  virtual int initialize() { return 0; }
56  virtual void terminate() {}
57  int getNbOutputs() const override {
58  return static_cast<int>(slice_point_.size()) + 1;
59  }
60  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
61  int nbInputDims) override {
62  nvinfer1::Dims out_dims = inputs[0];
63  out_dims.d[axis_] = out_slice_dims_[index];
64  return out_dims;
65  }
66 
67  void configure(const nvinfer1::Dims *inputDims, int nbInputs,
68  const nvinfer1::Dims *outputDims, int nbOutputs,
69  int maxBatchSize) override {
70  input_dims_ = inputDims[0];
71  }
72 
73  size_t getWorkspaceSize(int maxBatchSize) const override { return 0; }
74 
75  virtual int enqueue(int batchSize, const void *const *inputs, void **outputs,
76  void *workspace, cudaStream_t stream);
77 
78  size_t getSerializationSize() override { return 0; }
79 
80  void serialize(void *buffer) override {
81  char *d = reinterpret_cast<char *>(buffer), *a = d;
82  size_t size = getSerializationSize();
83  CHECK_EQ(d, a + size);
84  }
85 
86  private:
87  std::vector<int> slice_point_;
88  std::vector<int> out_slice_dims_;
89  int axis_;
90  nvinfer1::Dims input_dims_;
91 };
92 
93 } // namespace inference
94 } // namespace perception
95 } // namespace apollo
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs, int nbInputDims) override
Definition: slice_plugin.h:60
virtual int enqueue(int batchSize, const void *const *inputs, void **outputs, void *workspace, cudaStream_t stream)
~SLICEPlugin()
Definition: slice_plugin.h:54
PlanningContext is the runtime context in planning. It is persistent across multiple frames...
Definition: atomic_hash_map.h:25
virtual void terminate()
Definition: slice_plugin.h:56
apollo::perception::inference::SLICEPlugin
Definition: slice_plugin.h:28
void serialize(void *buffer) override
Definition: slice_plugin.h:80
void configure(const nvinfer1::Dims *inputDims, int nbInputs, const nvinfer1::Dims *outputDims, int nbOutputs, int maxBatchSize) override
Definition: slice_plugin.h:67
size_t getSerializationSize() override
Definition: slice_plugin.h:78
int getNbOutputs() const override
Definition: slice_plugin.h:57
virtual int initialize()
Definition: slice_plugin.h:55
SLICEPlugin(const SliceParameter &param, const nvinfer1::Dims &in_dims)
Definition: slice_plugin.h:30
size_t getWorkspaceSize(int maxBatchSize) const override
Definition: slice_plugin.h:73
SLICEPlugin()
Definition: slice_plugin.h:53