56void assignSpan(std::span<T> &to, std::span<T>
const &from)
79 auto log = [](std::string_view message) {
84 log(
"using generic CPU library compiled with no vectorizations");
89 log(
"using CUDA computation library");
104 std::shared_ptr<RooBatchCompute::AbsBuffer>
buffer;
156 throw std::runtime_error(
"Can't create Evaluator in CUDA mode because RooBatchCompute CUDA could not be loaded!");
172 std::map<RooFit::Detail::DataKey, NodeInfo *>
nodeInfos;
176 std::size_t iNode = 0;
184 nodeInfo.originalOperMode = arg->operMode();
191 arg->setDataToken(iNode);
201 info.serverInfos.reserve(
info.absArg->servers().size());
212 _nodes.back().isValueServer =
true;
213 for (
auto iter =
_nodes.rbegin(); iter !=
_nodes.rend(); ++iter) {
214 if (!iter->isValueServer)
257 throw std::runtime_error(
"Evaluator can only take device array as input in CUDA mode!");
272 info.fromArrayInput =
true;
273 info.absArg->setDataToken(
info.iNode);
281 if (
info.outputSize <= 1) {
309 std::map<RooFit::Detail::DataKey, std::size_t>
sizeMap;
311 if (
info.fromArrayInput) {
321 auto found =
sizeMap.find(key);
322 return found !=
sizeMap.
end() ? found->second : -1;
334 if (!
info.isScalar()) {
351 if (!
info.isVariable) {
352 info.absArg->resetDataToken();
359 using namespace Detail;
361 const std::size_t
nOut =
info.outputSize;
363 double *buffer =
nullptr;
365 buffer = &
info.scalarBuffer;
373 <<
" could not be evaluated on the GPU because the class doesn't support it. "
374 "Consider requesting or implementing it to benefit from a speed up."
376 info.hasLogged =
true;
382 buffer =
info.buffer->hostWritePtr();
389 if (
info.isCategory) {
394 throw std::runtime_error(
"RooFit::Evaluator - non-scalar category values are not supported!");
402 if (
info.copyAfterEvaluation) {
415 auto *var =
static_cast<RooRealVar const *
>(node);
416 if (
nodeInfo.lastSetValCount != var->valueResetCounter()) {
417 nodeInfo.lastSetValCount = var->valueResetCounter();
432 if (
nodeInfo.lastCatVal != cat->getCurrentIndex()) {
433 nodeInfo.lastCatVal = cat->getCurrentIndex();
487 info.remClients =
info.clientInfos.size();
488 info.remServers =
info.serverInfos.size();
489 if (
info.buffer && !
info.fromArrayInput) {
496 if (
info.remServers == 0 &&
info.computeInGPU) {
506 info.remServers = -2;
522 for (; it !=
_nodes.end(); it++) {
523 if (it->remServers == 0 && !it->computeInGPU)
529 std::this_thread::sleep_for(std::chrono::milliseconds(1));
536 info.remServers = -2;
538 if (!
info.fromArrayInput) {
561 using namespace Detail;
563 info.remServers = -1;
573 const std::size_t
nOut =
info.outputSize;
575 double *buffer =
nullptr;
577 buffer = &
info.scalarBuffer;
582 buffer =
info.buffer->deviceWritePtr();
588 if (
info.copyAfterEvaluation) {
599 info.computeInGPU =
false;
600 if (!
info.absArg->canComputeBatchWithCuda()) {
605 info.computeInGPU =
true;
613 info.copyAfterEvaluation =
false;
615 if (!
info.isScalar()) {
618 info.copyAfterEvaluation =
true;
637 std::cout <<
"--- RooFit BatchMode evaluation ---\n";
639 std::vector<int>
widths{9, 37, 20, 9, 10, 20};
642 const char separator =
' ';
643 os << separator << std::left << std::setw(
widths[
iCol]) << std::setfill(separator) << t;
652 for (
int i = 0; i <
n; i++) {
671 for (std::size_t iNode = 0; iNode <
_nodes.size(); ++iNode) {
728 if (
nodeInfo.absArg->isReducerNode()) {
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char mode
const_iterator begin() const
const_iterator end() const
Common abstract base class for objects that represent a value and a "shape" in RooFit.
OperMode operMode() const
Query the operation mode of this node.
A space to attach TBranches.
virtual bool add(const RooAbsArg &var, bool silent=false)
Add the specified argument to list.
void sort(bool reverse=false)
Sort collection using std::sort and name comparison.
Abstract base class for objects that represent a real value and implements functionality common to all real-valued objects.
RooArgSet is a container object that can hold multiple RooAbsArg objects.
Minimal configuration struct to steer the evaluation of a single node with the RooBatchCompute library.
void setCudaStream(CudaInterface::CudaStream *cudaStream)
virtual void deleteCudaEvent(CudaInterface::CudaEvent *) const =0
virtual CudaInterface::CudaEvent * newCudaEvent(bool forTiming) const =0
virtual void cudaEventRecord(CudaInterface::CudaEvent *, CudaInterface::CudaStream *) const =0
virtual std::unique_ptr< AbsBufferManager > createBufferManager() const =0
virtual void cudaStreamWaitForEvent(CudaInterface::CudaStream *, CudaInterface::CudaEvent *) const =0
virtual CudaInterface::CudaStream * newCudaStream() const =0
virtual void deleteCudaStream(CudaInterface::CudaStream *) const =0
virtual bool cudaStreamIsActive(CudaInterface::CudaStream *) const =0
void set(RooAbsArg const *arg, std::span< const double > const &span)
std::span< const double > at(RooAbsArg const *arg, RooAbsArg const *caller=nullptr)
void resetVectorBuffers()
void enableVectorBuffers(bool enable)
void setConfig(RooAbsArg const *arg, RooBatchCompute::Config const &config)
std::span< double > _currentOutput
void resize(std::size_t n)
void print(std::ostream &os)
void setClientsDirty(NodeInfo &nodeInfo)
Flags all the clients of a given node dirty.
RooArgSet getParameters() const
Gets all the parameters of the RooAbsReal.
void setOffsetMode(RooFit::EvalContext::OffsetMode)
Sets the offset mode for evaluation.
void syncDataTokens()
If there are servers with the same name that got de-duplicated in the _nodes list,...
std::unordered_map< TNamed const *, NodeInfo * > _nodesMap
std::vector< NodeInfo > _nodes
bool _needToUpdateOutputSizes
std::span< const double > getValHeterogeneous()
Returns the value of the top node in the computation graph.
std::span< const double > run()
Returns the value of the top node in the computation graph.
Evaluator(const RooAbsReal &absReal, bool useGPU=false)
Construct a new Evaluator.
void processVariable(NodeInfo &nodeInfo)
Process a variable in the computation graph.
void processCategory(NodeInfo &nodeInfo)
Process a category in the computation graph.
std::unique_ptr< RooBatchCompute::AbsBufferManager > _bufferManager
void markGPUNodes()
Decides which nodes are assigned to the GPU in a CUDA fit.
void assignToGPU(NodeInfo &info)
Assign a node to be computed in the GPU.
void setInput(std::string const &name, std::span< const double > inputArray, bool isOnDevice)
RooFit::EvalContext _evalContextCUDA
RooFit::EvalContext _evalContextCPU
void computeCPUNode(const RooAbsArg *node, NodeInfo &info)
std::stack< std::unique_ptr< ChangeOperModeRAII > > _changeOperModeRAIIs
void setOperMode(RooAbsArg *arg, RooAbsArg::OperMode opMode)
Temporarily change the operation mode of a RooAbsArg until the Evaluator gets deleted.
static RooMsgService & instance()
Return reference to singleton instance.
static const TNamed * ptr(const char *stringPtr)
Return a unique TNamed pointer for given C++ string.
Variable that can be changed from the outside.
const char * GetName() const override
Returns name of object.
virtual const char * ClassName() const
Returns name of class to which the object belongs.
R__EXTERN RooBatchComputeInterface * dispatchCUDA
std::string cpuArchitectureName()
R__EXTERN RooBatchComputeInterface * dispatchCPU
This dispatch pointer points to an implementation of the compute library, provided one has been loaded.
Architecture cpuArchitecture()
int initCPU()
Inspect hardware capabilities, and load the optimal library for RooFit computations.
The namespace RooFit contains mostly switches that change the behaviour of functions of PDFs (or other types of arguments).
void getSortedComputationGraph(RooAbsArg const &func, RooArgSet &out)
A struct used by the Evaluator to store information on the RooAbsArgs in the computation graph.
RooBatchCompute::CudaInterface::CudaStream * stream
std::size_t lastSetValCount
RooBatchCompute::CudaInterface::CudaEvent * event
std::vector< NodeInfo * > serverInfos
RooAbsArg::OperMode originalOperMode
std::vector< NodeInfo * > clientInfos
std::shared_ptr< RooBatchCompute::AbsBuffer > buffer
void decrementRemainingClients()
Check the servers of a node that has been computed and release its resources if they are no longer needed.