mlpack  3.0.4
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
adaboost_model.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ADABOOST_ADABOOST_MODEL_HPP
13 #define MLPACK_METHODS_ADABOOST_ADABOOST_MODEL_HPP
14 
15 #include <mlpack/core.hpp>
16 
17 // Use forward declaration instead of include to accelerate compilation.
18 class AdaBoost;
19 
20 namespace mlpack {
21 namespace adaboost {
22 
27 {
28  public:
30  {
33  };
34 
35  private:
37  arma::Col<size_t> mappings;
39  size_t weakLearnerType;
45  size_t dimensionality;
46 
47  public:
49  AdaBoostModel();
50 
52  AdaBoostModel(const arma::Col<size_t>& mappings,
53  const size_t weakLearnerType);
54 
56  AdaBoostModel(const AdaBoostModel& other);
57 
60 
62  AdaBoostModel& operator=(const AdaBoostModel& other);
63 
66 
68  const arma::Col<size_t>& Mappings() const { return mappings; }
70  arma::Col<size_t>& Mappings() { return mappings; }
71 
73  size_t WeakLearnerType() const { return weakLearnerType; }
75  size_t& WeakLearnerType() { return weakLearnerType; }
76 
78  size_t Dimensionality() const { return dimensionality; }
80  size_t& Dimensionality() { return dimensionality; }
81 
83  void Train(const arma::mat& data,
84  const arma::Row<size_t>& labels,
85  const size_t numClasses,
86  const size_t iterations,
87  const double tolerance);
88 
90  void Classify(const arma::mat& testData, arma::Row<size_t>& predictions);
91 
93  template<typename Archive>
94  void serialize(Archive& ar, const unsigned int /* version */)
95  {
96  if (Archive::is_loading::value)
97  {
98  if (dsBoost)
99  delete dsBoost;
100  if (pBoost)
101  delete pBoost;
102 
103  dsBoost = NULL;
104  pBoost = NULL;
105  }
106 
107  ar & BOOST_SERIALIZATION_NVP(mappings);
108  ar & BOOST_SERIALIZATION_NVP(weakLearnerType);
109  if (weakLearnerType == WeakLearnerTypes::DECISION_STUMP)
110  ar & BOOST_SERIALIZATION_NVP(dsBoost);
111  else if (weakLearnerType == WeakLearnerTypes::PERCEPTRON)
112  ar & BOOST_SERIALIZATION_NVP(pBoost);
113  ar & BOOST_SERIALIZATION_NVP(dimensionality);
114  }
115 };
116 
117 } // namespace adaboost
118 } // namespace mlpack
119 
120 #endif
~AdaBoostModel()
Clean up memory.
size_t Dimensionality() const
Get the dimensionality of the model.
void Classify(const arma::mat &testData, arma::Row< size_t > &predictions)
Classify test points.
The AdaBoost class.
Definition: adaboost.hpp:81
size_t & Dimensionality()
Modify the dimensionality of the model.
void serialize(Archive &ar, const unsigned int)
Serialize the model.
The serializable AdaBoost model that can be saved to and loaded from disk.
arma::Col< size_t > & Mappings()
Modify the mappings.
void Train(const arma::mat &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t iterations, const double tolerance)
Train the model.
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
AdaBoostModel()
Create an empty AdaBoost model.
AdaBoostModel & operator=(const AdaBoostModel &other)
Copy assignment operator.
size_t & WeakLearnerType()
Modify the weak learner type.
size_t WeakLearnerType() const
Get the weak learner type.
const arma::Col< size_t > & Mappings() const
Get the mappings.