mlpack  3.0.4
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
random_forest.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP
13 #define MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP
14 
17 #include "bootstrap.hpp"
18 
19 namespace mlpack {
20 namespace tree {
21 
// Random forest class template: a collection of DecisionTreeType instances
// held in the private `trees` vector, with Train() and Classify() overloads.
// The five template parameters (defaults shown below) are passed through to
// DecisionTree via the DecisionTreeType typedef.
// NOTE(review): listing line 27 is absent from this extract (the numbering
// jumps from 26 to 28); presumably it is the `class RandomForest` header --
// confirm against the original random_forest.hpp.
22 template<typename FitnessFunction = GiniGain,
23  typename DimensionSelectionType = MultipleRandomDimensionSelect<>,
24  template<typename> class NumericSplitType = BestBinaryNumericSplit,
25  template<typename> class CategoricalSplitType = AllCategoricalSplit,
26  typename ElemType = double>
28 {
29  public:
// Allow access to the underlying decision tree type.
31  typedef DecisionTree<FitnessFunction, NumericSplitType, CategoricalSplitType,
32  DimensionSelectionType, ElemType> DecisionTreeType;
33 
// NOTE(review): listing lines 34-38 are absent from this extract. The member
// index of this page lists a default constructor `RandomForest()` ("Construct
// the random forest without any training or specifying the number of trees"),
// which is presumably declared here -- confirm against the original header.
39 
// Constructor overload taking a dataset, its labels and the number of classes.
// numTrees defaults to 50 and minimumLeafSize defaults to 20.
// Presumably trains the forest on construction, mirroring the matching Train()
// overload -- TODO confirm in random_forest_impl.hpp.
51  template<typename MatType>
52  RandomForest(const MatType& dataset,
53  const arma::Row<size_t>& labels,
54  const size_t numClasses,
55  const size_t numTrees = 50,
56  const size_t minimumLeafSize = 20);
57 
// Constructor overload that additionally takes a data::DatasetInfo (auxiliary
// information for the dataset, e.g. mappings to/from strings).
71  template<typename MatType>
72  RandomForest(const MatType& dataset,
73  const data::DatasetInfo& datasetInfo,
74  const arma::Row<size_t>& labels,
75  const size_t numClasses,
76  const size_t numTrees = 50,
77  const size_t minimumLeafSize = 20);
78 
// Constructor overload that additionally takes `weights`.
// Presumably one weight per training point -- TODO confirm.
91  template<typename MatType>
92  RandomForest(const MatType& dataset,
93  const arma::Row<size_t>& labels,
94  const size_t numClasses,
95  const arma::rowvec& weights,
96  const size_t numTrees = 50,
97  const size_t minimumLeafSize = 20);
98 
// Constructor overload taking both a data::DatasetInfo and `weights`.
113  template<typename MatType>
114  RandomForest(const MatType& dataset,
115  const data::DatasetInfo& datasetInfo,
116  const arma::Row<size_t>& labels,
117  const size_t numClasses,
118  const arma::rowvec& weights,
119  const size_t numTrees = 50,
120  const size_t minimumLeafSize = 20);
121 
// Train the random forest on the given labeled training data with the given
// number of trees. numTrees defaults to 50, minimumLeafSize to 20.
133  template<typename MatType>
134  void Train(const MatType& data,
135  const arma::Row<size_t>& labels,
136  const size_t numClasses,
137  const size_t numTrees = 50,
138  const size_t minimumLeafSize = 20);
139 
// Train() overload that additionally takes a data::DatasetInfo (auxiliary
// information for the dataset, e.g. mappings to/from strings).
153  template<typename MatType>
154  void Train(const MatType& data,
155  const data::DatasetInfo& datasetInfo,
156  const arma::Row<size_t>& labels,
157  const size_t numClasses,
158  const size_t numTrees = 50,
159  const size_t minimumLeafSize = 20);
160 
// Train() overload that additionally takes `weights`.
// Presumably one weight per training point -- TODO confirm.
173  template<typename MatType>
174  void Train(const MatType& data,
175  const arma::Row<size_t>& labels,
176  const size_t numClasses,
177  const arma::rowvec& weights,
178  const size_t numTrees = 50,
179  const size_t minimumLeafSize = 20);
180 
// Train() overload taking both a data::DatasetInfo and `weights`.
195  template<typename MatType>
196  void Train(const MatType& data,
197  const data::DatasetInfo& datasetInfo,
198  const arma::Row<size_t>& labels,
199  const size_t numClasses,
200  const arma::rowvec& weights,
201  const size_t numTrees = 50,
202  const size_t minimumLeafSize = 20);
203 
// Predict the class of the given point; the predicted class is the return
// value.
210  template<typename VecType>
211  size_t Classify(const VecType& point) const;
212 
// Single-point overload that returns nothing and instead writes its results
// through the non-const `prediction` and `probabilities` references.
// Presumably `probabilities` holds one entry per class -- TODO confirm.
222  template<typename VecType>
223  void Classify(const VecType& point,
224  size_t& prediction,
225  arma::vec& probabilities) const;
226 
// Batch overload taking a matrix `data`; results are written to `predictions`.
// Presumably one prediction per point in `data` -- TODO confirm.
234  template<typename MatType>
235  void Classify(const MatType& data,
236  arma::Row<size_t>& predictions) const;
237 
// Batch overload that also writes to a `probabilities` matrix.
// NOTE(review): the layout of `probabilities` is not visible here -- confirm
// in random_forest_impl.hpp.
247  template<typename MatType>
248  void Classify(const MatType& data,
249  arma::Row<size_t>& predictions,
250  arma::mat& probabilities) const;
251 
// Access a tree in the forest (read-only). Returns trees[i]; the index is not
// range-checked here.
253  const DecisionTreeType& Tree(const size_t i) const { return trees[i]; }
// Modify a tree in the forest (be careful!). Returns trees[i]; the index is
// not range-checked here.
255  DecisionTreeType& Tree(const size_t i) { return trees[i]; }
256 
// Get the number of trees in the forest (the size of `trees`).
258  size_t NumTrees() const { return trees.size(); }
259 
// Serialize the random forest. The version parameter is left unnamed in this
// declaration.
263  template<typename Archive>
264  void serialize(Archive& ar, const unsigned int /* version */);
265 
266  private:
// Private Train() overload, parameterised at compile time on UseWeights and
// UseDatasetInfo. It takes every parameter that appears across the public
// Train() overloads, with no defaults.
// Presumably the public overloads forward to it -- TODO confirm in
// random_forest_impl.hpp.
283  template<bool UseWeights, bool UseDatasetInfo, typename MatType>
284  void Train(const MatType& data,
285  const data::DatasetInfo& datasetInfo,
286  const arma::Row<size_t>& labels,
287  const size_t numClasses,
288  const arma::rowvec& weights,
289  const size_t numTrees,
290  const size_t minimumLeafSize);
291 
// The trees in the forest; Tree(i) indexes into it and NumTrees() reports its
// size.
293  std::vector<DecisionTreeType> trees;
294 };
295 
296 } // namespace tree
297 } // namespace mlpack
298 
299 // Include implementation.
300 #include "random_forest_impl.hpp"
301 
302 #endif
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the datatype of each dimension.
const DecisionTreeType & Tree(const size_t i) const
Access a tree in the forest.
size_t Classify(const VecType &point) const
Predict the class of the given point.
This class implements a generic decision tree learner.
void Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t numTrees=50, const size_t minimumLeafSize=20)
Train the random forest on the given labeled training data with the given number of trees.
size_t NumTrees() const
Get the number of trees in the forest.
RandomForest()
Construct the random forest without any training or specifying the number of trees.
DecisionTreeType & Tree(const size_t i)
Modify a tree in the forest (be careful!).
void serialize(Archive &ar, const unsigned int)
Serialize the random forest.
DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, DimensionSelectionType, ElemType > DecisionTreeType
Allow access to the underlying decision tree type.