[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

random_forest.hxx VIGRA

1 /************************************************************************/
2 /* */
3 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 
36 
37 #ifndef VIGRA_RANDOM_FOREST_HXX
38 #define VIGRA_RANDOM_FOREST_HXX
39 
40 #include <iostream>
41 #include <algorithm>
42 #include <map>
43 #include <set>
44 #include <list>
45 #include <numeric>
46 #include "mathutil.hxx"
47 #include "array_vector.hxx"
48 #include "sized_int.hxx"
49 #include "matrix.hxx"
50 #include "metaprogramming.hxx"
51 #include "random.hxx"
52 #include "functorexpression.hxx"
53 #include "random_forest/rf_common.hxx"
54 #include "random_forest/rf_nodeproxy.hxx"
55 #include "random_forest/rf_split.hxx"
56 #include "random_forest/rf_decisionTree.hxx"
57 #include "random_forest/rf_visitors.hxx"
58 #include "random_forest/rf_region.hxx"
59 #include "sampling.hxx"
60 #include "random_forest/rf_preprocessing.hxx"
61 #include "random_forest/rf_online_prediction_set.hxx"
62 #include "random_forest/rf_earlystopping.hxx"
63 #include "random_forest/rf_ridge_split.hxx"
64 namespace vigra
65 {
66 
67 /** \addtogroup MachineLearning Machine Learning
68 
69  This module provides classification algorithms that map
70  features to labels or label probabilities.
 Look at the RandomForest class first for an overview of most of the
72  functionality provided as well as use cases.
73 **/
74 //@{
75 
76 namespace detail
77 {
78 
79 
80 
81 /* \brief sampling option factory function
82  */
83 inline SamplerOptions make_sampler_opt ( RandomForestOptions & RF_opt)
84 {
85  SamplerOptions return_opt;
86  return_opt.withReplacement(RF_opt.sample_with_replacement_);
87  return_opt.stratified(RF_opt.stratification_method_ == RF_EQUAL);
88  return return_opt;
89 }
90 }//namespace detail
91 
92 /** Random Forest class
93  *
94  * \tparam <LabelType = double> Type used for predicted labels.
95  * \tparam <PreprocessorTag = ClassificationTag> Class used to preprocess
96  * the input while learning and predicting. Currently Available:
97  * ClassificationTag and RegressionTag. It is recommended to use
98  * Splitfunctor::Preprocessor_t while using custom splitfunctors
99  * as they may need the data to be in a different format.
100  * \sa Preprocessor
101  *
102  * Simple usage for classification (regression is not yet supported):
103  * look at RandomForest::learn() as well as RandomForestOptions() for additional
104  * options.
105  *
106  * \code
107  * using namespace vigra;
108  * using namespace rf;
109  * typedef xxx feature_t; \\ replace xxx with whichever type
110  * typedef yyy label_t; \\ likewise
111  *
112  * // allocate the training data
113  * MultiArrayView<2, feature_t> f = get_training_features();
114  * MultiArrayView<2, label_t> l = get_training_labels();
115  *
116  * RandomForest<label_t> rf;
117  *
118  * // construct visitor to calculate out-of-bag error
119  * visitors::OOB_Error oob_v;
120  *
121  * // perform training
122  * rf.learn(f, l, visitors::create_visitor(oob_v));
123  *
124  * std::cout << "the out-of-bag error is: " << oob_v.oob_breiman << "\n";
125  *
126  * // get features for new data to be used for prediction
127  * MultiArrayView<2, feature_t> pf = get_features();
128  *
129  * // allocate space for the response (pf.shape(0) is the number of samples)
130  * MultiArrayView<2, label_t> prediction(pf.shape(0), 1);
131  * MultiArrayView<2, double> prob(pf.shape(0), rf.class_count());
132  *
133  * // perform prediction on new data
134  * rf.predictLabels(pf, prediction);
135  * rf.predictProbabilities(pf, prob);
136  *
137  * \endcode
138  *
139  * Additional information such as Variable Importance measures are accessed
140  * via Visitors defined in rf::visitors.
141  * Have a look at rf::split for other splitting methods.
142  *
143 */
144 template <class LabelType = double , class PreprocessorTag = ClassificationTag >
146 {
147 
148  public:
149  //public typedefs
151  typedef detail::DecisionTree DecisionTree_t;
153  typedef GiniSplit Default_Split_t;
157  StackEntry_t;
158  typedef LabelType LabelT;
159 
160  //problem independent data.
161  Options_t options_;
162  //problem dependent data members - is only set if
163  //a copy constructor, some sort of import
164  //function or the learn function is called
166  ProblemSpec_t ext_param_;
167  /*mutable ArrayVector<int> tree_indices_;*/
168  rf::visitors::OnlineLearnVisitor online_visitor_;
169 
170 
    /** \brief discard all trained state
     *
     * Clears the extrinsic (problem dependent) parameters and removes
     * all trees, returning the forest to its untrained state.
     */
    void reset()
    {
        ext_param_.clear();
        trees_.clear();
    }
176 
177  public:
178 
179  /** \name Constructors
180  * Note: No copy Constructor specified as no pointers are manipulated
181  * in this class
182  */
183  /*\{*/
184  /**\brief default constructor
185  *
186  * \param options general options to the Random Forest. Must be of Type
187  * Options_t
188  * \param ext_param problem specific values that can be supplied
189  * additionally. (class weights , labels etc)
190  * \sa RandomForestOptions, ProblemSpec
191  *
192  */
195  :
196  options_(options),
197  ext_param_(ext_param)/*,
198  tree_indices_(options.tree_count_,0)*/
199  {
200  /*for(int ii = 0 ; ii < int(tree_indices_.size()); ++ii)
201  tree_indices_[ii] = ii;*/
202  }
203 
204  /**\brief Create RF from external source
205  * \param treeCount Number of trees to add.
206  * \param topology_begin
207  * Iterator to a Container where the topology_ data
208  * of the trees are stored.
209  * Iterator should support at least treeCount forward
210  * iterations. (i.e. topology_end - topology_begin >= treeCount
211  * \param parameter_begin
212  * iterator to a Container where the parameters_ data
213  * of the trees are stored. Iterator should support at
214  * least treeCount forward iterations.
215  * \param problem_spec
216  * Extrinsic parameters that specify the problem e.g.
217  * ClassCount, featureCount etc.
218  * \param options (optional) specify options used to train the original
219  * Random forest. This parameter is not used anywhere
220  * during prediction and thus is optional.
221  *
222  */
223  /* TODO: This constructor may be replaced by a Constructor using
224  * NodeProxy iterators to encapsulate the underlying data type.
225  */
    template<class TopologyIterator, class ParameterIterator>
    RandomForest(int treeCount,
                 TopologyIterator topology_begin,
                 ParameterIterator parameter_begin,
                 ProblemSpec_t const & problem_spec,
                 Options_t const & options = Options_t())
    :
        // NOTE(review): the textual order of these initializers may not
        // match the member declaration order; members are initialized in
        // declaration order regardless. Harmless here because the three
        // initializers are independent of each other -- confirm against
        // the member declarations.
        trees_(treeCount, DecisionTree_t(problem_spec)),
        ext_param_(problem_spec),
        options_(options)
    {
        // copy each tree's topology and parameter arrays from the
        // externally supplied containers, one tree per iterator step
        for(int k=0; k<treeCount; ++k, ++topology_begin, ++parameter_begin)
        {
            trees_[k].topology_ = *topology_begin;
            trees_[k].parameters_ = *parameter_begin;
        }
    }
243 
244  /*\}*/
245 
246 
247  /** \name Data Access
248  * data access interface - usage of member variables is deprecated
249  */
250 
251  /*\{*/
252 
253 
254  /**\brief return external parameters for viewing
255  * \return ProblemSpec_t
256  */
    ProblemSpec_t const & ext_param() const
    {
        // ext_param_ only carries meaningful values after training (or
        // external construction) filled it in; used() guards against
        // reading stale defaults.
        vigra_precondition(ext_param_.used() == true,
            "RandomForest::ext_param(): "
            "Random forest has not been trained yet.");
        return ext_param_;
    }
264 
265  /**\brief set external parameters
266  *
267  * \param in external parameters to be set
268  *
269  * set external parameters explicitly.
270  * If Random Forest has not been trained the preprocessor will
271  * either ignore filling values set this way or will throw an exception
272  * if values specified manually do not match the value calculated
 * during the preparation step.
274  */
275  void set_ext_param(ProblemSpec_t const & in)
276  {
277  ignore_argument(in);
278  vigra_precondition(ext_param_.used() == false,
279  "RandomForest::set_ext_param():"
280  "Random forest has been trained! Call reset()"
281  "before specifying new extrinsic parameters.");
282  }
283 
284  /**\brief access random forest options
285  *
286  * \return random forest options
287  */
289  {
290  return options_;
291  }
292 
293 
294  /**\brief access const random forest options
295  *
296  * \return const Option_t
297  */
    Options_t const & options() const
    {
        // read-only view of the options the forest was configured with
        return options_;
    }
302 
303  /**\brief access const trees
304  */
    DecisionTree_t const & tree(int index) const
    {
        // no bounds check: index must be in [0, trees_.size())
        return trees_[index];
    }
309 
310  /**\brief access trees
311  */
    DecisionTree_t & tree(int index)
    {
        // mutable access; no bounds check: index must be in
        // [0, trees_.size())
        return trees_[index];
    }
316 
317  /*\}*/
318 
319  /**\brief return number of features used while
320  * training.
321  */
    int feature_count() const
    {
        // number of feature columns recorded in the extrinsic parameters
        return ext_param_.column_count_;
    }
326 
327 
328  /**\brief return number of features used while
329  * training.
330  *
331  * deprecated. Use feature_count() instead.
332  */
    int column_count() const
    {
        // deprecated alias for feature_count(); kept for backward
        // compatibility
        return ext_param_.column_count_;
    }
337 
338  /**\brief return number of classes used while
339  * training.
340  */
    int class_count() const
    {
        // number of distinct classes recorded in the extrinsic parameters
        return ext_param_.class_count_;
    }
345 
346  /**\brief return number of trees
347  */
    int tree_count() const
    {
        // NOTE(review): this reports the configured option value, not
        // trees_.size(). For a forest built via the external-source
        // constructor (which may leave options_ at its defaults) the two
        // can disagree -- confirm which one callers rely on.
        return options_.tree_count_;
    }
352 
353 
354 
355  template<class U,class C1,
356  class U2, class C2,
357  class Split_t,
358  class Stop_t,
359  class Visitor_t,
360  class Random_t>
361  void onlineLearn( MultiArrayView<2,U,C1> const & features,
362  MultiArrayView<2,U2,C2> const & response,
363  int new_start_index,
364  Visitor_t visitor_,
365  Split_t split_,
366  Stop_t stop_,
367  Random_t & random,
368  bool adjust_thresholds=false);
369 
370  template <class U, class C1, class U2,class C2>
371  void onlineLearn( MultiArrayView<2, U, C1> const & features,
372  MultiArrayView<2, U2,C2> const & labels,int new_start_index,bool adjust_thresholds=false)
373  {
375  onlineLearn(features,
376  labels,
377  new_start_index,
378  rf_default(),
379  rf_default(),
380  rf_default(),
381  rnd,
382  adjust_thresholds);
383  }
384 
385  template<class U,class C1,
386  class U2, class C2,
387  class Split_t,
388  class Stop_t,
389  class Visitor_t,
390  class Random_t>
391  void reLearnTree(MultiArrayView<2,U,C1> const & features,
392  MultiArrayView<2,U2,C2> const & response,
393  int treeId,
394  Visitor_t visitor_,
395  Split_t split_,
396  Stop_t stop_,
397  Random_t & random);
398 
399  template<class U, class C1, class U2, class C2>
400  void reLearnTree(MultiArrayView<2, U, C1> const & features,
401  MultiArrayView<2, U2, C2> const & labels,
402  int treeId)
403  {
404  RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
405  reLearnTree(features,
406  labels,
407  treeId,
408  rf_default(),
409  rf_default(),
410  rf_default(),
411  rnd);
412  }
413 
414 
415  /**\name Learning
416  * Following functions differ in the degree of customization
417  * allowed
418  */
419  /*\{*/
420  /**\brief learn on data with custom config and random number generator
421  *
422  * \param features a N x M matrix containing N samples with M
423  * features
424  * \param response a N x D matrix containing the corresponding
425  * response. Current split functors assume D to
426  * be 1 and ignore any additional columns.
427  * This is not enforced to allow future support
428  * for uncertain labels, label independent strata etc.
429  * The Preprocessor specified during construction
 * should be able to handle the features and the labels.
432  * see also: SplitFunctor, Preprocessing
433  *
434  * \param visitor visitor which is to be applied after each split,
435  * tree and at the end. Use rf_default() for using
436  * default value. (No Visitors)
437  * see also: rf::visitors
438  * \param split split functor to be used to calculate each split
439  * use rf_default() for using default value. (GiniSplit)
440  * see also: rf::split
441  * \param stop
442  * predicate to be used to calculate each split
443  * use rf_default() for using default value. (EarlyStoppStd)
444  * \param random RandomNumberGenerator to be used. Use
445  * rf_default() to use default value.(RandomMT19337)
446  *
447  *
448  */
449  template <class U, class C1,
450  class U2,class C2,
451  class Split_t,
452  class Stop_t,
453  class Visitor_t,
454  class Random_t>
455  void learn( MultiArrayView<2, U, C1> const & features,
456  MultiArrayView<2, U2,C2> const & response,
457  Visitor_t visitor,
458  Split_t split,
459  Stop_t stop,
460  Random_t const & random);
461 
462  template <class U, class C1,
463  class U2,class C2,
464  class Split_t,
465  class Stop_t,
466  class Visitor_t>
467  void learn( MultiArrayView<2, U, C1> const & features,
468  MultiArrayView<2, U2,C2> const & response,
469  Visitor_t visitor,
470  Split_t split,
471  Stop_t stop)
472 
473  {
474  RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
475  learn( features,
476  response,
477  visitor,
478  split,
479  stop,
480  rnd);
481  }
482 
483  template <class U, class C1, class U2,class C2, class Visitor_t>
484  void learn( MultiArrayView<2, U, C1> const & features,
485  MultiArrayView<2, U2,C2> const & labels,
486  Visitor_t visitor)
487  {
488  learn( features,
489  labels,
490  visitor,
491  rf_default(),
492  rf_default());
493  }
494 
495  template <class U, class C1, class U2,class C2,
496  class Visitor_t, class Split_t>
497  void learn( MultiArrayView<2, U, C1> const & features,
498  MultiArrayView<2, U2,C2> const & labels,
499  Visitor_t visitor,
500  Split_t split)
501  {
502  learn( features,
503  labels,
504  visitor,
505  split,
506  rf_default());
507  }
508 
509  /**\brief learn on data with default configuration
510  *
511  * \param features a N x M matrix containing N samples with M
512  * features
513  * \param labels a N x D matrix containing the corresponding
514  * N labels. Current split functors assume D to
515  * be 1 and ignore any additional columns.
516  * this is not enforced to allow future support
517  * for uncertain labels.
518  *
519  * learning is done with:
520  *
521  * \sa rf::split, EarlyStoppStd
522  *
523  * - Randomly seeded random number generator
524  * - default gini split functor as described by Breiman
525  * - default The standard early stopping criterion
526  */
527  template <class U, class C1, class U2,class C2>
528  void learn( MultiArrayView<2, U, C1> const & features,
529  MultiArrayView<2, U2,C2> const & labels)
530  {
531  learn( features,
532  labels,
533  rf_default(),
534  rf_default(),
535  rf_default());
536  }
537  /*\}*/
538 
539 
540 
541  /**\name prediction
542  */
543  /*\{*/
544  /** \brief predict a label given a feature.
545  *
546  * \param features: a 1 by featureCount matrix containing
547  * data point to be predicted (this only works in
548  * classification setting)
549  * \param stop: early stopping criterion
550  * \return double value representing class. You can use the
551  * predictLabels() function together with the
552  * rf.external_parameter().class_type_ attribute
553  * to get back the same type used during learning.
554  */
555  template <class U, class C, class Stop>
556  LabelType predictLabel(MultiArrayView<2, U, C>const & features, Stop & stop) const;
557 
558  template <class U, class C>
559  LabelType predictLabel(MultiArrayView<2, U, C>const & features)
560  {
561  return predictLabel(features, rf_default());
562  }
563  /** \brief predict a label with features and class priors
564  *
565  * \param features: same as above.
566  * \param prior: iterator to prior weighting of classes
 * \return same as above.
568  */
569  template <class U, class C>
570  LabelType predictLabel(MultiArrayView<2, U, C> const & features,
571  ArrayVectorView<double> prior) const;
572 
573  /** \brief predict multiple labels with given features
574  *
575  * \param features: a n by featureCount matrix containing
576  * data point to be predicted (this only works in
577  * classification setting)
578  * \param labels: a n by 1 matrix passed by reference to store
579  * output.
580  *
 * If the input contains a NaN value, a precondition exception is thrown.
582  */
583  template <class U, class C1, class T, class C2>
585  MultiArrayView<2, T, C2> & labels) const
586  {
587  vigra_precondition(features.shape(0) == labels.shape(0),
588  "RandomForest::predictLabels(): Label array has wrong size.");
589  for(int k=0; k<features.shape(0); ++k)
590  {
591  vigra_precondition(!detail::contains_nan(rowVector(features, k)),
592  "RandomForest::predictLabels(): NaN in feature matrix.");
593  labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), rf_default()));
594  }
595  }
596 
597  /** \brief predict multiple labels with given features
598  *
599  * \param features: a n by featureCount matrix containing
600  * data point to be predicted (this only works in
601  * classification setting)
602  * \param labels: a n by 1 matrix passed by reference to store
603  * output.
604  * \param nanLabel: label to be returned for the row of the input that
605  * contain an NaN value.
606  */
607  template <class U, class C1, class T, class C2>
609  MultiArrayView<2, T, C2> & labels,
610  LabelType nanLabel) const
611  {
612  vigra_precondition(features.shape(0) == labels.shape(0),
613  "RandomForest::predictLabels(): Label array has wrong size.");
614  for(int k=0; k<features.shape(0); ++k)
615  {
616  if(detail::contains_nan(rowVector(features, k)))
617  labels(k,0) = nanLabel;
618  else
619  labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), rf_default()));
620  }
621  }
622 
623  /** \brief predict multiple labels with given features
624  *
625  * \param features: a n by featureCount matrix containing
626  * data point to be predicted (this only works in
627  * classification setting)
628  * \param labels: a n by 1 matrix passed by reference to store
629  * output.
630  * \param stop: an early stopping criterion.
631  */
632  template <class U, class C1, class T, class C2, class Stop>
634  MultiArrayView<2, T, C2> & labels,
635  Stop & stop) const
636  {
637  vigra_precondition(features.shape(0) == labels.shape(0),
638  "RandomForest::predictLabels(): Label array has wrong size.");
639  for(int k=0; k<features.shape(0); ++k)
640  labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), stop));
641  }
642  /** \brief predict the class probabilities for multiple labels
643  *
644  * \param features same as above
645  * \param prob a n x class_count_ matrix. passed by reference to
646  * save class probabilities
647  * \param stop earlystopping criterion
648  * \sa EarlyStopping
649 
650  When a row of the feature array contains an NaN, the corresponding instance
651  cannot belong to any of the classes. The corresponding row in the probability
652  array will therefore contain all zeros.
653  */
654  template <class U, class C1, class T, class C2, class Stop>
655  void predictProbabilities(MultiArrayView<2, U, C1>const & features,
657  Stop & stop) const;
658  template <class T1,class T2, class C>
659  void predictProbabilities(OnlinePredictionSet<T1> & predictionSet,
660  MultiArrayView<2, T2, C> & prob);
661 
662  /** \brief predict the class probabilities for multiple labels
663  *
664  * \param features same as above
665  * \param prob a n x class_count_ matrix. passed by reference to
666  * save class probabilities
667  */
668  template <class U, class C1, class T, class C2>
670  MultiArrayView<2, T, C2> & prob) const
671  {
672  predictProbabilities(features, prob, rf_default());
673  }
674 
675  template <class U, class C1, class T, class C2>
676  void predictRaw(MultiArrayView<2, U, C1>const & features,
677  MultiArrayView<2, T, C2> & prob) const;
678 
679 
680  /*\}*/
681 
682 };
683 
684 
/**\brief incrementally grow the existing trees with new samples
 *
 * Appends the samples starting at row new_start_index of
 * features/response to the forest. For each tree a Poisson-resampled
 * subset of the new rows is routed to its leaves; leaves that do not
 * already classify their samples purely are re-split via
 * continueLearn(). Requires that the forest was constructed with
 * online learning enabled (options_.prepare_online_learning_).
 */
template <class LabelType, class PreprocessorTag>
template<class U,class C1,
         class U2, class C2,
         class Split_t,
         class Stop_t,
         class Visitor_t,
         class Random_t>
void RandomForest<LabelType, PreprocessorTag>::onlineLearn(MultiArrayView<2,U,C1> const & features,
                                                           MultiArrayView<2,U2,C2> const & response,
                                                           int new_start_index,
                                                           Visitor_t visitor_,
                                                           Split_t split_,
                                                           Stop_t stop_,
                                                           Random_t & random,
                                                           bool adjust_thresholds)
{
    online_visitor_.activate();
    online_visitor_.adjust_thresholds=adjust_thresholds;

    using namespace rf;
    //typedefs
    typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t;
    typedef UniformIntRandomFunctor<Random_t>
                                    RandFunctor_t;
    // default values and initialization
    // Value Chooser chooses second argument as value if first argument
    // is of type RF_DEFAULT. (thanks to template magic - don't care about
    // it - just smile and wave.

    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type stop
            = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    Default_Split_t default_split;
    typename RF_CHOOSER(Split_t)::type split
            = RF_CHOOSER(Split_t)::choose(split_, default_split);
    rf::visitors::StopVisiting stopvisiting;
    // chain the internal online-learning visitor in front of any
    // user-supplied visitor
    typedef rf::visitors::detail::VisitorNode
                <rf::visitors::OnlineLearnVisitor,
                 typename RF_CHOOSER(Visitor_t)::type>
            IntermedVis;
    IntermedVis
        visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
    #undef RF_CHOOSER
    vigra_precondition(options_.prepare_online_learning_,"onlineLearn: online learning must be enabled on RandomForest construction");

    // Preprocess the data to get something the split functor can work
    // with. Also fill the ext_param structure by preprocessing
    // option parameters that could only be completely evaluated
    // when the training data is known.
    // class_count_ is zeroed so the preprocessor recomputes it from the
    // (possibly extended) response data.
    ext_param_.class_count_=0;
    Preprocessor_t preprocessor( features, response,
                                 options_, ext_param_);

    // Make stl compatible random functor.
    RandFunctor_t randint ( random);

    // Give the Split functor information about the data.
    split.set_external_parameters(ext_param_);
    stop.set_external_parameters(ext_param_);

    //Create poisson samples
    // samples are drawn from the new rows only: [new_start_index, row_count_)
    PoissonSampler<RandomTT800> poisson_sampler(1.0,vigra::Int32(new_start_index),vigra::Int32(ext_param().row_count_));

    //TODO: visitors for online learning
    //visitor.visit_at_beginning(*this, preprocessor);

    // THE MAIN EFFING RF LOOP - YEAY DUDE!
    for(int ii = 0; ii < static_cast<int>(trees_.size()); ++ii)
    {
        online_visitor_.tree_id=ii;
        poisson_sampler.sample();
        // leaf node id -> parent node id, for leaves that need re-splitting
        std::map<int,int> leaf_parents;
        leaf_parents.clear();
        //Get all the leaf nodes for that sample
        for(int s=0;s<poisson_sampler.numOfSamples();++s)
        {
            int sample=poisson_sampler[s];
            online_visitor_.current_label=preprocessor.response()(sample,0);
            online_visitor_.last_node_id=StackEntry_t::DecisionTreeNoParent;
            // route the sample down tree ii; the visitor records the path
            int leaf=trees_[ii].getToLeaf(rowVector(features,sample),online_visitor_);

            //Add to the list for that leaf
            online_visitor_.add_to_index_list(ii,leaf,sample);
            //TODO: Class count?
            //Store parent
            // a leaf whose probability for this label is already 1.0
            // classifies the sample correctly and needs no re-split
            if(Node<e_ConstProbNode>(trees_[ii].topology_,trees_[ii].parameters_,leaf).prob_begin()[preprocessor.response()(sample,0)]!=1.0)
            {
                leaf_parents[leaf]=online_visitor_.last_node_id;
            }
        }

        std::map<int,int>::iterator leaf_iterator;
        for(leaf_iterator=leaf_parents.begin();leaf_iterator!=leaf_parents.end();++leaf_iterator)
        {
            int leaf=leaf_iterator->first;
            int parent=leaf_iterator->second;
            // take over the sample indices accumulated for this leaf
            int lin_index=online_visitor_.trees_online_information[ii].exterior_to_index[leaf];
            ArrayVector<Int32> indeces;
            indeces.clear();
            indeces.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]);
            StackEntry_t stack_entry(indeces.begin(),
                                     indeces.end(),
                                     ext_param_.class_count_);

            // record on which side of its parent the leaf hangs, so the
            // re-grown subtree can be linked back in
            if(parent!=-1)
            {
                if(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(0)==leaf)
                {
                    stack_entry.leftParent=parent;
                }
                else
                {
                    vigra_assert(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(1)==leaf,"last_node_id seems to be wrong");
                    stack_entry.rightParent=parent;
                }
            }
            //trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,leaf);
            trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,-1);
            //Now, the last one moved onto leaf
            online_visitor_.move_exterior_node(ii,trees_[ii].topology_.size(),ii,leaf);
            //Now it should be classified correctly!
        }

        /*visitor
            .visit_after_tree( *this,
                               preprocessor,
                               poisson_sampler,
                               stack_entry,
                               ii);*/
    }

    //visitor.visit_at_end(*this, preprocessor);
    online_visitor_.deactivate();
}
824 
825 template<class LabelType, class PreprocessorTag>
826 template<class U,class C1,
827  class U2, class C2,
828  class Split_t,
829  class Stop_t,
830  class Visitor_t,
831  class Random_t>
833  MultiArrayView<2,U2,C2> const & response,
834  int treeId,
835  Visitor_t visitor_,
836  Split_t split_,
837  Stop_t stop_,
838  Random_t & random)
839 {
840  using namespace rf;
841 
842 
844  RandFunctor_t;
845 
846  // See rf_preprocessing.hxx for more info on this
847  ext_param_.class_count_=0;
849 
850  // default values and initialization
851  // Value Chooser chooses second argument as value if first argument
852  // is of type RF_DEFAULT. (thanks to template magic - don't care about
853  // it - just smile and wave.
854 
855  #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
856  Default_Stop_t default_stop(options_);
857  typename RF_CHOOSER(Stop_t)::type stop
858  = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
859  Default_Split_t default_split;
860  typename RF_CHOOSER(Split_t)::type split
861  = RF_CHOOSER(Split_t)::choose(split_, default_split);
862  rf::visitors::StopVisiting stopvisiting;
865  typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
866  IntermedVis
867  visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
868  #undef RF_CHOOSER
869  vigra_precondition(options_.prepare_online_learning_,"reLearnTree: Re learning trees only makes sense, if online learning is enabled");
870  online_visitor_.activate();
871 
872  // Make stl compatible random functor.
873  RandFunctor_t randint ( random);
874 
875  // Preprocess the data to get something the split functor can work
876  // with. Also fill the ext_param structure by preprocessing
877  // option parameters that could only be completely evaluated
878  // when the training data is known.
879  Preprocessor_t preprocessor( features, response,
880  options_, ext_param_);
881 
882  // Give the Split functor information about the data.
883  split.set_external_parameters(ext_param_);
884  stop.set_external_parameters(ext_param_);
885 
886  /**\todo replace this crappy class out. It uses function pointers.
887  * and is making code slower according to me.
888  * Comment from Nathan: This is copied from Rahul, so me=Rahul
889  */
890  Sampler<Random_t > sampler(preprocessor.strata().begin(),
891  preprocessor.strata().end(),
892  detail::make_sampler_opt(options_)
893  .sampleSize(ext_param().actual_msample_),
894  &random);
895  //initialize First region/node/stack entry
896  sampler
897  .sample();
898 
900  first_stack_entry( sampler.sampledIndices().begin(),
901  sampler.sampledIndices().end(),
902  ext_param_.class_count_);
903  first_stack_entry
904  .set_oob_range( sampler.oobIndices().begin(),
905  sampler.oobIndices().end());
906  online_visitor_.reset_tree(treeId);
907  online_visitor_.tree_id=treeId;
908  trees_[treeId].reset();
909  trees_[treeId]
910  .learn( preprocessor.features(),
911  preprocessor.response(),
912  first_stack_entry,
913  split,
914  stop,
915  visitor,
916  randint);
917  visitor
918  .visit_after_tree( *this,
919  preprocessor,
920  sampler,
921  first_stack_entry,
922  treeId);
923 
924  online_visitor_.deactivate();
925 }
926 
927 template <class LabelType, class PreprocessorTag>
928 template <class U, class C1,
929  class U2,class C2,
930  class Split_t,
931  class Stop_t,
932  class Visitor_t,
933  class Random_t>
936  MultiArrayView<2, U2,C2> const & response,
937  Visitor_t visitor_,
938  Split_t split_,
939  Stop_t stop_,
940  Random_t const & random)
941 {
942  using namespace rf;
943  //this->reset();
944  //typedefs
946  RandFunctor_t;
947 
948  // See rf_preprocessing.hxx for more info on this
950 
951  vigra_precondition(features.shape(0) == response.shape(0),
952  "RandomForest::learn(): shape mismatch between features and response.");
953 
954  // default values and initialization
955  // Value Chooser chooses second argument as value if first argument
956  // is of type RF_DEFAULT. (thanks to template magic - don't care about
957  // it - just smile and wave).
958 
959  #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
960  Default_Stop_t default_stop(options_);
961  typename RF_CHOOSER(Stop_t)::type stop
962  = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
963  Default_Split_t default_split;
964  typename RF_CHOOSER(Split_t)::type split
965  = RF_CHOOSER(Split_t)::choose(split_, default_split);
966  rf::visitors::StopVisiting stopvisiting;
969  typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
970  IntermedVis
971  visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
972  #undef RF_CHOOSER
973  if(options_.prepare_online_learning_)
974  online_visitor_.activate();
975  else
976  online_visitor_.deactivate();
977 
978 
979  // Make stl compatible random functor.
980  RandFunctor_t randint ( random);
981 
982 
983  // Preprocess the data to get something the split functor can work
984  // with. Also fill the ext_param structure by preprocessing
985  // option parameters that could only be completely evaluated
986  // when the training data is known.
987  Preprocessor_t preprocessor( features, response,
988  options_, ext_param_);
989 
990  // Give the Split functor information about the data.
991  split.set_external_parameters(ext_param_);
992  stop.set_external_parameters(ext_param_);
993 
994 
995  //initialize trees.
996  trees_.resize(options_.tree_count_ , DecisionTree_t(ext_param_));
997 
998  Sampler<Random_t > sampler(preprocessor.strata().begin(),
999  preprocessor.strata().end(),
1000  detail::make_sampler_opt(options_)
1001  .sampleSize(ext_param().actual_msample_),
1002  &random);
1003 
1004  visitor.visit_at_beginning(*this, preprocessor);
1005  // THE MAIN EFFING RF LOOP - YEAY DUDE!
1006 
1007  for(int ii = 0; ii < static_cast<int>(trees_.size()); ++ii)
1008  {
1009  //initialize First region/node/stack entry
1010  sampler
1011  .sample();
1012  StackEntry_t
1013  first_stack_entry( sampler.sampledIndices().begin(),
1014  sampler.sampledIndices().end(),
1015  ext_param_.class_count_);
1016  first_stack_entry
1017  .set_oob_range( sampler.oobIndices().begin(),
1018  sampler.oobIndices().end());
1019  trees_[ii]
1020  .learn( preprocessor.features(),
1021  preprocessor.response(),
1022  first_stack_entry,
1023  split,
1024  stop,
1025  visitor,
1026  randint);
1027  visitor
1028  .visit_after_tree( *this,
1029  preprocessor,
1030  sampler,
1031  first_stack_entry,
1032  ii);
1033  }
1034 
1035  visitor.visit_at_end(*this, preprocessor);
1036  // Only for online learning?
1037  online_visitor_.deactivate();
1038 }
1039 
1040 
1041 
1042 
1043 template <class LabelType, class Tag>
1044 template <class U, class C, class Stop>
1046  ::predictLabel(MultiArrayView<2, U, C> const & features, Stop & stop) const
1047 {
1048  vigra_precondition(columnCount(features) >= ext_param_.column_count_,
1049  "RandomForestn::predictLabel():"
1050  " Too few columns in feature matrix.");
1051  vigra_precondition(rowCount(features) == 1,
1052  "RandomForestn::predictLabel():"
1053  " Feature matrix must have a singlerow.");
1054  MultiArray<2, double> probabilities(Shape2(1, ext_param_.class_count_), 0.0);
1055  LabelType d;
1056  predictProbabilities(features, probabilities, stop);
1057  ext_param_.to_classlabel(argMax(probabilities), d);
1058  return d;
1059 }
1060 
1061 
1062 //Same thing as above with priors for each label !!!
1063 template <class LabelType, class PreprocessorTag>
1064 template <class U, class C>
1067  ArrayVectorView<double> priors) const
1068 {
1069  using namespace functor;
1070  vigra_precondition(columnCount(features) >= ext_param_.column_count_,
1071  "RandomForestn::predictLabel(): Too few columns in feature matrix.");
1072  vigra_precondition(rowCount(features) == 1,
1073  "RandomForestn::predictLabel():"
1074  " Feature matrix must have a single row.");
1075  Matrix<double> prob(1,ext_param_.class_count_);
1076  predictProbabilities(features, prob);
1077  std::transform( prob.begin(), prob.end(),
1078  priors.begin(), prob.begin(),
1079  Arg1()*Arg2());
1080  LabelType d;
1081  ext_param_.to_classlabel(argMax(prob), d);
1082  return d;
1083 }
1084 
1085 template<class LabelType,class PreprocessorTag>
1086 template <class T1,class T2, class C>
1088  ::predictProbabilities(OnlinePredictionSet<T1> & predictionSet,
1089  MultiArrayView<2, T2, C> & prob)
1090 {
1091  //Features are n xp
1092  //prob is n x NumOfLabel probability for each feature in each class
1093 
1094  vigra_precondition(rowCount(predictionSet.features) == rowCount(prob),
1095  "RandomFroest::predictProbabilities():"
1096  " Feature matrix and probability matrix size mismatch.");
1097  // num of features must be bigger than num of features in Random forest training
1098  // but why bigger?
1099  vigra_precondition( columnCount(predictionSet.features) >= ext_param_.column_count_,
1100  "RandomForestn::predictProbabilities():"
1101  " Too few columns in feature matrix.");
1102  vigra_precondition( columnCount(prob)
1103  == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1104  "RandomForestn::predictProbabilities():"
1105  " Probability matrix must have as many columns as there are classes.");
1106  prob.init(0.0);
1107  //store total weights
1108  std::vector<T1> totalWeights(predictionSet.indices[0].size(),0.0);
1109  //Go through all trees
1110  int set_id=-1;
1111  for(int k=0; k<options_.tree_count_; ++k)
1112  {
1113  set_id=(set_id+1) % predictionSet.indices[0].size();
1114  typedef std::set<SampleRange<T1> > my_set;
1115  typedef typename my_set::iterator set_it;
1116  //typedef std::set<std::pair<int,SampleRange<T1> > >::iterator set_it;
1117  //Build a stack with all the ranges we have
1118  std::vector<std::pair<int,set_it> > stack;
1119  stack.clear();
1120  for(set_it i=predictionSet.ranges[set_id].begin();
1121  i!=predictionSet.ranges[set_id].end();++i)
1122  stack.push_back(std::pair<int,set_it>(2,i));
1123  //get weights predicted by single tree
1124  int num_decisions=0;
1125  while(!stack.empty())
1126  {
1127  set_it range=stack.back().second;
1128  int index=stack.back().first;
1129  stack.pop_back();
1130  ++num_decisions;
1131 
1132  if(trees_[k].isLeafNode(trees_[k].topology_[index]))
1133  {
1134  ArrayVector<double>::iterator weights=Node<e_ConstProbNode>(trees_[k].topology_,
1135  trees_[k].parameters_,
1136  index).prob_begin();
1137  for(int i=range->start;i!=range->end;++i)
1138  {
1139  //update votecount.
1140  for(int l=0; l<ext_param_.class_count_; ++l)
1141  {
1142  prob(predictionSet.indices[set_id][i], l) += static_cast<T2>(weights[l]);
1143  //every weight in totalWeight.
1144  totalWeights[predictionSet.indices[set_id][i]] += static_cast<T1>(weights[l]);
1145  }
1146  }
1147  }
1148 
1149  else
1150  {
1151  if(trees_[k].topology_[index]!=i_ThresholdNode)
1152  {
1153  throw std::runtime_error("predicting with online prediction sets is only supported for RFs with threshold nodes");
1154  }
1155  Node<i_ThresholdNode> node(trees_[k].topology_,trees_[k].parameters_,index);
1156  if(range->min_boundaries[node.column()]>=node.threshold())
1157  {
1158  //Everything goes to right child
1159  stack.push_back(std::pair<int,set_it>(node.child(1),range));
1160  continue;
1161  }
1162  if(range->max_boundaries[node.column()]<node.threshold())
1163  {
1164  //Everything goes to the left child
1165  stack.push_back(std::pair<int,set_it>(node.child(0),range));
1166  continue;
1167  }
1168  //We have to split at this node
1169  SampleRange<T1> new_range=*range;
1170  new_range.min_boundaries[node.column()]=FLT_MAX;
1171  range->max_boundaries[node.column()]=-FLT_MAX;
1172  new_range.start=new_range.end=range->end;
1173  int i=range->start;
1174  while(i!=range->end)
1175  {
1176  //Decide for range->indices[i]
1177  if(predictionSet.features(predictionSet.indices[set_id][i],node.column())>=node.threshold())
1178  {
1179  new_range.min_boundaries[node.column()]=std::min(new_range.min_boundaries[node.column()],
1180  predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1181  --range->end;
1182  --new_range.start;
1183  std::swap(predictionSet.indices[set_id][i],predictionSet.indices[set_id][range->end]);
1184 
1185  }
1186  else
1187  {
1188  range->max_boundaries[node.column()]=std::max(range->max_boundaries[node.column()],
1189  predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1190  ++i;
1191  }
1192  }
1193  //The old one ...
1194  if(range->start==range->end)
1195  {
1196  predictionSet.ranges[set_id].erase(range);
1197  }
1198  else
1199  {
1200  stack.push_back(std::pair<int,set_it>(node.child(0),range));
1201  }
1202  //And the new one ...
1203  if(new_range.start!=new_range.end)
1204  {
1205  std::pair<set_it,bool> new_it=predictionSet.ranges[set_id].insert(new_range);
1206  stack.push_back(std::pair<int,set_it>(node.child(1),new_it.first));
1207  }
1208  }
1209  }
1210  predictionSet.cumulativePredTime[k]=num_decisions;
1211  }
1212  for(unsigned int i=0;i<totalWeights.size();++i)
1213  {
1214  double test=0.0;
1215  //Normalise votes in each row by total VoteCount (totalWeight
1216  for(int l=0; l<ext_param_.class_count_; ++l)
1217  {
1218  test+=prob(i,l);
1219  prob(i, l) /= totalWeights[i];
1220  }
1221  assert(test==totalWeights[i]);
1222  assert(totalWeights[i]>0.0);
1223  }
1224 }
1225 
1226 template <class LabelType, class PreprocessorTag>
1227 template <class U, class C1, class T, class C2, class Stop_t>
1229  ::predictProbabilities(MultiArrayView<2, U, C1>const & features,
1230  MultiArrayView<2, T, C2> & prob,
1231  Stop_t & stop_) const
1232 {
1233  //Features are n xp
1234  //prob is n x NumOfLabel probability for each feature in each class
1235 
1236  vigra_precondition(rowCount(features) == rowCount(prob),
1237  "RandomForestn::predictProbabilities():"
1238  " Feature matrix and probability matrix size mismatch.");
1239 
1240  // num of features must be bigger than num of features in Random forest training
1241  // but why bigger?
1242  vigra_precondition( columnCount(features) >= ext_param_.column_count_,
1243  "RandomForestn::predictProbabilities():"
1244  " Too few columns in feature matrix.");
1245  vigra_precondition( columnCount(prob)
1246  == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1247  "RandomForestn::predictProbabilities():"
1248  " Probability matrix must have as many columns as there are classes.");
1249 
1250  #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1251  Default_Stop_t default_stop(options_);
1252  typename RF_CHOOSER(Stop_t)::type & stop
1253  = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
1254  #undef RF_CHOOSER
1255  stop.set_external_parameters(ext_param_, tree_count());
1256  prob.init(NumericTraits<T>::zero());
1257  /* This code was originally there for testing early stopping
1258  * - we wanted the order of the trees to be randomized
1259  if(tree_indices_.size() != 0)
1260  {
1261  std::random_shuffle(tree_indices_.begin(),
1262  tree_indices_.end());
1263  }
1264  */
1265  //Classify for each row.
1266  for(int row=0; row < rowCount(features); ++row)
1267  {
1268  MultiArrayView<2, U, StridedArrayTag> currentRow(rowVector(features, row));
1269 
1270  // when the features contain an NaN, the instance doesn't belong to any class
1271  // => indicate this by returning a zero probability array.
1272  if(detail::contains_nan(currentRow))
1273  {
1274  rowVector(prob, row).init(0.0);
1275  continue;
1276  }
1277 
1278  ArrayVector<double>::const_iterator weights;
1279 
1280  //totalWeight == totalVoteCount!
1281  double totalWeight = 0.0;
1282 
1283  //Let each tree classify...
1284  for(int k=0; k<options_.tree_count_; ++k)
1285  {
1286  //get weights predicted by single tree
1287  weights = trees_[k /*tree_indices_[k]*/].predict(currentRow);
1288 
1289  //update votecount.
1290  int weighted = options_.predict_weighted_;
1291  for(int l=0; l<ext_param_.class_count_; ++l)
1292  {
1293  double cur_w = weights[l] * (weighted * (*(weights-1))
1294  + (1-weighted));
1295  prob(row, l) += static_cast<T>(cur_w);
1296  //every weight in totalWeight.
1297  totalWeight += cur_w;
1298  }
1299  if(stop.after_prediction(weights,
1300  k,
1301  rowVector(prob, row),
1302  totalWeight))
1303  {
1304  break;
1305  }
1306  }
1307 
1308  //Normalise votes in each row by total VoteCount (totalWeight
1309  for(int l=0; l< ext_param_.class_count_; ++l)
1310  {
1311  prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
1312  }
1313  }
1314 
1315 }
1316 
1317 template <class LabelType, class PreprocessorTag>
1318 template <class U, class C1, class T, class C2>
1319 void RandomForest<LabelType, PreprocessorTag>
1320  ::predictRaw(MultiArrayView<2, U, C1>const & features,
1321  MultiArrayView<2, T, C2> & prob) const
1322 {
1323  //Features are n xp
1324  //prob is n x NumOfLabel probability for each feature in each class
1325 
1326  vigra_precondition(rowCount(features) == rowCount(prob),
1327  "RandomForestn::predictProbabilities():"
1328  " Feature matrix and probability matrix size mismatch.");
1329 
1330  // num of features must be bigger than num of features in Random forest training
1331  // but why bigger?
1332  vigra_precondition( columnCount(features) >= ext_param_.column_count_,
1333  "RandomForestn::predictProbabilities():"
1334  " Too few columns in feature matrix.");
1335  vigra_precondition( columnCount(prob)
1336  == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1337  "RandomForestn::predictProbabilities():"
1338  " Probability matrix must have as many columns as there are classes.");
1339 
1340  #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1341  prob.init(NumericTraits<T>::zero());
1342  /* This code was originally there for testing early stopping
1343  * - we wanted the order of the trees to be randomized
1344  if(tree_indices_.size() != 0)
1345  {
1346  std::random_shuffle(tree_indices_.begin(),
1347  tree_indices_.end());
1348  }
1349  */
1350  //Classify for each row.
1351  for(int row=0; row < rowCount(features); ++row)
1352  {
1353  ArrayVector<double>::const_iterator weights;
1354 
1355  //totalWeight == totalVoteCount!
1356  double totalWeight = 0.0;
1357 
1358  //Let each tree classify...
1359  for(int k=0; k<options_.tree_count_; ++k)
1360  {
1361  //get weights predicted by single tree
1362  weights = trees_[k /*tree_indices_[k]*/].predict(rowVector(features, row));
1363 
1364  //update votecount.
1365  int weighted = options_.predict_weighted_;
1366  for(int l=0; l<ext_param_.class_count_; ++l)
1367  {
1368  double cur_w = weights[l] * (weighted * (*(weights-1))
1369  + (1-weighted));
1370  prob(row, l) += static_cast<T>(cur_w);
1371  //every weight in totalWeight.
1372  totalWeight += cur_w;
1373  }
1374  }
1375  }
1376  prob/= options_.tree_count_;
1377 
1378 }
1379 
1380 //@}
1381 
1382 } // namespace vigra
1383 
1384 #include "random_forest/rf_algorithm.hxx"
1385 #endif // VIGRA_RANDOM_FOREST_HXX
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels, LabelType nanLabel) const
predict multiple labels with given features
Definition: random_forest.hxx:608
Definition: rf_region.hxx:57
void set_ext_param(ProblemSpec_t const &in)
set external parameters
Definition: random_forest.hxx:275
int class_count() const
return number of classes used while training.
Definition: random_forest.hxx:341
Definition: rf_nodeproxy.hxx:626
detail::RF_DEFAULT & rf_default()
factory function to return a RF_DEFAULT tag
Definition: rf_common.hxx:131
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:669
Definition: rf_preprocessing.hxx:63
int feature_count() const
return number of features used while training.
Definition: random_forest.hxx:322
int column_count() const
return number of features used while training.
Definition: random_forest.hxx:333
Create random samples from a sequence of indices.
Definition: sampling.hxx:233
const difference_type & shape() const
Definition: multi_array.hxx:1594
Definition: rf_split.hxx:993
const_iterator begin() const
Definition: array_vector.hxx:223
problem specification class for the random forest.
Definition: rf_common.hxx:533
RandomForest(Options_t const &options=Options_t(), ProblemSpec_t const &ext_param=ProblemSpec_t())
default constructor
Definition: random_forest.hxx:193
void sample()
Definition: sampling.hxx:468
Standard early stopping criterion.
Definition: rf_common.hxx:880
Definition: random.hxx:669
ProblemSpec_t const & ext_param() const
return external parameters for viewing
Definition: random_forest.hxx:257
DecisionTree_t & tree(int index)
access trees
Definition: random_forest.hxx:312
DecisionTree_t const & tree(int index) const
access const trees
Definition: random_forest.hxx:305
Options_t & set_options()
access random forest options
Definition: random_forest.hxx:288
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, Visitor_t visitor, Split_t split, Stop_t stop, Random_t const &random)
learn on data with custom config and random number generator
Definition: random_forest.hxx:935
void reLearnTree(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, int treeId, Visitor_t visitor_, Split_t split_, Stop_t stop_, Random_t &random)
Definition: random_forest.hxx:832
Definition: random_forest.hxx:145
Options_t const & options() const
access const random forest options
Definition: random_forest.hxx:298
void predictProbabilities(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob, Stop &stop) const
predict the class probabilities for multiple labels
Definition: rf_visitors.hxx:255
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition: sized_int.hxx:175
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition: algorithm.hxx:96
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels, Stop &stop) const
predict multiple labels with given features
Definition: random_forest.hxx:633
Definition: rf_visitors.hxx:584
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels) const
predict multiple labels with given features
Definition: random_forest.hxx:584
SamplerOptions & stratified(bool in=true)
Draw equally many samples from each "stratum". A stratum is a group of like entities, e.g. pixels belonging to the same object class. This is useful to create balanced samples when the class probabilities are very unbalanced (e.g. when there are many background and few foreground pixels). Stratified sampling thus avoids that a trained classifier is biased towards the majority class.
Definition: sampling.hxx:144
SamplerOptions & withReplacement(bool in=true)
Sample from training population with replacement.
Definition: sampling.hxx:86
void predictProbabilities(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob) const
predict the class probabilities for multiple labels
Definition: random_forest.hxx:669
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:695
int tree_count() const
return number of trees
Definition: random_forest.hxx:348
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:682
RandomForest(int treeCount, TopologyIterator topology_begin, ParameterIterator parameter_begin, ProblemSpec_t const &problem_spec, Options_t const &options=Options_t())
Create RF from external source.
Definition: random_forest.hxx:227
Definition: random.hxx:336
Base class for, and view to, vigra::MultiArray.
Definition: multi_array.hxx:650
Options object for the random forest.
Definition: rf_common.hxx:170
MultiArrayView & init(const U &init)
Definition: multi_array.hxx:1152
LabelType predictLabel(MultiArrayView< 2, U, C >const &features, Stop &stop) const
predict a label given a feature.
Definition: random_forest.hxx:1046
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &labels)
learn on data with default configuration
Definition: random_forest.hxx:528
Definition: rf_visitors.hxx:235

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.11.0 (Thu Mar 17 2016)