00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018 #include "mat.hpp"
00019 #include "vec.hpp"
00020
00021 #include <cassert>
00022 #include <cmath>
00023
00024 static void lu(const Mat<double> &X, Mat<double> &L, Mat<double> &U,
00025 Vec<int> &p)
00026 {
00027 assert(X.ys() == X.xs());
00028
00029 int u, k, i, j, n = X.ys();
00030 double Umax;
00031
00032
00033 U = X;
00034
00035 p.setSize(n);
00036 L.setSize(n, n);
00037
00038 for(k = 0; k < n - 1; k++) {
00039
00040 u = k;
00041 Umax = fabs(U(k, k));
00042 for(i = k + 1; i < n; i++) {
00043 if(fabs(U(k, i)) > Umax) {
00044 Umax=fabs(U(k, i));
00045 u=i;
00046 }
00047 }
00048 U.swapRows(k, u);
00049 p(k) = u;
00050
00051 if(U(k, k) != 0.) {
00052
00053 for(i = k + 1; i < n; i++)
00054 U(k, i) /= U(k, k);
00055
00056
00057
00058 double *iPos = U.data() + (k + 1) * U.xs();
00059 double *kPos = U.data() + k * U.xs();
00060 for(i = k + 1; i < n; i++) {
00061 for(j = k + 1; j < n; j++) {
00062 *(iPos + j) -= *(iPos + k) * *(kPos + j);
00063 }
00064 iPos += U.xs();
00065 }
00066 }
00067
00068 }
00069
00070 p(n - 1) = n - 1;
00071
00072
00073
00074 for(i = 0; i < n; i++) {
00075 L(i, i) = 1.;
00076 for(j = i + 1; j < n; j++) {
00077 L(i, j) = U(i, j);
00078 U(i, j) = 0;
00079 L(j, i) = 0;
00080 }
00081 }
00082 }
00083
00084 static void interchangePermutations(Vec<double> &b, const Vec<int> &p)
00085 {
00086 assert(b.size() == p.size());
00087 double temp;
00088
00089 for(int k = 0; k < b.size(); k++) {
00090 SWAP(b(k), b(p(k)), temp);
00091 }
00092 }
00093
00094 static void forwardSubstitution(const Mat<double> &L, const Vec<double> &b,
00095 Vec<double> &x)
00096 {
00097 assert(L.ys() == L.xs() && L.xs() == b.size() && b.size() == x.size());
00098 int n = L.ys(), i, j, iPos;
00099 double temp;
00100
00101 x(0) = b(0) / L(0, 0);
00102 for(i = 1; i < n; i++) {
00103
00104
00105 iPos = i * L.xs();
00106 temp = 0;
00107 for(j = 0; j < i; j++) {
00108 temp += L.data()[iPos + j] * x(j);
00109 }
00110 x(i) = (b(i) - temp) / L.data()[iPos + i];
00111 }
00112 }
00113
00114 static void backwardSubstitution(const Mat<double> &U, const Vec<double> &b,
00115 Vec<double> &x)
00116 {
00117 assert(U.ys() == U.xs() && U.xs() == b.size() && b.size() == x.size());
00118 int n = U.ys(), i, j, iPos;
00119 double temp;
00120
00121 x(n - 1) = b(n - 1) / U(n - 1, n - 1);
00122 if(std::isnan(x(n - 1)))
00123 x(n - 1) = 0.;
00124 for(i = n - 2; i >= 0; i--) {
00125
00126
00127 temp = 0;
00128 iPos = i * U.xs();
00129 for(j = i + 1; j < n; j++) {
00130 temp += U.data()[iPos + j] * x(j);
00131 }
00132 x(i) = (b(i) - temp) / U.data()[iPos + i];
00133 if(std::isnan(x(i)))
00134 x(i) = 0.;
00135 }
00136 }
00137
00138 static Vec<double> lsSolve(const Mat<double> &L, const Mat<double> &U,
00139 const Vec<double> &b)
00140 {
00141 Vec<double> x(L.ys());
00142
00143 forwardSubstitution(L, b, x);
00144
00145 backwardSubstitution(U, x, x);
00146 return x;
00147 }
00148
00149 Vec<double> lsSolve(const Mat<double> &A, const Vec<double> &b)
00150 {
00151 Mat<double> L, U;
00152 Vec<int> p;
00153 Vec<double> btemp(b);
00154
00155 lu(A, L, U, p);
00156 interchangePermutations(btemp, p);
00157 return lsSolve(L, U, btemp);
00158 }