SourceXtractorPlusPlus  0.19
SourceXtractor++, the next generation SExtractor
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
OnnxTaskFactory.cpp
Go to the documentation of this file.
1 
18 #include <onnxruntime_cxx_api.h>
19 
21 #include <NdArray/NdArray.h>
22 
24 
29 
31 
32 namespace SourceXtractor {
33 
// Body of generatePropertyName(const OnnxModel& model). The signature line is
// not present in this listing; it only appears in the symbol index at the end
// of the page.
//
// Builds the property name by concatenating, in this order:
//   - the model domain followed by '.', only when the domain is not empty;
//   - the graph name followed by '.', only when the graph name is not empty;
//   - the model output name (always appended).
// Returns the assembled string.
38  std::stringstream prop_name;
39 
40  std::string domain = model.getDomain();
41  if (!domain.empty()) {
42  prop_name << domain << '.';
43  }
44 
45  std::string graph_name = model.getGraphName();
46  if (!graph_name.empty()) {
47  prop_name << graph_name << '.';
48  }
49 
50  prop_name << model.getOutputName();
51 
52  return prop_name.str();
53 }
54 
56 
// Body of createTask(const PropertyId& property_id) const. The signature line
// is not present in this listing; see the symbol index at the end of the page.
//
// Returns a new OnnxSourceTask constructed from m_model_infos when the
// requested property id is the one created for OnnxProperty, and nullptr for
// every other property id.
58  if (property_id == PropertyId::create<OnnxProperty>()) {
59  return std::make_shared<OnnxSourceTask>(m_model_infos);
60  }
61  return nullptr;
62 }
63 
66 }
67 
// Body of configure(Euclid::Configuration::ConfigManager& manager). The
// signature line is not present in this listing; see the symbol index at the
// end of the page.
//
// Reads the list of models from OnnxConfig and, for each entry, constructs an
// OnnxModel, validates it, logs the generated property name and appends an
// OnnxModelInfo {model, prop_name} to m_model_infos.
//
// Throws Elements::Exception when a model's input type is not
// ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, or when its input shape does not have
// exactly 4 axes. Entries appended before the failing model stay in
// m_model_infos.
69  const auto& onnx_config = manager.getConfiguration<OnnxConfig>();
70  const auto& models = onnx_config.getModels();
71 
// NOTE(review): `auto model_path` copies each element on every iteration. The
// index lists getModels() as returning const std::vector<std::string>&, so
// `const auto&` would presumably avoid the copy - confirm that the OnnxModel
// constructor accepts a const reference before changing it.
72  for (auto model_path : models) {
73  auto model = std::make_shared<OnnxModel>(model_path);
74 
75  if (model->getInputType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
76  throw Elements::Exception() << "Only ONNX models with float input are supported";
77  }
78 
79  if (model->getInputShape().size() != 4) {
80  throw Elements::Exception() << "Expected 4 axes for the input layer, got " << model->getInputShape().size();
81  }
82 
83  auto prop_name = generatePropertyName(*model);
84  onnx_logger.info() << "Output name will be " << prop_name;
85 
86  m_model_infos.emplace_back(OnnxSourceTask::OnnxModelInfo {model, prop_name});
87 
88  }
89 }
90 
// File-local helper used to register one output column for a single model.
//
// T is the type passed to OnnxProperty::getData<T>() when the converter reads
// the value back from the property.
//
// registry   - the output registry; not referenced on any line visible in this
//              listing (see the note on the missing line below).
// model_info - supplies the property name and the model whose path is passed
//              as the last argument of the call.
91 template<typename T>
92 static void registerColumnConverter(OutputRegistry& registry, const OnnxSourceTask::OnnxModelInfo& model_info) {
// Copied into a local so the lambda below can capture it by value and keep its
// own copy of the name.
93  auto key = model_info.prop_name;
94 
// NOTE(review): listing line 95 (the opening of this call) is missing from the
// page scrape. It is presumably a call to registry.registerColumnConverter -
// confirm against the real source file. If so, per the signature shown in the
// index, the visible arguments are: the column name, the converter, an empty
// column unit, and the model path as the column description.
96  model_info.prop_name, [key](const OnnxProperty& prop) {
97  return prop.getData<T>(key);
98  }, "", model_info.model->getModelPath()
99  );
100 }
101 
// Body of registerPropertyInstances(OutputRegistry& registry). The signature
// line is not present in this listing; see the symbol index at the end of the
// page.
//
// For every entry of m_model_infos, selects the template argument of
// registerColumnConverter from the model's output type:
//   ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT -> float
//   ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 -> int32_t
// Any other output type throws Elements::Exception; columns registered for
// earlier entries are not undone by this code.
103  for (const auto& model_info : m_model_infos) {
104  switch (model_info.model->getOutputType()) {
105  case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
106  registerColumnConverter<float>(registry, model_info);
107  break;
108  case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
109  registerColumnConverter<int32_t>(registry, model_info);
110  break;
111  default:
112  throw Elements::Exception() << "Unsupported output type: " << model_info.model->getOutputType();
113  }
114  }
115 }
116 
117 } // end of namespace SourceXtractor
T empty(T...args)
std::string getOutputName() const
Definition: OnnxModel.h:134
Elements::Logging onnx_logger
Logger for the ONNX plugin.
Definition: OnnxPlugin.cpp:26
const std::vector< std::string > & getModels() const
Definition: OnnxConfig.h:44
void info(const std::string &logMessage)
static void registerColumnConverter(OutputRegistry &registry, const OnnxSourceTask::OnnxModelInfo &model_info)
void configure(Euclid::Configuration::ConfigManager &manager) override
Method which should initialize the object.
std::vector< OnnxSourceTask::OnnxModelInfo > m_model_infos
STL class.
std::string getDomain() const
Definition: OnnxModel.h:122
std::string getGraphName() const
Definition: OnnxModel.h:126
T str(T...args)
std::shared_ptr< Task > createTask(const PropertyId &property_id) const override
Returns a Task producing a Property corresponding to the given PropertyId.
void reportConfigDependencies(Euclid::Configuration::ConfigManager &manager) const override
Registers all the Configuration dependencies.
void registerPropertyInstances(OutputRegistry &registry) override
Identifier used to set and retrieve properties.
Definition: PropertyId.h:40
void registerColumnConverter(std::string column_name, ColumnConverter< PropertyType, OutType > converter, std::string column_unit="", std::string column_description="")
static std::string generatePropertyName(const OnnxModel &model)