// ROOT logo
// @(#)root/tmva $Id: TNeuron.h 31458 2009-11-30 13:58:20Z stelzer $
// Author: Matt Jachowski

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : TMVA::TNeuron                                                         *
 * Web    :                                           *
 *                                                                                *
 * Description:                                                                   *
 *      Neuron class to be used in MethodANNBase and its derivatives.             *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Matt Jachowski  <> - Stanford University, USA       *
 *                                                                                *
 * Copyright (c) 2005:                                                            *
 *      CERN, Switzerland                                                         *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (                                          *
 **********************************************************************************/

#ifndef ROOT_TMVA_TNeuron
#define ROOT_TMVA_TNeuron

//                                                                      //
// TNeuron                                                              //
//                                                                      //
// Neuron used by derivatives of MethodANNBase                          //
//                                                                      //

#include <iostream>

// ROOT core headers.
// Each include is skipped when the macro tested on its #ifndef line is already
// defined (presumably the included header's own guard macro -- confirm against
// the headers). Every conditional is closed right after its include so the
// guards do not nest into one another.
#ifndef ROOT_TString
#include "TString.h"
#endif
#ifndef ROOT_TObjArray
#include "TObjArray.h"
#endif
#ifndef ROOT_TFormula
#include "TFormula.h"
#endif

// TMVA headers, guarded the same way.
#ifndef ROOT_TMVA_TSynapse
#include "TMVA/TSynapse.h"
#endif
#ifndef ROOT_TMVA_TActivation
#include "TMVA/TActivation.h"
#endif
#ifndef ROOT_TMVA_Types
#include "TMVA/Types.h"
#endif

namespace TMVA {

   class TNeuronInput;

   class TNeuron : public TObject {


      virtual ~TNeuron();

      // force the input value
      void ForceValue(Double_t value);

      // calculate the input value
      void CalculateValue();

      // calculate the activation value
      void CalculateActivationValue();

      // calculate the error field of the neuron
      void CalculateDelta();

      // set the activation function
      void SetActivationEqn(TActivation* activation);

      // set the input calculator
      void SetInputCalculator(TNeuronInput* calculator);

      // add a synapse as a pre-link
      void AddPreLink(TSynapse* pre);

      // add a synapse as a post-link
      void AddPostLink(TSynapse* post);

      // delete all pre-links
      void DeletePreLinks();

      // set the error
      void SetError(Double_t error);

      // update the error fields of all pre-synapses, batch mode
      // to actually update the weights, call adjust synapse weights
      void UpdateSynapsesBatch();

      // update the error fields and weights of all pre-synapses, sequential mode
      void UpdateSynapsesSequential();

      // update the weights of the all pre-synapses, batch mode 
      //(call UpdateSynapsesBatch first)
      void AdjustSynapseWeights();

      // explicitly initialize error fields of pre-synapses, batch mode
      void InitSynapseDeltas();

      // print activation equation, for debugging
      void PrintActivationEqn();

      // inlined functions
      Double_t  GetValue() const                { return fValue;                          }
      Double_t  GetActivationValue() const      { return fActivationValue;                }
      Double_t  GetDelta() const                { return fDelta;                          }
      Double_t  GetDEDw() const                 { return fDEDw;                           }
      Int_t     NumPreLinks() const             { return NumLinks(fLinksIn);              }
      Int_t     NumPostLinks() const            { return NumLinks(fLinksOut);             }
      TSynapse* PreLinkAt ( Int_t index ) const { return (TSynapse*)fLinksIn->At(index);  }
      TSynapse* PostLinkAt( Int_t index ) const { return (TSynapse*)fLinksOut->At(index); }
      void      SetInputNeuron()                { NullifyLinks(fLinksIn);                 }
      void      SetOutputNeuron()               { NullifyLinks(fLinksOut);                }
      void      SetBiasNeuron()                 { NullifyLinks(fLinksIn);                 }
      void      SetDEDw( Double_t DEDw )        { fDEDw = DEDw;                           }
      Bool_t    IsInputNeuron() const           { return fLinksIn == NULL;                }
      Bool_t    IsOutputNeuron() const          { return fLinksOut == NULL;               }
      void      PrintPreLinks() const           { PrintLinks(fLinksIn); return;           }
      void      PrintPostLinks() const          { PrintLinks(fLinksOut); return;          }

      virtual void Print(Option_t* = "") const {
         std::cout << fValue << std::endl;
         //PrintPreLinks(); PrintPostLinks();


      // prviate helper functions
      void InitNeuron();
      void DeleteLinksArray( TObjArray*& links );
      void PrintLinks      ( TObjArray* links ) const;
      void PrintMessage    ( EMsgType, TString message );

      // inlined helper functions
      Int_t NumLinks(TObjArray* links) const { 
         if (links == NULL) return 0; return links->GetEntriesFast(); 
      void NullifyLinks(TObjArray*& links) { 
         if (links != NULL) delete links; links = NULL; 

      // private member variables
      TObjArray*    fLinksIn;                 // array of input synapses
      TObjArray*    fLinksOut;                // array of output synapses
      Double_t      fValue;                   // input value
      Double_t      fActivationValue;         // activation/output value
      Double_t      fDelta;                   // error field of neuron
      Double_t      fDEDw;                    // sum of all deltas
      Double_t      fError;                   // error, only set for output neurons
      Bool_t        fForcedValue;             // flag for forced input value
      TActivation*  fActivation;              // activation equation
      TNeuronInput* fInputCalculator;         // input calculator

      mutable MsgLogger* fLogger;                     //! message logger
      MsgLogger& Log() const { return *fLogger; }                       

      ClassDef(TNeuron,0) // Neuron class used by MethodANNBase derivative ANNs

} // namespace TMVA