8 #ifndef _SEIMPLEMENTATION_PLUGIN_FLEXIBLEMODELFITTING_ONNXCOMPACTMODEL_H_
9 #define _SEIMPLEMENTATION_PLUGIN_FLEXIBLEMODELFITTING_ONNXCOMPACTMODEL_H_
20 namespace ModelFitting {
24 template <
typename ImageType>
34 :
CompactModelBase<ImageType>(x_scale, y_scale, rotation, width, height, x, y, transform),
// Overrides the base-class point-evaluation method. Both coordinate
// parameters are unnamed, so the body cannot read them. The body itself is
// outside this excerpt.
41 double getValue(
double,
double)
const override {
// Body fragment of the rasterization routine (the signature is outside this
// excerpt; presumably getRasterizedImage(pixel_scale, size_x, size_y) -- TODO confirm).
// Visible steps: allocate the output image, pick an ONNX model, fill the
// per-pixel coordinate inputs, run the model, and copy the result into the image.

// Output image allocated with the requested size_x by size_y dimensions.
47 ImageType image = Traits::factory(size_x, size_y);
// The larger of the two requested dimensions is what candidate models are compared against.
49 int largest_size =
std::max(size_x, size_y);
// The loop that produces `model` is not visible here (presumably iterating m_models -- TODO confirm).
// A candidate is taken when index 2 of its output shape is strictly greater than
// largest_size; a model whose shape[2] equals largest_size is not selected by this test.
53 auto shape = model->getOutputShape();
54 if (largest_size < shape[2]) {
55 selected_model = model;
// No candidate qualified: a warning is logged. What happens after the warning
// (e.g. an early return) is not visible in this excerpt.
60 if (selected_model ==
nullptr) {
61 logger.
warn() <<
"No large enough ONNX model could be found, skipping...";
// Shapes of the selected model; input_shape is not referenced in the visible lines.
65 auto input_shape = selected_model->getInputShape();
66 auto output_shape = selected_model->getOutputShape();
// render_size is used below as the row stride of the model's input and output buffers.
67 int render_size = output_shape[2];
// Guard on the model's output element type being float; the handling of a
// mismatch is not visible here.
69 if (selected_model->getOutputType() != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
// Each entry `it` (presumably an element of m_params -- TODO confirm) becomes a
// one-element float input keyed by it.first, holding the parameter's value cast to float.
78 input_data_arrays[it.first] =
std::vector<float>( {
static_cast<float>(it.second->getValue()) } );
// Walk the requested size_x by size_y pixel grid row by row.
86 for (
int y=0;
y<(int)size_y; ++
y) {
// Offset of the row from size_y / 2 (integer division).
// NOTE(review): if size_y is std::size_t, this subtraction is evaluated in
// unsigned arithmetic before being converted to int -- confirm this is intended.
87 int dy =
y - size_y / 2;
88 for (
int x=0;
x<(int)size_x; ++
x) {
// Offset of the column from size_x / 2; same unsigned-arithmetic caveat as dy.
89 int dx =
x - size_x / 2;
// Apply the four coefficients of `transform` to (dx, dy) as a 2x2 linear map.
// The declaration of `transform` is not visible in this excerpt.
91 float x2 = dx * transform[0] + dy * transform[1];
92 float y2 = dx * transform[2] + dy * transform[3];
// Store the transformed coordinates in the "x" and "y" inputs. The row stride is
// render_size, not size_x.
// NOTE(review): the allocation of these arrays is not visible -- confirm they hold
// at least render_size * render_size elements.
94 input_data_arrays[
"x"][
x +
y * render_size] = x2;
95 input_data_arrays[
"y"][
x +
y * render_size] = y2;
// Run the selected model on all prepared inputs; results are written to output_data.
99 selected_model->runMultiInput<float,
float>(input_data_arrays, output_data);
// Copy the size_x by size_y window starting at index 0 of the model output into the
// image, again using render_size as the row stride.
101 for (
int y = 0;
y < (int) size_y; ++
y) {
102 for (
int x = 0;
x < (int) size_x; ++
x) {
103 Traits::at(image,
x,
y) = output_data[
x +
y * render_size];
std::shared_ptr< BasicParameter > m_flux
std::map< std::string, std::shared_ptr< BasicParameter > > m_params
static Elements::Logging logger
Mat22 getCombinedTransform(double pixel_scale) const
void warn(const std::string &logMessage)
std::vector< std::shared_ptr< SourceXtractor::OnnxModel > > m_models
void renormalize(ImageType &image, double flux) const
ImageType getRasterizedImage(double pixel_scale, std::size_t size_x, std::size_t size_y) const override
OnnxCompactModel(std::vector< std::shared_ptr< SourceXtractor::OnnxModel >> models, std::shared_ptr< BasicParameter > x_scale, std::shared_ptr< BasicParameter > y_scale, std::shared_ptr< BasicParameter > rotation, double width, double height, std::shared_ptr< BasicParameter > x, std::shared_ptr< BasicParameter > y, std::shared_ptr< BasicParameter > flux, std::map< std::string, std::shared_ptr< BasicParameter >> params, std::tuple< double, double, double, double > transform)
double getValue(double, double) const override
std::shared_ptr< EngineParameter > dx
static Logging getLogger(const std::string &name="")
virtual ~OnnxCompactModel()=default
std::shared_ptr< EngineParameter > dy