// dot_blas (fragment): matrix product dispatched to BLAS GEMM.
// Only the first four parameters (flat data spans A and B with their 2D
// shapes) and the two GEMM calls are visible in this excerpt. The output C,
// the scalars alpha/beta, the integers M/N/K/lda/ldb/ldc, the `trans` flag and
// the selection between the float and double call are defined on lines not
// shown here.
58void dot_blas(std::span<const T> A, std::array<std::size_t, 2> Ashape,
59 std::span<const T> B, std::array<std::size_t, 2> Bshape,
// Single-precision call. Note the operand order: B (with ldb) is passed before
// A (with lda), and N before M. NOTE(review): presumably this computes the
// product of row-major inputs through a column-major BLAS by evaluating the
// transposed product -- confirm against the definitions of M, N, K and `trans`.
// The const_cast removes const from the input spans' data pointers; presumably
// the GEMM prototype takes non-const pointers -- confirm against its
// declaration.
79 sgemm_(&trans, &trans, &N, &M, &K, &alpha,
const_cast<T*
>(B.data()), &ldb,
80 const_cast<T*
>(A.data()), &lda, &beta, C.data(), &ldc);
// Double-precision call with the same argument order as the sgemm_ call above.
84 dgemm_(&trans, &trans, &N, &M, &K, &alpha,
const_cast<T*
>(B.data()), &ldb,
85 const_cast<T*
>(A.data()), &lda, &beta, C.data(), &ldc);
// Outer-product body (the enclosing function header is not visible in this
// excerpt; `u` and `v` are its indexable, sized inputs).
// Allocates a flat buffer of u.size() * v.size() entries of U::value_type.
99 std::vector<typename U::value_type> result(u.size() * v.size());
// Entry (i, j) is stored at i * v.size() + j, i.e. rows of length v.size()
// laid out one after another, and holds u[i] * v[j].
100 for (std::size_t i = 0; i < u.size(); ++i)
101 for (std::size_t j = 0; j < v.size(); ++j)
102 result[i * v.size() + j] = u[i] * v[j];
// Returns the buffer together with its shape {u.size(), v.size()}.
103 return {std::move(result), {u.size(), v.size()}};
// eigh (fragment): eigen-decomposition via LAPACK ssyevd_/dsyevd_.
// Returns the pair {w, M}: `w` is the length-n buffer handed to the routine as
// its eigenvalue output and `M` is the working copy of A that the routine
// overwrites (NOTE(review): holds the eigenvectors if jobz requests them --
// jobz is set on a line not shown). The remaining parameters, and the
// definitions of n, N, ldA, jobz, uplo, lwork, liwork and info, are on lines
// not shown here.
126std::pair<std::vector<T>, std::vector<T>>
eigh(std::span<const T> A,
// Copy the input so the LAPACK call can work in place without touching A.
130 std::vector<T> M(A.begin(), A.end());
133 std::vector<T> w(n, 0);
// One-element work arrays for the first call. NOTE(review): presumably this
// first call is a workspace-size query (lwork = liwork = -1) -- the lines that
// set lwork/liwork are not shown; confirm.
142 std::vector<T> work(1);
143 std::vector<int> iwork(1);
146 if constexpr (std::is_same_v<T, float>)
148 ssyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
149 iwork.data(), &liwork, &info);
151 else if constexpr (std::is_same_v<T, double>)
153 dsyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
154 iwork.data(), &liwork, &info);
// The condition guarding this throw is not visible (presumably info != 0).
158 throw std::runtime_error(
"Could not find workspace size for syevd.");
// The first call left the required sizes in work[0] and iwork[0]. work[0] has
// type T, so it is converted from floating point to a size here.
161 work.resize(work[0]);
162 iwork.resize(iwork[0]);
// Only the liwork update is visible; lwork is presumably updated on an
// adjacent line not shown.
164 liwork = iwork.size();
// Second call, now with the enlarged work arrays.
165 if constexpr (std::is_same_v<T, float>)
167 ssyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
168 iwork.data(), &liwork, &info);
170 else if constexpr (std::is_same_v<T, double>)
172 dsyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
173 iwork.data(), &liwork, &info);
// The condition guarding this throw is not visible (presumably info != 0).
176 throw std::runtime_error(
"Eigenvalue computation did not converge.");
178 return {std::move(w), std::move(M)};
// solve (fragment): linear solve through LAPACK sgesv_/dgesv_.
// Takes two read-only 2D mdspans (named A and B in the body). The return type,
// the parameter names, the bodies of the copy loops and the return statement
// are on lines not shown here.
187solve(MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
188 const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
190 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
191 const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
195 = MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE;
// Owning layout_left (column-major) arrays with the same extents as A and B;
// these are the buffers handed to LAPACK below.
198 stdex::mdarray<T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>,
199 MDSPAN_IMPL_STANDARD_NAMESPACE::layout_left>
200 _A(A.extents()), _B(B.extents());
// Loops over every entry of A and then of B; the loop bodies are not visible
// (presumably element-wise copies into _A and _B -- confirm).
201 for (std::size_t i = 0; i < A.extent(0); ++i)
202 for (std::size_t j = 0; j < A.extent(1); ++j)
204 for (std::size_t i = 0; i < B.extent(0); ++i)
205 for (std::size_t j = 0; j < B.extent(1); ++j)
// LAPACK takes int arguments: system size, number of right-hand-side columns,
// and the leading dimensions of the two column-major buffers.
208 int N = _A.extent(0);
209 int nrhs = _B.extent(1);
210 int lda = _A.extent(0);
211 int ldb = _B.extent(0);
// Pivot index output array, one entry per row.
213 std::vector<int> piv(N);
215 if constexpr (std::is_same_v<T, float>)
216 sgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), _B.data(), &ldb, &info);
217 else if constexpr (std::is_same_v<T, double>)
218 dgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), _B.data(), &ldb, &info);
// The guard for this throw is not visible (presumably info != 0). The message
// names dgesv even when the float branch (sgesv_) was the one executed.
220 throw std::runtime_error(
"Call to dgesv failed: " + std::to_string(info));
// Flat result buffer viewed through a default-layout mdspan with _B's extents.
223 std::vector<T> rb(_B.extent(0) * _B.extent(1));
224 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
225 T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
226 r(rb.data(), _B.extents());
// Loops over every entry of _B; the body is not visible (presumably copies the
// solution from the column-major _B into r -- confirm).
227 for (std::size_t i = 0; i < _B.extent(0); ++i)
228 for (std::size_t j = 0; j < _B.extent(1); ++j)
// Fragment of a function whose name and return type are not visible in this
// excerpt. It takes a read-only 2D mdspan (named A in the body), runs LAPACK
// sgesv_/dgesv_ on a column-major copy of it with an all-ones right-hand side,
// and throws when the routine reports an invalid value. How `info` is
// interpreted and what is returned are on lines not shown here.
239 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
240 const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
245 = MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE;
// Owning layout_left (column-major) array; its name and constructor arguments
// are on a line not shown (it is used as _A below).
246 stdex::mdarray<T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>,
247 MDSPAN_IMPL_STANDARD_NAMESPACE::layout_left>
// Loops over every entry of A; the body is not visible (presumably copies A
// into _A -- confirm).
249 for (std::size_t i = 0; i < A.extent(0); ++i)
250 for (std::size_t j = 0; j < A.extent(1); ++j)
// Right-hand side: A.extent(1) entries, all set to 1.
253 std::vector<T> B(A.extent(1), 1);
254 int N = _A.extent(0);
256 int lda = _A.extent(0);
// Pivot index output array. nrhs, ldb and info are defined on lines not shown.
260 std::vector<int> piv(N);
262 if constexpr (std::is_same_v<T, float>)
263 sgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), B.data(), &ldb, &info);
264 else if constexpr (std::is_same_v<T, double>)
265 dgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), B.data(), &ldb, &info);
// The guard for this throw is not visible. NOTE(review): the message suggests
// it covers the negative-info (illegal argument) case -- confirm. The message
// names dgesv even when the float branch was executed.
269 throw std::runtime_error(
"dgesv failed due to invalid value: "
270 + std::to_string(info));
// Fragment of a function whose header is not visible in this excerpt. `A` is a
// pair-like object: A.first is a flat data buffer and A.second its 2D shape.
// The code LU-factorises the buffer in place with LAPACK sgetrf_/dgetrf_ and
// converts the resulting pivot indices to zero-based std::size_t values.
287 std::size_t dim = A.second[0];
// The matrix must be square (checked in debug builds only).
288 assert(dim == A.second[1]);
// Pivot output array. N and info are defined on lines not shown (N is
// presumably dim converted to int).
291 std::vector<int> lu_perm(dim);
// N is passed as the row count, the column count and the leading dimension.
294 if constexpr (std::is_same_v<T, float>)
295 sgetrf_(&N, &N, A.first.data(), &N, lu_perm.data(), &info);
296 else if constexpr (std::is_same_v<T, double>)
297 dgetrf_(&N, &N, A.first.data(), &N, lu_perm.data(), &info);
// The guard for this throw is not visible (presumably info != 0).
301 throw std::runtime_error(
"LU decomposition failed: "
302 + std::to_string(info));
// LAPACK pivot indices are 1-based; subtract 1 to obtain zero-based indices.
305 std::vector<std::size_t> perm(dim);
306 for (std::size_t i = 0; i < dim; ++i)
307 perm[i] =
static_cast<std::size_t
>(lu_perm[i] - 1);
// dot: matrix product C = A * B for 2D mdspan-like operands (they provide
// extent(), operator()(i, j), data_handle() and size()).
// C must already have shape (A.extent(0), B.extent(1)); the shape requirements
// are checked with assert, i.e. in debug builds only.
318void dot(
const U& A,
const V& B, W&& C)
320 assert(A.extent(1) == B.extent(0));
321 assert(C.extent(0) == A.extent(0));
322 assert(C.extent(1) == B.extent(1));
// Small problems (product of the three dimensions below 512) take the
// hand-written triple loop.
323 if (A.extent(0) * B.extent(1) * A.extent(1) < 512)
// Zero C through its raw pointer before accumulating.
// NOTE(review): assumes C's storage is one contiguous run of
// extent(0) * extent(1) elements -- confirm for non-default layouts.
325 std::fill_n(C.data_handle(), C.extent(0) * C.extent(1), 0);
326 for (std::size_t i = 0; i < A.extent(0); ++i)
327 for (std::size_t j = 0; j < B.extent(1); ++j)
328 for (std::size_t k = 0; k < A.extent(1); ++k)
329 C(i, j) += A(i, k) * B(k, j);
// Remaining lines: the operands are repackaged as flat spans plus {rows, cols}
// shapes and passed to a callee whose name is on a line not shown (presumably
// the BLAS-backed dot_blas, reached through an else branch -- confirm).
333 using T =
typename std::decay_t<U>::value_type;
335 std::span(A.data_handle(), A.size()), {A.extent(0), A.extent(1)},
336 std::span(B.data_handle(), B.size()), {B.extent(0), B.extent(1)},
337 std::span(C.data_handle(), C.size()));
// eye (fragment): builds an n-by-n matrix in a flat std::vector<T>.
// The buffer starts as n * n zeros; the body of the loop over i and the return
// statement are on lines not shown (presumably the loop sets Iview(i, i) to 1
// and I is returned -- confirm).
345std::vector<T>
eye(std::size_t n)
347 std::vector<T> I(n * n, 0);
349 = MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE;
// Non-owning n-by-n 2D view over I, used for (row, column) indexing.
350 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
351 T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
352 Iview(I.data(), n, n);
353 for (std::size_t i = 0; i < n; ++i)
// Fragment of a function (name and return type not visible in this excerpt)
// that orthonormalises the rows of `wcoeffs` in place, starting at row `start`
// (default 0). Rows before `start` are never read or written by the visible
// code. The declarations and initial values of `norm` and `a` are on lines not
// shown (presumably reset to zero for each row -- confirm).
// Throws std::runtime_error when a row's norm is below
// 2 * numeric_limits<T>::epsilon().
364 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
365 T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
367 std::size_t start = 0)
369 for (std::size_t i = start; i < wcoeffs.extent(0); ++i)
// Euclidean norm of row i: accumulate the squares, then take the square root.
372 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
373 norm += wcoeffs(i, k) * wcoeffs(i, k);
375 norm = std::sqrt(norm);
// A (numerically) zero row cannot be normalised: by this point earlier rows
// have already had their components removed from it.
376 if (norm < 2 * std::numeric_limits<T>::epsilon())
378 throw std::runtime_error(
379 "Cannot orthogonalise the rows of a matrix with incomplete row rank");
// Scale row i to unit length.
382 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
383 wcoeffs(i, k) /= norm;
// For every later row j: a = <row i, row j>, then row j -= a * row i, which
// removes the component of row j along the (now unit-length) row i.
385 for (std::size_t j = i + 1; j < wcoeffs.extent(0); ++j)
388 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
389 a += wcoeffs(i, k) * wcoeffs(j, k);
390 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
391 wcoeffs(j, k) -= a * wcoeffs(i, k);