#ifndef ROOT_TMVA_MethodKNN
#define ROOT_TMVA_MethodKNN
#include <vector>
#include <map>
#ifndef ROOT_TMVA_MethodBase
#include "TMVA/MethodBase.h"
#endif
#ifndef ROOT_TMVA_ModulekNN
#include "TMVA/ModulekNN.h"
#endif
#ifndef ROOT_TMVA_LDA
#include "TMVA/LDA.h"
#endif
namespace TMVA
{
   namespace kNN
   {
      // Forward declaration. TMVA/ModulekNN.h is already included above, so this
      // is presumably redundant; it is harmless and keeps the pointer member
      // (fModule) below self-explanatory.
      class ModulekNN;
   }

   // MethodKNN: a TMVA method (derives from MethodBase) built around the
   // kNN::ModulekNN helper. Judging by the names this is a k-nearest-neighbour
   // classifier/regressor; the algorithm itself lives in the implementation
   // file, which is not visible here.
   class MethodKNN : public MethodBase
   {
   public:

      // Constructor for booking a new method instance.
      //   jobName      - job name (presumably forwarded to MethodBase)
      //   methodTitle  - title of this method instance
      //   theData      - dataset description the method operates on
      //   theOption    - option string, defaults to "KNN"
      //   theTargetDir - output directory, defaults to NULL
      MethodKNN(const TString& jobName,
                const TString& methodTitle,
                DataSetInfo& theData,
                const TString& theOption = "KNN",
                TDirectory* theTargetDir = NULL);

      // Constructor taking a weight file; presumably used to rebuild an
      // already-trained method for application (confirm in the .cxx).
      //   theData       - dataset description
      //   theWeightFile - path/name of the weight file
      //   theTargetDir  - output directory, defaults to NULL
      MethodKNN(DataSetInfo& theData,
                const TString& theWeightFile,
                TDirectory* theTargetDir = NULL);

      // NOTE(review): fModule is a raw pointer and no copy constructor or copy
      // assignment is declared. If the destructor deletes fModule, the
      // implicitly generated copy operations would double-delete it. Confirm
      // that MethodKNN is never copied.
      virtual ~MethodKNN( void );

      // Returns whether this method supports analysis type `type` with
      // `numberClasses` classes and `numberTargets` regression targets.
      virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets );

      // Training entry point.
      void Train( void );

      // Returns the MVA output value. `err` is an optional out-pointer
      // (default 0), presumably for an error estimate; confirm in the .cxx.
      Double_t GetMvaValue( Double_t* err = 0 );

      // Returns a reference to the regression output values.
      const std::vector<Float_t>& GetRegressionValues();

      // Re-expose the MethodBase overloads of these names. Without these
      // using-declarations, the overloads declared below would hide the base
      // class overloads (C++ name hiding).
      using MethodBase::WriteWeightsToStream;
      using MethodBase::ReadWeightsFromStream;

      // Weight persistence: plain text stream, ROOT file, and XML node.
      void WriteWeightsToStream(std::ostream& o) const;
      void WriteWeightsToStream(TFile& rf) const;
      void AddWeightsXMLTo( void* parent ) const;
      void ReadWeightsFromXML( void* wghtnode );
      void ReadWeightsFromStream(std::istream& istr);
      void ReadWeightsFromStream(TFile &rf);

      // Returns a variable-ranking object. Ownership of the returned pointer
      // is not visible here; check the implementation and caller.
      const Ranking* CreateRanking();

   protected:

      // Writes method-specific code to the given stream. The TString argument
      // is presumably the class name; confirm in the .cxx.
      void MakeClassSpecific( std::ostream&, const TString& ) const;

      // Emits the help text for this method.
      void GetHelpMessage() const;

   private:

      // Option handling and initialisation hooks.
      void DeclareOptions();
      void ProcessOptions();
      void Init( void );

      // Builds the kNN search structure; presumably fills fModule from fEvent
      // (confirm in the .cxx).
      void MakeKNN( void );

      // Polynomial kernel weight evaluated at `value` (name-based; see .cxx
      // for the exact formula).
      Double_t PolnKernel(Double_t value) const;

      // Gaussian kernel between `event_knn` and `event`, using per-variable
      // widths `svec` (name-based; confirm in the .cxx).
      Double_t GausKernel(const kNN::Event &event_knn, const kNN::Event &event, const std::vector<Double_t> &svec) const;

      // Kernel radius derived from the neighbour list `rlist`.
      Double_t getKernelRadius(const kNN::List &rlist) const;

      // Per-variable RMS over the neighbour list `rlist`, relative to
      // `event_knn` (name-based). NOTE: the const-by-value return type
      // inhibits moving the result. It is kept as-is because the out-of-line
      // definition must match this declaration exactly.
      const std::vector<Double_t> getRMS(const kNN::List &rlist, const kNN::Event &event_knn) const;

      // LDA response computed from the neighbour list `rlist` for `event_knn`.
      // Spelled Double_t (ROOT's typedef for double) for consistency with the
      // other helpers. The type is unchanged, so the existing out-of-line
      // definition still matches.
      Double_t getLDAValue(const kNN::List &rlist, const kNN::Event &event_knn);

   private:

      // Data members. Declaration order is construction order; do not reorder
      // without checking the constructors' initializer lists.
      Double_t fSumOfWeightsS;   // presumably sum of signal event weights
      Double_t fSumOfWeightsB;   // presumably sum of background event weights
      kNN::ModulekNN *fModule;   // kNN helper module (raw pointer; ownership handled in .cxx, confirm)
      Int_t fnkNN;               // presumably k, the number of nearest neighbours
      Int_t fBalanceDepth;       // presumably the search-tree balance depth
      Float_t fScaleFrac;        // presumably the fraction of events used for variable scaling
      Float_t fSigmaFact;        // presumably the width scale factor for the Gaussian kernel
      TString fKernel;           // kernel selection option (name-based)
      Bool_t fTrim;              // option flag (name-based: trim the training sample)
      Bool_t fUseKernel;         // option flag (name-based: use kernel weighting)
      Bool_t fUseWeight;         // option flag (name-based: use event weights)
      Bool_t fUseLDA;            // option flag (name-based: use local LDA, see getLDAValue)
      kNN::EventVec fEvent;      // event collection, presumably the training sample
      LDA fLDA;                  // LDA helper object

      // ROOT class-dictionary macro. Per ROOT documentation, class version 0
      // means the class is not made persistent (no streamer I/O).
      ClassDef(MethodKNN,0)
   };
} // namespace TMVA
#endif // MethodKNN