#include "Riostream.h"
#include "RooFit.h"
#include "RooFFTConvPdf.h"
#include "RooAbsReal.h"
#include "RooMsgService.h"
#include "RooDataHist.h"
#include "RooHistPdf.h"
#include "RooRealVar.h"
#include "TComplex.h"
#include "TVirtualFFT.h"
#include "RooGenContext.h"
#include "RooConvGenContext.h"
#include "RooBinning.h"
#include "RooLinearVar.h"
#include "RooCustomizer.h"
#include "RooGlobalFunc.h"
#include "RooLinearVar.h"
#include "RooConstVar.h"
#include "TClass.h"
#include "TSystem.h"
#include <vector>
using namespace std ;
// ROOT class-implementation macro for RooFFTConvPdf.
ClassImp(RooFFTConvPdf)
// Construct the FFT convolution pdf1 (*) pdf2 in the observable convVar.
// ipOrder is forwarded to RooAbsCachedPdf (presumably the interpolation order of
// the cached histogram -- confirm against RooAbsCachedPdf).
// Side effect: if convVar has no binning named "cache", a copy of its default
// binning is registered under that name.
RooFFTConvPdf::RooFFTConvPdf(const char *name, const char *title, RooRealVar& convVar, RooAbsPdf& pdf1, RooAbsPdf& pdf2, Int_t ipOrder) :
  RooAbsCachedPdf(name,title,ipOrder),
  _x("!x","Convolution Variable",this,convVar),
  _xprime("!xprime","External Convolution Variable",this,0),
  _pdf1("!pdf1","pdf1",this,pdf1,kFALSE),
  _pdf2("!pdf2","pdf2",this,pdf2,kFALSE),
  _params("!params","effective parameters",this),
  _bufFrac(0.1),
  _bufStrat(Extend),
  _shift1(0),
  _shift2(0),
  _cacheObs("!cacheObs","Cached observables",this,kFALSE,kFALSE)
{
  const Bool_t lacksCacheBinning = !convVar.hasBinning("cache") ;
  if (lacksCacheBinning) {
    convVar.setBinning(convVar.getBinning(),"cache") ;
  }

  // The second input is shifted by the centre of the cache range.
  const Double_t cacheLo = convVar.getMin("cache") ;
  const Double_t cacheHi = convVar.getMax("cache") ;
  _shift2 = (cacheHi+cacheLo)/2 ;

  calcParams() ;
}
// Construct the FFT convolution pdf1 (*) pdf2 where the input p.d.f.s depend on
// pdfConvVar, while convVar is the RooRealVar used as the cache observable.
// pdfObservable() maps the cache observable back to pdfConvVar.
// ipOrder is forwarded to RooAbsCachedPdf (presumably the interpolation order -- confirm).
// Side effect: if convVar has no binning named "cache", a copy of its default
// binning is registered under that name.
RooFFTConvPdf::RooFFTConvPdf(const char *name, const char *title, RooAbsReal& pdfConvVar, RooRealVar& convVar, RooAbsPdf& pdf1, RooAbsPdf& pdf2, Int_t ipOrder) :
  RooAbsCachedPdf(name,title,ipOrder),
  _x("!x","Convolution Variable",this,convVar,kFALSE,kFALSE),
  _xprime("!xprime","External Convolution Variable",this,pdfConvVar),
  _pdf1("!pdf1","pdf1",this,pdf1,kFALSE),
  _pdf2("!pdf2","pdf2",this,pdf2,kFALSE),
  _params("!params","effective parameters",this),
  _bufFrac(0.1),
  _bufStrat(Extend),
  _shift1(0),
  _shift2(0),
  _cacheObs("!cacheObs","Cached observables",this,kFALSE,kFALSE)
{
  const Bool_t lacksCacheBinning = !convVar.hasBinning("cache") ;
  if (lacksCacheBinning) {
    convVar.setBinning(convVar.getBinning(),"cache") ;
  }

  // The second input is shifted by the centre of the cache range.
  const Double_t cacheLo = convVar.getMin("cache") ;
  const Double_t cacheHi = convVar.getMax("cache") ;
  _shift2 = (cacheHi+cacheLo)/2 ;

  calcParams() ;
}
// Copy constructor. Every member is copied from 'other'; each proxy is
// constructed with 'this' as its owner so it binds to the new object.
RooFFTConvPdf::RooFFTConvPdf(const RooFFTConvPdf& other, const char* name)
  : RooAbsCachedPdf(other,name)
  , _x("!x",this,other._x)
  , _xprime("!xprime",this,other._xprime)
  , _pdf1("!pdf1",this,other._pdf1)
  , _pdf2("!pdf2",this,other._pdf2)
  , _params("!params",this,other._params)
  , _bufFrac(other._bufFrac)
  , _bufStrat(other._bufStrat)
  , _shift1(other._shift1)
  , _shift2(other._shift2)
  , _cacheObs("!cacheObs",this,other._cacheObs)
{
  // Nothing beyond the member-wise copy above.
}
// Destructor. No resources are released explicitly here; members and base
// classes clean up after themselves.
RooFFTConvPdf::~RooFFTConvPdf() {}
const char* RooFFTConvPdf::inputBaseName() const
{
static TString name ;
name = _pdf1.arg().GetName() ;
name.Append("_CONV_") ;
name.Append(_pdf2.arg().GetName()) ;
return name.Data() ;
}
// Factory hook for RooAbsCachedPdf: create the FFT-specific cache element for
// normalization set 'nset'. The caller takes ownership of the returned object.
RooFFTConvPdf::PdfCacheElem* RooFFTConvPdf::createCache(const RooArgSet* nset) const
{
  FFTCacheElem* fftElem = new FFTCacheElem(*this,nset) ;
  return fftElem ;
}
// Build the cache element used by the FFT convolution.
//  - Clone both input p.d.f.s and attach the clones to the cache histogram.
//  - If a shift is configured for an input, wrap its clone so it sees the
//    convolution observable shifted by -shift (RooLinearVar 1*x + (-shift)).
//  - Redirect the clones' parameters to those of 'self' (excluding the cache
//    histogram's own observables).
//  - Prepare the sampling binning (N bins plus Nbuf buffer bins on each side)
//    and keep a copy of the histogram's original binning.
RooFFTConvPdf::FFTCacheElem::FFTCacheElem(const RooFFTConvPdf& self, const RooArgSet* nsetIn) :
  PdfCacheElem(self,nsetIn),
  fftr2c1(0),fftr2c2(0),fftc2r(0)
{
  RooAbsPdf* clonePdf1 = static_cast<RooAbsPdf*>(self._pdf1.arg().cloneTree()) ;
  RooAbsPdf* clonePdf2 = static_cast<RooAbsPdf*>(self._pdf2.arg().cloneTree()) ;
  clonePdf1->attachDataSet(*hist()) ;
  clonePdf2->attachDataSet(*hist()) ;

  // The convolution observable as it lives in the cache histogram.
  RooRealVar* convObs = static_cast<RooRealVar*>(hist()->get()->find(self._x.arg().GetName())) ;

  // Record the full, unshifted observable range under a per-object name. It is
  // passed to fixAddCoefRange() below (presumably so that RooAddPdf coefficients
  // keep referring to this range -- confirm against RooAbsArg::fixAddCoefRange).
  string refName = Form("refrange_fft_%s",self.GetName()) ;
  convObs->setRange(refName.c_str(),convObs->getMin(),convObs->getMax()) ;

  if (self._shift1!=0) {
    RooLinearVar* shiftObs1 = new RooLinearVar(Form("%s_shifted_FFTBuffer1",convObs->GetName()),"shiftObs1",
                                               *convObs,RooFit::RooConst(1),RooFit::RooConst(-1*self._shift1)) ;
    RooCustomizer cust(*clonePdf1,"fft") ;
    cust.replaceArg(*convObs,*shiftObs1) ;
    pdf1Clone = static_cast<RooAbsPdf*>(cust.build()) ;
    // The customized clone depends on these, so it owns them.
    pdf1Clone->addOwnedComponents(*shiftObs1) ;
    pdf1Clone->addOwnedComponents(*clonePdf1) ;
  } else {
    pdf1Clone = clonePdf1 ;
  }

  if (self._shift2!=0) {
    RooLinearVar* shiftObs2 = new RooLinearVar(Form("%s_shifted_FFTBuffer2",convObs->GetName()),"shiftObs2",
                                               *convObs,RooFit::RooConst(1),RooFit::RooConst(-1*self._shift2)) ;
    RooCustomizer cust(*clonePdf2,"fft") ;
    cust.replaceArg(*convObs,*shiftObs2) ;
    pdf2Clone = static_cast<RooAbsPdf*>(cust.build()) ;
    // Ownership goes to the object that depends on these (pdf2Clone), mirroring
    // the _shift1 branch. Previously they were attached to pdf1Clone, which the
    // cache-element destructor deletes before pdf2Clone.
    pdf2Clone->addOwnedComponents(*shiftObs2) ;
    pdf2Clone->addOwnedComponents(*clonePdf2) ;
  } else {
    pdf2Clone = clonePdf2 ;
  }

  // Share the parameters of 'self' with the clones (the histogram observables stay local).
  RooArgSet* fftParams = self.getParameters(*convObs) ;
  fftParams->remove(*hist()->get(),kTRUE,kTRUE) ;
  pdf1Clone->recursiveRedirectServers(*fftParams) ;
  pdf2Clone->recursiveRedirectServers(*fftParams) ;
  pdf1Clone->fixAddCoefRange(refName.c_str()) ;
  pdf2Clone->fixAddCoefRange(refName.c_str()) ;
  delete fftParams ;

  // Sampling binning: the observable range extended by Nbuf bins of the same
  // width on each side, Nbuf = round(N*bufferFraction/2).
  Int_t N = convObs->numBins() ;
  Int_t Nbuf = static_cast<Int_t>((N*self.bufferFraction())/2 + 0.5) ;
  Double_t obw = (convObs->getMax() - convObs->getMin())/N ;
  Int_t N2 = N+2*Nbuf ;
  scanBinning = new RooUniformBinning (convObs->getMin()-Nbuf*obw,convObs->getMax()+Nbuf*obw,N2) ;
  histBinning = convObs->getBinning().clone() ;

  hist()->setDirtyProp(kFALSE) ;
  convObs->setOperMode(ADirty,kTRUE) ;
}
// Suffix appended to the cache histogram name, encoding the buffer fraction
// (printed with one decimal) and the buffer strategy (as an integer).
TString RooFFTConvPdf::histNameSuffix() const
{
  TString suffix ;
  suffix.Form("_BufFrac%3.1f_BufStrat%d",_bufFrac,_bufStrat) ;
  return suffix ;
}
// List every RooAbsArg held by this cache element, in this order: the
// base-class contents, both p.d.f. clones, then each clone's owned components
// (if any).
RooArgList RooFFTConvPdf::FFTCacheElem::containedArgs(Action a)
{
  RooArgList contents(PdfCacheElem::containedArgs(a)) ;
  RooAbsPdf* const clones[2] = { pdf1Clone, pdf2Clone } ;

  for (Int_t idx=0 ; idx<2 ; idx++) {
    contents.add(*clones[idx]) ;
  }
  for (Int_t idx=0 ; idx<2 ; idx++) {
    if (clones[idx]->ownedComponents()) {
      contents.add(*clones[idx]->ownedComponents()) ;
    }
  }
  return contents ;
}
// Release everything allocated by the cache element: the three FFT plans
// (created lazily in fillCacheSlice, so they may still be null), both p.d.f.
// clones, and the two binnings created in the constructor.
// NOTE(review): the clones may own other components (addOwnedComponents in the
// constructor), so keep pdf1Clone/pdf2Clone deletion order in mind when editing.
RooFFTConvPdf::FFTCacheElem::~FFTCacheElem()
{
delete fftr2c1 ;
delete fftr2c2 ;
delete fftc2r ;
delete pdf1Clone ;
delete pdf2Clone ;
delete histBinning ;
delete scanBinning ;
}
// Fill the cache histogram with the convolution result.
// The FFT runs along the convolution observable only. For every bin
// combination of the remaining cache observables ("otherObs"), one slice is
// computed by fillCacheSlice(). Combinations are enumerated odometer-style,
// with the first observable in otherObs varying fastest.
void RooFFTConvPdf::fillCacheObject(RooAbsCachedPdf::PdfCacheElem& cache) const
{
  FFTCacheElem& fftCache = static_cast<FFTCacheElem&>(cache) ;
  RooDataHist& cacheHist = *fftCache.hist() ;

  fftCache.pdf1Clone->setOperMode(ADirty,kTRUE) ;
  fftCache.pdf2Clone->setOperMode(ADirty,kTRUE) ;

  // Snapshot of the cache observables without the convolution observable.
  RooArgSet otherObs ;
  RooArgSet(*cacheHist.get()).snapshot(otherObs) ;
  RooAbsArg* histArg = otherObs.find(_x.arg().GetName()) ;
  if (histArg) {
    otherObs.remove(*histArg,kTRUE,kTRUE) ;
    delete histArg ;
  }

  // Only the convolution observable: a single slice covers everything.
  if (otherObs.getSize()==0) {
    fillCacheSlice(fftCache,RooArgSet()) ;
    return ;
  }

  const Int_t n = otherObs.getSize() ;
  std::vector<RooAbsLValue*> obsLV ;
  std::vector<Int_t> binMax ;
  obsLV.reserve(n) ;
  binMax.reserve(n) ;

  TIterator* iter = otherObs.createIterator() ;
  RooAbsArg* arg ;
  while((arg=(RooAbsArg*)iter->Next())) {
    RooAbsLValue* lvarg = dynamic_cast<RooAbsLValue*>(arg) ;
    obsLV.push_back(lvarg) ;
    binMax.push_back(lvarg->numBins(binningName())-1) ;
  }
  delete iter ;

  // Odometer over all bin combinations. Termination is checked before
  // incrementing. The previous version incremented the never-initialized
  // element binCur[n] on the last step, which read an indeterminate value.
  std::vector<Int_t> binCur(n,0) ;
  for (;;) {
    for (Int_t j=0 ; j<n ; j++) { obsLV[j]->setBin(binCur[j],binningName()) ; }
    fillCacheSlice(fftCache,otherObs) ;

    Int_t digit = 0 ;
    while (digit<n && binCur[digit]==binMax[digit]) {
      binCur[digit] = 0 ;
      ++digit ;
    }
    if (digit==n) break ;   // every combination has been visited
    ++binCur[digit] ;
  }
}
// Compute one slice of the cache histogram (all bins along the convolution
// observable at the other-observable position 'slicePos') via FFT:
//  1. sample both input clones into arrays of length N2 (scanPdf),
//  2. forward real-to-complex FFT of both arrays,
//  3. multiply the spectra point-wise (convolution theorem -> circular convolution),
//  4. inverse complex-to-real FFT,
//  5. copy N of the N2 output points, rotated by totalShift, into the slice.
// NOTE(review): no 1/N2 normalization is applied here; presumably the overall
// normalization of the cached histogram is handled downstream -- confirm.
void RooFFTConvPdf::fillCacheSlice(FFTCacheElem& aux, const RooArgSet& slicePos) const
{
RooDataHist& cacheHist = *aux.hist() ;
Int_t N,N2,binShift1,binShift2 ;
RooRealVar* histX = (RooRealVar*) cacheHist.get()->find(_x.arg().GetName()) ;
// Extend strategy: temporarily install the N2-bin scan binning as histX's binning for sampling.
if (_bufStrat==Extend) histX->setBinning(*aux.scanBinning) ;
// Both calls set N and N2 (same values each time); binShift1/2 receive each input's zero-bin offset.
Double_t* input1 = scanPdf((RooRealVar&)_x.arg(),*aux.pdf1Clone,cacheHist,slicePos,N,N2,binShift1,_shift1) ;
Double_t* input2 = scanPdf((RooRealVar&)_x.arg(),*aux.pdf2Clone,cacheHist,slicePos,N,N2,binShift2,_shift2) ;
// Restore the histogram's original binning.
if (_bufStrat==Extend) histX->setBinning(*aux.histBinning) ;
// FFT plans are created lazily on the first call and reused afterwards.
// NOTE(review): they are sized by the first call's N2; this assumes N2 does not
// change during the lifetime of the cache element. See TVirtualFFT::FFT for the option strings.
if (!aux.fftr2c1) {
aux.fftr2c1 = TVirtualFFT::FFT(1, &N2, "R2CK");
aux.fftr2c2 = TVirtualFFT::FFT(1, &N2, "R2CK");
aux.fftc2r = TVirtualFFT::FFT(1, &N2, "C2RK");
}
aux.fftr2c1->SetPoints(input1);
aux.fftr2c1->Transform();
aux.fftr2c2->SetPoints(input2);
aux.fftr2c2->Transform();
// Point-wise complex product (re1 + i*im1)*(re2 + i*im2) over the N2/2+1
// coefficients produced by a real-to-complex transform of length N2.
for (Int_t i=0 ; i<N2/2+1 ; i++) {
Double_t re1,re2,im1,im2 ;
aux.fftr2c1->GetPointComplex(i,re1,im1) ;
aux.fftr2c2->GetPointComplex(i,re2,im2) ;
Double_t re = re1*re2 - im1*im2 ;
Double_t im = re1*im2 + re2*im1 ;
TComplex t(re,im) ;
aux.fftc2r->SetPointComplex(i,t) ;
}
aux.fftc2r->Transform() ;
// Output offset: pdf1's zero-bin shift plus the size of one buffer (N2-N)/2.
Int_t totalShift = binShift1 + (N2-N)/2 ;
// Walk the N bins of the slice along x; presumably Next() positions cacheHist
// on the next bin and set() stores the value there -- confirm against RooDataHist.
TIterator* iter = const_cast<RooDataHist&>(cacheHist).sliceIterator(const_cast<RooAbsReal&>(_x.arg()),slicePos) ;
for (Int_t i =0 ; i<N ; i++) {
Int_t j = i + totalShift ;
// Wrap j into [0,N2) since the FFT output is periodic.
while (j<0) j+= N2 ;
while (j>=N2) j-= N2 ;
iter->Next() ;
cacheHist.set(aux.fftc2r->GetPointReal(j)) ;
}
delete iter ;
delete[] input1 ;
delete[] input2 ;
}
// Sample 'pdf' along the convolution observable into a newly allocated array of
// length N2 = N + 2*Nbuf. The caller owns the array and must release it with delete[].
//   obs      - convolution observable; its name selects the matching variable in 'hist'
//   hist     - cache histogram whose variables are set while sampling
//   slicePos - position of the other cache observables for this slice
//   N  (out) - number of bins of the histogram variable under binningName()
//   N2 (out) - N plus Nbuf = round(N*bufferFraction()/2) buffer bins on each side
//   zeroBin (out) - rotation offset applied to the samples (see the end of the function)
//   shift    - shift of this input in observable units, converted to bins
// Buffer fill depends on _bufStrat:
//   Extend - all N2 bins are sampled from the p.d.f.; this relies on the
//            caller (fillCacheSlice) installing the N2-bin scan binning first
//   Flat   - buffers repeat the value of the first/last in-range bin
//   Mirror - buffers reflect the in-range values around the range edges
// NOTE(review): Mirror indexes bins k and N-k for k<=Nbuf; presumably assumes Nbuf < N.
Double_t* RooFFTConvPdf::scanPdf(RooRealVar& obs, RooAbsPdf& pdf, const RooDataHist& hist, const RooArgSet& slicePos,
Int_t& N, Int_t& N2, Int_t& zeroBin, Double_t shift) const
{
RooRealVar* histX = (RooRealVar*) hist.get()->find(obs.GetName()) ;
N = histX->numBins(binningName()) ;
Int_t Nbuf = static_cast<Int_t>((N*bufferFraction())/2 + 0.5) ;
N2 = N+2*Nbuf ;
Double_t* array = new Double_t[N2] ;
// Position the histogram at the requested slice (presumably loads slicePos into its variables -- confirm).
hist.get(slicePos) ;
// Locate the bin of x=0: directly if 0 is inside the range, otherwise
// extrapolated with bin width (max-min)/N2.
zeroBin = 0 ;
if (histX->getMax()>=0 && histX->getMin()<=0) {
zeroBin = histX->getBinning().binNumber(0) ;
} else if (histX->getMin()>0) {
Double_t bw = (histX->getMax() - histX->getMin())/N2 ;
zeroBin = Int_t(-histX->getMin()/bw) ;
} else {
Double_t bw = (histX->getMax() - histX->getMin())/N2 ;
zeroBin = Int_t(-1*histX->getMax()/bw) ;
}
// Convert the shift from observable units to bins and add it, then wrap into [0,N2).
Int_t binShift = Int_t((N2* shift) / (histX->getMax()-histX->getMin())) ;
zeroBin += binShift ;
while(zeroBin>=N2) zeroBin-= N2 ;
while(zeroBin<0) zeroBin+= N2 ;
// tmp holds the samples in natural order (buffer, range, buffer).
Double_t *tmp = new Double_t[N2] ;
Int_t k(0) ;
switch(_bufStrat) {
case Extend:
for (k=0 ; k<N2 ; k++) {
histX->setBin(k) ;
tmp[k] = pdf.getVal(hist.get()) ;
}
break ;
case Flat:
{
// Left buffer: value at the first bin.
histX->setBin(0) ;
Double_t val = pdf.getVal(hist.get()) ;
for (k=0 ; k<Nbuf ; k++) {
tmp[k] = val ;
}
for (k=0 ; k<N ; k++) {
histX->setBin(k) ;
tmp[k+Nbuf] = pdf.getVal(hist.get()) ;
}
// Right buffer: value at the last bin.
histX->setBin(N-1) ;
val = pdf.getVal(hist.get()) ;
for (k=0 ; k<Nbuf ; k++) {
tmp[N+Nbuf+k] = val ;
}
}
break ;
case Mirror:
for (k=0 ; k<N ; k++) {
histX->setBin(k) ;
tmp[k+Nbuf] = pdf.getVal(hist.get()) ;
}
// Reflect bins 1..Nbuf into the left buffer and bins N-1..N-Nbuf into the right buffer.
for (k=1 ; k<=Nbuf ; k++) {
histX->setBin(k) ;
tmp[Nbuf-k] = pdf.getVal(hist.get()) ;
histX->setBin(N-k) ;
tmp[Nbuf+N+k-1] = pdf.getVal(hist.get()) ;
}
break ;
}
// Cyclic rotation: array[i] = tmp[(i - zeroBin) mod N2].
for (Int_t i=0 ; i<N2 ; i++) {
Int_t j = i - (zeroBin) ;
if (j<0) j+= N2 ;
if (j>=N2) j-= N2 ;
array[i] = tmp[j] ;
}
delete[] tmp ;
return array ;
}
RooArgSet* RooFFTConvPdf::actualObservables(const RooArgSet& nset) const
{
RooArgSet* obs1 = _pdf1.arg().getObservables(nset) ;
RooArgSet* obs2 = _pdf2.arg().getObservables(nset) ;
obs1->add(*obs2,kTRUE) ;
if (nset.contains(_x.arg())) {
TIterator* iter = obs1->createIterator() ;
RooAbsArg* arg ;
RooArgSet killList ;
while((arg=(RooAbsArg*)iter->Next())) {
if (arg->IsA()->InheritsFrom(RooAbsReal::Class()) && !_cacheObs.find(arg->GetName())) {
killList.add(*arg) ;
}
}
delete iter ;
obs1->remove(killList) ;
obs1->add(_x.arg(),kTRUE) ;
obs1->add(_cacheObs) ;
delete obs2 ;
} else {
if (_cacheObs.getSize()>0) {
TIterator* iter = obs1->createIterator() ;
RooAbsArg* arg ;
RooArgSet killList ;
while((arg=(RooAbsArg*)iter->Next())) {
if (arg->IsA()->InheritsFrom(RooAbsReal::Class()) && !_cacheObs.find(arg->GetName())) {
killList.add(*arg) ;
}
}
delete iter ;
obs1->remove(killList) ;
}
obs1->add(_x.arg(),kTRUE) ;
delete obs2 ;
}
return obs1 ;
}
// Parameters of the cache: all variables of this object minus the actual
// observables for 'nset'. The caller owns the returned set.
RooArgSet* RooFFTConvPdf::actualParameters(const RooArgSet& nset) const
{
  RooArgSet* allVars = getVariables() ;
  RooArgSet* observables = actualObservables(nset) ;
  allVars->remove(*observables) ;
  delete observables ;
  return allVars ;
}
// Map a cache-histogram observable to the observable the input p.d.f.s use.
// When a separate p.d.f. convolution variable (_xprime) was supplied, the
// observable named like _x maps to it; everything else maps to itself.
RooAbsArg& RooFFTConvPdf::pdfObservable(RooAbsArg& histObservable) const
{
  RooAbsArg* pdfConvVar = _xprime.absArg() ;
  if (!pdfConvVar) {
    return histObservable ;
  }
  const Bool_t isConvObs = (string(histObservable.GetName())==_x.absArg()->GetName()) ;
  return isConvObs ? *pdfConvVar : histObservable ;
}
// Select the event-generation context.
// A RooConvGenContext is used only when both inputs report a generator for the
// convolution variable that is safe for direct use, and nothing besides the
// convolution variable is to be generated. Otherwise a RooGenContext (the
// accept/reject generator, per the log message) is used. The caller owns the result.
RooAbsGenContext* RooFFTConvPdf::genContext(const RooArgSet &vars, const RooDataSet *prototype,
                                            const RooArgSet* auxProto, Bool_t verbose) const
{
  RooArgSet extraGenObs(vars) ;
  extraGenObs.remove(_x.arg(),kTRUE,kTRUE) ;
  const Int_t numAddDep = extraGenObs.getSize() ;

  const RooAbsPdf& inputPdf1 = static_cast<const RooAbsPdf&>(_pdf1.arg()) ;
  const RooAbsPdf& inputPdf2 = static_cast<const RooAbsPdf&>(_pdf2.arg()) ;

  // One scratch set is shared by both getGenerator() calls, as before.
  RooArgSet dummy ;
  const Bool_t pdfCanDir = inputPdf1.getGenerator(_x.arg(),dummy)!=0 && inputPdf1.isDirectGenSafe(_x.arg()) ;
  const Bool_t resCanDir = inputPdf2.getGenerator(_x.arg(),dummy)!=0 && inputPdf2.isDirectGenSafe(_x.arg()) ;

  if (pdfCanDir) {
    cxcoutI(Generation) << "RooFFTConvPdf::genContext() input p.d.f " << _pdf1.arg().GetName()
                        << " has internal generator that is safe to use in current context" << endl ;
  }
  if (resCanDir) {
    cxcoutI(Generation) << "RooFFTConvPdf::genContext() input p.d.f. " << _pdf2.arg().GetName()
                        << " has internal generator that is safe to use in current context" << endl ;
  }
  if (numAddDep>0) {
    cxcoutI(Generation) << "RooFFTConvPdf::genContext() generation requested for observables other than the convolution observable " << _x.arg().GetName() << endl ;
  }

  const Bool_t canUseConvContext = (numAddDep==0) && pdfCanDir && resCanDir ;
  if (!canUseConvContext) {
    cxcoutI(Generation) << "RooFFTConvPdf::genContext() selecting accept/reject generator context because one or both of the input "
                        << "p.d.f.s cannot use internal generator and/or "
                        << "observables other than the convolution variable are requested for generation" << endl ;
    return new RooGenContext(*this,vars,prototype,auxProto,verbose) ;
  }

  cxcoutI(Generation) << "RooFFTConvPdf::genContext() selecting specialized convolution generator context as both input "
                      << "p.d.fs are safe for internal generator and only "
                      << "the convolution observables is requested for generation" << endl ;
  return new RooConvGenContext(*this,vars,prototype,auxProto,verbose) ;
}
// Set the size of the sampling buffer added on each side of the observable
// range, as a fraction of the number of bins. Nbuf = round(N*frac/2) bins are
// added per side. Cached results are invalidated.
// Invalid values are rejected with an error message and leave the setting unchanged.
void RooFFTConvPdf::setBufferFraction(Double_t frac)
{
  // Test the positive condition so that NaN (which fails every comparison) is
  // rejected as well. A NaN fraction would otherwise reach the
  // static_cast<Int_t>(...) buffer-size computation, which is undefined behavior.
  if (!(frac>=0)) {
    coutE(InputArguments) << "RooFFTConvPdf::setBufferFraction(" << GetName() << ") fraction should be greater than or equal to zero" << endl ;
    return ;
  }
  _bufFrac = frac ;
  _cacheMgr.sterilize() ;
}
// Choose how the sampling buffer outside the observable range is filled
// (Extend, Flat or Mirror; see scanPdf()).
// Already-cached convolution results were computed with the previous strategy,
// so they are invalidated here, just like in setBufferFraction(). Otherwise a
// strategy change would go unnoticed until some parameter changed.
void RooFFTConvPdf::setBufferStrategy(BufStrat bs)
{
  _bufStrat = bs ;
  _cacheMgr.sterilize() ;
}
// Print the convolution as "pdf1(x) (*) pdf2(x) " for RooPrintable output.
void RooFFTConvPdf::printMetaArgs(ostream& os) const
{
  const char* convName = _x.arg().GetName() ;
  os << _pdf1.arg().GetName() << "(" << convName << ") (*) "
     << _pdf2.arg().GetName() << "(" << convName << ") " ;
}
void RooFFTConvPdf::calcParams()
{
RooArgSet* params1 = _pdf1.arg().getParameters(_x.arg()) ;
RooArgSet* params2 = _pdf2.arg().getParameters(_x.arg()) ;
_params.removeAll() ;
_params.add(*params1) ;
_params.add(*params2,kTRUE) ;
delete params1 ;
delete params2 ;
}
// Server-redirection hook: no extra bookkeeping is done here and all arguments
// are ignored. Always returns kFALSE (presumably "no error" in the RooAbsArg
// hook convention -- confirm).
Bool_t RooFFTConvPdf::redirectServersHook(const RooAbsCollection& /*newServerList*/, Bool_t /*mustReplaceAll*/, Bool_t /*nameChange*/, Bool_t /*isRecursive*/)
{
  return kFALSE ;
}