標籤:

BLAS簡介

對於機器學習來說,線性代數庫是必不可少的組件,通常都是使用事實上的標準——BLAS庫——來完成矩陣運算。

blas庫包含四類數據的計算:單精度實數s,雙精度實數d,單精度複數c,雙精度複數z。在機器學習中,最常用的是單精度實數。nvidia現在搞出來了半精度計算,但是現在用的並不多。

計算速度來說,自然是越快越好,在機器學習中大量依賴的是gemm(矩陣乘)和gemv(矩陣乘向量)。這裡不建議使用blas網站上的fortran示例,這個庫的實現很慢。現在計算速度比較快的blas庫是intel的mkl,非商業的實現可以用openblas,後者最好自己編譯一遍,以適合自己電腦的cpu。在gpu上計算一般使用cublas,號稱速度是mkl的20倍左右。pascal架構的顯卡計算雙精度的速度遠遠慢於單精度(Tesla除外),所以使用雙精度時gpu速度一般會反而慢於cpu。

如果想讓你的庫同時適應單精度和雙精度,那麼就要對介面做一番設計。blas中每個相同的功能都針對4種數據類型做了4遍,在c++裡面則可以自己將其改為重載。示例代碼在本文結尾,將cblas和cublas都調整為完全相同的介面。

當然,先做一個基類,再通過虛函數(由派生類覆寫)來實現也是可以的,這樣在實際編寫的時候代碼會很簡化。但是虛函數調用無法在編譯期做inline,所以一般我並不這樣做。而且,即使全部使用gpu計算,仍然難免需要進行少量的cpu計算(例如在驗證期),所以虛函數方案可能並不是一個好辦法。

blas_types.h#pragma oncetypedef enum{ MATRIX_NO_TRANS = 0, MATRIX_TRANS = 1,} MatrixTransType;typedef enum{ MATRIX_LOWER = 0, MATRIX_UPPER = 1} MatrixFillType;typedef enum{ MATRIX_NON_UNIT = 0, MATRIX_UNIT = 1} MatrixDiagType;typedef enum{ MATRIX_LEFT = 0, MATRIX_RIGHT = 1} MatrixSideType;cblas_real.h#pragma once#include "cblas.h"#include "blas_types.h"//Class of blas, overload functions with the same name for float and double.class Cblas : Blas{#ifdef VIRTUAL_BLASpublic:#elseprivate:#endif Cblas() {} ~Cblas() {}private: BLAS_FUNC CBLAS_TRANSPOSE get_trans(MatrixTransType t) { return t == MATRIX_NO_TRANS ? CblasNoTrans : CblasTrans; } BLAS_FUNC CBLAS_UPLO get_uplo(MatrixFillType t) { return t == MATRIX_UPPER ? CblasUpper : CblasLower; } BLAS_FUNC CBLAS_DIAG get_diag(MatrixDiagType t) { return t == MATRIX_NON_UNIT ? CblasNonUnit : CblasUnit; } BLAS_FUNC CBLAS_SIDE get_side(MatrixSideType t) { return t == MATRIX_LEFT ? CblasLeft : CblasRight; }public: BLAS_FUNC float dot(const int N, const float* X, const int incX, const float* Y, const int incY) { return cblas_sdot(N, X, incX, Y, incY); } BLAS_FUNC double dot(const int N, const double* X, const int incX, const double* Y, const int incY) { return cblas_ddot(N, X, incX, Y, incY); } BLAS_FUNC float nrm2(const int N, const float* X, const int incX) { return cblas_snrm2(N, X, incX); } BLAS_FUNC float asum(const int N, const float* X, const int incX) { return cblas_sasum(N, X, incX); } BLAS_FUNC double nrm2(const int N, const double* X, const int incX) { return cblas_dnrm2(N, X, incX); } BLAS_FUNC double asum(const int N, const double* X, const int incX) { return cblas_dasum(N, X, incX); } BLAS_FUNC int iamax(const int N, const float* X, const int incX) { return int(cblas_isamax(N, X, incX)); } BLAS_FUNC int iamax(const int N, const double* X, const int incX) { return int(cblas_idamax(N, X, incX)); } BLAS_FUNC void swap(const int N, float* X, const int incX, float* Y, const int incY) { cblas_sswap(N, X, incX, Y, incY); } 
BLAS_FUNC void copy(const int N, const float* X, const int incX, float* Y, const int incY) { cblas_scopy(N, X, incX, Y, incY); } BLAS_FUNC void axpy(const int N, const float alpha, const float* X, const int incX, float* Y, const int incY) { cblas_saxpy(N, alpha, X, incX, Y, incY); } BLAS_FUNC void swap(const int N, double* X, const int incX, double* Y, const int incY) { cblas_dswap(N, X, incX, Y, incY); } BLAS_FUNC void copy(const int N, const double* X, const int incX, double* Y, const int incY) { cblas_dcopy(N, X, incX, Y, incY); } BLAS_FUNC void axpy(const int N, const double alpha, const double* X, const int incX, double* Y, const int incY) { cblas_daxpy(N, alpha, X, incX, Y, incY); } BLAS_FUNC void rotg(float* a, float* b, float* c, float* s) { cblas_srotg(a, b, c, s); } BLAS_FUNC void rotmg(float* d1, float* d2, float* b1, const float b2, float* P) { cblas_srotmg(d1, d2, b1, b2, P); } BLAS_FUNC void rot(const int N, float* X, const int incX, float* Y, const int incY, const float c, const float s) { cblas_srot(N, X, incX, Y, incY, c, s); } BLAS_FUNC void rotm(const int N, float* X, const int incX, float* Y, const int incY, const float* P) { cblas_srotm(N, X, incX, Y, incY, P); } BLAS_FUNC void rotg(double* a, double* b, double* c, double* s) { cblas_drotg(a, b, c, s); } BLAS_FUNC void rotmg(double* d1, double* d2, double* b1, const double b2, double* P) { cblas_drotmg(d1, d2, b1, b2, P); } BLAS_FUNC void rot(const int N, double* X, const int incX, double* Y, const int incY, const double c, const double s) { cblas_drot(N, X, incX, Y, incY, c, s); } BLAS_FUNC void rotm(const int N, double* X, const int incX, double* Y, const int incY, const double* P) { cblas_drotm(N, X, incX, Y, incY, P); } BLAS_FUNC void scal(const int N, const float alpha, float* X, const int incX) { cblas_sscal(N, alpha, X, incX); } BLAS_FUNC void scal(const int N, const double alpha, double* X, const int incX) { cblas_dscal(N, alpha, X, incX); } BLAS_FUNC void gemv(const MatrixTransType 
TransA, const int M, const int N, const float alpha, const float* A, const int lda, const float* X, const int incX, const float beta, float* Y, const int incY) { cblas_sgemv(CblasColMajor, get_trans(TransA), M, N, alpha, A, lda, X, incX, beta, Y, incY); } BLAS_FUNC void gbmv(const MatrixTransType TransA, const int M, const int N, const int KL, const int KU, const float alpha, const float* A, const int lda, const float* X, const int incX, const float beta, float* Y, const int incY) { cblas_sgbmv(CblasColMajor, get_trans(TransA), M, N, KL, KU, alpha, A, lda, X, incX, beta, Y, incY); } BLAS_FUNC void trmv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const float* A, const int lda, float* X, const int incX) { cblas_strmv(CblasColMajor, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, A, lda, X, incX); } BLAS_FUNC void tbmv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const int K, const float* A, const int lda, float* X, const int incX) { cblas_stbmv(CblasColMajor, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, K, A, lda, X, incX); } BLAS_FUNC void tpmv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const float* Ap, float* X, const int incX) { cblas_stpmv(CblasColMajor, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, Ap, X, incX); } BLAS_FUNC void trsv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const float* A, const int lda, float* X, const int incX) { cblas_strsv(CblasColMajor, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, A, lda, X, incX); } BLAS_FUNC void tbsv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const int K, const float* A, const int lda, float* X, const int incX) { cblas_stbsv(CblasColMajor, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, K, A, lda, X, incX); } BLAS_FUNC 
void tpsv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const float* Ap, float* X, const int incX) { cblas_stpsv(CblasColMajor, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, Ap, X, incX); } BLAS_FUNC void gemv(const MatrixTransType TransA, const int M, const int N, const double alpha, const double* A, const int lda, const double* X, const int incX, const double beta, double* Y, const int incY) { cblas_dgemv(CblasColMajor, get_trans(TransA), M, N, alpha, A, lda, X, incX, beta, Y, incY); } BLAS_FUNC void gbmv(const MatrixTransType TransA, const int M, const int N, const int KL, const int KU, const double alpha, const double* A, const int lda, const double* X, const int incX, const double beta, double* Y, const int incY) { cblas_dgbmv(CblasColMajor, get_trans(TransA), M, N, KL, KU, alpha, A, lda, X, incX, beta, Y, incY); } BLAS_FUNC void trmv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const double* A, const int lda, double* X, const int incX) { cblas_dtrmv(CblasColMajor, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, A, lda, X, incX); } BLAS_FUNC void tbmv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const int K, const double* A, const int lda, double* X, const int incX) { cblas_dtbmv(CblasColMajor, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, K, A, lda, X, incX); } BLAS_FUNC void tpmv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const double* Ap, double* X, const int incX) { cblas_dtpmv(CblasColMajor, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, Ap, X, incX); } BLAS_FUNC void trsv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const double* A, const int lda, double* X, const int incX) { cblas_dtrsv(CblasColMajor, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, A, lda, X, 
incX); } BLAS_FUNC void tbsv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const int K, const double* A, const int lda, double* X, const int incX) { cblas_dtbsv(CblasColMajor, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, K, A, lda, X, incX); } BLAS_FUNC void tpsv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const double* Ap, double* X, const int incX) { cblas_dtpsv(CblasColMajor, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, Ap, X, incX); } BLAS_FUNC void symv(const MatrixFillType Uplo, const int N, const float alpha, const float* A, const int lda, const float* X, const int incX, const float beta, float* Y, const int incY) { cblas_ssymv(CblasColMajor, get_uplo(Uplo), N, alpha, A, lda, X, incX, beta, Y, incY); } BLAS_FUNC void sbmv(const MatrixFillType Uplo, const int N, const int K, const float alpha, const float* A, const int lda, const float* X, const int incX, const float beta, float* Y, const int incY) { cblas_ssbmv(CblasColMajor, get_uplo(Uplo), N, K, alpha, A, lda, X, incX, beta, Y, incY); } BLAS_FUNC void spmv(const MatrixFillType Uplo, const int N, const float alpha, const float* Ap, const float* X, const int incX, const float beta, float* Y, const int incY) { cblas_sspmv(CblasColMajor, get_uplo(Uplo), N, alpha, Ap, X, incX, beta, Y, incY); } BLAS_FUNC void ger(const int M, const int N, const float alpha, const float* X, const int incX, const float* Y, const int incY, float* A, const int lda) { cblas_sger(CblasColMajor, M, N, alpha, X, incX, Y, incY, A, lda); } BLAS_FUNC void syr(const MatrixFillType Uplo, const int N, const float alpha, const float* X, const int incX, float* A, const int lda) { cblas_ssyr(CblasColMajor, get_uplo(Uplo), N, alpha, X, incX, A, lda); } BLAS_FUNC void spr(const MatrixFillType Uplo, const int N, const float alpha, const float* X, const int incX, float* Ap) { cblas_sspr(CblasColMajor, get_uplo(Uplo), N, alpha, 
X, incX, Ap); } BLAS_FUNC void syr2(const MatrixFillType Uplo, const int N, const float alpha, const float* X, const int incX, const float* Y, const int incY, float* A, const int lda) { cblas_ssyr2(CblasColMajor, get_uplo(Uplo), N, alpha, X, incX, Y, incY, A, lda); } BLAS_FUNC void spr2(const MatrixFillType Uplo, const int N, const float alpha, const float* X, const int incX, const float* Y, const int incY, float* A) { cblas_sspr2(CblasColMajor, get_uplo(Uplo), N, alpha, X, incX, Y, incY, A); } BLAS_FUNC void symv(const MatrixFillType Uplo, const int N, const double alpha, const double* A, const int lda, const double* X, const int incX, const double beta, double* Y, const int incY) { cblas_dsymv(CblasColMajor, get_uplo(Uplo), N, alpha, A, lda, X, incX, beta, Y, incY); } BLAS_FUNC void sbmv(const MatrixFillType Uplo, const int N, const int K, const double alpha, const double* A, const int lda, const double* X, const int incX, const double beta, double* Y, const int incY) { cblas_dsbmv(CblasColMajor, get_uplo(Uplo), N, K, alpha, A, lda, X, incX, beta, Y, incY); } BLAS_FUNC void spmv(const MatrixFillType Uplo, const int N, const double alpha, const double* Ap, const double* X, const int incX, const double beta, double* Y, const int incY) { cblas_dspmv(CblasColMajor, get_uplo(Uplo), N, alpha, Ap, X, incX, beta, Y, incY); } BLAS_FUNC void ger(const int M, const int N, const double alpha, const double* X, const int incX, const double* Y, const int incY, double* A, const int lda) { cblas_dger(CblasColMajor, M, N, alpha, X, incX, Y, incY, A, lda); } BLAS_FUNC void syr(const MatrixFillType Uplo, const int N, const double alpha, const double* X, const int incX, double* A, const int lda) { cblas_dsyr(CblasColMajor, get_uplo(Uplo), N, alpha, X, incX, A, lda); } BLAS_FUNC void spr(const MatrixFillType Uplo, const int N, const double alpha, const double* X, const int incX, double* Ap) { cblas_dspr(CblasColMajor, get_uplo(Uplo), N, alpha, X, incX, Ap); } BLAS_FUNC void syr2(const 
MatrixFillType Uplo, const int N, const double alpha, const double* X, const int incX, const double* Y, const int incY, double* A, const int lda) { cblas_dsyr2(CblasColMajor, get_uplo(Uplo), N, alpha, X, incX, Y, incY, A, lda); } BLAS_FUNC void spr2(const MatrixFillType Uplo, const int N, const double alpha, const double* X, const int incX, const double* Y, const int incY, double* A) { cblas_dspr2(CblasColMajor, get_uplo(Uplo), N, alpha, X, incX, Y, incY, A); } BLAS_FUNC void gemm(const MatrixTransType TransA, const MatrixTransType TransB, const int M, const int N, const int K, const float alpha, const float* A, const int lda, const float* B, const int ldb, const float beta, float* C, const int ldc) { cblas_sgemm(CblasColMajor, get_trans(TransA), get_trans(TransB), M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); } BLAS_FUNC void symm(const MatrixSideType Side, const MatrixFillType Uplo, const int M, const int N, const float alpha, const float* A, const int lda, const float* B, const int ldb, const float beta, float* C, const int ldc) { cblas_ssymm(CblasColMajor, get_side(Side), get_uplo(Uplo), M, N, alpha, A, lda, B, ldb, beta, C, ldc); } BLAS_FUNC void syrk(const MatrixFillType Uplo, const MatrixTransType Trans, const int N, const int K, const float alpha, const float* A, const int lda, const float beta, float* C, const int ldc) { cblas_ssyrk(CblasColMajor, get_uplo(Uplo), get_trans(Trans), N, K, alpha, A, lda, beta, C, ldc); } BLAS_FUNC void syr2k(const MatrixFillType Uplo, const MatrixTransType Trans, const int N, const int K, const float alpha, const float* A, const int lda, const float* B, const int ldb, const float beta, float* C, const int ldc) { cblas_ssyr2k(CblasColMajor, get_uplo(Uplo), get_trans(Trans), N, K, alpha, A, lda, B, ldb, beta, C, ldc); } BLAS_FUNC void trmm(const MatrixSideType Side, const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int M, const int N, const float alpha, const float* A, const int lda, 
float* B, const int ldb) { cblas_strmm(CblasColMajor, get_side(Side), get_uplo(Uplo), get_trans(TransA), get_diag(Diag), M, N, alpha, A, lda, B, ldb); } BLAS_FUNC void trsm(const MatrixSideType Side, const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int M, const int N, const float alpha, const float* A, const int lda, float* B, const int ldb) { cblas_strsm(CblasColMajor, get_side(Side), get_uplo(Uplo), get_trans(TransA), get_diag(Diag), M, N, alpha, A, lda, B, ldb); } BLAS_FUNC void gemm(const MatrixTransType TransA, const MatrixTransType TransB, const int M, const int N, const int K, const double alpha, const double* A, const int lda, const double* B, const int ldb, const double beta, double* C, const int ldc) { cblas_dgemm(CblasColMajor, get_trans(TransA), get_trans(TransB), M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); } BLAS_FUNC void symm(const MatrixSideType Side, const MatrixFillType Uplo, const int M, const int N, const double alpha, const double* A, const int lda, const double* B, const int ldb, const double beta, double* C, const int ldc) { cblas_dsymm(CblasColMajor, get_side(Side), get_uplo(Uplo), M, N, alpha, A, lda, B, ldb, beta, C, ldc); } BLAS_FUNC void syrk(const MatrixFillType Uplo, const MatrixTransType Trans, const int N, const int K, const double alpha, const double* A, const int lda, const double beta, double* C, const int ldc) { cblas_dsyrk(CblasColMajor, get_uplo(Uplo), get_trans(Trans), N, K, alpha, A, lda, beta, C, ldc); } BLAS_FUNC void syr2k(const MatrixFillType Uplo, const MatrixTransType Trans, const int N, const int K, const double alpha, const double* A, const int lda, const double* B, const int ldb, const double beta, double* C, const int ldc) { cblas_dsyr2k(CblasColMajor, get_uplo(Uplo), get_trans(Trans), N, K, alpha, A, lda, B, ldb, beta, C, ldc); } BLAS_FUNC void trmm(const MatrixSideType Side, const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int M, 
const int N, const double alpha, const double* A, const int lda, double* B, const int ldb) { cblas_dtrmm(CblasColMajor, get_side(Side), get_uplo(Uplo), get_trans(TransA), get_diag(Diag), M, N, alpha, A, lda, B, ldb); } BLAS_FUNC void trsm(const MatrixSideType Side, const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int M, const int N, const double alpha, const double* A, const int lda, double* B, const int ldb) { cblas_dtrsm(CblasColMajor, get_side(Side), get_uplo(Uplo), get_trans(TransA), get_diag(Diag), M, N, alpha, A, lda, B, ldb); }};cublas_real.h#pragma once#ifndef _NO_CUDA#include "cublas_v2.h"#include "blas_types.h"//Class of cublas, overload functions with the same name for float and double.class Cublas : Blas{#ifdef VIRTUAL_BLASpublic:#elseprivate:#endif Cublas() {} ~Cublas() {}private:#ifndef VIRTUAL_BLAS static#endif cublasHandle_t handle_; BLAS_FUNC cublasOperation_t get_trans(MatrixTransType t) { return t == MATRIX_NO_TRANS ? CUBLAS_OP_N : CUBLAS_OP_T; } BLAS_FUNC cublasFillMode_t get_uplo(MatrixFillType t) { return t == MATRIX_UPPER ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; } BLAS_FUNC cublasDiagType_t get_diag(MatrixDiagType t) { return t == MATRIX_NON_UNIT ? CUBLAS_DIAG_NON_UNIT : CUBLAS_DIAG_UNIT; } BLAS_FUNC cublasSideMode_t get_side(MatrixSideType t) { return t == MATRIX_LEFT ? 
CUBLAS_SIDE_LEFT : CUBLAS_SIDE_RIGHT; }public: static cublasStatus_t init() { return cublasCreate(&handle_); } static void destroy() { cublasDestroy(handle_); } void set_handle(cublasHandle_t h) { handle_ = h; }public: BLAS_FUNC float dot(const int N, const float* X, const int incX, const float* Y, const int incY) { float r; cublasSdot(handle_, N, X, incX, Y, incY, &r); return r; } BLAS_FUNC double dot(const int N, const double* X, const int incX, const double* Y, const int incY) { double r; cublasDdot(handle_, N, X, incX, Y, incY, &r); return r; } BLAS_FUNC float nrm2(const int N, const float* X, const int incX) { float r; cublasSnrm2(handle_, N, X, incX, &r); return r; } BLAS_FUNC float asum(const int N, const float* X, const int incX) { float r; cublasSasum(handle_, N, X, incX, &r); return r; } BLAS_FUNC double nrm2(const int N, const double* X, const int incX) { double r; cublasDnrm2(handle_, N, X, incX, &r); return r; } BLAS_FUNC double asum(const int N, const double* X, const int incX) { double r; cublasDasum(handle_, N, X, incX, &r); return r; } BLAS_FUNC int iamax(const int N, const float* X, const int incX) { int r; cublasIsamax(handle_, N, X, incX, &r); return r - 1; } BLAS_FUNC int iamax(const int N, const double* X, const int incX) { int r; cublasIdamax(handle_, N, X, incX, &r); return r - 1; } BLAS_FUNC void swap(const int N, float* X, const int incX, float* Y, const int incY) { cublasSswap(handle_, N, X, incX, Y, incY); } BLAS_FUNC void copy(const int N, const float* X, const int incX, float* Y, const int incY) { cublasScopy(handle_, N, X, incX, Y, incY); } BLAS_FUNC void axpy(const int N, const float alpha, const float* X, const int incX, float* Y, const int incY) { cublasSaxpy(handle_, N, &alpha, X, incX, Y, incY); } BLAS_FUNC void swap(const int N, double* X, const int incX, double* Y, const int incY) { cublasDswap(handle_, N, X, incX, Y, incY); } BLAS_FUNC void copy(const int N, const double* X, const int incX, double* Y, const int incY) { 
cublasDcopy(handle_, N, X, incX, Y, incY); } BLAS_FUNC void axpy(const int N, const double alpha, const double* X, const int incX, double* Y, const int incY) { cublasDaxpy(handle_, N, &alpha, X, incX, Y, incY); } BLAS_FUNC void rotg(float* a, float* b, float* c, float* s) { cublasSrotg(handle_, a, b, c, s); } BLAS_FUNC void rotmg(float* d1, float* d2, float* b1, const float b2, float* P) { cublasSrotmg(handle_, d1, d2, b1, &b2, P); } BLAS_FUNC void rot(const int N, float* X, const int incX, float* Y, const int incY, const float c, const float s) { cublasSrot(handle_, N, X, incX, Y, incY, &c, &s); } BLAS_FUNC void rotm(const int N, float* X, const int incX, float* Y, const int incY, const float* P) { cublasSrotm(handle_, N, X, incX, Y, incY, P); } BLAS_FUNC void rotg(double* a, double* b, double* c, double* s) { cublasDrotg(handle_, a, b, c, s); } BLAS_FUNC void rotmg(double* d1, double* d2, double* b1, const double b2, double* P) { cublasDrotmg(handle_, d1, d2, b1, &b2, P); } BLAS_FUNC void rot(const int N, double* X, const int incX, double* Y, const int incY, const double c, const double s) { cublasDrot(handle_, N, X, incX, Y, incY, &c, &s); } BLAS_FUNC void rotm(const int N, double* X, const int incX, double* Y, const int incY, const double* P) { cublasDrotm(handle_, N, X, incX, Y, incY, P); } BLAS_FUNC void scal(const int N, const float alpha, float* X, const int incX) { cublasSscal(handle_, N, &alpha, X, incX); } BLAS_FUNC void scal(const int N, const double alpha, double* X, const int incX) { cublasDscal(handle_, N, &alpha, X, incX); } BLAS_FUNC void gemv(const MatrixTransType TransA, const int M, const int N, const float alpha, const float* A, const int lda, const float* X, const int incX, const float beta, float* Y, const int incY) { cublasSgemv(handle_, get_trans(TransA), M, N, &alpha, A, lda, X, incX, &beta, Y, incY); } BLAS_FUNC void gbmv(const MatrixTransType TransA, const int M, const int N, const int KL, const int KU, const float alpha, const float* A, 
const int lda, const float* X, const int incX, const float beta, float* Y, const int incY) { cublasSgbmv(handle_, get_trans(TransA), M, N, KL, KU, &alpha, A, lda, X, incX, &beta, Y, incY); } BLAS_FUNC void trmv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const float* A, const int lda, float* X, const int incX) { cublasStrmv(handle_, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, A, lda, X, incX); } BLAS_FUNC void tbmv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const int K, const float* A, const int lda, float* X, const int incX) { cublasStbmv(handle_, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, K, A, lda, X, incX); } BLAS_FUNC void tpmv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const float* Ap, float* X, const int incX) { cublasStpmv(handle_, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, Ap, X, incX); } BLAS_FUNC void trsv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const float* A, const int lda, float* X, const int incX) { cublasStrsv(handle_, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, A, lda, X, incX); } BLAS_FUNC void tbsv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const int K, const float* A, const int lda, float* X, const int incX) { cublasStbsv(handle_, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, K, A, lda, X, incX); } BLAS_FUNC void tpsv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const float* Ap, float* X, const int incX) { cublasStpsv(handle_, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, Ap, X, incX); } BLAS_FUNC void gemv(const MatrixTransType TransA, const int M, const int N, const double alpha, const double* A, const int lda, const double* X, const int incX, const double beta, 
double* Y, const int incY) { cublasDgemv(handle_, get_trans(TransA), M, N, &alpha, A, lda, X, incX, &beta, Y, incY); } BLAS_FUNC void gbmv(const MatrixTransType TransA, const int M, const int N, const int KL, const int KU, const double alpha, const double* A, const int lda, const double* X, const int incX, const double beta, double* Y, const int incY) { cublasDgbmv(handle_, get_trans(TransA), M, N, KL, KU, &alpha, A, lda, X, incX, &beta, Y, incY); } BLAS_FUNC void trmv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const double* A, const int lda, double* X, const int incX) { cublasDtrmv(handle_, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, A, lda, X, incX); } BLAS_FUNC void tbmv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const int K, const double* A, const int lda, double* X, const int incX) { cublasDtbmv(handle_, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, K, A, lda, X, incX); } BLAS_FUNC void tpmv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const double* Ap, double* X, const int incX) { cublasDtpmv(handle_, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, Ap, X, incX); } BLAS_FUNC void trsv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const double* A, const int lda, double* X, const int incX) { cublasDtrsv(handle_, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, A, lda, X, incX); } BLAS_FUNC void tbsv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const int K, const double* A, const int lda, double* X, const int incX) { cublasDtbsv(handle_, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, K, A, lda, X, incX); } BLAS_FUNC void tpsv(const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int N, const double* Ap, double* X, const int incX) { 
cublasDtpsv(handle_, get_uplo(Uplo), get_trans(TransA), get_diag(Diag), N, Ap, X, incX); } BLAS_FUNC void symv(const MatrixFillType Uplo, const int N, const float alpha, const float* A, const int lda, const float* X, const int incX, const float beta, float* Y, const int incY) { cublasSsymv(handle_, get_uplo(Uplo), N, &alpha, A, lda, X, incX, &beta, Y, incY); } BLAS_FUNC void sbmv(const MatrixFillType Uplo, const int N, const int K, const float alpha, const float* A, const int lda, const float* X, const int incX, const float beta, float* Y, const int incY) { cublasSsbmv(handle_, get_uplo(Uplo), N, K, &alpha, A, lda, X, incX, &beta, Y, incY); } BLAS_FUNC void spmv(const MatrixFillType Uplo, const int N, const float alpha, const float* Ap, const float* X, const int incX, const float beta, float* Y, const int incY) { cublasSspmv(handle_, get_uplo(Uplo), N, &alpha, Ap, X, incX, &beta, Y, incY); } BLAS_FUNC void ger(const int M, const int N, const float alpha, const float* X, const int incX, const float* Y, const int incY, float* A, const int lda) { cublasSger(handle_, M, N, &alpha, X, incX, Y, incY, A, lda); } BLAS_FUNC void syr(const MatrixFillType Uplo, const int N, const float alpha, const float* X, const int incX, float* A, const int lda) { cublasSsyr(handle_, get_uplo(Uplo), N, &alpha, X, incX, A, lda); } BLAS_FUNC void spr(const MatrixFillType Uplo, const int N, const float alpha, const float* X, const int incX, float* Ap) { cublasSspr(handle_, get_uplo(Uplo), N, &alpha, X, incX, Ap); } BLAS_FUNC void syr2(const MatrixFillType Uplo, const int N, const float alpha, const float* X, const int incX, const float* Y, const int incY, float* A, const int lda) { cublasSsyr2(handle_, get_uplo(Uplo), N, &alpha, X, incX, Y, incY, A, lda); } BLAS_FUNC void spr2(const MatrixFillType Uplo, const int N, const float alpha, const float* X, const int incX, const float* Y, const int incY, float* A) { cublasSspr2(handle_, get_uplo(Uplo), N, &alpha, X, incX, Y, incY, A); } BLAS_FUNC 
void symv(const MatrixFillType Uplo, const int N, const double alpha, const double* A, const int lda, const double* X, const int incX, const double beta, double* Y, const int incY) { cublasDsymv(handle_, get_uplo(Uplo), N, &alpha, A, lda, X, incX, &beta, Y, incY); } BLAS_FUNC void sbmv(const MatrixFillType Uplo, const int N, const int K, const double alpha, const double* A, const int lda, const double* X, const int incX, const double beta, double* Y, const int incY) { cublasDsbmv(handle_, get_uplo(Uplo), N, K, &alpha, A, lda, X, incX, &beta, Y, incY); } BLAS_FUNC void spmv(const MatrixFillType Uplo, const int N, const double alpha, const double* Ap, const double* X, const int incX, const double beta, double* Y, const int incY) { cublasDspmv(handle_, get_uplo(Uplo), N, &alpha, Ap, X, incX, &beta, Y, incY); } BLAS_FUNC void ger(const int M, const int N, const double alpha, const double* X, const int incX, const double* Y, const int incY, double* A, const int lda) { cublasDger(handle_, M, N, &alpha, X, incX, Y, incY, A, lda); } BLAS_FUNC void syr(const MatrixFillType Uplo, const int N, const double alpha, const double* X, const int incX, double* A, const int lda) { cublasDsyr(handle_, get_uplo(Uplo), N, &alpha, X, incX, A, lda); } BLAS_FUNC void spr(const MatrixFillType Uplo, const int N, const double alpha, const double* X, const int incX, double* Ap) { cublasDspr(handle_, get_uplo(Uplo), N, &alpha, X, incX, Ap); } BLAS_FUNC void syr2(const MatrixFillType Uplo, const int N, const double alpha, const double* X, const int incX, const double* Y, const int incY, double* A, const int lda) { cublasDsyr2(handle_, get_uplo(Uplo), N, &alpha, X, incX, Y, incY, A, lda); } BLAS_FUNC void spr2(const MatrixFillType Uplo, const int N, const double alpha, const double* X, const int incX, const double* Y, const int incY, double* A) { cublasDspr2(handle_, get_uplo(Uplo), N, &alpha, X, incX, Y, incY, A); } BLAS_FUNC void gemm(const MatrixTransType TransA, const MatrixTransType TransB, 
const int M, const int N, const int K, const float alpha, const float* A, const int lda, const float* B, const int ldb, const float beta, float* C, const int ldc) { cublasSgemm(handle_, get_trans(TransA), get_trans(TransB), M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc); } BLAS_FUNC void symm(const MatrixSideType Side, const MatrixFillType Uplo, const int M, const int N, const float alpha, const float* A, const int lda, const float* B, const int ldb, const float beta, float* C, const int ldc) { cublasSsymm(handle_, get_side(Side), get_uplo(Uplo), M, N, &alpha, A, lda, B, ldb, &beta, C, ldc); } BLAS_FUNC void syrk(const MatrixFillType Uplo, const MatrixTransType Trans, const int N, const int K, const float alpha, const float* A, const int lda, const float beta, float* C, const int ldc) { cublasSsyrk(handle_, get_uplo(Uplo), get_trans(Trans), N, K, &alpha, A, lda, &beta, C, ldc); } BLAS_FUNC void syr2k(const MatrixFillType Uplo, const MatrixTransType Trans, const int N, const int K, const float alpha, const float* A, const int lda, const float* B, const int ldb, const float beta, float* C, const int ldc) { cublasSsyr2k(handle_, get_uplo(Uplo), get_trans(Trans), N, K, &alpha, A, lda, B, ldb, &beta, C, ldc); } BLAS_FUNC void trmm(const MatrixSideType Side, const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int M, const int N, const float alpha, const float* A, const int lda, float* B, const int ldb) { cublasStrmm(handle_, get_side(Side), get_uplo(Uplo), get_trans(TransA), get_diag(Diag), M, N, &alpha, A, lda, B, ldb, B, ldb); } BLAS_FUNC void trsm(const MatrixSideType Side, const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int M, const int N, const float alpha, const float* A, const int lda, float* B, const int ldb) { cublasStrsm(handle_, get_side(Side), get_uplo(Uplo), get_trans(TransA), get_diag(Diag), M, N, &alpha, A, lda, B, ldb); } BLAS_FUNC void gemm(const MatrixTransType TransA, 
const MatrixTransType TransB, const int M, const int N, const int K, const double alpha, const double* A, const int lda, const double* B, const int ldb, const double beta, double* C, const int ldc) { cublasDgemm(handle_, get_trans(TransA), get_trans(TransB), M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc); } BLAS_FUNC void symm(const MatrixSideType Side, const MatrixFillType Uplo, const int M, const int N, const double alpha, const double* A, const int lda, const double* B, const int ldb, const double beta, double* C, const int ldc) { cublasDsymm(handle_, get_side(Side), get_uplo(Uplo), M, N, &alpha, A, lda, B, ldb, &beta, C, ldc); } BLAS_FUNC void syrk(const MatrixFillType Uplo, const MatrixTransType Trans, const int N, const int K, const double alpha, const double* A, const int lda, const double beta, double* C, const int ldc) { cublasDsyrk(handle_, get_uplo(Uplo), get_trans(Trans), N, K, &alpha, A, lda, &beta, C, ldc); } BLAS_FUNC void syr2k(const MatrixFillType Uplo, const MatrixTransType Trans, const int N, const int K, const double alpha, const double* A, const int lda, const double* B, const int ldb, const double beta, double* C, const int ldc) { cublasDsyr2k(handle_, get_uplo(Uplo), get_trans(Trans), N, K, &alpha, A, lda, B, ldb, &beta, C, ldc); } BLAS_FUNC void trmm(const MatrixSideType Side, const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int M, const int N, const double alpha, const double* A, const int lda, double* B, const int ldb) { cublasDtrmm(handle_, get_side(Side), get_uplo(Uplo), get_trans(TransA), get_diag(Diag), M, N, &alpha, A, lda, B, ldb, B, ldb); } BLAS_FUNC void trsm(const MatrixSideType Side, const MatrixFillType Uplo, const MatrixTransType TransA, const MatrixDiagType Diag, const int M, const int N, const double alpha, const double* A, const int lda, double* B, const int ldb) { cublasDtrsm(handle_, get_side(Side), get_uplo(Uplo), get_trans(TransA), get_diag(Diag), M, N, &alpha, A, lda, B, ldb); } 
//extensions of cublas BLAS_FUNC void geam(const MatrixTransType TransA, const MatrixTransType TransB, int m, int n, const float alpha, const float* A, int lda, const float beta, const float* B, int ldb, float* C, int ldc) { cublasSgeam(handle_, get_trans(TransA), get_trans(TransB), m, n, &alpha, A, lda, &beta, B, lda, C, ldc); } BLAS_FUNC void geam(const MatrixTransType TransA, const MatrixTransType TransB, int m, int n, const double alpha, const double* A, int lda, const double beta, const double* B, int ldb, double* C, int ldc) { cublasDgeam(handle_, get_trans(TransA), get_trans(TransB), m, n, &alpha, A, lda, &beta, B, lda, C, ldc); } BLAS_FUNC void dgem(MatrixSideType Side, int m, int n, const float* A, int lda, const float* x, int incx, float* C, int ldc) { cublasSdgmm(handle_, get_side(Side), m, n, A, lda, x, incx, C, ldc); } BLAS_FUNC void dgem(MatrixSideType Side, int m, int n, const double* A, int lda, const double* x, int incx, double* C, int ldc) { cublasDdgmm(handle_, get_side(Side), m, n, A, lda, x, incx, C, ldc); }};#endif

推薦閱讀:

學習機器學習時需要儘早知道的三件事
機器學習--感知機科普入門
Python · 神經網路(六)· 拓展
One-Page AlphaGo -- 10分鐘看懂AlphaGo的核心演算法
跟我學做菜吧!sklearn快速上手!用樸素貝葉斯/SVM分析新聞主題

TAG:机器学习 |