27 #include "torch/script.h" 28 #include "torch/torch.h" 38 namespace prediction {
63 std::string
GetName()
override {
return "LANE_AGGREGATING_EVALUATOR"; }
71 bool ExtractObstacleFeatures(
const Obstacle* obstacle_ptr,
72 std::vector<double>* feature_values);
79 bool ExtractStaticEnvFeatures(
80 const Obstacle* obstacle_ptr,
const LaneGraph* lane_graph_ptr,
81 std::vector<std::vector<double>>* feature_values,
82 std::vector<int>* lane_sequence_idx_to_remove);
84 torch::Tensor AggregateLaneEncodings(
85 const std::vector<torch::Tensor>& lane_encoding_list);
87 torch::Tensor LaneEncodingMaxPooling(
88 const std::vector<torch::Tensor>& lane_encoding_list);
90 torch::Tensor LaneEncodingAvgPooling(
91 const std::vector<torch::Tensor>& lane_encoding_list);
93 std::vector<double> StableSoftmax(
94 const std::vector<double>& prediction_scores);
99 torch::jit::script::Module torch_obstacle_encoding_;
100 torch::jit::script::Module torch_lane_encoding_;
101 torch::jit::script::Module torch_prediction_layer_;
102 torch::Device device_;
// Input feature sizes.
// NOTE(review): 20 * 9 presumably means 20 history frames x 9 features per
// frame — confirm against ExtractObstacleFeatures.
static const size_t OBSTACLE_FEATURE_SIZE = 20 * 9;
static const size_t SINGLE_LANE_FEATURE_SIZE = 4;
static const size_t LANE_POINTS_SIZE = 100;
static const size_t BACKWARD_LANE_POINTS_SIZE = 50;

// Encoding sizes produced by the TorchScript sub-networks.
static const size_t OBSTACLE_ENCODING_SIZE = 128;
static const size_t SINGLE_LANE_ENCODING_SIZE = 128;
// Equal to 2 * SINGLE_LANE_ENCODING_SIZE.
static const size_t AGGREGATED_ENCODING_SIZE = 256;
LaneAggregatingEvaluator()
Constructor.
Prediction obstacle.
Definition: obstacle.h:52
bool Evaluate(Obstacle *obstacle_ptr, ObstaclesContainer *obstacles_container) override
Override Evaluate.
Definition: obstacles_container.h:39
PlanningContext is the runtime context in planning. It is persistent across multiple frames...
Definition: atomic_hash_map.h:25
virtual ~LaneAggregatingEvaluator()=default
Destructor.
std::string GetName() override
Get the name of evaluator.
Definition: lane_aggregating_evaluator.h:63
Definition: evaluator.h:39
Definition: lane_aggregating_evaluator.h:40
Define the data container base class.