Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
DataLoader.cxx
Go to the documentation of this file.
1// @(#)root/tmva/tmva/dnn:$Id$
2// Author: Simon Pfreundschuh 06/06/17
3
4/*************************************************************************
5 * Copyright (C) 2016, Simon Pfreundschuh *
6 * All rights reserved. *
7 * *
8 * For the licensing terms see $ROOTSYS/LICENSE. *
9 * For the list of contributors see $ROOTSYS/README/CREDITS. *
10 *************************************************************************/
11
12/////////////////////////////////////////////////////////////
13// Specializations of Copy functions for the DataLoader //
14// specialized for the reference architecture. //
15/////////////////////////////////////////////////////////////
16
18#include "TMVA/DataSetInfo.h"
19
20namespace TMVA {
21namespace DNN {
22
23//______________________________________________________________________________
// Copies one batch of rows from the matrix held in tuple element 0 of fData into
// `matrix`. For every row of `matrix` the next sample index is read from
// `sampleIterator`, and that source row is copied column by column.
// NOTE(review): the signature line of this specialization is missing from this
// listing; `matrix` and `sampleIterator` are presumably its parameters — confirm.
24template <>
26{
27 const TMatrixT<Real_t> &input = std::get<0>(fData);
28 Int_t m = matrix.GetNrows();   // number of rows to fill (one per sample)
29 Int_t n = input.GetNcols();    // column count is taken from the source matrix
30
31 for (Int_t i = 0; i < m; i++) {
32 Int_t sampleIndex = *sampleIterator;   // source row for destination row i
33 for (Int_t j = 0; j < n; j++) {
34 matrix(i, j) = static_cast<Real_t>(input(sampleIndex, j));
35 }
36 sampleIterator++;
37 }
38}
39
40//______________________________________________________________________________
// Copies one batch of rows from the matrix held in tuple element 1 of fData into
// `matrix`. Row i of `matrix` receives the source row whose index is the i-th
// value read from `sampleIterator`.
// NOTE(review): the first signature line is missing from this listing.
41template <>
43 IndexIterator_t sampleIterator)
44{
45 const TMatrixT<Real_t> &output = std::get<1>(fData);
46 Int_t m = matrix.GetNrows();   // number of rows to fill
47 Int_t n = output.GetNcols();   // column count is taken from the source matrix
48
49 for (Int_t i = 0; i < m; i++) {
50 Int_t sampleIndex = *sampleIterator;   // source row for destination row i
51 for (Int_t j = 0; j < n; j++) {
52 matrix(i, j) = static_cast<Real_t>(output(sampleIndex, j));
53 }
54 sampleIterator++;
55 }
56}
57
58//______________________________________________________________________________
// Copies one weight per sample from the matrix held in tuple element 2 of fData
// into column 0 of `matrix`. Only column 0 of the source and of the destination
// is touched.
// NOTE(review): the first signature line is missing from this listing.
59template <>
61 IndexIterator_t sampleIterator)
62{
63 const TMatrixT<Real_t> &weights = std::get<2>(fData);
64 Int_t m = matrix.GetNrows();   // number of rows to fill
65
66 for (Int_t i = 0; i < m; i++) {
67 Int_t sampleIndex = *sampleIterator;   // source row for destination row i
68 matrix(i, 0) = static_cast<Real_t>(weights(sampleIndex, 0));
69 sampleIterator++;
70 }
71}
72
73//______________________________________________________________________________
// Double_t counterpart of the matrix-input copy: copies one batch of rows from
// the matrix held in tuple element 0 of fData into `matrix`, selecting source
// rows through `sampleIterator`.
// NOTE(review): the first signature line is missing from this listing.
74template <>
76 IndexIterator_t sampleIterator)
77{
78 const TMatrixT<Double_t> &input = std::get<0>(fData);
79 Int_t m = matrix.GetNrows();   // number of rows to fill
80 Int_t n = input.GetNcols();    // column count is taken from the source matrix
81
82 for (Int_t i = 0; i < m; i++) {
83 Int_t sampleIndex = *sampleIterator;   // source row for destination row i
84 for (Int_t j = 0; j < n; j++) {
85 matrix(i, j) = static_cast<Double_t>(input(sampleIndex, j));
86 }
87 sampleIterator++;
88 }
89}
90
91//______________________________________________________________________________
// Double_t counterpart of the matrix-output copy: copies one batch of rows from
// the matrix held in tuple element 1 of fData into `matrix`, selecting source
// rows through `sampleIterator`.
// NOTE(review): the first signature line is missing from this listing.
92template <>
94 IndexIterator_t sampleIterator)
95{
96 const TMatrixT<Double_t> &output = std::get<1>(fData);
97 Int_t m = matrix.GetNrows();   // number of rows to fill
98 Int_t n = output.GetNcols();   // column count is taken from the source matrix
99
100 for (Int_t i = 0; i < m; i++) {
101 Int_t sampleIndex = *sampleIterator;   // source row for destination row i
102 for (Int_t j = 0; j < n; j++) {
103 matrix(i, j) = static_cast<Double_t>(output(sampleIndex, j));
104 }
105 sampleIterator++;
106 }
107}
108
109//______________________________________________________________________________
// Double_t counterpart of the matrix-weights copy: copies one value per sample
// from column 0 of the matrix held in tuple element 2 of fData into column 0 of
// `matrix`.
// NOTE(review): the first signature line is missing from this listing.
110template <>
112 IndexIterator_t sampleIterator)
113{
// The local is called `output` but is bound to tuple element 2, the element the
// Real_t version of this copy names `weights`.
114 const TMatrixT<Double_t> &output = std::get<2>(fData);
115 Int_t m = matrix.GetNrows();   // number of rows to fill
116
117 for (Int_t i = 0; i < m; i++) {
118 Int_t sampleIndex = *sampleIterator;   // source row for destination row i
119 matrix(i, 0) = static_cast<Double_t>(output(sampleIndex, 0));
120 sampleIterator++;
121 }
122}
123
124//______________________________________________________________________________
// Copies the input variables of one batch of events into `matrix`. The events
// are read from the container held in tuple element 0 of fData; row i of
// `matrix` receives the values of the event whose index is the i-th value read
// from `sampleIterator`.
// The number of variables per row is taken from the first stored event, so the
// event container must not be empty.
// NOTE(review): the signature line of this specialization is missing from this
// listing; `matrix` and `sampleIterator` are presumably its parameters — confirm.
125template <>
127{
// Start from the first stored event: the pointer is dereferenced below, before
// the loop assigns it, so it must not be null here.
128 Event *event = std::get<0>(fData).front();
129
130 Int_t m = matrix.GetNrows();
131 Int_t n = event->GetNVariables();
132
133 // Copy input variables.
134
135 for (Int_t i = 0; i < m; i++) {
136 Int_t sampleIndex = *sampleIterator++;
137 event = std::get<0>(fData)[sampleIndex];
138 for (Int_t j = 0; j < n; j++) {
139 matrix(i, j) = event->GetValue(j);
140 }
141 }
142}
143
144//______________________________________________________________________________
// Fills `matrix` with the training targets of one batch of events. Events come
// from the container in tuple element 0 of fData, selected through
// `sampleIterator`; the DataSetInfo in tuple element 1 is used for the
// signal/background decision.
// Per event:
//  - no targets and a single-column matrix: 1.0 if info.IsSignal(event), else 0.0;
//  - no targets and several columns: 1.0 in the column equal to event->GetClass(),
//    0.0 elsewhere;
//  - otherwise: column j receives event->GetTarget(j).
// NOTE(review): the signature line of this specialization is missing from this listing.
145template <>
147{
// This initial value is overwritten inside the loop before `event` is read.
148 Event *event = std::get<0>(fData).front();
149 const DataSetInfo &info = std::get<1>(fData);
150 Int_t m = matrix.GetNrows();
151 Int_t n = matrix.GetNcols();   // column count is taken from the destination
152
153 for (Int_t i = 0; i < m; i++) {
154 Int_t sampleIndex = *sampleIterator++;
155 event = std::get<0>(fData)[sampleIndex];
156 for (Int_t j = 0; j < n; j++) {
157 // Classification
158 if (event->GetNTargets() == 0) {
159 if (n == 1) {
160 // Binary.
161 matrix(i, j) = (info.IsSignal(event)) ? 1.0 : 0.0;
162 } else {
163 // Multiclass.
164 matrix(i, j) = 0.0;
165 if (j == static_cast<Int_t>(event->GetClass())) {
166 matrix(i, j) = 1.0;
167 }
168 }
169 } else {
// The event has targets: copy target j as-is.
170 matrix(i, j) = static_cast<Real_t>(event->GetTarget(j));
171 }
172 }
173 }
174}
175
176//______________________________________________________________________________
// Writes the weight of each selected event into column 0 of `matrix`. Events come
// from the container in tuple element 0 of fData; row i uses the event whose
// index is the i-th value read from `sampleIterator`.
// NOTE(review): the signature line of this specialization is missing from this listing.
177template <>
179{
// This initial value is overwritten inside the loop before `event` is read.
180 Event *event = std::get<0>(fData).front();
181 for (Int_t i = 0; i < matrix.GetNrows(); i++) {
182 Int_t sampleIndex = *sampleIterator++;
183 event = std::get<0>(fData)[sampleIndex];
184 matrix(i, 0) = event->GetWeight();
185 }
186}
187
188//______________________________________________________________________________
// Copies the input variables of one batch of events into `matrix`. Events come
// from the container in tuple element 0 of fData; row i of `matrix` receives the
// values of the event whose index is the i-th value read from `sampleIterator`.
// The number of variables per row is taken from the first stored event.
// NOTE(review): the first signature line is missing from this listing.
189template <>
191 IndexIterator_t sampleIterator)
192{
// The first stored event supplies the variable count used for every row.
193 Event *event = std::get<0>(fData).front();
194 Int_t m = matrix.GetNrows();
195 Int_t n = event->GetNVariables();
196
197 // Copy input variables.
198
199 for (Int_t i = 0; i < m; i++) {
200 Int_t sampleIndex = *sampleIterator++;
201 event = std::get<0>(fData)[sampleIndex];
202 for (Int_t j = 0; j < n; j++) {
203 matrix(i, j) = event->GetValue(j);
204 }
205 }
206}
207
208//______________________________________________________________________________
// Fills `matrix` with the training targets of one batch of events. Events come
// from the container in tuple element 0 of fData, selected through
// `sampleIterator`; the DataSetInfo in tuple element 1 is used for the
// signal/background decision.
// Per event:
//  - no targets and a single-column matrix: 1.0 if info.IsSignal(event), else 0.0;
//  - no targets and several columns: 1.0 in the column equal to event->GetClass(),
//    0.0 elsewhere;
//  - otherwise: column j receives event->GetTarget(j).
// NOTE(review): the first signature line is missing from this listing.
209template <>
211 IndexIterator_t sampleIterator)
212{
// This initial value is overwritten inside the loop before `event` is read.
213 Event *event = std::get<0>(fData).front();
214 const DataSetInfo &info = std::get<1>(fData);
215 Int_t m = matrix.GetNrows();
216 Int_t n = matrix.GetNcols();   // column count is taken from the destination
217
218 for (Int_t i = 0; i < m; i++) {
219 Int_t sampleIndex = *sampleIterator++;
220 event = std::get<0>(fData)[sampleIndex];
221 for (Int_t j = 0; j < n; j++) {
222 // Classification
223 if (event->GetNTargets() == 0) {
224 if (n == 1) {
225 // Binary.
226 matrix(i, j) = (info.IsSignal(event)) ? 1.0 : 0.0;
227 } else {
228 // Multiclass.
229 matrix(i, j) = 0.0;
230 if (j == static_cast<Int_t>(event->GetClass())) {
231 matrix(i, j) = 1.0;
232 }
233 }
234 } else {
// The event has targets: copy target j as-is.
235 matrix(i, j) = static_cast<Real_t>(event->GetTarget(j));
236 }
237 }
238 }
239}
240
241//______________________________________________________________________________
// Writes the weight of each selected event into column 0 of `matrix`. Events come
// from the container in tuple element 0 of fData; row i uses the event whose
// index is the i-th value read from `sampleIterator`.
// NOTE(review): the first signature line is missing from this listing.
242template <>
244 IndexIterator_t sampleIterator)
245{
// Null is safe as the initial value: `event` is assigned in the loop before any use.
246 Event *event = nullptr;
247
248 for (Int_t i = 0; i < matrix.GetNrows(); i++) {
249 Int_t sampleIndex = *sampleIterator++;
250 event = std::get<0>(fData)[sampleIndex];
251 matrix(i, 0) = event->GetWeight();
252 }
253}
254
255// Explicit instantiations.
260
261} // namespace DNN
262} // namespace TMVA
float Real_t
Definition RtypesCore.h:68
double Double_t
Definition RtypesCore.h:59
Class that contains all the data information.
Definition DataSetInfo.h:62
Bool_t IsSignal(const Event *ev) const
Int_t GetNrows() const
Int_t GetNcols() const
TMatrixT.
Definition TMatrixT.h:39
const Int_t n
Definition legend1.C:16
typename std::vector< size_t >::iterator IndexIterator_t
Definition DataLoader.h:42
create variable transformations
auto * m
Definition textangle.C:8
static void output(int code)
Definition gifencode.c:226