22 #include "torch/script.h" 23 #include "torch/torch.h" 29 namespace prediction {
58 std::vector<Obstacle*> dynamic_env)
override;
67 const LaneGraph* lane_graph_ptr,
68 std::vector<double>* feature_values);
71 const LaneGraph& lane_graph,
72 std::vector<std::string>*
const string_feature_values);
77 std::string
GetName()
override {
return "LANE_SCANNING_EVALUATOR"; }
90 bool ExtractObstacleFeatures(
const Obstacle* obstacle_ptr,
91 std::vector<double>* feature_values);
98 bool ExtractStaticEnvFeatures(
const Obstacle* obstacle_ptr,
99 const LaneGraph* lane_graph_ptr,
100 std::vector<double>* feature_values,
101 std::vector<int>* lane_sequence_idx_to_remove);
103 void ModelInference(
const std::vector<torch::jit::IValue>& torch_inputs,
104 torch::jit::script::Module torch_model,
105 Feature* feature_ptr);
// Fixed sizes used when assembling model inputs.
// NOTE(review): the factorisation 20 * (9 + 40) presumably means
// 20 history frames x (9 + 40) per-frame values — confirm against the model.
static const size_t OBSTACLE_FEATURE_SIZE = 20 * (9 + 40);
static const size_t INTERACTION_FEATURE_SIZE = 8;
static const size_t SINGLE_LANE_FEATURE_SIZE = 4;
static const size_t LANE_POINTS_SIZE = 100;
static const size_t BACKWARD_LANE_POINTS_SIZE = 50;
static const size_t MAX_NUM_LANE = 10;
static const size_t SHORT_TERM_TRAJECTORY_SIZE = 10;
116 torch::jit::script::Module torch_lane_scanning_model_;
117 torch::Device device_;
Prediction obstacle.
Definition: obstacle.h:52
Definition: obstacles_container.h:39
PlanningContext is the runtime context in planning. It is persistent across multiple frames...
Definition: atomic_hash_map.h:25
std::string GetName() override
Get the name of the evaluator.
Definition: lane_scanning_evaluator.h:77
Definition: lane_scanning_evaluator.h:31
bool ExtractStringFeatures(const LaneGraph &lane_graph, std::vector< std::string > *const string_feature_values)
virtual ~LaneScanningEvaluator()=default
Destructor.
LaneScanningEvaluator()
Constructor.
bool Evaluate(Obstacle *obstacle_ptr, ObstaclesContainer *obstacles_container) override
Override Evaluate.
Definition: evaluator.h:39
Defines the data container base class.
bool ExtractFeatures(const Obstacle *obstacle_ptr, const LaneGraph *lane_graph_ptr, std::vector< double > *feature_values)
Extract features for the learning model's input.