ergo
mat_utils.h
Go to the documentation of this file.
1 /* Ergo, version 3.8, a program for linear scaling electronic structure
2  * calculations.
3  * Copyright (C) 2019 Elias Rudberg, Emanuel H. Rubensson, Pawel Salek,
4  * and Anastasia Kruchinina.
5  *
6  * This program is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This program is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with this program. If not, see <http://www.gnu.org/licenses/>.
18  *
19  * Primary academic reference:
20  * Ergo: An open-source program for linear-scaling electronic structure
21  * calculations,
22  * Elias Rudberg, Emanuel H. Rubensson, Pawel Salek, and Anastasia
23  * Kruchinina,
24  * SoftwareX 7, 107 (2018),
25  * <http://dx.doi.org/10.1016/j.softx.2018.03.005>
26  *
27  * For further information about Ergo, see <http://www.ergoscf.org>.
28  */
29 
36 #ifndef MAT_UTILS_HEADER
37 #define MAT_UTILS_HEADER
38 #include "Interval.h"
39 #include "matrix_proxy.h"
40 namespace mat {
41 
42  template<typename Tmatrix, typename Treal>
43  struct DiffMatrix {
44  typedef typename Tmatrix::VectorType VectorType;
45  void getCols(SizesAndBlocks & colsCopy) const {
46  A.getCols(colsCopy);
47  }
48  int get_nrows() const {
49  assert( A.get_nrows() == B.get_nrows() );
50  return A.get_nrows();
51  }
52  Treal frob() const {
53  return Tmatrix::frob_diff(A, B);
54  }
55  void quickEuclBounds(Treal & euclLowerBound,
56  Treal & euclUpperBound) const {
57  Treal frobTmp = frob();
58  euclLowerBound = frobTmp / template_blas_sqrt( (Treal)get_nrows() );
59  euclUpperBound = frobTmp;
60  }
61 
62  Tmatrix const & A;
63  Tmatrix const & B;
64  DiffMatrix(Tmatrix const & A_, Tmatrix const & B_)
65  : A(A_), B(B_) {}
66  template<typename Tvector>
67  void matVecProd(Tvector & y, Tvector const & x) const {
68  Tvector tmp(y);
69  tmp = (Treal)-1.0 * B * x; // -B * x
70  y = (Treal)1.0 * A * x; // A * x
71  y += (Treal)1.0 * tmp; // A * x - B * x => (A - B) * x
72  }
73  };
74 
75 
76  // ATAMatrix AT*A
77  template<typename Tmatrix, typename Treal>
78  struct ATAMatrix {
79  typedef typename Tmatrix::VectorType VectorType;
80  Tmatrix const & A;
81  explicit ATAMatrix(Tmatrix const & A_)
82  : A(A_) {}
83  void getCols(SizesAndBlocks & colsCopy) const {
84  A.getRows(colsCopy);
85  }
86  void quickEuclBounds(Treal & euclLowerBound,
87  Treal & euclUpperBound) const {
88  Treal frobA = A.frob();
89  euclLowerBound = 0;
90  euclUpperBound = frobA * frobA;
91  }
92 
93  // y = AT*A*x
94  template<typename Tvector>
95  void matVecProd(Tvector & y, Tvector const & x) const {
96  y = x;
97  y = A * y;
98  y = transpose(A) * y;
99  }
100  // Number of rows of A^T * A is the number of columns of A
101  int get_nrows() const { return A.get_ncols(); }
102  };
103 
104 
105  template<typename Tmatrix, typename Tmatrix2, typename Treal>
106  struct TripleMatrix {
107  typedef typename Tmatrix::VectorType VectorType;
108  void getCols(SizesAndBlocks & colsCopy) const {
109  A.getCols(colsCopy);
110  }
111  int get_nrows() const {
112  assert( A.get_nrows() == Z.get_nrows() );
113  return A.get_nrows();
114  }
115  void quickEuclBounds(Treal & euclLowerBound,
116  Treal & euclUpperBound) const {
117  Treal frobA = A.frob();
118  Treal frobZ = Z.frob();
119  euclLowerBound = 0;
120  euclUpperBound = frobA * frobZ * frobZ;
121  }
122 
123  Tmatrix const & A;
124  Tmatrix2 const & Z;
125  TripleMatrix(Tmatrix const & A_, Tmatrix2 const & Z_)
126  : A(A_), Z(Z_) {}
127  void matVecProd(VectorType & y, VectorType const & x) const {
128  VectorType tmp(x);
129  tmp = Z * tmp; // Z * x
130  y = (Treal)1.0 * A * tmp; // A * Z * x
131  y = transpose(Z) * y; // Z^T * A * Z * x
132  }
133  };
134 
135 
136  template<typename Tmatrix, typename Tmatrix2, typename Treal>
138  typedef typename Tmatrix::VectorType VectorType;
139  void getCols(SizesAndBlocks & colsCopy) const {
140  E.getRows(colsCopy);
141  }
142  int get_nrows() const {
143  return E.get_ncols();
144  }
145  void quickEuclBounds(Treal & euclLowerBound,
146  Treal & euclUpperBound) const {
147  Treal frobA = A.frob();
148  Treal frobZ = Zt.frob();
149  Treal frobE = E.frob();
150  euclLowerBound = 0;
151  euclUpperBound = frobA * frobE * frobE + 2 * frobA * frobE * frobZ;
152  }
153 
154  Tmatrix const & A;
155  Tmatrix2 const & Zt;
156  Tmatrix2 const & E;
157 
158  CongrTransErrorMatrix(Tmatrix const & A_,
159  Tmatrix2 const & Z_,
160  Tmatrix2 const & E_)
161  : A(A_), Zt(Z_), E(E_) {}
162  void matVecProd(VectorType & y, VectorType const & x) const {
163 
164  VectorType tmp(x);
165  tmp = E * tmp; // E * x
166  y = (Treal)-1.0 * A * tmp; // -A * E * x
167  y = transpose(E) * y; // -E^T * A * E * x
168 
169  VectorType tmp1;
170  tmp = x;
171  tmp = Zt * tmp; // Zt * x
172  tmp1 = (Treal)1.0 * A * tmp; // A * Zt * x
173  tmp1 = transpose(E) * tmp1; // E^T * A * Zt * x
174  y += (Treal)1.0 * tmp1;
175 
176  tmp = x;
177  tmp = E * tmp; // E * x
178  tmp1 = (Treal)1.0 * A * tmp; // A * E * x
179  tmp1 = transpose(Zt) * tmp1; // Zt^T * A * E * x
180  y += (Treal)1.0 * tmp1;
181  }
182  };
183 
184 
185 
186 } /* end namespace mat */
187 #endif
template_blas_sqrt
Treal template_blas_sqrt(Treal x)
mat::CongrTransErrorMatrix::Zt
Tmatrix2 const & Zt
Definition: mat_utils.h:155
mat::ATAMatrix::getCols
void getCols(SizesAndBlocks &colsCopy) const
Definition: mat_utils.h:83
mat::CongrTransErrorMatrix
Definition: mat_utils.h:137
mat::ATAMatrix::VectorType
Tmatrix::VectorType VectorType
Definition: mat_utils.h:79
mat::DiffMatrix::B
Tmatrix const & B
Definition: mat_utils.h:63
mat::ATAMatrix::ATAMatrix
ATAMatrix(Tmatrix const &A_)
Definition: mat_utils.h:81
mat::TripleMatrix::getCols
void getCols(SizesAndBlocks &colsCopy) const
Definition: mat_utils.h:108
mat::TripleMatrix::Z
Tmatrix2 const & Z
Definition: mat_utils.h:124
mat::DiffMatrix::DiffMatrix
DiffMatrix(Tmatrix const &A_, Tmatrix const &B_)
Definition: mat_utils.h:64
VectorType
generalVector VectorType
Definition: GetDensFromFock.cc:62
matrix_proxy.h
mat::TripleMatrix::VectorType
Tmatrix::VectorType VectorType
Definition: mat_utils.h:107
mat::CongrTransErrorMatrix::getCols
void getCols(SizesAndBlocks &colsCopy) const
Definition: mat_utils.h:139
mat::ATAMatrix::get_nrows
int get_nrows() const
Definition: mat_utils.h:101
mat::CongrTransErrorMatrix::quickEuclBounds
void quickEuclBounds(Treal &euclLowerBound, Treal &euclUpperBound) const
Definition: mat_utils.h:145
mat::ATAMatrix::quickEuclBounds
void quickEuclBounds(Treal &euclLowerBound, Treal &euclUpperBound) const
Definition: mat_utils.h:86
mat::CongrTransErrorMatrix::matVecProd
void matVecProd(VectorType &y, VectorType const &x) const
Definition: mat_utils.h:162
mat::ATAMatrix::matVecProd
void matVecProd(Tvector &y, Tvector const &x) const
Definition: mat_utils.h:95
mat::DiffMatrix::get_nrows
int get_nrows() const
Definition: mat_utils.h:48
mat::DiffMatrix::VectorType
Tmatrix::VectorType VectorType
Definition: mat_utils.h:44
mat::TripleMatrix::matVecProd
void matVecProd(VectorType &y, VectorType const &x) const
Definition: mat_utils.h:127
mat::TripleMatrix::A
Tmatrix const & A
Definition: mat_utils.h:123
mat::TripleMatrix
Definition: mat_utils.h:106
mat::CongrTransErrorMatrix::get_nrows
int get_nrows() const
Definition: mat_utils.h:142
mat::DiffMatrix::quickEuclBounds
void quickEuclBounds(Treal &euclLowerBound, Treal &euclUpperBound) const
Definition: mat_utils.h:55
mat::CongrTransErrorMatrix::E
Tmatrix2 const & E
Definition: mat_utils.h:156
mat::transpose
Xtrans< TX > transpose(TX const &A)
Transposition.
Definition: matrix_proxy.h:131
mat
Definition: allocate.cc:39
mat::CongrTransErrorMatrix::CongrTransErrorMatrix
CongrTransErrorMatrix(Tmatrix const &A_, Tmatrix2 const &Z_, Tmatrix2 const &E_)
Definition: mat_utils.h:158
mat::CongrTransErrorMatrix::A
Tmatrix const & A
Definition: mat_utils.h:154
mat::ATAMatrix
Definition: mat_utils.h:78
mat::CongrTransErrorMatrix::VectorType
Tmatrix::VectorType VectorType
Definition: mat_utils.h:138
mat::DiffMatrix::matVecProd
void matVecProd(Tvector &y, Tvector const &x) const
Definition: mat_utils.h:67
mat::DiffMatrix::frob
Treal frob() const
Definition: mat_utils.h:52
mat::ATAMatrix::A
Tmatrix const & A
Definition: mat_utils.h:80
mat::TripleMatrix::quickEuclBounds
void quickEuclBounds(Treal &euclLowerBound, Treal &euclUpperBound) const
Definition: mat_utils.h:115
mat::DiffMatrix::A
Tmatrix const & A
Definition: mat_utils.h:62
mat::SizesAndBlocks
Describes dimensions of matrix and its blocks on all levels.
Definition: SizesAndBlocks.h:45
mat::TripleMatrix::TripleMatrix
TripleMatrix(Tmatrix const &A_, Tmatrix2 const &Z_)
Definition: mat_utils.h:125
mat::DiffMatrix::getCols
void getCols(SizesAndBlocks &colsCopy) const
Definition: mat_utils.h:45
mat::DiffMatrix
Definition: mat_utils.h:43
mat::TripleMatrix::get_nrows
int get_nrows() const
Definition: mat_utils.h:111
Interval.h