mlpack  3.1.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
random_forest.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP
13 #define MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP
14 
17 #include "bootstrap.hpp"
18 
19 namespace mlpack {
20 namespace tree {
21 
22 template<typename FitnessFunction = GiniGain,
23  typename DimensionSelectionType = MultipleRandomDimensionSelect,
24  template<typename> class NumericSplitType = BestBinaryNumericSplit,
25  template<typename> class CategoricalSplitType = AllCategoricalSplit,
26  typename ElemType = double>
28 {
29  public:
31  typedef DecisionTree<FitnessFunction, NumericSplitType, CategoricalSplitType,
32  DimensionSelectionType, ElemType> DecisionTreeType;
33 
39 
55  template<typename MatType>
56  RandomForest(const MatType& dataset,
57  const arma::Row<size_t>& labels,
58  const size_t numClasses,
59  const size_t numTrees = 20,
60  const size_t minimumLeafSize = 1,
61  const double minimumGainSplit = 1e-7,
62  DimensionSelectionType dimensionSelector =
63  DimensionSelectionType());
64 
82  template<typename MatType>
83  RandomForest(const MatType& dataset,
84  const data::DatasetInfo& datasetInfo,
85  const arma::Row<size_t>& labels,
86  const size_t numClasses,
87  const size_t numTrees = 20,
88  const size_t minimumLeafSize = 1,
89  const double minimumGainSplit = 1e-7,
90  DimensionSelectionType dimensionSelector =
91  DimensionSelectionType());
92 
105  template<typename MatType>
106  RandomForest(const MatType& dataset,
107  const arma::Row<size_t>& labels,
108  const size_t numClasses,
109  const arma::rowvec& weights,
110  const size_t numTrees = 20,
111  const size_t minimumLeafSize = 1,
112  const double minimumGainSplit = 1e-7,
113  DimensionSelectionType dimensionSelector =
114  DimensionSelectionType());
115 
134  template<typename MatType>
135  RandomForest(const MatType& dataset,
136  const data::DatasetInfo& datasetInfo,
137  const arma::Row<size_t>& labels,
138  const size_t numClasses,
139  const arma::rowvec& weights,
140  const size_t numTrees = 20,
141  const size_t minimumLeafSize = 1,
142  const double minimumGainSplit = 1e-7,
143  DimensionSelectionType dimensionSelector =
144  DimensionSelectionType());
145 
162  template<typename MatType>
163  double Train(const MatType& data,
164  const arma::Row<size_t>& labels,
165  const size_t numClasses,
166  const size_t numTrees = 20,
167  const size_t minimumLeafSize = 1,
168  const double minimumGainSplit = 1e-7,
169  DimensionSelectionType dimensionSelector =
170  DimensionSelectionType());
171 
191  template<typename MatType>
192  double Train(const MatType& data,
193  const data::DatasetInfo& datasetInfo,
194  const arma::Row<size_t>& labels,
195  const size_t numClasses,
196  const size_t numTrees = 20,
197  const size_t minimumLeafSize = 1,
198  const double minimumGainSplit = 1e-7,
199  DimensionSelectionType dimensionSelector =
200  DimensionSelectionType());
201 
219  template<typename MatType>
220  double Train(const MatType& data,
221  const arma::Row<size_t>& labels,
222  const size_t numClasses,
223  const arma::rowvec& weights,
224  const size_t numTrees = 20,
225  const size_t minimumLeafSize = 1,
226  const double minimumGainSplit = 1e-7,
227  DimensionSelectionType dimensionSelector =
228  DimensionSelectionType());
229 
249  template<typename MatType>
250  double Train(const MatType& data,
251  const data::DatasetInfo& datasetInfo,
252  const arma::Row<size_t>& labels,
253  const size_t numClasses,
254  const arma::rowvec& weights,
255  const size_t numTrees = 20,
256  const size_t minimumLeafSize = 1,
257  const double minimumGainSplit = 1e-7,
258  DimensionSelectionType dimensionSelector =
259  DimensionSelectionType());
260 
267  template<typename VecType>
268  size_t Classify(const VecType& point) const;
269 
279  template<typename VecType>
280  void Classify(const VecType& point,
281  size_t& prediction,
282  arma::vec& probabilities) const;
283 
291  template<typename MatType>
292  void Classify(const MatType& data,
293  arma::Row<size_t>& predictions) const;
294 
304  template<typename MatType>
305  void Classify(const MatType& data,
306  arma::Row<size_t>& predictions,
307  arma::mat& probabilities) const;
308 
310  const DecisionTreeType& Tree(const size_t i) const { return trees[i]; }
312  DecisionTreeType& Tree(const size_t i) { return trees[i]; }
313 
315  size_t NumTrees() const { return trees.size(); }
316 
320  template<typename Archive>
321  void serialize(Archive& ar, const unsigned int /* version */);
322 
323  private:
343  template<bool UseWeights, bool UseDatasetInfo, typename MatType>
344  double Train(const MatType& data,
345  const data::DatasetInfo& datasetInfo,
346  const arma::Row<size_t>& labels,
347  const size_t numClasses,
348  const arma::rowvec& weights,
349  const size_t numTrees,
350  const size_t minimumLeafSize,
351  const double minimumGainSplit,
352  DimensionSelectionType& dimensionSelector);
353 
355  std::vector<DecisionTreeType> trees;
356 };
357 
358 } // namespace tree
359 } // namespace mlpack
360 
361 // Include implementation.
362 #include "random_forest_impl.hpp"
363 
364 #endif
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the datatype of each dimension.
const DecisionTreeType & Tree(const size_t i) const
Access a tree in the forest.
size_t Classify(const VecType &point) const
Predict the class of the given point.
This class implements a generic decision tree learner.
double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t numTrees=20, const size_t minimumLeafSize=1, const double minimumGainSplit=1e-7, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the random forest on the given labeled training data with the given number of trees...
size_t NumTrees() const
Get the number of trees in the forest.
RandomForest()
Construct the random forest without any training or specifying the number of trees.
DecisionTreeType & Tree(const size_t i)
Modify a tree in the forest (be careful!).
void serialize(Archive &ar, const unsigned int)
Serialize the random forest.
DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, DimensionSelectionType, ElemType > DecisionTreeType
Allow access to the underlying decision tree type.