OpenBLAS矩阵乘法源码结构分析
用于记录阅读分析OpenBLAS源代码的各种知识点,防止遗忘。这里主要记录OpenBLAS的代码结构,因为确实比较复杂,直接看源代码很可能比较蒙比,如果知道其结构,看起来就比较轻松了。至于OpenBLAS矩阵乘法的算法,这篇不涉及,我会在另一篇文章中简单(瞎jb)分析。OpenBLAS代码总体上可以分成三个层次:1.接口层在OpenBLAS接口层中,运算又分为三个类型,分别是leve
用于记录阅读分析OpenBLAS源代码的各种知识点,防止遗忘。这里主要记录OpenBLAS的代码结构,因为确实比较复杂,直接看源代码很可能比较蒙比,如果知道其结构,看起来就比较轻松了。至于OpenBLAS矩阵乘法的算法,这篇不涉及,我会在另一篇文章中简单(瞎jb)分析。
OpenBLAS代码总体上可以分成三个层次:
1.接口层
在OpenBLAS接口层中,运算又分为三个类型,分别是level1到3,其中level3对应矩阵和矩阵的运算,level2和level1依次维度越来越低。不过这些level1~3都是BLAS内部计算时使用的接口(源代码基本在driver/level下),对外界用户的接口是不涉及这个概念的(对外接口基本都在interface文件夹下):
每一个源代码文件对应一种操作,如这里的gemm指的是普通矩阵乘法(General Matrix Multiplication)而gemv指的是普通矩阵向量乘法(General Matrix Vector)。打开gemm.c可以大致观察一下其源代码(大幅度阉割版):
//后边函数体中要使用的函数表,是计算矩阵乘法的核心函数,有这么多不同的函数指针,是区分了大量特殊情况,如GEMM_NN是两个普通矩阵相乘,而GEMM_TN说明第一个矩阵是转置过的,而带THREAD标签的是多线程实现的核心函数,其执行效率是单核执行的若干倍。
static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, FLOAT *, FLOAT *, BLASLONG) = {
#ifndef GEMM3M
GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN,
GEMM_NT, GEMM_TT, GEMM_RT, GEMM_CT,
GEMM_NR, GEMM_TR, GEMM_RR, GEMM_CR,
GEMM_NC, GEMM_TC, GEMM_RC, GEMM_CC,
#if defined(SMP) && !defined(USE_SIMPLE_THREADED_LEVEL3)
GEMM_THREAD_NN, GEMM_THREAD_TN, GEMM_THREAD_RN, GEMM_THREAD_CN,
GEMM_THREAD_NT, GEMM_THREAD_TT, GEMM_THREAD_RT, GEMM_THREAD_CT,
GEMM_THREAD_NR, GEMM_THREAD_TR, GEMM_THREAD_RR, GEMM_THREAD_CR,
GEMM_THREAD_NC, GEMM_THREAD_TC, GEMM_THREAD_RC, GEMM_THREAD_CC,
#endif
#else
GEMM3M_NN, GEMM3M_TN, GEMM3M_RN, GEMM3M_CN,
GEMM3M_NT, GEMM3M_TT, GEMM3M_RT, GEMM3M_CT,
GEMM3M_NR, GEMM3M_TR, GEMM3M_RR, GEMM3M_CR,
GEMM3M_NC, GEMM3M_TC, GEMM3M_RC, GEMM3M_CC,
#if defined(SMP) && !defined(USE_SIMPLE_THREADED_LEVEL3)
GEMM3M_THREAD_NN, GEMM3M_THREAD_TN, GEMM3M_THREAD_RN, GEMM3M_THREAD_CN,
GEMM3M_THREAD_NT, GEMM3M_THREAD_TT, GEMM3M_THREAD_RT, GEMM3M_THREAD_CT,
GEMM3M_THREAD_NR, GEMM3M_THREAD_TR, GEMM3M_THREAD_RR, GEMM3M_THREAD_CR,
GEMM3M_THREAD_NC, GEMM3M_THREAD_TC, GEMM3M_THREAD_RC, GEMM3M_THREAD_CC,
#endif
#endif
};
//gemm的函数体
void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB,
blasint m, blasint n, blasint k,
#ifndef COMPLEX
FLOAT alpha,
#else
FLOAT *alpha,
#endif
FLOAT *a, blasint lda,
FLOAT *b, blasint ldb,
#ifndef COMPLEX
FLOAT beta,
#else
FLOAT *beta,
#endif
FLOAT *c, blasint ldc) {
//大量代码实现,主要是对输入矩阵格式进行操作,并且选择正确分支,最后调用上边函数表中的一个函数完成运算
。
。
。
。
//调用函数表中的一个核心计算函数
(gemm[(transb << 2) | transa])(&args, NULL, NULL, sa, sb, 0);
}
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
首先比较蒙比的是,其函数名字为CNAME而不是gemm,而且其函数体中有大量ifdef的分支,这是因为OpenBLAS采用了类似模版编程的办法,因为大量操作涉及相同的接口结构,所以OpenBLAS使用同一个源代码文件gemm.c来生成各种不同的矩阵相乘操作的真正代码文件。在使用cmake生成工程之后(生成之前是没有的),我们可以在interface找到一个文件sgemm.c,代码很简洁:
#define ASMNAME _sgemm
#define ASMFNAME _sgemm_
#define NAME sgemm_
#define CNAME sgemm
#define CHAR_NAME "sgemm_"
#define CHAR_CNAME "sgemm"
#include "PATH_TO_YOUR_OPENBLAS/OpenBLAS-develop/interface/gemm.c"
- 1
- 2
- 3
- 4
- 5
- 6
- 7
可以看到,这个sgemm的接口define了CNAME,且include的了上边的gemm.c,这样gemm.c的源代码就会被粘贴进来,而CNAME会被替换成这个接口真正的名字,形成一个有效的编译单元。同理,interface的zgemm,cgemm等等都会include上边的gemm.c,生成一个完整的函数实现。这种写法就省去了大量冗余代码,因为无论zgemm、sgemm、cgemm等等他们都是矩阵乘法,大致实现相同,只是一些细节不同,所以不需要每个函数都从头写。至于zgemm,sgemm这些接口有啥区别,这里有比较详细的说明,大致为C代表复数计算,Z代表双精度复数,而S代表单精度常数(链接中说是半精度,实际上OpenBLAS中是单精度),则zgemm是双精度复数矩阵乘法,他们通过这里的CNAME来区分名字,同时通过很多其他的预编译宏来区分一些细微的实现细节,因为没有详细研究过复数矩阵乘法的源代码,就不多赘述那些分支的区别了。
2.核心函数层
之后可以深入到gemm中函数表里的核心计算函数了。我们挑选最简单的GEMM_NN,两个普通矩阵相乘来看。GEMM_NN这个预编译宏最后指向哪个函数,是和很多其他预编译宏相关的(对预编译宏的谜の热爱),比如define了COMPLEX,表示复数矩阵乘法时,GEMM_NN就指向qgemm_nn函数,而如果是实数矩阵乘法,就会指向sgemm_nn函数。核心函数层依然是那种模版编程的思路,直接在VS里边是索引不到sgemm_nn的实现的,因为sgemm_nn的实现代码如下:
#define NN
#define ASMNAME _sgemm_nn
#define ASMFNAME _sgemm_nn_
#define NAME sgemm_nn_
#define CNAME sgemm_nn
#define CHAR_NAME "sgemm_nn_"
#define CHAR_CNAME "sgemm_nn"
#include "PATH_TO_YOUR_OPENBLAS/OpenBLAS-develop/driver/level3/gemm.c"
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
和上边sgemm.c的代码思路相同,这个源代码中仅仅定义函数的名字,真正的源代码模版在gemm.c中,不过注意了,这里的gemm.c和上边提到的gemm.c不是同一份代码,这个gemm.c在level3下,其内容是真正的矩阵乘法实现,上边的gemm.c干的事情主要是一些准备和分支控制工作(他是暴漏给用户的接口函数,这个是内部的干活函数)。level3下的gemm.c:
#include <stdio.h>
#include "common.h"
#undef TIMING
#ifdef PARAMTEST
#undef GEMM_P
#undef GEMM_Q
#undef GEMM_R
#define GEMM_P (args -> gemm_p)
#define GEMM_Q (args -> gemm_q)
#define GEMM_R (args -> gemm_r)
#endif
#if 0
#undef GEMM_P
#undef GEMM_Q
#define GEMM_P 504
#define GEMM_Q 128
#endif
#ifdef THREADED_LEVEL3
#include "level3_thread.c"
#else
#include "level3.c"
#endif
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
又是故技重施,这里使用预编译宏控制了一些变量,而有效代码都在level3.c中。level3代码(大量精简版):
//这里大量的define比较吓人,但是除去分支控制,主要内容是定义了三种操作(函数),BETA_OPERATION,ICOPY_OPERATION,OCOPY_OPERATION,KERNEL_OPERATION。BETA_OPERATION用于给整个矩阵乘以一个系数,因为gemm计算的是c=alpha*a*b+beta*c,这个BETA_OPERATION就是用于乘以那个beta的。两个COPY函数用于矩阵的拷贝,这个和OpenBLAS的实现算法有关,在算法分析中在详细说明。而KERNEL_OPERATION就是核心函数中的核心函数了,他是真正做乘法和乘累加的地方,整个矩阵乘法就是他算完的了。
#ifndef BETA_OPERATION
#if !defined(XDOUBLE) || !defined(QUAD_PRECISION)
#ifndef COMPLEX
#define BETA_OPERATION(M_FROM, M_TO, N_FROM, N_TO, BETA, C, LDC) \
GEMM_BETA((M_TO) - (M_FROM), (N_TO - N_FROM), 0, \
BETA[0], NULL, 0, NULL, 0, \
(FLOAT *)(C) + ((M_FROM) + (N_FROM) * (LDC)) * COMPSIZE, LDC)
#else
。
。
。
。
#ifndef ICOPY_OPERATION
#if defined(NN) || defined(NT) || defined(NC) || defined(NR) || \
defined(RN) || defined(RT) || defined(RC) || defined(RR)
#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
#else
#define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_INCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
#endif
#endif
#ifndef OCOPY_OPERATION
#if defined(NN) || defined(TN) || defined(CN) || defined(RN) || \
defined(NR) || defined(TR) || defined(CR) || defined(RR)
#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
#else
#define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_OTCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
#endif
#endif
。
。
。
。
#ifndef KERNEL_OPERATION
#if !defined(XDOUBLE) || !defined(QUAD_PRECISION)
#ifndef COMPLEX
#define KERNEL_OPERATION(M, N, K, ALPHA, SA, SB, C, LDC, X, Y) \
KERNEL_FUNC(M, N, K, ALPHA[0], SA, SB, (FLOAT *)(C) + ((X) + (Y) * LDC) * COMPSIZE, LDC)
#else
#define KERNEL_OPERATION(M, N, K, ALPHA, SA, SB, C, LDC, X, Y) \
KERNEL_FUNC(M, N, K, ALPHA[0], ALPHA[1], SA, SB, (FLOAT *)(C) + ((X) + (Y) * LDC) * COMPSIZE, LDC)
#endif
#else
#define KERNEL_OPERATION(M, N, K, ALPHA, SA, SB, C, LDC, X, Y) \
KERNEL_FUNC(M, N, K, ALPHA, SA, SB, (FLOAT *)(C) + ((X) + (Y) * LDC) * COMPSIZE, LDC)
#endif
#endif
。
。
。
。
//sgemm_nn的函数体,应该说是所有gemm的模版函数体,负责将矩阵乘法完成
int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
XFLOAT *sa, XFLOAT *sb, BLASLONG dummy){
。
。
。
。
//对矩阵c调用beta操作(c=alpha*a*b+beta*c)
BETA_OPERATION(m_from, m_to, n_from, n_to, beta, c, ldc);
。
。
。
。
//大量复杂的for循环控制,因为这一篇不详细分析算法细节,这里就简化了,源代码中是多个不同层级for嵌套,用于分片
for(。。。。。。)
{
//进行一次矩阵拷贝,大致意义是从两个操作数矩阵之一中切出一小块,而这一小块会被放入L1Cache中反复使用,提高缓存命中率以达到加速效果
ICOPY_OPERATION(min_l, min_i, a, lda, ls, m_from, sa);
//同上
OCOPY_OPERATION(min_l, min_jj, b, ldb, ls, jjs, sb + min_l * (jjs - js) * COMPSIZE * l1stride);
//对切出的小块进行矩阵乘法
KERNEL_OPERATION(min_i, min_jj, min_l, alpha, sa, sb + min_l * (jjs - js) * COMPSIZE * l1stride, c, ldc, m_from, jjs);
}
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
可见sgemm_nn已经完成了两个不转置矩阵的乘法了,但是sgemm_nn中的BETA_OPERATION,ICOPY_OPERATION,OCOPY_OPERATION,KERNEL_OPERATION这三个最热点的关键函数又是怎么实现的?你会发现VS f12是找不到这三个函数的实现的,因为他们又是使用模版的思路编写的。。。如这里的BETA_OPERATION最后指向sgemm_beta,我们可以找到sgemm_beta.c:
#define ASMNAME _sgemm_beta
#define ASMFNAME _sgemm_beta_
#define NAME sgemm_beta_
#define CNAME sgemm_beta
#define CHAR_NAME "sgemm_beta_"
#define CHAR_CNAME "sgemm_beta"
#include "PATH_TO_YOUR_OPENBLAS/OpenBLAS-develop/kernel/x86/../generic/gemm_beta.c"
- 1
- 2
- 3
- 4
- 5
- 6
- 7
可见实现代码在kernel文件夹下的gemm_beta.c中。
不过这里有一点值得一提,就是使用VS编译OpenBLAS时,这三个最核心最热点的OPERATION都是像上边sgemm_beta.c一样,最后是使用c代码实现的,其执行效率非常低,而使用官方推荐mingw或者在其他平台,如arm上使用cmake编译时,这些OPERATION会使用更高效的汇编.S代码实现(这个不是编译器的锅,似乎是makefile方面的问题)。我们可以进入kernel文件夹,发现大量kernel代码:
可以看到对于不同平台,OpenBLAS都准备的大量的kernel代码,这些代码基本都是平台特定的高效汇编代码,我们进入arm64文件夹中,查看KERNEL.ARMV8文件:
。
。
。
。
SGEMMKERNEL = sgemm_kernel_4x4.S
SGEMMONCOPY = ../generic/gemm_ncopy_4.c
SGEMMOTCOPY = ../generic/gemm_tcopy_4.c
SGEMMONCOPYOBJ = sgemm_oncopy.o
SGEMMOTCOPYOBJ = sgemm_otcopy.o
DGEMMKERNEL = ../generic/gemmkernel_2x2.c
DGEMMONCOPY = ../generic/gemm_ncopy_2.c
DGEMMOTCOPY = ../generic/gemm_tcopy_2.c
DGEMMONCOPYOBJ = dgemm_oncopy.o
DGEMMOTCOPYOBJ = dgemm_otcopy.o
CGEMMKERNEL = ../generic/zgemmkernel_2x2.c
CGEMMONCOPY = ../generic/zgemm_ncopy_2.c
CGEMMOTCOPY = ../generic/zgemm_tcopy_2.c
CGEMMONCOPYOBJ = cgemm_oncopy.o
CGEMMOTCOPYOBJ = cgemm_otcopy.o
ZGEMMKERNEL = ../generic/zgemmkernel_2x2.c
ZGEMMONCOPY = ../generic/zgemm_ncopy_2.c
ZGEMMOTCOPY = ../generic/zgemm_tcopy_2.c
ZGEMMONCOPYOBJ = zgemm_oncopy.o
ZGEMMOTCOPYOBJ = zgemm_otcopy.o
。
。
。
。
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
可以看到对于特定平台的特定CPU构架,makefile最后会控制这些最热点函数的实现文件,如果这里sgemm_kernel(KERNEL_OPERATION)指向了.S文件,最后编译时就会编译对应的.S文件,然后链接进入库中,而不像在windows环境下使用vs默认的全部使用.c文件实现(所以不使用官方推荐的mingw在windows平台编译的话,编译出来OpenBLAS性能会慢若干倍)。当然,对于一些性能要求不高,或者编写实在麻烦的操作,比如ICOPY_OPERATION,OCOPY_OPERATION,arm64位平台还是指向了generic文件夹,使用了通用的.c实现。
至于平台特定的.S代码如何链接进入最后的库中,我们可以简单查看下sgemm_kernel_4x4.S:
。
。
。
。
。
PROLOGUE
.align 5
add sp, sp, #-(11 * 16)
stp d8, d9, [sp, #(0 * 16)]
stp d10, d11, [sp, #(1 * 16)]
stp d12, d13, [sp, #(2 * 16)]
stp d14, d15, [sp, #(3 * 16)]
stp d16, d17, [sp, #(4 * 16)]
stp x18, x19, [sp, #(5 * 16)]
stp x20, x21, [sp, #(6 * 16)]
stp x22, x23, [sp, #(7 * 16)]
stp x24, x25, [sp, #(8 * 16)]
stp x26, x27, [sp, #(9 * 16)]
str x28, [sp, #(10 * 16)]
fmov alpha0, s0
fmov alpha1, s0
fmov alpha2, s0
fmov alpha3, s0
lsl LDC, LDC, #2 // ldc = ldc * 4
mov pB, origPB
mov counterJ, origN
asr counterJ, counterJ, #2 // J = J / 4
cmp counterJ, #0
ble sgemm_kernel_L2_BEGIN
。
。
。
。
。
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
这里的PROLOGUE是一个预编译宏(OpenBLAS到底多喜欢预编译宏啊),最后编译.S的时候会被替换成声明global sgemm_kernel和label名sgemm_kernel,而OpenBLAS的头文件cblas中声明了同样名字的sgemm_kernel函数,于是因为函数名相同,这个汇编代码就被链接进最后的二进制库中了。
这样整个OpenBLAS的代码结构就差不多讲完了,因为OpenBLAS开发者对于预编译宏的谜の执着和对不同平台的特定优化,整个库源代码的结构还是比较复杂的,这就比较阻碍我们进一步去理解他的算法(找个函数半天找不到实现实在是(╯‵□′)╯︵┴─┴ )。OpenBLAS对于矩阵乘法实现的算法也是非常精彩,会在另一篇文章中简单(瞎jb)分析。
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)