#include <cmath>
#include <cstdlib>
#include <iostream>  // std::istream/std::ostream used by the stream-based I/O methods
#include <sstream>   // std::stringstream used for event (de)serialisation
#include <string>
#include <vector>    // std::vector used for kernel widths and regression outputs
#include "TFile.h"
#include "TMath.h"
#include "TTree.h"
#include "TMVA/ClassifierFactory.h"
#include "TMVA/MethodKNN.h"
#include "TMVA/Ranking.h"
#include "TMVA/Tools.h"
REGISTER_METHOD(KNN)
ClassImp(TMVA::MethodKNN)
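////////////////////////////////////////////////////////////////////////////////
/// Standard constructor.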
TMVA::MethodKNN::MethodKNN( const TString& jobName,
const TString& methodTitle,
DataSetInfo& theData,
const TString& theOption,
TDirectory* theTargetDir )
: TMVA::MethodBase(jobName, Types::kKNN, methodTitle, theData, theOption, theTargetDir),
fModule(0)
{
}
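////////////////////////////////////////////////////////////////////////////////
/// Constructor from a weight file.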
TMVA::MethodKNN::MethodKNN( DataSetInfo& theData,
const TString& theWeightFile,
TDirectory* theTargetDir )
: TMVA::MethodBase( Types::kKNN, theData, theWeightFile, theTargetDir),
fModule(0)
{
}
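////////////////////////////////////////////////////////////////////////////////
/// Destructor.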
TMVA::MethodKNN::~MethodKNN()
{
if (fModule) delete fModule;
}
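////////////////////////////////////////////////////////////////////////////////
/// Declare the configuration options of the kNN method together with their
/// default values.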
void TMVA::MethodKNN::DeclareOptions()
{
DeclareOptionRef(fnkNN = 20, "nkNN", "Number of k-nearest neighbors");
DeclareOptionRef(fBalanceDepth = 6, "BalanceDepth", "Binary tree balance depth");
DeclareOptionRef(fScaleFrac = 0.80, "ScaleFrac", "Fraction of events used to compute variable width");
DeclareOptionRef(fSigmaFact = 1.0, "SigmaFact", "Scale factor for sigma in Gaussian kernel");
DeclareOptionRef(fKernel = "Gaus", "Kernel", "Use polynomial (=Poln) or Gaussian (=Gaus) kernel");
DeclareOptionRef(fTrim = kFALSE, "Trim", "Use equal number of signal and background events");
DeclareOptionRef(fUseKernel = kFALSE, "UseKernel", "Use polynomial kernel weight");
DeclareOptionRef(fUseWeight = kTRUE, "UseWeight", "Use weight to count kNN events");
DeclareOptionRef(fUseLDA = kFALSE, "UseLDA", "Use local linear discriminant - experimental feature");
}
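// A minimal booking sketch showing how these options are typically combined
// into a configuration string; the factory setup, file, and variable names
// are hypothetical, following the standard TMVA::Factory interface:
//
//    TMVA::Factory factory("TMVAClassification", outputFile, "!V:!Silent");
//    factory.AddVariable("var1", 'F');
//    factory.AddVariable("var2", 'F');
//    factory.BookMethod(TMVA::Types::kKNN, "KNN",
//                       "nkNN=20:ScaleFrac=0.8:UseKernel=F:UseWeight=T:Trim=F");

////////////////////////////////////////////////////////////////////////////////
/// Declare options kept only for backward compatibility with older weight
/// files.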
void TMVA::MethodKNN::DeclareCompatibilityOptions() {
MethodBase::DeclareCompatibilityOptions();
DeclareOptionRef(fTreeOptDepth = 6, "TreeOptDepth", "Binary tree optimisation depth");
}
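////////////////////////////////////////////////////////////////////////////////
/// Validate the configuration options and fall back to safe defaults where
/// necessary.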
void TMVA::MethodKNN::ProcessOptions()
{
if (!(fnkNN > 0)) {
fnkNN = 10;
Log() << kWARNING << "nkNN must be a positive integer: set nkNN = " << fnkNN << Endl;
}
if (fScaleFrac < 0.0) {
fScaleFrac = 0.0;
Log() << kWARNING << "ScaleFrac cannot be negative: set ScaleFrac = " << fScaleFrac << Endl;
}
if (fScaleFrac > 1.0) {
fScaleFrac = 1.0;
Log() << kWARNING << "ScaleFrac cannot exceed 1: set ScaleFrac = " << fScaleFrac << Endl;
}
if (!(fBalanceDepth > 0)) {
fBalanceDepth = 6;
Log() << kWARNING << "BalanceDepth must be a positive integer: set BalanceDepth = " << fBalanceDepth << Endl;
}
Log() << kVERBOSE
<< "kNN options:\n"
<< "  kNN = " << fnkNN << "\n"
<< "  UseKernel = " << fUseKernel << "\n"
<< "  SigmaFact = " << fSigmaFact << "\n"
<< "  ScaleFrac = " << fScaleFrac << "\n"
<< "  Kernel = " << fKernel << "\n"
<< "  Trim = " << fTrim << "\n"
<< "  BalanceDepth = " << fBalanceDepth << Endl;
}
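////////////////////////////////////////////////////////////////////////////////
/// The kNN method handles two-class classification and (multi-target)
/// regression.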
Bool_t TMVA::MethodKNN::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t )
{
if (type == Types::kClassification && numberClasses == 2) return kTRUE;
if (type == Types::kRegression) return kTRUE;
return kFALSE;
}
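////////////////////////////////////////////////////////////////////////////////
/// Initialisation: create the kNN module that owns the kd-tree and reset the
/// summed signal and background weights.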
void TMVA::MethodKNN::Init()
{
fModule = new kNN::ModulekNN();
fSumOfWeightsS = 0;
fSumOfWeightsB = 0;
}
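////////////////////////////////////////////////////////////////////////////////
/// Build the kd-tree from the stored training events. The "metric" option
/// enables per-variable rescaling (controlled by ScaleFrac) and the "trim"
/// option equalises the number of signal and background events.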
void TMVA::MethodKNN::MakeKNN()
{
if (!fModule) {
Log() << kFATAL << "ModulekNN object has not been created" << Endl;
}
fModule->Clear();
std::string option;
if (fScaleFrac > 0.0) {
option += "metric";
}
if (fTrim) {
option += "trim";
}
Log() << kINFO << "Creating kd-tree with " << fEvent.size() << " events" << Endl;
for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {
fModule->Add(*event);
}
fModule->Fill(static_cast<UInt_t>(fBalanceDepth),
static_cast<UInt_t>(100.0*fScaleFrac),
option);
}
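////////////////////////////////////////////////////////////////////////////////
/// Read the training events, convert them to kNN::Event objects and build the
/// kd-tree. Events with non-positive weight are optionally skipped.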
void TMVA::MethodKNN::Train()
{
Log() << kINFO << "<Train> start..." << Endl;
if (IsNormalised()) {
Log() << kINFO << "Input events are normalized - setting ScaleFrac to 0" << Endl;
fScaleFrac = 0.0;
}
if (!fEvent.empty()) {
Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
fEvent.clear();
}
if (GetNVariables() < 1)
Log() << kFATAL << "MethodKNN::Train() - mismatched or wrong number of event variables" << Endl;
Log() << kINFO << "Reading " << GetNEvents() << " events" << Endl;
for (UInt_t ievt = 0; ievt < GetNEvents(); ++ievt) {
const Event* evt_ = GetEvent(ievt);
Double_t weight = evt_->GetWeight();
if (IgnoreEventsWithNegWeightsInTraining() && weight <= 0) continue;
kNN::VarVec vvec(GetNVariables(), 0.0);
for (UInt_t ivar = 0; ivar < evt_->GetNVariables(); ++ivar) vvec[ivar] = evt_->GetValue(ivar);
Short_t event_type = 0;
if (evt_->IsSignal()) {
fSumOfWeightsS += weight;
event_type = 1;
}
else {
fSumOfWeightsB += weight;
event_type = 2;
}
kNN::Event event_knn(vvec, weight, event_type);
event_knn.SetTargets(evt_->GetTargets());
fEvent.push_back(event_knn);
}
Log() << kINFO
<< "Number of signal events " << fSumOfWeightsS << Endl
<< "Number of background events " << fSumOfWeightsB << Endl;
MakeKNN();
}
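////////////////////////////////////////////////////////////////////////////////
/// Compute the classification response for the current event: the
/// (optionally kernel-weighted) fraction of signal events among the k
/// nearest neighbours,
///
///    D = sum_{i in kNN, signal} w_i*K_i / sum_{i in kNN} w_i*K_i ,
///
/// where w_i is the training-event weight and K_i the kernel weight (1 if
/// UseKernel is off). Note that knn + 2 neighbours are requested from the
/// kd-tree, but only the first knn enter the sums.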
Double_t TMVA::MethodKNN::GetMvaValue( Double_t* err )
{
if (err != 0) *err = -1;
const Event *ev = GetEvent();
const Int_t nvar = GetNVariables();
const Double_t weight = ev->GetWeight();
const UInt_t knn = static_cast<UInt_t>(fnkNN);
kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);
for (Int_t ivar = 0; ivar < nvar; ++ivar) {
vvec[ivar] = ev->GetValue(ivar);
}
const kNN::Event event_knn(vvec, weight, 3);
fModule->Find(event_knn, knn + 2);
const kNN::List &rlist = fModule->GetkNNList();
if (rlist.size() != knn + 2) {
Log() << kFATAL << "kNN result list has unexpected size " << rlist.size() << " (expected " << knn + 2 << ")" << Endl;
return -100.0;
}
if (fUseLDA) return MethodKNN::getLDAValue(rlist, event_knn);
Bool_t use_gaus = false, use_poln = false;
if (fUseKernel) {
if (fKernel == "Gaus") use_gaus = true;
else if (fKernel == "Poln") use_poln = true;
}
Double_t kradius = -1.0;
if (use_poln) {
kradius = MethodKNN::getKernelRadius(rlist);
if (!(kradius > 0.0)) {
Log() << kFATAL << "kNN radius is not positive" << Endl;
return -100.0;
}
kradius = 1.0/TMath::Sqrt(kradius);
}
std::vector<Double_t> rms_vec;
if (use_gaus) {
rms_vec = TMVA::MethodKNN::getRMS(rlist, event_knn);
if (rms_vec.empty() || rms_vec.size() != event_knn.GetNVar()) {
Log() << kFATAL << "Failed to compute RMS vector" << Endl;
return -100.0;
}
}
UInt_t count_all = 0;
Double_t weight_all = 0, weight_sig = 0, weight_bac = 0;
for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
const kNN::Node<kNN::Event> &node = *(lit->first);
if (lit->second < 0.0) {
Log() << kFATAL << "A neighbor has negative distance to query event" << Endl;
}
else if (!(lit->second > 0.0)) {
Log() << kVERBOSE << "A neighbor has zero distance to query event" << Endl;
}
Double_t evweight = node.GetWeight();
if (use_gaus) evweight *= MethodKNN::GausKernel(event_knn, node.GetEvent(), rms_vec);
else if (use_poln) evweight *= MethodKNN::PolnKernel(TMath::Sqrt(lit->second)*kradius);
if (fUseWeight) weight_all += evweight;
else ++weight_all;
if (node.GetEvent().GetType() == 1) {
if (fUseWeight) weight_sig += evweight;
else ++weight_sig;
}
else if (node.GetEvent().GetType() == 2) {
if (fUseWeight) weight_bac += evweight;
else ++weight_bac;
}
else {
Log() << kFATAL << "Unknown type for training event" << Endl;
}
++count_all;
if (count_all >= knn) {
break;
}
}
if (!(count_all > 0)) {
Log() << kFATAL << "kNN result list size is not positive" << Endl;
return -100.0;
}
if (count_all < knn) {
Log() << kDEBUG << "Found fewer neighbors than requested: " << count_all << " < " << knn << Endl;
}
if (!(weight_all > 0.0)) {
Log() << kFATAL << "kNN result total weight is not positive" << Endl;
return -100.0;
}
return weight_sig/weight_all;
}
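////////////////////////////////////////////////////////////////////////////////
/// Return the regression targets for the current event as the (optionally
/// weighted) average of the targets of its k nearest neighbours.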
const std::vector< Float_t >& TMVA::MethodKNN::GetRegressionValues()
{
if( fRegressionReturnVal == 0 )
fRegressionReturnVal = new std::vector<Float_t>;
else
fRegressionReturnVal->clear();
const Event *evt = GetEvent();
const Int_t nvar = GetNVariables();
const UInt_t knn = static_cast<UInt_t>(fnkNN);
std::vector<float> reg_vec;
kNN::VarVec vvec(static_cast<UInt_t>(nvar), 0.0);
for (Int_t ivar = 0; ivar < nvar; ++ivar) {
vvec[ivar] = evt->GetValue(ivar);
}
const kNN::Event event_knn(vvec, evt->GetWeight(), 3);
fModule->Find(event_knn, knn + 2);
const kNN::List &rlist = fModule->GetkNNList();
if (rlist.size() != knn + 2) {
Log() << kFATAL << "kNN result list has unexpected size " << rlist.size() << " (expected " << knn + 2 << ")" << Endl;
return *fRegressionReturnVal;
}
Double_t weight_all = 0;
UInt_t count_all = 0;
for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
const kNN::Node<kNN::Event> &node = *(lit->first);
const kNN::VarVec &tvec = node.GetEvent().GetTargets();
const Double_t weight = node.GetEvent().GetWeight();
if (reg_vec.empty()) {
reg_vec = kNN::VarVec(tvec.size(), 0.0);
}
for(UInt_t ivar = 0; ivar < tvec.size(); ++ivar) {
if (fUseWeight) reg_vec[ivar] += tvec[ivar]*weight;
else reg_vec[ivar] += tvec[ivar];
}
if (fUseWeight) weight_all += weight;
else ++weight_all;
++count_all;
if (count_all == knn) {
break;
}
}
if (!(weight_all > 0.0)) {
Log() << kFATAL << "Total weight sum is not positive: " << weight_all << Endl;
return *fRegressionReturnVal;
}
for (UInt_t ivar = 0; ivar < reg_vec.size(); ++ivar) {
reg_vec[ivar] /= weight_all;
}
fRegressionReturnVal->insert(fRegressionReturnVal->begin(), reg_vec.begin(), reg_vec.end());
return *fRegressionReturnVal;
}
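////////////////////////////////////////////////////////////////////////////////
/// No input-variable ranking is available for the kNN method.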
const TMVA::Ranking* TMVA::MethodKNN::CreateRanking()
{
return 0;
}
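////////////////////////////////////////////////////////////////////////////////
/// Write the training events to the XML weight file. The resulting layout is
/// (sketch):
///
///    <Weights NEvents="..." NVar="..." NTgt="...">
///       <Event Type="..." Weight="...">var0 var1 ... tgt0 tgt1 ...</Event>
///       ...
///    </Weights>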
void TMVA::MethodKNN::AddWeightsXMLTo( void* parent ) const {
void* wght = gTools().AddChild(parent, "Weights");
gTools().AddAttr(wght,"NEvents",fEvent.size());
if (!fEvent.empty()) {
gTools().AddAttr(wght,"NVar",fEvent.begin()->GetNVar());
gTools().AddAttr(wght,"NTgt",fEvent.begin()->GetNTgt());
}
for (kNN::EventVec::const_iterator event = fEvent.begin(); event != fEvent.end(); ++event) {
std::stringstream s("");
s.precision( 16 );
for (UInt_t ivar = 0; ivar < event->GetNVar(); ++ivar) {
if (ivar>0) s << " ";
s << std::scientific << event->GetVar(ivar);
}
for (UInt_t itgt = 0; itgt < event->GetNTgt(); ++itgt) {
s << " " << std::scientific << event->GetTgt(itgt);
}
void* evt = gTools().AddChild(wght, "Event", s.str().c_str());
gTools().AddAttr(evt,"Type", event->GetType());
gTools().AddAttr(evt,"Weight", event->GetWeight());
}
}
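////////////////////////////////////////////////////////////////////////////////
/// Read the training events back from the XML weight file and rebuild the
/// kd-tree.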
void TMVA::MethodKNN::ReadWeightsFromXML( void* wghtnode ) {
void* ch = gTools().GetChild(wghtnode);
UInt_t nvar = 0, ntgt = 0;
gTools().ReadAttr( wghtnode, "NVar", nvar );
gTools().ReadAttr( wghtnode, "NTgt", ntgt );
Short_t evtType(0);
Double_t evtWeight(0);
while (ch) {
kNN::VarVec vvec(nvar, 0);
kNN::VarVec tvec(ntgt, 0);
gTools().ReadAttr( ch, "Type", evtType );
gTools().ReadAttr( ch, "Weight", evtWeight );
std::stringstream s( gTools().GetContent(ch) );
for(UInt_t ivar=0; ivar<nvar; ivar++)
s >> vvec[ivar];
for(UInt_t itgt=0; itgt<ntgt; itgt++)
s >> tvec[itgt];
ch = gTools().GetNextChild(ch);
kNN::Event event_knn(vvec, evtWeight, evtType, tvec);
fEvent.push_back(event_knn);
}
MakeKNN();
}
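////////////////////////////////////////////////////////////////////////////////
/// Read events from a plain-text weight stream. Each non-comment line is
/// expected to hold comma-separated values:
///
///    ievent,type,weight,var0,...,varN-1
///
/// Lines containing '#' are skipped.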
void TMVA::MethodKNN::ReadWeightsFromStream(std::istream& is)
{
Log() << kINFO << "Starting ReadWeightsFromStream(istream& is) function..." << Endl;
if (!fEvent.empty()) {
Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
fEvent.clear();
}
UInt_t nvar = 0;
while (!is.eof()) {
std::string line;
std::getline(is, line);
if (line.empty() || line.find("#") != std::string::npos) {
continue;
}
UInt_t count = 0;
std::string::size_type pos=0;
while( (pos=line.find(',',pos)) != std::string::npos ) { count++; pos++; }
if (nvar == 0) {
nvar = count - 2;
}
if (count < 3 || nvar != count - 2) {
Log() << kFATAL << "Missing comma delimiter(s)" << Endl;
}
Int_t ievent = -1, type = -1;
Double_t weight = -1.0;
kNN::VarVec vvec(nvar, 0.0);
UInt_t vcount = 0;
std::string::size_type prev = 0;
for (std::string::size_type ipos = 0; ipos < line.size(); ++ipos) {
if (line[ipos] != ',' && ipos + 1 != line.size()) {
continue;
}
if (!(ipos > prev)) {
Log() << kFATAL << "Wrong substring limits" << Endl;
}
std::string vstring = line.substr(prev, ipos - prev);
if (ipos + 1 == line.size()) {
vstring = line.substr(prev, ipos - prev + 1);
}
if (vstring.empty()) {
Log() << kFATAL << "Failed to parse string" << Endl;
}
if (vcount == 0) {
ievent = std::atoi(vstring.c_str());
}
else if (vcount == 1) {
type = std::atoi(vstring.c_str());
}
else if (vcount == 2) {
weight = std::atof(vstring.c_str());
}
else if (vcount - 3 < vvec.size()) {
vvec[vcount - 3] = std::atof(vstring.c_str());
}
else {
Log() << kFATAL << "Wrong variable count" << Endl;
}
prev = ipos + 1;
++vcount;
}
fEvent.push_back(kNN::Event(vvec, weight, type));
}
Log() << kINFO << "Read " << fEvent.size() << " events from text file" << Endl;
MakeKNN();
}
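////////////////////////////////////////////////////////////////////////////////
/// Write the training events to a ROOT file as a TTree named "knn" with a
/// single TMVA::kNN::Event branch.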
void TMVA::MethodKNN::WriteWeightsToStream(TFile &rf) const
{
Log() << kINFO << "Starting WriteWeightsToStream(TFile &rf) function..." << Endl;
if (fEvent.empty()) {
Log() << kWARNING << "MethodKNN contains no events " << Endl;
return;
}
kNN::Event *event = new kNN::Event();
TTree *tree = new TTree("knn", "event tree");
tree->SetDirectory(0);
tree->Branch("event", "TMVA::kNN::Event", &event);
Double_t size = 0.0;
for (kNN::EventVec::const_iterator it = fEvent.begin(); it != fEvent.end(); ++it) {
(*event) = (*it);
size += tree->Fill();
}
rf.WriteTObject(tree, "knn", "Overwrite");
size /= 1048576.0;
Log() << kINFO << "Wrote " << size << "MB and " << fEvent.size()
<< " events to ROOT file" << Endl;
delete tree;
delete event;
}
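////////////////////////////////////////////////////////////////////////////////
/// Read the training events back from the "knn" tree in a ROOT file and
/// rebuild the kd-tree.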
void TMVA::MethodKNN::ReadWeightsFromStream(TFile &rf)
{
Log() << kINFO << "Starting ReadWeightsFromStream(TFile &rf) function..." << Endl;
if (!fEvent.empty()) {
Log() << kINFO << "Erasing " << fEvent.size() << " previously stored events" << Endl;
fEvent.clear();
}
TTree *tree = dynamic_cast<TTree *>(rf.Get("knn"));
if (!tree) {
Log() << kFATAL << "Failed to find knn tree" << Endl;
return;
}
kNN::Event *event = new kNN::Event();
tree->SetBranchAddress("event", &event);
const Int_t nevent = tree->GetEntries();
Double_t size = 0.0;
for (Int_t i = 0; i < nevent; ++i) {
size += tree->GetEntry(i);
fEvent.push_back(*event);
}
size /= 1048576.0;
Log() << kINFO << "Read " << size << "MB and " << fEvent.size()
<< " events from ROOT file" << Endl;
delete event;
MakeKNN();
}
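////////////////////////////////////////////////////////////////////////////////
/// Standalone C++ response-class code is not implemented for the kNN method;
/// only a placeholder comment is written.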
void TMVA::MethodKNN::MakeClassSpecific( std::ostream& fout, const TString& className ) const
{
fout << " // not implemented for class: \"" << className << "\"" << std::endl;
fout << "};" << std::endl;
}
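////////////////////////////////////////////////////////////////////////////////
/// Print a description of the method and guidance on tuning its options.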
void TMVA::MethodKNN::GetHelpMessage() const
{
Log() << Endl;
Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
Log() << Endl;
Log() << "The k-nearest neighbor (k-NN) algorithm is a multi-dimensional classification" << Endl
<< "and regression algorithm. Similarly to other TMVA algorithms, k-NN uses a set of" << Endl
<< "training events for which a classification category/regression target is known. " << Endl
<< "The k-NN method compares a test event to all training events using a distance " << Endl
<< "function, which is an Euclidean distance in a space defined by the input variables. "<< Endl
<< "The k-NN method, as implemented in TMVA, uses a kd-tree algorithm to perform a" << Endl
<< "quick search for the k events with shortest distance to the test event. The method" << Endl
<< "returns a fraction of signal events among the k neighbors. It is recommended" << Endl
<< "that a histogram which stores the k-NN decision variable is binned with k+1 bins" << Endl
<< "between 0 and 1." << Endl;
Log() << Endl;
Log() << gTools().Color("bold") << "--- Performance tuning via configuration options: "
<< gTools().Color("reset") << Endl;
Log() << Endl;
Log() << "The k-NN method estimates a density of signal and background events in a "<< Endl
<< "neighborhood around the test event. The method assumes that the density of the " << Endl
<< "signal and background events is uniform and constant within the neighborhood. " << Endl
<< "k is an adjustable parameter and it determines an average size of the " << Endl
<< "neighborhood. Small k values (less than 10) are sensitive to statistical " << Endl
<< "fluctuations and large (greater than 100) values might not sufficiently capture " << Endl
<< "local differences between events in the training set. The speed of the k-NN" << Endl
<< "method also increases with larger values of k. " << Endl;
Log() << Endl;
Log() << "The k-NN method assigns equal weight to all input variables. Different scales " << Endl
<< "among the input variables is compensated using ScaleFrac parameter: the input " << Endl
<< "variables are scaled so that the widths for central ScaleFrac*100% events are " << Endl
<< "equal among all the input variables." << Endl;
Log() << Endl;
Log() << gTools().Color("bold") << "--- Additional configuration options: "
<< gTools().Color("reset") << Endl;
Log() << Endl;
Log() << "The method inclues an option to use a Gaussian kernel to smooth out the k-NN" << Endl
<< "response. The kernel re-weights events using a distance to the test event." << Endl;
}
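////////////////////////////////////////////////////////////////////////////////
/// Polynomial ("tricube") kernel: K(x) = (1 - |x|^3)^3 for |x| < 1,
/// 0 otherwise.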
Double_t TMVA::MethodKNN::PolnKernel(const Double_t value) const
{
const Double_t avalue = TMath::Abs(value);
if (!(avalue < 1.0)) {
return 0.0;
}
const Double_t prod = 1.0 - avalue * avalue * avalue;
return (prod * prod * prod);
}
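////////////////////////////////////////////////////////////////////////////////
/// Gaussian kernel with per-variable widths taken from svec:
///
///    K = exp( - sum_i (x_i - y_i)^2 / (2*sigma_i^2) )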
Double_t TMVA::MethodKNN::GausKernel(const kNN::Event &event_knn,
const kNN::Event &event, const std::vector<Double_t> &svec) const
{
if (event_knn.GetNVar() != event.GetNVar() || event_knn.GetNVar() != svec.size()) {
Log() << kFATAL << "Mismatched vectors in Gaussian kernel function" << Endl;
return 0.0;
}
double sum_exp = 0.0;
for(unsigned int ivar = 0; ivar < event_knn.GetNVar(); ++ivar) {
const Double_t diff_ = event.GetVar(ivar) - event_knn.GetVar(ivar);
const Double_t sigm_ = svec[ivar];
if (!(sigm_ > 0.0)) {
Log() << kFATAL << "Bad sigma value = " << sigm_ << Endl;
return 0.0;
}
sum_exp += diff_*diff_/(2.0*sigm_*sigm_);
}
return std::exp(-sum_exp);
}
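////////////////////////////////////////////////////////////////////////////////
/// Return the largest squared distance among the first knn neighbours with
/// positive distance; GetMvaValue() uses its inverse square root to
/// normalise the polynomial-kernel argument to [0, 1].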
Double_t TMVA::MethodKNN::getKernelRadius(const kNN::List &rlist) const
{
Double_t kradius = -1.0;
UInt_t kcount = 0;
const UInt_t knn = static_cast<UInt_t>(fnkNN);
for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit)
{
if (!(lit->second > 0.0)) continue;
if (kradius < lit->second || kradius < 0.0) kradius = lit->second;
++kcount;
if (kcount >= knn) break;
}
return kradius;
}
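////////////////////////////////////////////////////////////////////////////////
/// Compute, for each input variable, the RMS spread of the k nearest
/// neighbours around the query event, scaled by |SigmaFact|; these widths
/// serve as the sigmas of the Gaussian kernel.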
const std::vector<Double_t> TMVA::MethodKNN::getRMS(const kNN::List &rlist, const kNN::Event &event_knn) const
{
std::vector<Double_t> rvec;
UInt_t kcount = 0;
const UInt_t knn = static_cast<UInt_t>(fnkNN);
for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit)
{
if (!(lit->second > 0.0)) continue;
const kNN::Node<kNN::Event> *node_ = lit->first;
const kNN::Event &event_ = node_->GetEvent();
if (rvec.empty()) {
rvec.insert(rvec.end(), event_.GetNVar(), 0.0);
}
else if (rvec.size() != event_.GetNVar()) {
Log() << kFATAL << "Wrong number of variables, should never happen!" << Endl;
rvec.clear();
return rvec;
}
for(unsigned int ivar = 0; ivar < event_.GetNVar(); ++ivar) {
const Double_t diff_ = event_.GetVar(ivar) - event_knn.GetVar(ivar);
rvec[ivar] += diff_*diff_;
}
++kcount;
if (kcount >= knn) break;
}
if (kcount < 1) {
Log() << kFATAL << "Bad event kcount = " << kcount << Endl;
rvec.clear();
return rvec;
}
for(unsigned int ivar = 0; ivar < rvec.size(); ++ivar) {
if (!(rvec[ivar] > 0.0)) {
Log() << kFATAL << "Bad RMS value = " << rvec[ivar] << Endl;
rvec.clear();
return rvec;
}
rvec[ivar] = std::abs(fSigmaFact)*std::sqrt(rvec[ivar]/kcount);
}
return rvec;
}
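////////////////////////////////////////////////////////////////////////////////
/// Experimental: fit a local linear discriminant to the returned neighbours
/// and return its signal probability for the query event.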
Double_t TMVA::MethodKNN::getLDAValue(const kNN::List &rlist, const kNN::Event &event_knn)
{
LDAEvents sig_vec, bac_vec;
for (kNN::List::const_iterator lit = rlist.begin(); lit != rlist.end(); ++lit) {
const kNN::Node<kNN::Event> &node = *(lit->first);
const kNN::VarVec &tvec = node.GetEvent().GetVars();
if (node.GetEvent().GetType() == 1) {
sig_vec.push_back(tvec);
}
else if (node.GetEvent().GetType() == 2) {
bac_vec.push_back(tvec);
}
else {
Log() << kFATAL << "Unknown type for training event" << Endl;
}
}
fLDA.Initialize(sig_vec, bac_vec);
return fLDA.GetProb(event_knn.GetVars(), 1);
}
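////////////////////////////////////////////////////////////////////////////////
// A hedged sketch of applying a trained kNN method with TMVA::Reader; the
// variable names and the weight-file path are hypothetical:
//
//    TMVA::Reader reader("!Color:!Silent");
//    Float_t var1 = 0, var2 = 0;
//    reader.AddVariable("var1", &var1);
//    reader.AddVariable("var2", &var2);
//    reader.BookMVA("KNN", "weights/TMVAClassification_KNN.weights.xml");
//    var1 = 0.3; var2 = -1.2;
//    const Double_t mvaValue = reader.EvaluateMVA("KNN");  // in [0, 1]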