#include "TDecompChol.h"
#include "TMath.h"
// ROOT ClassImp macro for TDecompChol (class-implementation hook of ROOT's
// dictionary/I/O machinery; presumably paired with a ClassDef in TDecompChol.h — confirm).
ClassImp(TDecompChol)
TDecompChol::TDecompChol(Int_t nrows)
{
   // Only allocate storage for an nrows x nrows factor; no matrix is attached
   // yet (kMatrixSet stays cleared), so SetMatrix() must be called before use.
   const Int_t ncols = nrows;
   fU.ResizeTo(nrows,ncols);
}
TDecompChol::TDecompChol(Int_t row_lwb,Int_t row_upb)
{
   // Allocate a square factor whose rows and columns both span the index
   // range [row_lwb,row_upb]; no matrix is attached yet.
   fRowLwb = row_lwb;
   fColLwb = fRowLwb;
   fU.ResizeTo(row_lwb,row_upb,row_lwb,row_upb);
}
TDecompChol::TDecompChol(const TMatrixDSym &a,Double_t tol)
{
   // Attach symmetric matrix a; the factorisation itself is deferred to
   // Decompose(), which works in place on the copy stored in fU.
   R__ASSERT(a.IsValid());

   SetBit(kMatrixSet);
   fCondition = a.Norm1();

   // A strictly positive caller tolerance overrides the matrix' own one.
   fTol = (tol > 0) ? tol : a.GetTol();

   fRowLwb = a.GetRowLwb();
   fColLwb = a.GetColLwb();

   fU.ResizeTo(a);
   fU = a;
}
TDecompChol::TDecompChol(const TMatrixD &a,Double_t tol)
{
   // Attach a general matrix a; it must be square with equal row/column lower
   // bounds. Decompose() later reads only its upper triangle.
   R__ASSERT(a.IsValid());

   const Bool_t isSquare = (a.GetNrows() == a.GetNcols()) && (a.GetRowLwb() == a.GetColLwb());
   if (!isSquare) {
      Error("TDecompChol(const TMatrixD &","matrix should be square");
      return;
   }

   SetBit(kMatrixSet);
   fCondition = a.Norm1();

   // A strictly positive caller tolerance overrides the matrix' own one.
   fTol = (tol > 0) ? tol : a.GetTol();

   fRowLwb = a.GetRowLwb();
   fColLwb = a.GetColLwb();

   fU.ResizeTo(a);
   fU = a;
}
TDecompChol::TDecompChol(const TDecompChol &another) : TDecompBase(another)
{
   // Reuse the assignment operator for the member-wise copy of the factor.
   this->operator=(another);
}
// Factorise the attached matrix in place into an upper-triangular U with
// A = U^T U. Only the upper triangle of fU is read; the strict lower triangle
// is zeroed at the end. Storage is row-major (element (r,c) at r*n+c).
//
// Returns kTRUE on success (and sets kDecomposed). Returns kFALSE if no
// matrix was attached or if a non-positive pivot shows the matrix is not
// positive definite. In the latter case fU already holds a partially
// factorised matrix, so kSingular is set: Solve()/GetMatrix() test that bit
// and will no longer re-run the factorisation on the corrupted data (a
// repeated pass could otherwise spuriously "succeed").
Bool_t TDecompChol::Decompose()
{
   if (TestBit(kDecomposed)) return kTRUE;

   if ( !TestBit(kMatrixSet) ) {
      Error("Decompose()","Matrix has not been set");
      return kFALSE;
   }

   Int_t i,j,icol,irow;
   const Int_t n = fU.GetNrows();
   Double_t *pU = fU.GetMatrixArray();
   for (icol = 0; icol < n; icol++) {
      const Int_t rowOff = icol*n;

      // Diagonal: u(icol,icol) = sqrt(a(icol,icol) - sum_{k<icol} u(k,icol)^2)
      Double_t ujj = pU[rowOff+icol];
      for (irow = 0; irow < icol; irow++) {
         const Int_t pos_ij = irow*n+icol;
         ujj -= pU[pos_ij]*pU[pos_ij];
      }
      if (ujj <= 0) {
         Error("Decompose()","matrix not positive definite");
         // fU is now partially overwritten: flag it so callers do not retry.
         SetBit(kSingular);
         return kFALSE;
      }
      ujj = TMath::Sqrt(ujj);
      pU[rowOff+icol] = ujj;

      // Off-diagonal entries of row icol:
      // u(icol,j) = (a(icol,j) - sum_{i<icol} u(i,j)*u(i,icol)) / u(icol,icol)
      if (icol < n-1) {
         for (j = icol+1; j < n; j++) {
            for (i = 0; i < icol; i++) {
               const Int_t rowOff2 = i*n;
               pU[rowOff+j] -= pU[rowOff2+j]*pU[rowOff2+icol];
            }
         }
         for (j = icol+1; j < n; j++)
            pU[rowOff+j] /= ujj;
      }
   }

   // Clear the strict lower triangle (it still holds the input values).
   for (irow = 0; irow < n; irow++) {
      const Int_t rowOff = irow*n;
      for (icol = 0; icol < irow; icol++)
         pU[rowOff+icol] = 0.;
   }

   SetBit(kDecomposed);

   return kTRUE;
}
const TMatrixDSym TDecompChol::GetMatrix()
{
   // Rebuild the original matrix as U^T U (kAtA applied to the factor fU),
   // factorising first if that has not happened yet. Returns an empty matrix
   // when the matrix is flagged singular or the factorisation fails.
   if (TestBit(kSingular)) {
      Error("GetMatrix()","Matrix is singular");
      return TMatrixDSym();
   }

   // Short-circuit: Decompose() is only called when not yet decomposed.
   const Bool_t haveFactor = TestBit(kDecomposed) || Decompose();
   if (!haveFactor) {
      Error("GetMatrix()","Decomposition failed");
      return TMatrixDSym();
   }

   return TMatrixDSym(TMatrixDSym::kAtA,fU);
}
void TDecompChol::SetMatrix(const TMatrixDSym &a)
{
R__ASSERT(a.IsValid());
ResetStatus();
if (a.GetNrows() != a.GetNcols() || a.GetRowLwb() != a.GetColLwb()) {
Error("SetMatrix(const TMatrixDSym &","matrix should be square");
return;
}
SetBit(kMatrixSet);
fCondition = -1.0;
fRowLwb = a.GetRowLwb();
fColLwb = a.GetColLwb();
fU.ResizeTo(a);
fU = a;
}
// Solve A x = b in place (b is overwritten by x) using A = U^T U:
// forward substitution with U^T, then backward substitution with U.
// The matrix is factorised on demand.
//
// Returns kFALSE (with an Error message) when:
//   - the matrix is flagged singular;
//   - the factorisation fails;
//   - b does not match the matrix in size/lower bound;
//   - a diagonal element of U is below fTol.
// All pivots are checked before b is modified, so on failure b is left untouched.
Bool_t TDecompChol::Solve(TVectorD &b)
{
   R__ASSERT(b.IsValid());
   if (TestBit(kSingular)) {
      Error("Solve()","Matrix is singular");
      return kFALSE;
   }
   if ( !TestBit(kDecomposed) ) {
      if (!Decompose()) {
         Error("Solve()","Decomposition failed");
         return kFALSE;
      }
   }

   if (fU.GetNrows() != b.GetNrows() || fU.GetRowLwb() != b.GetLwb()) {
      Error("Solve(TVectorD &","vector and matrix incompatible");
      return kFALSE;
   }

   const Int_t n = fU.GetNrows();
   const Double_t *pU = fU.GetMatrixArray();
         Double_t *pb = b.GetMatrixArray();

   // Validate every pivot up front (same order, so the same first offending
   // index is reported) instead of aborting half-way through the solve.
   for (Int_t i = 0; i < n; i++) {
      const Double_t uii = pU[i*n+i];
      if (uii < fTol) {
         Error("Solve(TVectorD &b)","u[%d,%d]=%.4e < %.4e",i,i,uii,fTol);
         return kFALSE;
      }
   }

   // Step 1: forward substitution, U^T y = b (column i of U is row i of U^T).
   for (Int_t i = 0; i < n; i++) {
      Double_t r = pb[i];
      for (Int_t j = 0; j < i; j++)
         r -= pU[j*n+i]*pb[j];
      pb[i] = r/pU[i*n+i];
   }

   // Step 2: backward substitution, U x = y.
   for (Int_t i = n-1; i >= 0; i--) {
      const Int_t off_i = i*n;
      Double_t r = pb[i];
      for (Int_t j = i+1; j < n; j++)
         r -= pU[off_i+j]*pb[j];
      pb[i] = r/pU[off_i+i];
   }

   return kTRUE;
}
// Solve A x = b in place for a matrix column b (the column is overwritten by
// x), using A = U^T U: forward substitution with U^T, then backward
// substitution with U. Column elements are addressed with stride cb.GetInc().
// The matrix is factorised on demand.
//
// Returns kFALSE (with an Error message) when:
//   - the matrix is flagged singular;
//   - the factorisation fails;
//   - the column's matrix does not match in size/lower bound;
//   - a diagonal element of U is below fTol.
// All pivots are checked before the column is modified, so on failure it is left untouched.
Bool_t TDecompChol::Solve(TMatrixDColumn &cb)
{
   TMatrixDBase *b = const_cast<TMatrixDBase *>(cb.GetMatrix());
   R__ASSERT(b->IsValid());
   if (TestBit(kSingular)) {
      Error("Solve()","Matrix is singular");
      return kFALSE;
   }
   if ( !TestBit(kDecomposed) ) {
      if (!Decompose()) {
         Error("Solve()","Decomposition failed");
         return kFALSE;
      }
   }

   if (fU.GetNrows() != b->GetNrows() || fU.GetRowLwb() != b->GetRowLwb())
   {
      Error("Solve(TMatrixDColumn &cb","vector and matrix incompatible");
      return kFALSE;
   }

   const Int_t n = fU.GetNrows();
   const Double_t *pU  = fU.GetMatrixArray();
         Double_t *pcb = cb.GetPtr();
   const Int_t inc = cb.GetInc();

   // Validate every pivot up front (same order, so the same first offending
   // index is reported) instead of aborting half-way through the solve.
   for (Int_t i = 0; i < n; i++) {
      const Double_t uii = pU[i*n+i];
      if (uii < fTol) {
         Error("Solve(TMatrixDColumn &cb)","u[%d,%d]=%.4e < %.4e",i,i,uii,fTol);
         return kFALSE;
      }
   }

   // Step 1: forward substitution, U^T y = b.
   for (Int_t i = 0; i < n; i++) {
      const Int_t off_i2 = i*inc;
      Double_t r = pcb[off_i2];
      for (Int_t j = 0; j < i; j++)
         r -= pU[j*n+i]*pcb[j*inc];
      pcb[off_i2] = r/pU[i*n+i];
   }

   // Step 2: backward substitution, U x = y.
   for (Int_t i = n-1; i >= 0; i--) {
      const Int_t off_i  = i*n;
      const Int_t off_i2 = i*inc;
      Double_t r = pcb[off_i2];
      for (Int_t j = i+1; j < n; j++)
         r -= pU[off_i+j]*pcb[j*inc];
      pcb[off_i2] = r/pU[off_i+i];
   }

   return kTRUE;
}
void TDecompChol::Det(Double_t &d1,Double_t &d2)
{
   // Determinant of A from the factor: det(A) = det(U)^2. The pair is cached
   // in (fDet1,fDet2) once kDetermined is set.
   if (TestBit(kDetermined)) {
      d1 = fDet1;
      d2 = fDet2;
      return;
   }

   if (!TestBit(kDecomposed))
      Decompose();

   // The base class fills (fDet1,fDet2) from the factor's diagonal.
   TDecompBase::Det(d1,d2);

   // Squaring: presumably the value is stored as fDet1 * 2^fDet2 (TDecompBase
   // convention — confirm), so square the first part and double the exponent.
   fDet1 = fDet1*fDet1;
   fDet2 *= 2.0;
   SetBit(kDetermined);

   d1 = fDet1;
   d2 = fDet2;
}
Bool_t TDecompChol::Invert(TMatrixDSym &inv)
{
   // Compute A^-1 into inv by solving A x = e_k for every unit column e_k.
   // inv must already have the decomposed matrix' row count and lower bound.
   const Bool_t shapeMatches = (inv.GetNrows() == GetNrows()) && (inv.GetRowLwb() == GetRowLwb());
   if (!shapeMatches) {
      Error("Invert(TMatrixDSym &","Input matrix has wrong shape");
      return kFALSE;
   }

   inv.UnitMatrix();

   // Stop at the first column whose solve fails.
   Bool_t status = kTRUE;
   Int_t col = inv.GetColLwb();
   const Int_t lastCol = inv.GetColUpb();
   while (status && col <= lastCol) {
      TMatrixDColumn unitCol(inv,col);
      status = Solve(unitCol);
      ++col;
   }

   return status;
}
TMatrixDSym TDecompChol::Invert(Bool_t &status)
{
   // Allocate a correctly shaped symmetric matrix, fill it with A^-1 and
   // report success through status.
   const Int_t lwb = GetRowLwb();
   TMatrixDSym inv(lwb,lwb+GetNrows()-1);
   inv.UnitMatrix();

   status = Invert(inv);
   return inv;
}
void TDecompChol::Print(Option_t *opt) const
{
   // Base-class status first, then the stored factor labelled "fU".
   TDecompBase::Print(opt);
   fU.Print("fU");
}
TDecompChol &TDecompChol::operator=(const TDecompChol &source)
{
   // Self-assignment is a no-op.
   if (this == &source)
      return *this;

   TDecompBase::operator=(source);
   // Resize first so the element-wise assignment sees matching shapes.
   fU.ResizeTo(source.fU);
   fU = source.fU;

   return *this;
}
TVectorD NormalEqn(const TMatrixD &A,const TVectorD &b)
{
   // Least-squares solution of A x ~ b through the normal equations
   // (A^T A) x = A^T b, solved with a Cholesky decomposition of A^T A.
   const TMatrixDSym ata(TMatrixDSym::kAtA,A);
   TDecompChol ch(ata);

   const TVectorD atb = TMatrixD(TMatrixD::kTransposed,A)*b;
   Bool_t ok;
   return ch.Solve(atb,ok);
}
// Weighted least-squares solution of A x ~ b, where row i carries the
// standard deviation std(i).
// Rows of A and elements of b are divided by std before solving the normal
// equations (Aw^T Aw) x = Aw^T bw with a Cholesky decomposition.
// Returns an empty vector (after an Error) when A, b and std do not share
// the same row count and lower bound.
TVectorD NormalEqn(const TMatrixD &A,const TVectorD &b,const TVectorD &std)
{
   if (!AreCompatible(b,std)) {
      ::Error("NormalEqn","vectors b and std are not compatible");
      return TVectorD();
   }
   if (A.GetNrows() != b.GetNrows() || A.GetRowLwb() != b.GetLwb()) {
      ::Error("NormalEqn","matrix A and vector b are not compatible");
      return TVectorD();
   }

   TMatrixD mAw = A;
   TVectorD mBw = b;
   // Iterate over A's actual row index range: the lower bound need not be 0.
   const Int_t rowLwb = A.GetRowLwb();
   const Int_t rowUpb = A.GetRowUpb();
   for (Int_t irow = rowLwb; irow <= rowUpb; irow++) {
      TMatrixDRow(mAw,irow) *= 1/std(irow);
      mBw(irow) /= std(irow);
   }

   TDecompChol ch(TMatrixDSym(TMatrixDSym::kAtA,mAw));
   Bool_t ok;
   return ch.Solve(TMatrixD(TMatrixD::kTransposed,mAw)*mBw,ok);
}
TMatrixD NormalEqn(const TMatrixD &A,const TMatrixD &B)
{
   // Least-squares solution of A X ~ B for several right-hand sides at once:
   // (A^T A) X = A^T B, solved column by column via Cholesky of A^T A.
   const TMatrixDSym ata(TMatrixDSym::kAtA,A);
   TDecompChol ch(ata);

   TMatrixD result(A,TMatrixD::kTransposeMult,B);
   ch.MultiSolve(result);
   return result;
}
// Weighted least-squares solution of A X ~ B for several right-hand sides,
// where row i carries the standard deviation std(i).
// Rows of A and B are divided by std before solving
// (Aw^T Aw) X = Aw^T Bw with a Cholesky decomposition.
// Returns an empty matrix (after an Error) when A, B and std do not share
// the same row count and lower bound — mirroring the check in the
// vector-valued overload; previously a short std produced out-of-range reads.
TMatrixD NormalEqn(const TMatrixD &A,const TMatrixD &B,const TVectorD &std)
{
   if (std.GetNrows() != A.GetNrows() || std.GetLwb() != A.GetRowLwb() ||
       B.GetNrows()   != A.GetNrows() || B.GetRowLwb() != A.GetRowLwb()) {
      ::Error("NormalEqn","matrices A, B and vector std are not compatible");
      return TMatrixD();
   }

   TMatrixD mAw = A;
   TMatrixD mBw = B;
   // Iterate over A's actual row index range: the lower bound need not be 0.
   const Int_t rowLwb = A.GetRowLwb();
   const Int_t rowUpb = A.GetRowUpb();
   for (Int_t irow = rowLwb; irow <= rowUpb; irow++) {
      TMatrixDRow(mAw,irow) *= 1/std(irow);
      TMatrixDRow(mBw,irow) *= 1/std(irow);
   }

   TDecompChol ch(TMatrixDSym(TMatrixDSym::kAtA,mAw));
   TMatrixD mX(mAw,TMatrixD::kTransposeMult,mBw);
   ch.MultiSolve(mX);
   return mX;
}