12 #ifndef MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP
13 #define MLPACK_METHODS_RANDOM_FOREST_RANDOM_FOREST_HPP
// Template parameters of the random-forest class; defaults are shown inline
// (GiniGain, MultipleRandomDimensionSelect, BestBinaryNumericSplit,
// AllCategoricalSplit, double).
// NumericSplitType and CategoricalSplitType are template template parameters
// that take a single type argument.
// All five parameters are forwarded to DecisionTree to form DecisionTreeType:
//   DecisionTree<FitnessFunction, NumericSplitType, CategoricalSplitType,
//                DimensionSelectionType, ElemType>
// That is the element type of the `trees` vector, i.e. the type of every tree
// in the forest.
// NOTE(review): this text is a Doxygen source-listing extraction. Source
// lines 27-30 (presumably the `class RandomForest` head) and line 32 (the
// remainder of the typedef below) are missing; restore them from the
// original header.
22 template<
typename FitnessFunction = GiniGain,
23 typename DimensionSelectionType = MultipleRandomDimensionSelect,
24 template<
typename>
class NumericSplitType = BestBinaryNumericSplit,
25 template<
typename>
class CategoricalSplitType = AllCategoricalSplit,
26 typename ElemType =
double>
31 typedef DecisionTree<FitnessFunction, NumericSplitType, CategoricalSplitType,
// Presumably a RandomForest constructor taking labeled training data
// (original header lines 55-63).
// NOTE(review): source line 56 is missing from this Doxygen listing. It
// presumably holds the declarator and the data parameter, e.g.
// `RandomForest(const MatType& dataset,`; restore it from the original header.
// The visible trailing parameters and defaults match the basic
// Train(data, labels, numClasses, ...) overload, which trains the forest on
// labeled data with the given number of trees. Presumably this constructor
// trains the same way; confirm.
//   labels            - labels for the training data (arma::Row<size_t>)
//   numClasses        - number of classes
//   numTrees          - number of trees (default 20)
//   minimumLeafSize   - presumably the minimum points per leaf (default 1)
//   minimumGainSplit  - presumably the minimum gain to split (default 1e-7)
//   dimensionSelector - taken by value; default-constructed if omitted
55 template<
typename MatType>
57 const arma::Row<size_t>& labels,
58 const size_t numClasses,
59 const size_t numTrees = 20,
60 const size_t minimumLeafSize = 1,
61 const double minimumGainSplit = 1e-7,
62 DimensionSelectionType dimensionSelector =
63 DimensionSelectionType());
// Presumably a second RandomForest constructor overload (original header
// lines 82-91).
// NOTE(review): source lines 83-84 are missing from this Doxygen listing.
// They presumably hold the declarator, the data parameter and one more
// parameter (possibly dataset information; confirm against the original
// header). As shown, the declaration is incomplete.
// The visible trailing parameters and defaults are the same as the unweighted
// Train() overloads: labels (arma::Row<size_t>), numClasses,
// numTrees = 20, minimumLeafSize = 1, minimumGainSplit = 1e-7, and a by-value
// dimensionSelector defaulting to DimensionSelectionType().
82 template<
typename MatType>
85 const arma::Row<size_t>& labels,
86 const size_t numClasses,
87 const size_t numTrees = 20,
88 const size_t minimumLeafSize = 1,
89 const double minimumGainSplit = 1e-7,
90 DimensionSelectionType dimensionSelector =
91 DimensionSelectionType());
// Presumably the weighted RandomForest constructor (original header lines
// 105-114).
// NOTE(review): source line 106 is missing from this Doxygen listing. It
// presumably holds the declarator and the data parameter; restore it.
// Compared with the unweighted form, it additionally takes `weights`
// (const arma::rowvec&), placed between numClasses and numTrees. The weights
// are presumably one per training point; confirm.
// Defaults: numTrees = 20, minimumLeafSize = 1, minimumGainSplit = 1e-7,
// dimensionSelector = DimensionSelectionType() (taken by value).
105 template<
typename MatType>
107 const arma::Row<size_t>& labels,
108 const size_t numClasses,
109 const arma::rowvec& weights,
110 const size_t numTrees = 20,
111 const size_t minimumLeafSize = 1,
112 const double minimumGainSplit = 1e-7,
113 DimensionSelectionType dimensionSelector =
114 DimensionSelectionType());
// Presumably a second weighted RandomForest constructor overload (original
// header lines 134-144).
// NOTE(review): source lines 135-136 are missing from this Doxygen listing.
// They presumably hold the declarator, the data parameter and one more
// parameter (possibly dataset information; confirm). As shown, the
// declaration is incomplete.
// Visible parameters: labels (arma::Row<size_t>), numClasses, weights
// (arma::rowvec), numTrees = 20, minimumLeafSize = 1,
// minimumGainSplit = 1e-7, and a by-value dimensionSelector defaulting to
// DimensionSelectionType().
134 template<
typename MatType>
137 const arma::Row<size_t>& labels,
138 const size_t numClasses,
139 const arma::rowvec& weights,
140 const size_t numTrees = 20,
141 const size_t minimumLeafSize = 1,
142 const double minimumGainSplit = 1e-7,
143 DimensionSelectionType dimensionSelector =
144 DimensionSelectionType());
162 template<
typename MatType>
163 double Train(
const MatType& data,
164 const arma::Row<size_t>& labels,
165 const size_t numClasses,
166 const size_t numTrees = 20,
167 const size_t minimumLeafSize = 1,
168 const double minimumGainSplit = 1e-7,
169 DimensionSelectionType dimensionSelector =
170 DimensionSelectionType());
// Train() overload (original header lines 191-200). It has the same visible
// parameters and defaults as Train(data, labels, numClasses, ...), plus one
// extra parameter on source line 193.
// NOTE(review): source line 193 is missing from this Doxygen listing. It
// presumably declares a parameter between `data` and `labels` (possibly
// dataset information; confirm against the original header). As shown, the
// declaration is incomplete.
// Returns a double; its meaning is not visible here.
191 template<
typename MatType>
192 double Train(
const MatType& data,
194 const arma::Row<size_t>& labels,
195 const size_t numClasses,
196 const size_t numTrees = 20,
197 const size_t minimumLeafSize = 1,
198 const double minimumGainSplit = 1e-7,
199 DimensionSelectionType dimensionSelector =
200 DimensionSelectionType());
219 template<
typename MatType>
220 double Train(
const MatType& data,
221 const arma::Row<size_t>& labels,
222 const size_t numClasses,
223 const arma::rowvec& weights,
224 const size_t numTrees = 20,
225 const size_t minimumLeafSize = 1,
226 const double minimumGainSplit = 1e-7,
227 DimensionSelectionType dimensionSelector =
228 DimensionSelectionType());
// Weighted Train() overload (original header lines 249-259). It has the same
// visible parameters and defaults as the weighted
// Train(data, labels, numClasses, weights, ...) overload, plus one extra
// parameter on source line 251.
// NOTE(review): source line 251 is missing from this Doxygen listing. It
// presumably declares a parameter between `data` and `labels` (possibly
// dataset information; confirm). As shown, the declaration is incomplete.
// Returns a double; its meaning is not visible here.
249 template<
typename MatType>
250 double Train(
const MatType& data,
252 const arma::Row<size_t>& labels,
253 const size_t numClasses,
254 const arma::rowvec& weights,
255 const size_t numTrees = 20,
256 const size_t minimumLeafSize = 1,
257 const double minimumGainSplit = 1e-7,
258 DimensionSelectionType dimensionSelector =
259 DimensionSelectionType());
267 template<
typename VecType>
268 size_t Classify(
const VecType& point)
const;
// Presumably a single-point Classify() overload that also writes class
// probabilities into `probabilities` (arma::vec&); declared const
// (original header lines 279-282).
// NOTE(review): source lines 280-281 are missing from this Doxygen listing.
// They presumably hold the declarator, the input point and any other output
// parameter; restore them from the original header.
279 template<
typename VecType>
282 arma::vec& probabilities)
const;
// Presumably batch classification (original header lines 291-293): a const
// member templated on MatType that writes its results into `predictions`
// (arma::Row<size_t>&).
// NOTE(review): source line 292 is missing from this Doxygen listing. It
// presumably holds the declarator and the input data parameter; restore it.
291 template<
typename MatType>
293 arma::Row<size_t>& predictions)
const;
// Presumably batch classification with probabilities (original header lines
// 304-307): a const member templated on MatType. It writes `predictions`
// (arma::Row<size_t>&) and `probabilities` (arma::mat&).
// NOTE(review): source line 305 is missing from this Doxygen listing. It
// presumably holds the declarator and the input data parameter; restore it.
304 template<
typename MatType>
306 arma::Row<size_t>& predictions,
307 arma::mat& probabilities)
const;
/**
 * Serialize the random forest.
 *
 * @param ar Archive to serialize to or from.
 * The unsigned int version parameter is unnamed because it is not used by
 * the declaration.
 */
template<typename Archive>
void serialize(Archive& ar, const unsigned int /* version */);
// Templated Train() (original header lines 343-352) with two compile-time
// bool flags, UseWeights and UseDatasetInfo, ahead of MatType.
// Judging by their names, the flags select whether weights and dataset
// information are used. It is presumably the shared implementation behind
// the public Train() overloads; confirm.
// Unlike the public overloads, no parameter has a default, and
// dimensionSelector is taken by non-const reference.
// NOTE(review): source line 345 is missing from this Doxygen listing. It
// presumably declares a parameter between `data` and `labels`; restore it
// from the original header.
343 template<
bool UseWeights,
bool UseDatasetInfo,
typename MatType>
344 double Train(
const MatType& data,
346 const arma::Row<size_t>& labels,
347 const size_t numClasses,
348 const arma::rowvec& weights,
349 const size_t numTrees,
350 const size_t minimumLeafSize,
351 const double minimumGainSplit,
352 DimensionSelectionType& dimensionSelector);
// The trees that make up the forest, each of type DecisionTreeType.
// Presumably this is what NumTrees() and Tree(i) report and expose.
355 std::vector<DecisionTreeType> trees;
362 #include "random_forest_impl.hpp"
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
const DecisionTreeType & Tree(const size_t i) const
Access a tree in the forest.
size_t Classify(const VecType &point) const
Predict the class of the given point.
This class implements a generic decision tree learner.
double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const size_t numTrees=20, const size_t minimumLeafSize=1, const double minimumGainSplit=1e-7, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the random forest on the given labeled training data with the given number of trees...
size_t NumTrees() const
Get the number of trees in the forest.
RandomForest()
Construct the random forest without any training or specifying the number of trees.
DecisionTreeType & Tree(const size_t i)
Modify a tree in the forest (be careful!).
void serialize(Archive &ar, const unsigned int)
Serialize the random forest.
DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, DimensionSelectionType, ElemType > DecisionTreeType
Allow access to the underlying decision tree type.