ergo
mat_utils.h
Go to the documentation of this file.
1 /* Ergo, version 3.4, a program for linear scaling electronic structure
2  * calculations.
3  * Copyright (C) 2014 Elias Rudberg, Emanuel H. Rubensson, and Pawel Salek.
4  *
5  * This program is free software: you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License as published by
7  * the Free Software Foundation, either version 3 of the License, or
8  * (at your option) any later version.
9  *
10  * This program is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  * GNU General Public License for more details.
14  *
15  * You should have received a copy of the GNU General Public License
16  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17  *
18  * Primary academic reference:
19  * Kohn−Sham Density Functional Theory Electronic Structure Calculations
20  * with Linearly Scaling Computational Time and Memory Usage,
21  * Elias Rudberg, Emanuel H. Rubensson, and Pawel Salek,
22  * J. Chem. Theory Comput. 7, 340 (2011),
23  * <http://dx.doi.org/10.1021/ct100611z>
24  *
25  * For further information about Ergo, see <http://www.ergoscf.org>.
26  */
27 #ifndef MAT_UTILS_HEADER
28 #define MAT_UTILS_HEADER
29 #include "Interval.h"
30 #include "matrix_proxy.h"
31 namespace mat {
32 
33  template<typename Tmatrix, typename Treal>
34  struct DiffMatrix {
35  typedef typename Tmatrix::VectorType VectorType;
36  void getCols(SizesAndBlocks & colsCopy) const {
37  A.getCols(colsCopy);
38  }
39  int get_nrows() const {
40  assert( A.get_nrows() == B.get_nrows() );
41  return A.get_nrows();
42  }
43  Treal frob() const {
44  return Tmatrix::frob_diff(A, B);
45  }
46  void quickEuclBounds(Treal & euclLowerBound,
47  Treal & euclUpperBound) const {
48  Treal frobTmp = frob();
49  euclLowerBound = frobTmp / template_blas_sqrt( (Treal)get_nrows() );
50  euclUpperBound = frobTmp;
51  }
52 
53  Tmatrix const & A;
54  Tmatrix const & B;
55  DiffMatrix(Tmatrix const & A_, Tmatrix const & B_)
56  : A(A_), B(B_) {}
57  template<typename Tvector>
58  void matVecProd(Tvector & y, Tvector const & x) const {
59  Tvector tmp(y);
60  tmp = (Treal)-1.0 * B * x; // -B * x
61  y = (Treal)1.0 * A * x; // A * x
62  y += (Treal)1.0 * tmp; // A * x - B * x => (A - B) * x
63  }
64  };
65 
66 
67  // ATAMatrix AT*A
68  template<typename Tmatrix, typename Treal>
69  struct ATAMatrix {
70  typedef typename Tmatrix::VectorType VectorType;
71  Tmatrix const & A;
72  explicit ATAMatrix(Tmatrix const & A_)
73  : A(A_) {}
74  void getCols(SizesAndBlocks & colsCopy) const {
75  A.getRows(colsCopy);
76  }
77  void quickEuclBounds(Treal & euclLowerBound,
78  Treal & euclUpperBound) const {
79  Treal frobA = A.frob();
80  euclLowerBound = 0;
81  euclUpperBound = frobA * frobA;
82  }
83 
84  // y = AT*A*x
85  template<typename Tvector>
86  void matVecProd(Tvector & y, Tvector const & x) const {
87  y = x;
88  y = A * y;
89  y = transpose(A) * y;
90  }
91  // Number of rows of A^T * A is the number of columns of A
92  int get_nrows() const { return A.get_ncols(); }
93  };
94 
95 
96  template<typename Tmatrix, typename Tmatrix2, typename Treal>
97  struct TripleMatrix {
98  typedef typename Tmatrix::VectorType VectorType;
99  void getCols(SizesAndBlocks & colsCopy) const {
100  A.getCols(colsCopy);
101  }
102  int get_nrows() const {
103  assert( A.get_nrows() == Z.get_nrows() );
104  return A.get_nrows();
105  }
106  void quickEuclBounds(Treal & euclLowerBound,
107  Treal & euclUpperBound) const {
108  Treal frobA = A.frob();
109  Treal frobZ = Z.frob();
110  euclLowerBound = 0;
111  euclUpperBound = frobA * frobZ * frobZ;
112  }
113 
114  Tmatrix const & A;
115  Tmatrix2 const & Z;
116  TripleMatrix(Tmatrix const & A_, Tmatrix2 const & Z_)
117  : A(A_), Z(Z_) {}
118  void matVecProd(VectorType & y, VectorType const & x) const {
119  VectorType tmp(x);
120  tmp = Z * tmp; // Z * x
121  y = (Treal)1.0 * A * tmp; // A * Z * x
122  y = transpose(Z) * y; // Z^T * A * Z * x
123  }
124  };
125 
126 
127  template<typename Tmatrix, typename Tmatrix2, typename Treal>
129  typedef typename Tmatrix::VectorType VectorType;
130  void getCols(SizesAndBlocks & colsCopy) const {
131  E.getRows(colsCopy);
132  }
133  int get_nrows() const {
134  return E.get_ncols();
135  }
136  void quickEuclBounds(Treal & euclLowerBound,
137  Treal & euclUpperBound) const {
138  Treal frobA = A.frob();
139  Treal frobZ = Zt.frob();
140  Treal frobE = E.frob();
141  euclLowerBound = 0;
142  euclUpperBound = frobA * frobE * frobE + 2 * frobA * frobE * frobZ;
143  }
144 
145  Tmatrix const & A;
146  Tmatrix2 const & Zt;
147  Tmatrix2 const & E;
148 
149  CongrTransErrorMatrix(Tmatrix const & A_,
150  Tmatrix2 const & Z_,
151  Tmatrix2 const & E_)
152  : A(A_), Zt(Z_), E(E_) {}
153  void matVecProd(VectorType & y, VectorType const & x) const {
154 
155  VectorType tmp(x);
156  tmp = E * tmp; // E * x
157  y = (Treal)-1.0 * A * tmp; // -A * E * x
158  y = transpose(E) * y; // -E^T * A * E * x
159 
160  VectorType tmp1;
161  tmp = x;
162  tmp = Zt * tmp; // Zt * x
163  tmp1 = (Treal)1.0 * A * tmp; // A * Zt * x
164  tmp1 = transpose(E) * tmp1; // E^T * A * Zt * x
165  y += (Treal)1.0 * tmp1;
166 
167  tmp = x;
168  tmp = E * tmp; // E * x
169  tmp1 = (Treal)1.0 * A * tmp; // A * E * x
170  tmp1 = transpose(Zt) * tmp1; // Zt^T * A * E * x
171  y += (Treal)1.0 * tmp1;
172  }
173  };
174 
175 
176 
177 } /* end namespace mat */
178 #endif
TripleMatrix(Tmatrix const &A_, Tmatrix2 const &Z_)
Definition: mat_utils.h:116
Tmatrix const & B
Definition: mat_utils.h:54
int get_nrows() const
Definition: mat_utils.h:92
void matVecProd(Tvector &y, Tvector const &x) const
Definition: mat_utils.h:58
DiffMatrix(Tmatrix const &A_, Tmatrix const &B_)
Definition: mat_utils.h:55
void matVecProd(Tvector &y, Tvector const &x) const
Definition: mat_utils.h:86
Proxy structs used by the matrix API.
Tmatrix::VectorType VectorType
Definition: mat_utils.h:70
Tmatrix2 const & Zt
Definition: mat_utils.h:146
Tmatrix::VectorType VectorType
Definition: mat_utils.h:35
Tmatrix2 const & Z
Definition: mat_utils.h:115
int get_nrows() const
Definition: mat_utils.h:39
void matVecProd(VectorType &y, VectorType const &x) const
Definition: mat_utils.h:118
int get_nrows() const
Definition: mat_utils.h:133
void getCols(SizesAndBlocks &colsCopy) const
Definition: mat_utils.h:74
Definition: allocate.cc:30
Describes dimensions of matrix and its blocks on all levels.
Definition: SizesAndBlocks.h:37
CongrTransErrorMatrix(Tmatrix const &A_, Tmatrix2 const &Z_, Tmatrix2 const &E_)
Definition: mat_utils.h:149
void quickEuclBounds(Treal &euclLowerBound, Treal &euclUpperBound) const
Definition: mat_utils.h:136
struct TripleMatrix
Definition: mat_utils.h:97
Tmatrix const & A
Definition: mat_utils.h:114
void getCols(SizesAndBlocks &colsCopy) const
Definition: mat_utils.h:36
Tmatrix const & A
Definition: mat_utils.h:53
struct CongrTransErrorMatrix
Definition: mat_utils.h:128
void matVecProd(VectorType &y, VectorType const &x) const
Definition: mat_utils.h:153
void quickEuclBounds(Treal &euclLowerBound, Treal &euclUpperBound) const
Definition: mat_utils.h:106
struct ATAMatrix
Definition: mat_utils.h:69
Tmatrix2 const & E
Definition: mat_utils.h:147
struct DiffMatrix
Definition: mat_utils.h:34
void getCols(SizesAndBlocks &colsCopy) const
Definition: mat_utils.h:130
Tmatrix const & A
Definition: mat_utils.h:71
Treal frob() const
Definition: mat_utils.h:43
Tmatrix::VectorType VectorType
Definition: mat_utils.h:129
Tmatrix::VectorType VectorType
Definition: mat_utils.h:98
void quickEuclBounds(Treal &euclLowerBound, Treal &euclUpperBound) const
Definition: mat_utils.h:77
Tmatrix const & A
Definition: mat_utils.h:145
void quickEuclBounds(Treal &euclLowerBound, Treal &euclUpperBound) const
Definition: mat_utils.h:46
Interval class.
Xtrans< TX > transpose(TX const &A)
Transposition.
Definition: matrix_proxy.h:129
int get_nrows() const
Definition: mat_utils.h:102
ATAMatrix(Tmatrix const &A_)
Definition: mat_utils.h:72
Treal template_blas_sqrt(Treal x)
void getCols(SizesAndBlocks &colsCopy) const
Definition: mat_utils.h:99