Basix
Loading...
Searching...
No Matches
math.h
1// Copyright (C) 2021 Igor Baratta
2//
3// This file is part of DOLFINx (https://www.fenicsproject.org)
4//
5// SPDX-License-Identifier: LGPL-3.0-or-later
6
7#pragma once
8
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <concepts>
#include <cstddef>
#include <limits>
#include <span>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "mdspan.hpp"
19
// Fortran-style LAPACK/BLAS entry points (trailing underscore, every argument
// passed by pointer). The leading `s`/`d` selects the float/double variant.
extern "C"
{
  // ---- Single precision ----

  // Symmetric eigenproblem, divide and conquer
  void ssyevd_(char* jobz, char* uplo, int* n, float* a, int* lda, float* w,
               float* work, int* lwork, int* iwork, int* liwork, int* info);

  // Dense linear solve A X = B through an LU factorisation
  void sgesv_(int* n, int* nrhs, float* a, int* lda, int* ipiv, float* b,
              int* ldb, int* info);

  // General matrix-matrix product C = alpha op(A) op(B) + beta C
  void sgemm_(char* transa, char* transb, int* m, int* n, int* k, float* alpha,
              float* a, int* lda, float* b, int* ldb, float* beta, float* c,
              int* ldc);

  // LU factorisation with partial pivoting
  // NOTE(review): reference LAPACK defines ?getrf as a subroutine; the int
  // return type is kept as-is for compatibility with existing declarations.
  int sgetrf_(const int* m, const int* n, float* a, const int* lda, int* ipiv,
              int* info);

  // ---- Double precision ----

  void dsyevd_(char* jobz, char* uplo, int* n, double* a, int* lda, double* w,
               double* work, int* lwork, int* iwork, int* liwork, int* info);

  void dgesv_(int* n, int* nrhs, double* a, int* lda, int* ipiv, double* b,
              int* ldb, int* info);

  void dgemm_(char* transa, char* transb, int* m, int* n, int* k,
              double* alpha, double* a, int* lda, double* b, int* ldb,
              double* beta, double* c, int* ldc);

  int dgetrf_(const int* m, const int* n, double* a, const int* lda,
              int* ipiv, int* info);
}
44
49namespace basix::math
50{
51namespace impl
52{
57template <std::floating_point T>
58void dot_blas(std::span<const T> A, std::array<std::size_t, 2> Ashape,
59 std::span<const T> B, std::array<std::size_t, 2> Bshape,
60 std::span<T> C)
61{
62 static_assert(std::is_same_v<T, float> or std::is_same_v<T, double>);
63
64 assert(Ashape[1] == Bshape[0]);
65 assert(C.size() == Ashape[0] * Bshape[1]);
66
67 int M = Ashape[0];
68 int N = Bshape[1];
69 int K = Ashape[1];
70
71 T alpha = 1;
72 T beta = 0;
73 int lda = K;
74 int ldb = N;
75 int ldc = N;
76 char trans = 'N';
77 if constexpr (std::is_same_v<T, float>)
78 {
79 sgemm_(&trans, &trans, &N, &M, &K, &alpha, const_cast<T*>(B.data()), &ldb,
80 const_cast<T*>(A.data()), &lda, &beta, C.data(), &ldc);
81 }
82 else if constexpr (std::is_same_v<T, double>)
83 {
84 dgemm_(&trans, &trans, &N, &M, &K, &alpha, const_cast<T*>(B.data()), &ldb,
85 const_cast<T*>(A.data()), &lda, &beta, C.data(), &ldc);
86 }
87}
88
89} // namespace impl
90
/// @brief Outer product of two vectors.
///
/// @param[in] u Vector of length m.
/// @param[in] v Vector of length n.
/// @return A pair of the row-major m x n data (entry (r, c) = u[r] * v[c])
/// and its shape {m, n}.
template <typename U, typename V>
std::pair<std::vector<typename U::value_type>, std::array<std::size_t, 2>>
outer(const U& u, const V& v)
{
  const std::size_t rows = u.size();
  const std::size_t cols = v.size();

  std::vector<typename U::value_type> product;
  product.reserve(rows * cols);
  for (std::size_t r = 0; r < rows; ++r)
  {
    for (std::size_t c = 0; c < cols; ++c)
      product.push_back(u[r] * v[c]);
  }

  std::array<std::size_t, 2> shape = {rows, cols};
  return {std::move(product), shape};
}
105
/// @brief Cross product u x v of two 3-vectors.
///
/// @param[in] u First vector (size 3).
/// @param[in] v Second vector (size 3).
/// @return The three components of u x v.
template <typename U, typename V>
std::array<typename U::value_type, 3> cross(const U& u, const V& v)
{
  assert(u.size() == 3);
  assert(v.size() == 3);

  // Each component is the 2x2 determinant of the other two components
  const auto cx = u[1] * v[2] - u[2] * v[1];
  const auto cy = u[2] * v[0] - u[0] * v[2];
  const auto cz = u[0] * v[1] - u[1] * v[0];
  return {cx, cy, cz};
}
118
125template <std::floating_point T>
126std::pair<std::vector<T>, std::vector<T>> eigh(std::span<const T> A,
127 std::size_t n)
128{
129 // Copy A
130 std::vector<T> M(A.begin(), A.end());
131
132 // Allocate storage for eigenvalues
133 std::vector<T> w(n, 0);
134
135 int N = n;
136 char jobz = 'V'; // Compute eigenvalues and eigenvectors
137 char uplo = 'L'; // Lower
138 int ldA = n;
139 int lwork = -1;
140 int liwork = -1;
141 int info;
142 std::vector<T> work(1);
143 std::vector<int> iwork(1);
144
145 // Query optimal workspace size
146 if constexpr (std::is_same_v<T, float>)
147 {
148 ssyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
149 iwork.data(), &liwork, &info);
150 }
151 else if constexpr (std::is_same_v<T, double>)
152 {
153 dsyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
154 iwork.data(), &liwork, &info);
155 }
156
157 if (info != 0)
158 throw std::runtime_error("Could not find workspace size for syevd.");
159
160 // Solve eigen problem
161 work.resize(work[0]);
162 iwork.resize(iwork[0]);
163 lwork = work.size();
164 liwork = iwork.size();
165 if constexpr (std::is_same_v<T, float>)
166 {
167 ssyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
168 iwork.data(), &liwork, &info);
169 }
170 else if constexpr (std::is_same_v<T, double>)
171 {
172 dsyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
173 iwork.data(), &liwork, &info);
174 }
175 if (info != 0)
176 throw std::runtime_error("Eigenvalue computation did not converge.");
177
178 return {std::move(w), std::move(M)};
179}
180
185template <std::floating_point T>
186std::vector<T>
187solve(MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
188 const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
189 A,
190 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
191 const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
192 B)
193{
194 namespace stdex
195 = MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE;
196
197 // Copy A and B to column-major storage
198 stdex::mdarray<T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>,
199 MDSPAN_IMPL_STANDARD_NAMESPACE::layout_left>
200 _A(A.extents()), _B(B.extents());
201 for (std::size_t i = 0; i < A.extent(0); ++i)
202 for (std::size_t j = 0; j < A.extent(1); ++j)
203 _A(i, j) = A(i, j);
204 for (std::size_t i = 0; i < B.extent(0); ++i)
205 for (std::size_t j = 0; j < B.extent(1); ++j)
206 _B(i, j) = B(i, j);
207
208 int N = _A.extent(0);
209 int nrhs = _B.extent(1);
210 int lda = _A.extent(0);
211 int ldb = _B.extent(0);
212 // Pivot indices that define the permutation matrix for the LU solver
213 std::vector<int> piv(N);
214 int info;
215 if constexpr (std::is_same_v<T, float>)
216 sgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), _B.data(), &ldb, &info);
217 else if constexpr (std::is_same_v<T, double>)
218 dgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), _B.data(), &ldb, &info);
219 if (info != 0)
220 throw std::runtime_error("Call to dgesv failed: " + std::to_string(info));
221
222 // Copy result to row-major storage
223 std::vector<T> rb(_B.extent(0) * _B.extent(1));
224 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
225 T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
226 r(rb.data(), _B.extents());
227 for (std::size_t i = 0; i < _B.extent(0); ++i)
228 for (std::size_t j = 0; j < _B.extent(1); ++j)
229 r(i, j) = _B(i, j);
230
231 return rb;
232}
233
237template <std::floating_point T>
239 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
240 const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
241 A)
242{
243 // Copy to column major matrix
244 namespace stdex
245 = MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE;
246 stdex::mdarray<T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>,
247 MDSPAN_IMPL_STANDARD_NAMESPACE::layout_left>
248 _A(A.extents());
249 for (std::size_t i = 0; i < A.extent(0); ++i)
250 for (std::size_t j = 0; j < A.extent(1); ++j)
251 _A(i, j) = A(i, j);
252
253 std::vector<T> B(A.extent(1), 1);
254 int N = _A.extent(0);
255 int nrhs = 1;
256 int lda = _A.extent(0);
257 int ldb = B.size();
258
259 // Pivot indices that define the permutation matrix for the LU solver
260 std::vector<int> piv(N);
261 int info;
262 if constexpr (std::is_same_v<T, float>)
263 sgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), B.data(), &ldb, &info);
264 else if constexpr (std::is_same_v<T, double>)
265 dgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), B.data(), &ldb, &info);
266
267 if (info < 0)
268 {
269 throw std::runtime_error("dgesv failed due to invalid value: "
270 + std::to_string(info));
271 }
272 else if (info > 0)
273 return true;
274 else
275 return false;
276}
277
283template <std::floating_point T>
284std::vector<std::size_t>
285transpose_lu(std::pair<std::vector<T>, std::array<std::size_t, 2>>& A)
286{
287 std::size_t dim = A.second[0];
288 assert(dim == A.second[1]);
289 int N = dim;
290 int info;
291 std::vector<int> lu_perm(dim);
292
293 // Comput LU decomposition of M
294 if constexpr (std::is_same_v<T, float>)
295 sgetrf_(&N, &N, A.first.data(), &N, lu_perm.data(), &info);
296 else if constexpr (std::is_same_v<T, double>)
297 dgetrf_(&N, &N, A.first.data(), &N, lu_perm.data(), &info);
298
299 if (info != 0)
300 {
301 throw std::runtime_error("LU decomposition failed: "
302 + std::to_string(info));
303 }
304
305 std::vector<std::size_t> perm(dim);
306 for (std::size_t i = 0; i < dim; ++i)
307 perm[i] = static_cast<std::size_t>(lu_perm[i] - 1);
308
309 return perm;
310}
311
317template <typename U, typename V, typename W>
318void dot(const U& A, const V& B, W&& C)
319{
320 assert(A.extent(1) == B.extent(0));
321 assert(C.extent(0) == A.extent(0));
322 assert(C.extent(1) == B.extent(1));
323 if (A.extent(0) * B.extent(1) * A.extent(1) < 512)
324 {
325 std::fill_n(C.data_handle(), C.extent(0) * C.extent(1), 0);
326 for (std::size_t i = 0; i < A.extent(0); ++i)
327 for (std::size_t j = 0; j < B.extent(1); ++j)
328 for (std::size_t k = 0; k < A.extent(1); ++k)
329 C(i, j) += A(i, k) * B(k, j);
330 }
331 else
332 {
333 using T = typename std::decay_t<U>::value_type;
334 impl::dot_blas<T>(
335 std::span(A.data_handle(), A.size()), {A.extent(0), A.extent(1)},
336 std::span(B.data_handle(), B.size()), {B.extent(0), B.extent(1)},
337 std::span(C.data_handle(), C.size()));
338 }
339}
340
344template <std::floating_point T>
345std::vector<T> eye(std::size_t n)
346{
347 std::vector<T> I(n * n, 0);
348 namespace stdex
349 = MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE;
350 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
351 T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
352 Iview(I.data(), n, n);
353 for (std::size_t i = 0; i < n; ++i)
354 Iview(i, i) = 1;
355 return I;
356}
357
362template <std::floating_point T>
364 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
365 T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
366 wcoeffs,
367 std::size_t start = 0)
368{
369 for (std::size_t i = start; i < wcoeffs.extent(0); ++i)
370 {
371 T norm = 0;
372 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
373 norm += wcoeffs(i, k) * wcoeffs(i, k);
374
375 norm = std::sqrt(norm);
376 if (norm < 2 * std::numeric_limits<T>::epsilon())
377 {
378 throw std::runtime_error(
379 "Cannot orthogonalise the rows of a matrix with incomplete row rank");
380 }
381
382 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
383 wcoeffs(i, k) /= norm;
384
385 for (std::size_t j = i + 1; j < wcoeffs.extent(0); ++j)
386 {
387 T a = 0;
388 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
389 a += wcoeffs(i, k) * wcoeffs(j, k);
390 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
391 wcoeffs(j, k) -= a * wcoeffs(i, k);
392 }
393 }
394}
395} // namespace basix::math
A finite element.
Definition finite-element.h:139
Mathematical functions.
Definition math.h:50
void dot(const U &A, const V &B, W &&C)
Compute C = A * B.
Definition math.h:318
std::vector< T > solve(MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan< const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents< std::size_t, 2 > > A, MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan< const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents< std::size_t, 2 > > B)
Solve A X = B.
Definition math.h:187
std::array< typename U::value_type, 3 > cross(const U &u, const V &v)
Definition math.h:111
bool is_singular(MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan< const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents< std::size_t, 2 > > A)
Check if A is a singular matrix.
Definition math.h:238
std::vector< std::size_t > transpose_lu(std::pair< std::vector< T >, std::array< std::size_t, 2 > > &A)
Compute the LU decomposition of the transpose of a square matrix A.
Definition math.h:285
void orthogonalise(MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan< T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents< std::size_t, 2 > > wcoeffs, std::size_t start=0)
Orthogonalise the rows of a matrix (in place).
Definition math.h:363
std::pair< std::vector< T >, std::vector< T > > eigh(std::span< const T > A, std::size_t n)
Definition math.h:126
std::vector< T > eye(std::size_t n)
Build an identity matrix.
Definition math.h:345
std::pair< std::vector< typename U::value_type >, std::array< std::size_t, 2 > > outer(const U &u, const V &v)
Compute the outer product of vectors u and v.
Definition math.h:97