00001 /** 00002 * This stuff is shamelessly ripped from IT++. 00003 * Original copyright notice follows. 00004 * 00005 *---------------------------------------------------------------------------* 00006 * IT++ * 00007 *---------------------------------------------------------------------------* 00008 * Copyright (c) 1995-2001 by Tony Ottosson, Thomas Eriksson, Pål Frenger, * 00009 * Tobias Ringström, and Jonas Samuelsson. * 00010 * * 00011 * Permission to use, copy, modify, and distribute this software and its * 00012 * documentation under the terms of the GNU General Public License is hereby * 00013 * granted. No representations are made about the suitability of this * 00014 * software for any purpose. It is provided "as is" without expressed or * 00015 * implied warranty. See the GNU General Public License for more details. * 00016 *---------------------------------------------------------------------------*/ 00017 00018 #include "mat.hpp" 00019 #include "vec.hpp" 00020 00021 #include <cassert> 00022 #include <cmath> 00023 00024 static void lu(const Mat<double> &X, Mat<double> &L, Mat<double> &U, 00025 Vec<int> &p) 00026 { 00027 assert(X.ys() == X.xs()); 00028 00029 int u, k, i, j, n = X.ys(); 00030 double Umax; 00031 00032 // temporary matrix 00033 U = X; 00034 00035 p.setSize(n); 00036 L.setSize(n, n); 00037 00038 for(k = 0; k < n - 1; k++) { 00039 // determine u. Alt. u=max_index(abs(U(k,n-1,k,k))); 00040 u = k; 00041 Umax = fabs(U(k, k)); 00042 for(i = k + 1; i < n; i++) { 00043 if(fabs(U(k, i)) > Umax) { 00044 Umax=fabs(U(k, i)); 00045 u=i; 00046 } 00047 } 00048 U.swapRows(k, u); 00049 p(k) = u; 00050 00051 if(U(k, k) != 0.) { 00052 //U(k+1,n-1,k,k)/=U(k,k); 00053 for(i = k + 1; i < n; i++) 00054 U(k, i) /= U(k, k); 00055 // Should be: U(k+1,n-1,k+1,n-1)-=U(k+1,n-1,k,k)*U(k,k,k+1,n-1); 00056 // but this is too slow. 00057 // Instead work directly on the matrix data-structure. 00058 double *iPos = U.data() + (k + 1) * U.xs(); 00059 double *kPos = U.data() + k * U.xs(); 00060 for(i = k + 1; i < n; i++) { 00061 for(j = k + 1; j < n; j++) { 00062 *(iPos + j) -= *(iPos + k) * *(kPos + j); 00063 } 00064 iPos += U.xs(); 00065 } 00066 } 00067 00068 } 00069 00070 p(n - 1) = n - 1; 00071 00072 // Set L and reset all lower elements of U. 00073 // set all lower triangular elements to zero 00074 for(i = 0; i < n; i++) { 00075 L(i, i) = 1.; 00076 for(j = i + 1; j < n; j++) { 00077 L(i, j) = U(i, j); 00078 U(i, j) = 0; 00079 L(j, i) = 0; 00080 } 00081 } 00082 } 00083 00084 static void interchangePermutations(Vec<double> &b, const Vec<int> &p) 00085 { 00086 assert(b.size() == p.size()); 00087 double temp; 00088 00089 for(int k = 0; k < b.size(); k++) { 00090 SWAP(b(k), b(p(k)), temp); 00091 } 00092 } 00093 00094 static void forwardSubstitution(const Mat<double> &L, const Vec<double> &b, 00095 Vec<double> &x) 00096 { 00097 assert(L.ys() == L.xs() && L.xs() == b.size() && b.size() == x.size()); 00098 int n = L.ys(), i, j, iPos; 00099 double temp; 00100 00101 x(0) = b(0) / L(0, 0); 00102 for(i = 1; i < n; i++) { 00103 // Should be: x(i)=((b(i)-L(i,i,0,i-1)*x(0,i-1))/L(i,i))(0); 00104 // but this is to slow. 00105 iPos = i * L.xs(); 00106 temp = 0; 00107 for(j = 0; j < i; j++) { 00108 temp += L.data()[iPos + j] * x(j); 00109 } 00110 x(i) = (b(i) - temp) / L.data()[iPos + i]; 00111 } 00112 } 00113 00114 static void backwardSubstitution(const Mat<double> &U, const Vec<double> &b, 00115 Vec<double> &x) 00116 { 00117 assert(U.ys() == U.xs() && U.xs() == b.size() && b.size() == x.size()); 00118 int n = U.ys(), i, j, iPos; 00119 double temp; 00120 00121 x(n - 1) = b(n - 1) / U(n - 1, n - 1); 00122 if(std::isnan(x(n - 1))) 00123 x(n - 1) = 0.; 00124 for(i = n - 2; i >= 0; i--) { 00125 // Should be: x(i)=((b(i)-U(i,i,i+1,n-1)*x(i+1,n-1))/U(i,i))(0); 00126 // but this is too slow. 00127 temp = 0; 00128 iPos = i * U.xs(); 00129 for(j = i + 1; j < n; j++) { 00130 temp += U.data()[iPos + j] * x(j); 00131 } 00132 x(i) = (b(i) - temp) / U.data()[iPos + i]; 00133 if(std::isnan(x(i))) 00134 x(i) = 0.; 00135 } 00136 } 00137 00138 static Vec<double> lsSolve(const Mat<double> &L, const Mat<double> &U, 00139 const Vec<double> &b) 00140 { 00141 Vec<double> x(L.ys()); 00142 // Solve Ly=b, Here y=x 00143 forwardSubstitution(L, b, x); 00144 // Solve Ux=y, Here x=y 00145 backwardSubstitution(U, x, x); 00146 return x; 00147 } 00148 00149 Vec<double> lsSolve(const Mat<double> &A, const Vec<double> &b) 00150 { 00151 Mat<double> L, U; 00152 Vec<int> p; 00153 Vec<double> btemp(b); 00154 00155 lu(A, L, U, p); 00156 interchangePermutations(btemp, p); 00157 return lsSolve(L, U, btemp); 00158 }