forestTrain

PURPOSE

Train random forest classifier.

SYNOPSIS

function forest = forestTrain( data, hs, varargin )

DESCRIPTION

 Train random forest classifier.

 Dimensions:
  M - number of trees
  F - number of features
  N - number of input vectors
  H - number of classes
  K - number of nodes in a tree

 USAGE
  forest = forestTrain( data, hs, [varargin] )

 INPUTS
  data     - [NxF] N feature vectors, each of length F
  hs       - [Nx1] target class labels in [1,H], or {Nx1} structured labels (see .discretize)
  varargin - additional params (struct or name/value pairs)
   .M          - [1] number of trees to train
   .H          - [max(hs)] number of classes
   .N1         - [5*N/M] number of data points for training each tree
   .F1         - [sqrt(F)] number of features to sample for each node split
   .split      - ['gini'] split criterion: 'gini', 'entropy', or 'twoing'
   .minCount   - [1] minimum number of data points to allow split
   .minChild   - [1] minimum number of data points allowed at child nodes
   .maxDepth   - [64] maximum depth of tree
   .dWts       - [] weights used for sampling and weighting each data point
   .fWts       - [] weights used for sampling features
   .discretize - [] optional function mapping structured labels to class labels
                    format: [hsClass,hBest] = discretize(hsStructured,H);
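 The .N1, .F1, .dWts and .fWts parameters control how much data each tree and each split sees. The Python sketch below illustrates that sampling step under stated assumptions: draws are taken without replacement in proportion to the weights, an empty weight vector means uniform weights, and the defaults use MATLAB-style rounding with N1 capped at N. The helper names (default_n1, default_f1, weighted_sample) are hypothetical and are not part of the toolbox.

```python
import math
import random

def matlab_round(x):
    """Round half away from zero for x >= 0, mimicking MATLAB's round."""
    return int(math.floor(x + 0.5))

def default_n1(n, m):
    """Assumed default for .N1: round(5*N/M), capped at N."""
    return min(n, matlab_round(5 * n / m))

def default_f1(f):
    """Assumed default for .F1: round(sqrt(F)), at least 1."""
    return max(1, matlab_round(math.sqrt(f)))

def weighted_sample(n, k, rng, weights=None):
    """Draw k distinct indices from range(n) without replacement.

    Each draw is proportional to the weights of the items still in the pool.
    weights=None plays the role of an empty .dWts / .fWts (uniform weights).
    Assumes at least k items have positive weight.
    """
    pool = list(range(n))
    w = [1.0] * n if weights is None else list(weights)
    chosen = []
    for _ in range(min(k, n)):
        i = rng.choices(range(len(pool)), weights=w)[0]
        chosen.append(pool.pop(i))  # remove the drawn item from the pool
        w.pop(i)                    # and its weight, so it cannot recur
    return chosen
```

 For the example below (N=10000, M=150), default_n1 gives 333 points per tree, so each tree is trained on a small random subset, which is what decorrelates the trees.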

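 The .split option selects how candidate thresholds are scored. The Python sketch below uses the textbook definitions of the three criteria: Gini impurity, Shannon entropy (weighted over the two children), and Breiman's twoing rule. It is an illustration only; the toolbox's internal scoring may differ in normalization, and the names class_probs, gini, entropy and split_cost are hypothetical.

```python
import math
from collections import Counter

def class_probs(labels):
    """Empirical class distribution of a list of labels."""
    n = len(labels)
    return {c: k / n for c, k in Counter(labels).items()}

def gini(labels):
    """Gini impurity: 1 - sum_c p_c^2."""
    return 1.0 - sum(p * p for p in class_probs(labels).values())

def entropy(labels):
    """Shannon entropy in bits: -sum_c p_c * log2(p_c)."""
    return -sum(p * math.log2(p) for p in class_probs(labels).values())

def split_cost(left, right, criterion="gini"):
    """Score a binary split of the labels at a node; lower is better."""
    n_l, n_r = len(left), len(right)
    n = n_l + n_r
    if criterion in ("gini", "entropy"):
        f = gini if criterion == "gini" else entropy
        # Impurity of each child, weighted by the fraction of points it gets.
        return (n_l * f(left) + n_r * f(right)) / n
    if criterion == "twoing":
        p_l, p_r = class_probs(left), class_probs(right)
        classes = set(p_l) | set(p_r)
        diff = sum(abs(p_l.get(c, 0.0) - p_r.get(c, 0.0)) for c in classes)
        # Twoing is maximized, so negate it to keep "lower is better".
        return -(n_l / n) * (n_r / n) / 4.0 * diff ** 2
    raise ValueError("unknown criterion: " + criterion)
```

 A perfectly separating split such as split_cost([1, 1], [2, 2]) costs 0 under Gini, while a split that leaves both children mixed, such as split_cost([1, 2], [1, 2]), costs 0.5.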
 OUTPUTS
  forest   - learned forest model, a struct array (one element per tree) with the following fields
   .fids     - [Kx1] feature ids for each node
   .thrs     - [Kx1] threshold corresponding to each fid
   .child    - [Kx1] index of child for each node
   .distr    - [KxH] class probability distribution at each node
   .hs       - [Kx1] or {Kx1} most likely label at each node
   .count    - [Kx1] number of data points at each node
   .depth    - [Kx1] depth of each node
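 The fields above encode each tree as flat per-node arrays. The Python sketch below shows one way such a tree could be traversed and how a forest prediction could be formed. It assumes that child(k)==0 marks a leaf, that the two children of node k are stored consecutively at child(k) and child(k)+1 (1-based, as in MATLAB), that a sample goes left when x(fid) < thr, and that the forest averages the leaf distributions over trees. These layout details are assumptions, not confirmed by this page; forestApply is the toolbox's actual routine, and tree_leaf and forest_predict are hypothetical names.

```python
def tree_leaf(tree, x):
    """Return the 0-based index of the leaf node reached by sample x."""
    k = 0                                    # start at the root
    while tree["child"][k] != 0:             # assumed: child == 0 marks a leaf
        fid = tree["fids"][k] - 1            # 1-based feature id -> 0-based
        left = tree["child"][k] - 1          # 1-based left child -> 0-based
        # Assumed: go left if below the threshold, else to the next node.
        k = left if x[fid] < tree["thrs"][k] else left + 1
    return k

def forest_predict(forest, x):
    """Average leaf distributions over all trees; return (label, probs)."""
    h = len(forest[0]["distr"][0])
    probs = [0.0] * h
    for tree in forest:
        leaf = tree_leaf(tree, x)
        for c in range(h):
            probs[c] += tree["distr"][leaf][c] / len(forest)
    label = max(range(h), key=lambda c: probs[c]) + 1  # labels lie in [1,H]
    return label, probs
```

 For a single stump with fids=[1,0,0], thrs=[0.5,0,0] and child=[2,0,0], a sample with x(1)=0.2 reaches node 2 and a sample with x(1)=0.9 reaches node 3 (1-based).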

 EXAMPLE
  N=10000; H=5; d=2; [xs0,hs0,xs1,hs1]=demoGenData(N,N,H,d,1,1);
  xs0=single(xs0); xs1=single(xs1);
  pTrain={'maxDepth',50,'F1',2,'M',150,'minChild',5};
  tic, forest=forestTrain(xs0,hs0,pTrain{:}); toc
  hsPr0 = forestApply(xs0,forest);
  hsPr1 = forestApply(xs1,forest);
  e0=mean(hsPr0~=hs0); e1=mean(hsPr1~=hs1);
  fprintf('errors trn=%f tst=%f\n',e0,e1); figure(1);
  subplot(2,2,1); visualizeData(xs0,2,hs0);
  subplot(2,2,2); visualizeData(xs0,2,hsPr0);
  subplot(2,2,3); visualizeData(xs1,2,hs1);
  subplot(2,2,4); visualizeData(xs1,2,hsPr1);

 See also forestApply, fernsClfTrain

 Piotr's Computer Vision Matlab Toolbox      Version 3.24
 Copyright 2014 Piotr Dollar.  [pdollar-at-gmail.com]
 Licensed under the Simplified BSD License [see external/bsd.txt]
