mlpack  3.4.2
ns_model.hpp
Go to the documentation of this file.
1 
15 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP
16 #define MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP
17 
23 #include <boost/variant.hpp>
24 #include "neighbor_search.hpp"
25 
26 namespace mlpack {
27 namespace neighbor {
28 
// Alias template for Euclidean neighbor search.  It fixes the metric to
// metric::EuclideanDistance and the matrix type to arma::mat, leaving only the
// sort policy and the tree template as parameters.  The final argument picks
// the DualTreeTraverser nested in the concrete tree type, which is why the
// tree is instantiated a second time with NeighborSearchStat<SortPolicy>.
32 template<typename SortPolicy,
33  template<typename TreeMetricType,
34  typename TreeStatType,
35  typename TreeMatType> class TreeType>
36 using NSType = NeighborSearch<SortPolicy,
37  metric::EuclideanDistance,
38  arma::mat,
39  TreeType,
40  TreeType<metric::EuclideanDistance,
41  NeighborSearchStat<SortPolicy>,
42  arma::mat>::template DualTreeTraverser>;
43 
// MonoSearchVisitor executes a monochromatic neighbor search on the given
// NSType.  It derives from boost::static_visitor<void>; operator() is only
// declared in this listing, so the search itself lives in the implementation.
48 class MonoSearchVisitor : public boost::static_visitor<void>
49 {
50  private:
    // Copy of the constructor argument k; presumably the number of neighbors
    // to search for -- TODO confirm against the implementation.
52  const size_t k;
    // Reference to the caller's matrix (not owned); presumably receives the
    // neighbor indices -- TODO confirm.
54  arma::Mat<size_t>& neighbors;
    // Reference to the caller's matrix (not owned); presumably receives the
    // neighbor distances -- TODO confirm.
56  arma::mat& distances;
57 
58  public:
    // Perform monochromatic nearest neighbor search.
60  template<typename NSType>
61  void operator()(NSType* ns) const;
62 
    // Construct the MonoSearchVisitor object with the given parameters.
    // neighbors and distances are bound by reference, so both matrices must
    // outlive the visitor.
64  MonoSearchVisitor(const size_t k,
65  arma::Mat<size_t>& neighbors,
66  arma::mat& distances) :
67  k(k),
68  neighbors(neighbors),
69  distances(distances)
70  {};
71 };
72 
// BiSearchVisitor executes a bichromatic neighbor search on the given NSType.
// Next to the generic operator() it declares non-template overloads for
// KDTree, BallTree, SpillKNN and Octree searchers.  Every body lives in the
// implementation file, so what each overload does differently is not visible
// in this listing.
79 template<typename SortPolicy>
80 class BiSearchVisitor : public boost::static_visitor<void>
81 {
82  private:
    // The members below mirror the constructor's parameters one-to-one.  The
    // constructor body is not part of this listing, so the binding is presumed
    // -- TODO confirm.  querySet, neighbors and distances are references and
    // are therefore not owned by the visitor.
84  const arma::mat& querySet;
86  const size_t k;
88  arma::Mat<size_t>& neighbors;
90  arma::mat& distances;
    // NOTE(review): TrainVisitor's constructor is described as taking leafSize
    // for BinarySpaceTrees and tau/rho for spill trees; presumably the same
    // meaning applies here -- confirm.
92  const size_t leafSize;
94  const double tau;
96  const double rho;
97 
    // Private helper template; presumably shared by the tree-specific
    // overloads below -- TODO confirm in the implementation.
99  template<typename NSType>
100  void SearchLeaf(NSType* ns) const;
101 
102  public:
    // Shorthand for NSType bound to this visitor's SortPolicy, so that the
    // overloads below can be spelled NSTypeT<Tree>.
104  template<template<typename TreeMetricType,
105  typename TreeStatType,
106  typename TreeMatType> class TreeType>
107  using NSTypeT = NSType<SortPolicy, TreeType>;
108 
    // Default bichromatic neighbor search on the given NSType instance.
110  template<template<typename TreeMetricType,
111  typename TreeStatType,
112  typename TreeMatType> class TreeType>
113  void operator()(NSTypeT<TreeType>* ns) const;
114 
    // Overload selected for KDTree-based searchers.
116  void operator()(NSTypeT<tree::KDTree>* ns) const;
117 
    // Overload selected for BallTree-based searchers.
119  void operator()(NSTypeT<tree::BallTree>* ns) const;
120 
    // Overload selected for SpillKNN searchers.
122  void operator()(SpillKNN* ns) const;
123 
    // Overload selected for Octree-based searchers.
125  void operator()(NSTypeT<tree::Octree>* ns) const;
126 
    // Construct the BiSearchVisitor.
128  BiSearchVisitor(const arma::mat& querySet,
129  const size_t k,
130  arma::Mat<size_t>& neighbors,
131  arma::mat& distances,
132  const size_t leafSize,
133  const double tau,
134  const double rho);
135 };
136 
// TrainVisitor sets the reference set to a new reference set on the given
// NSType.  Next to the generic operator() it declares non-template overloads
// for KDTree, BallTree, SpillKNN and Octree searchers; all bodies live in the
// implementation file.
143 template<typename SortPolicy>
144 class TrainVisitor : public boost::static_visitor<void>
145 {
146  private:
    // Held as an rvalue reference: the visitor does not own the matrix, and
    // the referenced object must stay alive until the visit has run.
148  arma::mat&& referenceSet;
    // Per the constructor description: leafSize is for BinarySpaceTrees, and
    // tau and rho are for spill trees.
150  size_t leafSize;
152  const double tau;
154  const double rho;
155 
    // Private helper template; presumably shared by the tree-specific
    // overloads below -- TODO confirm in the implementation.
157  template<typename NSType>
158  void TrainLeaf(NSType* ns) const;
159 
160  public:
    // Shorthand for NSType bound to this visitor's SortPolicy, so that the
    // overloads below can be spelled NSTypeT<Tree>.
162  template<template<typename TreeMetricType,
163  typename TreeStatType,
164  typename TreeMatType> class TreeType>
165  using NSTypeT = NSType<SortPolicy, TreeType>;
166 
    // Default Train on the given NSType instance.
168  template<template<typename TreeMetricType,
169  typename TreeStatType,
170  typename TreeMatType> class TreeType>
171  void operator()(NSTypeT<TreeType>* ns) const;
172 
    // Overload selected for KDTree-based searchers.
174  void operator()(NSTypeT<tree::KDTree>* ns) const;
175 
    // Overload selected for BallTree-based searchers.
177  void operator()(NSTypeT<tree::BallTree>* ns) const;
178 
    // Overload selected for SpillKNN searchers.
180  void operator()(SpillKNN* ns) const;
181 
    // Overload selected for Octree-based searchers.
183  void operator()(NSTypeT<tree::Octree>* ns) const;
184 
    // Construct the TrainVisitor object with the given reference set, leafSize
    // for BinarySpaceTrees, and tau and rho for spill trees.
187  TrainVisitor(arma::mat&& referenceSet,
188  const size_t leafSize,
189  const double tau,
190  const double rho);
191 };
192 
// SearchModeVisitor exposes the SearchMode() method of the given NSType.
196 class SearchModeVisitor : public boost::static_visitor<NeighborSearchMode&>
197 {
198  public:
    // Return the search mode.  The result is a mutable reference, matching the
    // visitor's declared result type NeighborSearchMode&.
200  template<typename NSType>
201  NeighborSearchMode& operator()(NSType *ns) const;
202 };
203 
// EpsilonVisitor exposes the Epsilon method of the given NSType.
207 class EpsilonVisitor : public boost::static_visitor<double&>
208 {
209  public:
    // Return epsilon, the approximation parameter.  The result is a mutable
    // double&, so a caller can also assign through it.
211  template<typename NSType>
212  double& operator()(NSType *ns) const;
213 };
214 
// ReferenceSetVisitor exposes the referenceSet of the given NSType.
218 class ReferenceSetVisitor : public boost::static_visitor<const arma::mat&>
219 {
220  public:
    // Return the reference set.  The matrix is returned by const reference,
    // so no copy is made and the caller cannot modify it through the result.
222  template<typename NSType>
223  const arma::mat& operator()(NSType *ns) const;
224 };
225 
// DeleteVisitor deletes the given NSType instance.
229 class DeleteVisitor : public boost::static_visitor<void>
230 {
231  public:
    // Delete the NSType object.  The body is in the implementation file;
    // NOTE(review): presumably a plain `delete` on ns -- confirm there.
233  template<typename NSType>
234  void operator()(NSType *ns) const;
235 };
236 
247 template<typename SortPolicy>
248 class NSModel
249 {
250  public:
253  {
269  };
270 
271  private:
273  TreeTypes treeType;
274 
276  size_t leafSize;
277 
279  double tau;
281  double rho;
282 
284  bool randomBasis;
286  arma::mat q;
287 
293  boost::variant<NSType<SortPolicy, tree::KDTree>*,
305  SpillKNN*,
308 
309  public:
318  NSModel(TreeTypes treeType = TreeTypes::KD_TREE, bool randomBasis = false);
319 
325  NSModel(const NSModel& other);
326 
332  NSModel(NSModel&& other);
333 
339  NSModel& operator=(const NSModel& other);
340 
346  NSModel& operator=(NSModel&& other);
347 
349  ~NSModel();
350 
352  template<typename Archive>
353  void serialize(Archive& ar, const unsigned int /* version */);
354 
356  const arma::mat& Dataset() const;
357 
361 
363  double Epsilon() const;
364  double& Epsilon();
365 
367  size_t LeafSize() const { return leafSize; }
368  size_t& LeafSize() { return leafSize; }
369 
371  double Tau() const { return tau; }
372  double& Tau() { return tau; }
373 
375  double Rho() const { return rho; }
376  double& Rho() { return rho; }
377 
379  TreeTypes TreeType() const { return treeType; }
380  TreeTypes& TreeType() { return treeType; }
381 
383  bool RandomBasis() const { return randomBasis; }
384  bool& RandomBasis() { return randomBasis; }
385 
387  void BuildModel(arma::mat&& referenceSet,
388  const size_t leafSize,
389  const NeighborSearchMode searchMode,
390  const double epsilon = 0);
391 
393  void Search(arma::mat&& querySet,
394  const size_t k,
395  arma::Mat<size_t>& neighbors,
396  arma::mat& distances);
397 
399  void Search(const size_t k,
400  arma::Mat<size_t>& neighbors,
401  arma::mat& distances);
402 
404  std::string TreeName() const;
405 };
406 
407 } // namespace neighbor
408 } // namespace mlpack
409 
// Set the serialization version of the NSModel class to 1, for every
// SortPolicy instantiation.
411 BOOST_TEMPLATE_CLASS_VERSION(template<typename SortPolicy>,
412  mlpack::neighbor::NSModel<SortPolicy>, 1);
413 
414 // Include implementation.
415 #include "ns_model_impl.hpp"
416 
417 #endif
double Epsilon() const
Expose Epsilon.
void operator()(NSTypeT< TreeType > *ns) const
Default bichromatic neighbor search on the given NSType instance.
MonoSearchVisitor(const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
Construct the MonoSearchVisitor object with the given parameters.
Definition: ns_model.hpp:64
EpsilonVisitor exposes the Epsilon method of the given NSType.
Definition: ns_model.hpp:207
TrainVisitor(arma::mat &&referenceSet, const size_t leafSize, const double tau, const double rho)
Construct the TrainVisitor object with the given reference set, leafSize for BinarySpaceTrees, and tau and rho for spill trees.
Linear algebra utility functions, generally performed on matrices or vectors.
bool RandomBasis() const
Expose randomBasis.
Definition: ns_model.hpp:383
TreeTypes
Enum type to identify each accepted tree type.
Definition: ns_model.hpp:252
const arma::mat & operator()(NSType *ns) const
Return the reference set.
BOOST_TEMPLATE_CLASS_VERSION(template< typename SortPolicy >, mlpack::neighbor::NSModel< SortPolicy >, 1)
Set the serialization version of the NSModel class.
ReferenceSetVisitor exposes the referenceSet of the given NSType.
Definition: ns_model.hpp:218
SearchModeVisitor exposes the SearchMode() method of the given NSType.
Definition: ns_model.hpp:196
void operator()(NSTypeT< TreeType > *ns) const
Default Train on the given NSType instance.
The NeighborSearch class is a template class for performing distance-based neighbor searches...
TreeTypes TreeType() const
Expose treeType.
Definition: ns_model.hpp:379
const arma::mat & Dataset() const
Expose the dataset.
NeighborSearch< SortPolicy, metric::EuclideanDistance, arma::mat, TreeType, TreeType< metric::EuclideanDistance, NeighborSearchStat< SortPolicy >, arma::mat >::template DualTreeTraverser > NSType
Alias template for Euclidean neighbor search.
Definition: ns_model.hpp:42
void serialize(Archive &ar, const unsigned int)
Serialize the neighbor search model.
~NSModel()
Clean memory, if necessary.
size_t LeafSize() const
Expose leafSize.
Definition: ns_model.hpp:367
BiSearchVisitor executes a bichromatic neighbor search on the given NSType.
Definition: ns_model.hpp:80
NSModel & operator=(const NSModel &other)
Copy the given NSModel.
std::string TreeName() const
Return a string representation of the current tree type.
TreeTypes & TreeType()
Definition: ns_model.hpp:380
The NSModel class provides an easy way to serialize a model, abstracts away the different types of tr...
Definition: ns_model.hpp:248
double Tau() const
Expose tau.
Definition: ns_model.hpp:371
NSModel(TreeTypes treeType=TreeTypes::KD_TREE, bool randomBasis=false)
Initialize the NSModel with the given type and whether or not a random basis should be used...
NeighborSearchMode & operator()(NSType *ns) const
Return the search mode.
TrainVisitor sets the reference set to a new reference set on the given NSType.
void Search(arma::mat &&querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
Perform neighbor search. The query set will be reordered.
BiSearchVisitor(const arma::mat &querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances, const size_t leafSize, const double tau, const double rho)
Construct the BiSearchVisitor.
void operator()(NSType *ns) const
Perform monochromatic nearest neighbor search.
MonoSearchVisitor executes a monochromatic neighbor search on the given NSType.
Definition: ns_model.hpp:48
void BuildModel(arma::mat &&referenceSet, const size_t leafSize, const NeighborSearchMode searchMode, const double epsilon=0)
Build the reference tree.
NeighborSearchMode SearchMode() const
Expose SearchMode.
DeleteVisitor deletes the given NSType instance.
Definition: ns_model.hpp:229
LMetric< 2, true > EuclideanDistance
The Euclidean (L2) distance.
Definition: lmetric.hpp:112
double Rho() const
Expose rho.
Definition: ns_model.hpp:375
double & operator()(NSType *ns) const
Return epsilon, the approximation parameter.
NeighborSearchMode
NeighborSearchMode represents the different neighbor search modes available.
void operator()(NSType *ns) const
Delete the NSType object.