SourceXtractorPlusPlus  0.19
SourceXtractor++, the next generation SExtractor
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
SegmentationConfig.cpp
Go to the documentation of this file.
1 
23 #include <iostream>
24 #include <fstream>
25 
26 #include <boost/regex.hpp>
27 #include <boost/algorithm/string.hpp>
28 
35 
36 using boost::regex;
37 using boost::regex_match;
38 using boost::smatch;
39 
40 using namespace Euclid::Configuration;
41 namespace po = boost::program_options;
42 
43 namespace SourceXtractor {
44 
46 
// Names of the configuration options registered by
// SegmentationConfig::getProgramOptions().
static const std::string SEGMENTATION_ALGORITHM        = "segmentation-algorithm";
static const std::string SEGMENTATION_USE_FILTERING    = "segmentation-use-filtering";
static const std::string SEGMENTATION_FILTER           = "segmentation-filter";
static const std::string SEGMENTATION_LUTZ_WINDOW_SIZE = "segmentation-lutz-window-size";
static const std::string SEGMENTATION_BFS_MAX_DELTA    = "segmentation-bfs-max-delta";
static const std::string SEGMENTATION_ML_MODEL         = "segmentation-ml-model";
static const std::string SEGMENTATION_ML_THRESHOLD     = "segmentation-ml-threshold";
54 
55 SegmentationConfig::SegmentationConfig(long manager_id) : Configuration(manager_id), m_selected_algorithm(Algorithm::UNKNOWN)
56  , m_lutz_window_size(0)
57  , m_bfs_max_delta(1000)
58  , m_ml_threshold(0.9) {}
59 
61  return { {"Detection image", {
62  {SEGMENTATION_ALGORITHM.c_str(), po::value<std::string>()->default_value("LUTZ"),
63  "Segmentation algorithm to be used (LUTZ, TILES or ML (a ONNX-format model must be provided))"},
64  {SEGMENTATION_USE_FILTERING.c_str(), po::value<bool>()->default_value(true),
65  "Is filtering used"},
66  {SEGMENTATION_FILTER.c_str(), po::value<std::string>()->default_value(""),
67  "Loads a filter"},
68  {SEGMENTATION_LUTZ_WINDOW_SIZE.c_str(), po::value<int>()->default_value(0),
69  "Lutz sliding window size (0=disable)"},
70  {SEGMENTATION_BFS_MAX_DELTA.c_str(), po::value<int>()->default_value(1000),
71  "BFS algorithm max source x/y size (default=1000)"},
72  {SEGMENTATION_ML_MODEL.c_str(), po::value<std::string>()->default_value(""),
73  "ONNX model to use with machine learning segmentation"},
74  {SEGMENTATION_ML_THRESHOLD.c_str(), po::value<double>()->default_value(0.9),
75  "Probability threshold for ML detection"},
76  }}};
77 }
78 
// Reads the parsed user options and stores the segmentation settings.
//  - The algorithm name is upper-cased before comparison, so matching is
//    case-insensitive. "LUTZ", "BFS" and "ML" are recognised; "ML" is only
//    available when built with WITH_ML_SEGMENTATION, otherwise it throws.
//    Any other name throws Elements::Exception.
//  - If filtering is enabled and a filter file name is given, the filter is
//    loaded with loadFilter(); a null result throws. If filtering is disabled,
//    m_filter is reset to nullptr.
//  - The BFS max delta and the ML probability threshold are copied from the options.
//  - Selecting ML while m_onnx_model_path is empty throws.
// NOTE(review): this listing omits Doxygen lines 82, 84, 87, 102, 108 and 110
// (statements inside several branches); their content is not visible here.
79 void SegmentationConfig::preInitialize(const UserValues& args) {
80  auto algorithm_name = boost::to_upper_copy(args.at(SEGMENTATION_ALGORITHM).as<std::string>());
81  if (algorithm_name == "LUTZ") {
83  } else if (algorithm_name == "BFS") {
85  } else if (algorithm_name == "ML") {
86 #ifdef WITH_ML_SEGMENTATION
88 #else
89  throw Elements::Exception() << "SourceXtractor++ has not been compiled with ONNX support";
90 #endif
91  } else {
92  throw Elements::Exception() << "Unknown segmentation algorithm : " << algorithm_name;
93  }
94 
// Filter selection: an explicit filter file, an empty file name (branch body
// not shown in this listing), or no filtering at all.
95  if (args.at(SEGMENTATION_USE_FILTERING).as<bool>()) {
96  auto filter_filename = args.at(SEGMENTATION_FILTER).as<std::string>();
97  if (filter_filename != "") {
98  m_filter = loadFilter(filter_filename);
99  if (m_filter == nullptr)
100  throw Elements::Exception() << "Can not load filter: " << filter_filename;
101  } else {
103  }
104  } else {
105  m_filter = nullptr;
106  }
107 
109  m_bfs_max_delta = args.at(SEGMENTATION_BFS_MAX_DELTA).as<int>();
111  m_ml_threshold = args.at(SEGMENTATION_ML_THRESHOLD).as<double>();
112 
// ML segmentation cannot run without a model file, so fail early here.
113  if (m_selected_algorithm == Algorithm::ML && m_onnx_model_path == "") {
114  throw Elements::Exception() << "Machine learning segmentation requested but no ONNX model was provided";
115  }
116 }
117 
// Builds the built-in 3x3 smoothing kernel
//     1 2 1
//     2 4 2
//     1 2 1
// and wraps it in a BackgroundConvolution whose second constructor argument is
// true (presumably "normalize", as passed by loadASCIIFilter — TODO confirm).
// NOTE(review): the function signature (Doxygen line 118) is missing from this listing.
119  segConfigLogger.info() << "Using the default segmentation (3x3) filter.";
120  auto convolution_kernel = VectorImage<SeFloat>::create(3, 3);
121  convolution_kernel->setValue(0,0, 1);
122  convolution_kernel->setValue(0,1, 2);
123  convolution_kernel->setValue(0,2, 1);
124 
125  convolution_kernel->setValue(1,0, 2);
126  convolution_kernel->setValue(1,1, 4);
127  convolution_kernel->setValue(1,2, 2);
128 
129  convolution_kernel->setValue(2,0, 1);
130  convolution_kernel->setValue(2,1, 2);
131  convolution_kernel->setValue(2,2, 1);
132 
133  return std::make_shared<BackgroundConvolution>(convolution_kernel, true);
134 }
135 
// Loads a convolution filter from `filename`, choosing the reader by extension.
// Names ending in ".fits" go to loadFITSFilter(). The comparison is an exact,
// case-sensitive suffix match, so ".FITS" or ".fits.gz" are not treated as FITS.
// Every other name is parsed by loadASCIIFilter().
// NOTE(review): the function signature (Doxygen line 136) is missing from this listing.
137  // check for the extension ".fits"
138  std::string fits_ending(".fits");
139  if (filename.length() >= fits_ending.length()
140  && filename.compare (filename.length() - fits_ending.length(), fits_ending.length(), fits_ending)==0) {
141  // load a FITS filter
142  return loadFITSFilter(filename);
143  }
144  else{
145  // load an ASCII filter
146  return loadASCIIFilter(filename);
147  }
148 }
149 
151 
// Reads the convolution kernel from the FITS image `filename` and logs its size.
// The kernel is wrapped in a BackgroundConvolution with `true` as the second
// argument unconditionally. Unlike loadASCIIFilter, there is no NORM/NONORM switch.
// NOTE(review): the function signature (Doxygen line 150) is missing from this listing.
152  // read in the FITS file
153  auto convolution_kernel = FitsReader<SeFloat>::readFile(filename);
154 
155  // give some feedback on the filter
156  segConfigLogger.info() << "Loaded segmentation filter: " << filename << " height: " << convolution_kernel->getHeight() << " width: " << convolution_kernel->getWidth();
157 
158  // return the correct object
159  return std::make_shared<BackgroundConvolution>(convolution_kernel, true);
160 }
161 
162 static bool getNormalization(std::istream& line_stream) {
163  std::string conv, norm_type;
164  line_stream >> conv >> norm_type;
165  if (conv != "CONV") {
166  throw Elements::Exception() << "Unexpected start for ASCII filter: " << conv;
167  }
168  if (norm_type == "NORM") {
169  return true;
170  }
171  else if (norm_type == "NONORM") {
172  return false;
173  }
174 
175  throw Elements::Exception() << "Unexpected normalization type: " << norm_type;
176 }
177 
/**
 * Parse whitespace-separated values of type T from a stream and append them.
 *
 * Reading stops at the end of the stream or at the first token that cannot be
 * parsed as T. Only successfully extracted values are appended, so trailing
 * whitespace, an empty stream or a malformed token never adds a spurious element.
 *
 * @param line_stream stream holding one row of the filter definition
 * @param data        destination vector; values are appended in reading order
 */
template <typename T>
static void extractValues(std::istream& line_stream, std::vector<T>& data) {
  T value{};
  // Test the extraction itself, not good() beforehand: a failed read sets
  // failbit (and, since C++11, value = 0), and must not be pushed as a kernel value.
  while (line_stream >> value) {
    data.push_back(value);
  }
}
186 
188  std::ifstream file;
189 
190  // open the file and check
191  file.open(filename);
192  if (!file.good() || !file.is_open()){
193  throw Elements::Exception() << "Can not load filter: " << filename;
194  }
195 
196  enum class LoadState {
197  STATE_START,
198  STATE_FIRST_LINE,
199  STATE_OTHER_LINES
200  };
201 
202  LoadState state = LoadState::STATE_START;
203  bool normalize = false;
204  std::vector<SeFloat> kernel_data;
205  size_t kernel_width = 0;
206 
207  while (file.good()) {
208  std::string line;
209  std::getline(file, line);
210  line = regex_replace(line, regex("\\s*#.*"), std::string(""));
211  line = regex_replace(line, regex("\\s*$"), std::string(""));
212  if (line.size() == 0) {
213  continue;
214  }
215 
216  std::stringstream line_stream(line);
217 
218  switch (state) {
219  case LoadState::STATE_START:
220  normalize = getNormalization(line_stream);
221  state = LoadState::STATE_FIRST_LINE;
222  break;
223  case LoadState::STATE_FIRST_LINE:
224  extractValues(line_stream, kernel_data);
225  kernel_width = kernel_data.size();
226  state = LoadState::STATE_OTHER_LINES;
227  break;
228  case LoadState::STATE_OTHER_LINES:
229  extractValues(line_stream, kernel_data);
230  break;
231  }
232  }
233 
234  // compute the dimensions and create the kernel
235  if (kernel_width == 0) {
236  throw Elements::Exception() << "Malformed segmentation filter: width is 0";
237  }
238  auto kernel_height = kernel_data.size() / kernel_width;
239  auto convolution_kernel = VectorImage<SeFloat>::create(kernel_width, kernel_height, kernel_data);
240 
241  // give some feedback on the filter
242  segConfigLogger.info() << "Loaded segmentation filter: " << filename << " width: " << convolution_kernel->getWidth() << " height: " << convolution_kernel->getHeight();
243 
244  // return the correct object
245  return std::make_shared<BackgroundConvolution>(convolution_kernel, normalize);
246 }
247 
248 } // SourceXtractor namespace
static const std::string SEGMENTATION_ALGORITHM
T open(T...args)
static const std::string SEGMENTATION_LUTZ_WINDOW_SIZE
static void extractValues(std::istream &line_stream, std::vector< T > &data)
T good(T...args)
static const std::string SEGMENTATION_FILTER
T getline(T...args)
void info(const std::string &logMessage)
static std::shared_ptr< VectorImage< T > > create(Args &&...args)
Definition: VectorImage.h:100
STL class.
std::shared_ptr< DetectionImageFrame::ImageFilter > m_filter
STL class.
static bool getNormalization(std::istream &line_stream)
STL class.
T at(T...args)
std::shared_ptr< DetectionImageFrame::ImageFilter > loadFilter(const std::string &filename) const
T push_back(T...args)
T regex_replace(T...args)
static const std::string SEGMENTATION_ML_THRESHOLD
string filename
Definition: conf.py:65
static const std::string SEGMENTATION_ML_MODEL
static Elements::Logging segConfigLogger
std::map< std::string, Configuration::OptionDescriptionList > getProgramOptions() override
std::shared_ptr< DetectionImageFrame::ImageFilter > loadFITSFilter(const std::string &filename) const
T length(T...args)
STL class.
static const std::string SEGMENTATION_BFS_MAX_DELTA
T c_str(T...args)
static const std::string SEGMENTATION_USE_FILTERING
std::shared_ptr< DetectionImageFrame::ImageFilter > loadASCIIFilter(const std::string &filename) const
T is_open(T...args)
void preInitialize(const UserValues &args) override
static Logging getLogger(const std::string &name="")
STL class.
std::shared_ptr< DetectionImageFrame::ImageFilter > getDefaultFilter() const
T compare(T...args)