Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
RSampler.hxx
Go to the documentation of this file.
1// Author: Martin Føll, University of Oslo (UiO) & CERN 01/2026
2
3/*************************************************************************
4 * Copyright (C) 1995-2026, Rene Brun and Fons Rademakers. *
5 * All rights reserved. *
6 * *
7 * For the licensing terms see $ROOTSYS/LICENSE. *
8 * For the list of contributors see $ROOTSYS/README/CREDITS. *
9 *************************************************************************/
10
11#ifndef TMVA_RSAMPLER
12#define TMVA_RSAMPLER
13
#include <algorithm>
#include <random>
#include <stdexcept>
#include <vector>
17
18#include "ROOT/RDataFrame.hxx"
19#include "ROOT/RDF/Utils.hxx"
20#include "ROOT/RVec.hxx"
22#include "ROOT/RLogger.hxx"
23
25// clang-format off
26/**
27\class ROOT::TMVA::Experimental::Internal::RSampler
28\ingroup tmva
29\brief Implementation of different sampling strategies.
30*/
31
32class RSampler {
33private:
34 // clang-format on
35 std::vector<RFlat2DMatrix> &fDatasets;
36 std::string fSampleType;
40 std::size_t fSetSeed;
41 std::size_t fNumEntries;
42
43 std::size_t fMajor;
44 std::size_t fMinor;
45 std::size_t fNumMajor;
46 std::size_t fNumMinor;
47 std::size_t fNumResampledMajor;
48 std::size_t fNumResampledMinor;
49
50 std::vector<std::size_t> fSamples;
51
52 std::unique_ptr<RFlat2DMatrixOperators> fTensorOperators;
53public:
54 RSampler(std::vector<RFlat2DMatrix> &datasets, const std::string &sampleType, float sampleRatio,
55 bool replacement = false, bool shuffle = true, std::size_t setSeed = 0)
56 : fDatasets(datasets),
62 {
63 fTensorOperators = std::make_unique<RFlat2DMatrixOperators>(fShuffle, fSetSeed);
64
65 // setup the sampler for the datasets
66 SetupSampler();
67 }
68
69 //////////////////////////////////////////////////////////////////////////
70 /// \brief Calculate fNumEntries and major/minor variables
72 {
73 if (fSampleType == "undersampling") {
75 }
76 else if (fSampleType == "oversampling") {
78 }
79 }
80
81 //////////////////////////////////////////////////////////////////////////
82 /// \brief Collection of sampling types
83 /// \param[in] SampledTensor Tensor with all the sampled entries
85 {
86 if (fSampleType == "undersampling") {
88 }
89 else if (fSampleType == "oversampling") {
91 }
92 }
93
94 //////////////////////////////////////////////////////////////////////////
95 /// \brief Calculate fNumEntries and major/minor variables for the random undersampler
97 {
98 if (fDatasets[0].GetRows() > fDatasets[1].GetRows()) {
99 fMajor = 0;
100 fMinor = 1;
101 }
102 else {
103 fMajor = 1;
104 fMinor = 0;
105 }
106
107 fNumMajor = fDatasets[fMajor].GetRows();
108 fNumMinor = fDatasets[fMinor].GetRows();
109 fNumResampledMajor = static_cast<std::size_t>(fNumMinor / fSampleRatio);
111 }
112
113 //////////////////////////////////////////////////////////////////////////
114 /// \brief Calculate fNumEntries and major/minor variables for the random oversampler
116 {
117 if (fDatasets[0].GetRows() > fDatasets[1].GetRows()) {
118 fMajor = 0;
119 fMinor = 1;
120 }
121 else {
122 fMajor = 1;
123 fMinor = 0;
124 }
125
126 fNumMajor = fDatasets[fMajor].GetRows();
127 fNumMinor = fDatasets[fMinor].GetRows();
128 fNumResampledMinor = static_cast<std::size_t>(fSampleRatio * fNumMajor);
130 }
131
132 //////////////////////////////////////////////////////////////////////////
133 /// \brief Undersample entries randomly from the majority dataset
134 /// \param[in] ShuffledTensor Tensor with all the sampled entries
136 {
137 if (fReplacement) {
139 }
140
141 else {
143 }
144
145 std::size_t cols = fDatasets[0].GetCols();
149
150 std::size_t index = 0;
151 for (std::size_t i = 0; i < fNumResampledMajor; i++) {
152 std::copy(fDatasets[fMajor].GetData() + fSamples[i] * cols, fDatasets[fMajor].GetData() + (fSamples[i]+1) * cols,
153 UndersampledMajorTensor.GetData() + index * cols);
154 index++;
155 }
156
159 }
160
161 //////////////////////////////////////////////////////////////////////////
162 /// \brief Oversample entries randomly from the minority dataset
163 /// \param[in] ShuffledTensor Tensor with all the sampled entries
165 {
167
168 std::size_t cols = fDatasets[0].GetCols();
172
173 std::size_t index = 0;
174 for (std::size_t i = 0; i < fNumResampledMinor; i++) {
175 std::copy(fDatasets[fMinor].GetData() + fSamples[i] * cols, fDatasets[fMinor].GetData() + (fSamples[i]+1) * cols,
176 OversampledMinorTensor.GetData() + index * cols);
177 index++;
178 }
179
182 }
183
184 //////////////////////////////////////////////////////////////////////////
185 /// \brief Add indices with replacement to fSamples
186 /// \param[in] n_samples Number of indices to sample
187 /// \param[in] max Number of entries to draw from; sampled indices lie in the range [0, max)
188 void SampleWithReplacement(std::size_t n_samples, std::size_t max)
189 {
190 std::uniform_int_distribution<> dist(0, max - 1);
191 fSamples.clear();
192 fSamples.reserve(n_samples);
193 for (std::size_t i = 0; i < n_samples; ++i) {
194 std::size_t sample;
195 if (fShuffle) {
196 std::random_device rd;
197 std::mt19937 g;
198
199 if (fSetSeed == 0) {
200 g.seed(rd());
201 } else {
202 g.seed(fSetSeed);
203 }
204
205 sample = dist(g);
206 }
207
208 else {
209 sample = i % max;
210 }
211 fSamples.push_back(sample);
212 }
213 }
214
215 //////////////////////////////////////////////////////////////////////////
216 /// \brief Add indices without replacement to fSamples
217 /// \param[in] n_samples Number of indices to sample
218 /// \param[in] max Number of entries to draw from; sampled indices lie in the range [0, max)
219 void SampleWithoutReplacement(std::size_t n_samples, std::size_t max)
220 {
221 std::vector<std::size_t> UniqueSamples;
222 UniqueSamples.reserve(max);
223 fSamples.clear();
224 fSamples.reserve(n_samples);
225
226 for (std::size_t i = 0; i < max; ++i)
227 UniqueSamples.push_back(i);
228
229 if (fShuffle) {
230 std::random_device rd;
231 std::mt19937 g;
232
233 if (fSetSeed == 0) {
234 g.seed(rd());
235 } else {
236 g.seed(fSetSeed);
237 }
238 std::shuffle(UniqueSamples.begin(), UniqueSamples.end(), g);
239 }
240
241 for (std::size_t i = 0; i < n_samples; ++i) {
242 fSamples.push_back(UniqueSamples[i]);
243 }
244 }
245
246 std::size_t GetNumEntries() { return fNumEntries;}
247};
248
249} // namespace TMVA::Experimental::Internal
250#endif // TMVA_RSAMPLER
#define g(i)
Definition RSha256.hxx:105
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t index
const_iterator begin() const
const_iterator end() const
void RandomOversampler(RFlat2DMatrix &ShuffledTensor)
Oversample entries randomly from the minority dataset.
Definition RSampler.hxx:164
void RandomUndersampler(RFlat2DMatrix &ShuffledTensor)
Undersample entries randomly from the majority dataset.
Definition RSampler.hxx:135
RSampler(std::vector< RFlat2DMatrix > &datasets, const std::string &sampleType, float sampleRatio, bool replacement=false, bool shuffle=true, std::size_t setSeed=0)
Definition RSampler.hxx:54
std::vector< RFlat2DMatrix > & fDatasets
Definition RSampler.hxx:35
void SetupRandomOversampler()
Calculate fNumEntries and major/minor variables for the random oversampler.
Definition RSampler.hxx:115
void Sampler(RFlat2DMatrix &SampledTensor)
Collection of sampling types.
Definition RSampler.hxx:84
void SetupSampler()
Calculate fNumEntries and major/minor variables.
Definition RSampler.hxx:71
void SampleWithoutReplacement(std::size_t n_samples, std::size_t max)
Add indices without replacement to fSamples.
Definition RSampler.hxx:219
void SampleWithReplacement(std::size_t n_samples, std::size_t max)
Add indices with replacement to fSamples.
Definition RSampler.hxx:188
std::unique_ptr< RFlat2DMatrixOperators > fTensorOperators
Definition RSampler.hxx:52
std::vector< std::size_t > fSamples
Definition RSampler.hxx:50
void SetupRandomUndersampler()
Calculate fNumEntries and major/minor variables for the random undersampler.
Definition RSampler.hxx:96
Wrapper around ROOT::RVec<float> representing a 2D matrix.