mlpack  3.1.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
cf_model.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_CF_CF_MODEL_HPP
13 #define MLPACK_METHODS_CF_CF_MODEL_HPP
14 
15 #include <mlpack/core.hpp>
16 #include <boost/variant.hpp>
17 #include "cf.hpp"
18 
26 
27 namespace mlpack {
28 namespace cf {
29 
34 class DeleteVisitor : public boost::static_visitor<void>
35 {
36  public:
37  //! Delete CFType object.
38  template<typename DecompositionPolicy>
39  void operator()(CFType<DecompositionPolicy>* c) const;
40 };
41 
42 /**
43  * GetValueVisitor returns the pointer which points to the CFType object.
44  */
45 class GetValueVisitor : public boost::static_visitor<void*>
46 {
47  public:
48  //! Return stored pointer as void* type.
49  template<typename DecompositionPolicy>
50  void* operator()(CFType<DecompositionPolicy>* c) const;
51 };
52 
57 template <typename NeighborSearchPolicy,
58  typename InterpolationPolicy>
59 class PredictVisitor : public boost::static_visitor<void>
60 {
61  private:
63  const arma::Mat<size_t>& combinations;
65  arma::vec& predictions;
66 
67  public:
68  //! Predict ratings for each user-item combination.
69  template<typename DecompositionPolicy>
70  void operator()(CFType<DecompositionPolicy>* c) const;
71 
72  //! Visitor constructor.
73  PredictVisitor(const arma::Mat<size_t>& combinations,
74  arma::vec& predictions);
75 };
76 
81 template <typename NeighborSearchPolicy,
82  typename InterpolationPolicy>
83 class RecommendationVisitor : public boost::static_visitor<void>
84 {
85  private:
87  const size_t numRecs;
89  arma::Mat<size_t>& recommendations;
91  const arma::Col<size_t>& users;
93  const bool usersGiven;
94 
95  public:
96  //! Visitor constructor.
97  RecommendationVisitor(const size_t numRecs,
98  arma::Mat<size_t>& recommendations,
99  const arma::Col<size_t>& users,
100  const bool usersGiven);
101 
102  //! Generates the given number of recommendations.
103  template<typename DecompositionPolicy>
104  void operator()(CFType<DecompositionPolicy>* c) const;
105 };
106 
107 /**
108  * The model to save to disk.
109  */
110 class CFModel
111 {
112  private:
118  boost::variant<CFType<NMFPolicy>*,
     // NOTE(review): lines 119-125 (the remaining CFType<...>* alternatives and
     // the closing of this declaration) are missing from this listing; the
     // member is named `cf` per the DeleteVisitor description. Do not treat the
     // alternative list above as complete -- confirm against the real header.
126 
127  public:
128  //! Create an empty CF model.
129  CFModel() { }
130 
131  //! Clean up memory.
132  ~CFModel();
133 
134  //! Get the pointer to CFType<> object.
135  template<typename DecompositionPolicy>
136  const CFType<DecompositionPolicy>* CFPtr() const;
137 
138  //! Train the model.
139  template<typename DecompositionPolicy,
140  typename MatType>
141  void Train(const MatType& data,
142  const size_t numUsersForSimilarity,
143  const size_t rank,
144  const size_t maxIterations,
145  const double minResidue,
146  const bool mit);
147 
148  //! Make predictions.
149  template <typename NeighborSearchPolicy,
150  typename InterpolationPolicy>
151  void Predict(const arma::Mat<size_t>& combinations,
152  arma::vec& predictions);
153 
154  //! Compute recommendations for query users.
155  template<typename NeighborSearchPolicy,
156  typename InterpolationPolicy>
157  void GetRecommendations(const size_t numRecs,
158  arma::Mat<size_t>& recommendations,
159  const arma::Col<size_t>& users);
160 
162  template<typename NeighborSearchPolicy,
163  typename InterpolationPolicy>
164  void GetRecommendations(const size_t numRecs,
165  arma::Mat<size_t>& recommendations);
166 
167  //! Serialize the model.
168  template<typename Archive>
169  void serialize(Archive& ar, const unsigned int /* version */);
170 };
171 
172 } // namespace cf
173 } // namespace mlpack
174 
175 // Include implementation.
176 #include "cf_model_impl.hpp"
177 
178 #endif
const CFType< DecompositionPolicy > * CFPtr() const
Get the pointer to CFType<> object.
void Train(const MatType &data, const size_t numUsersForSimilarity, const size_t rank, const size_t maxIterations, const double minResidue, const bool mit)
Train the model.
RecommendationVisitor(const size_t numRecs, arma::Mat< size_t > &recommendations, const arma::Col< size_t > &users, const bool usersGiven)
Visitor constructor.
void * operator()(CFType< DecompositionPolicy > *c) const
Return stored pointer as void* type.
RecommendationVisitor uses the CFType object to get recommendations for the given users...
Definition: cf_model.hpp:83
void serialize(Archive &ar, const unsigned int)
Serialize the model.
~CFModel()
Clean up memory.
PredictVisitor uses the CFType object to make predictions on the given combinations of users and item...
Definition: cf_model.hpp:59
void GetRecommendations(const size_t numRecs, arma::Mat< size_t > &recommendations, const arma::Col< size_t > &users)
Compute recommendations for query users.
DeleteVisitor deletes the CFType<> object which is pointed to by the variable cf in class CFModel...
Definition: cf_model.hpp:34
The model to save to disk.
Definition: cf_model.hpp:110
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen docu...
void operator()(CFType< DecompositionPolicy > *c) const
Generates the given number of recommendations.
GetValueVisitor returns the pointer which points to the CFType object.
Definition: cf_model.hpp:45
void Predict(const arma::Mat< size_t > &combinations, arma::vec &predictions)
Make predictions.
CFModel()
Create an empty CF model.
Definition: cf_model.hpp:129
This class implements Collaborative Filtering (CF).
Definition: cf.hpp:70
void operator()(CFType< DecompositionPolicy > *c) const
Delete CFType object.
void operator()(CFType< DecompositionPolicy > *c) const
Predict ratings for each user-item combination.
PredictVisitor(const arma::Mat< size_t > &combinations, arma::vec &predictions)
Visitor constructor.