前言

今天来介绍一下DarkNet中卷积层的前向传播和反向传播的实现,卷积层是卷积神经网络中的核心组件,了解它的底层代码实现对我们理解卷积神经网络以及优化卷积神经网络都有一些帮助。

卷积层的构造

卷积层的构造主要在src/convolutional_layer.c中的make_convolutional_layer中进行实现,下面给出部分核心代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
/*
** batch 每个batch含有的图片数
** step
** h 图像高度(行数)
** w 图像宽度(列数)
** c 输入图像通道数
** n 卷积核个数
** groups 分组数
** size 卷积核尺寸
** stride 步长
** dilation 空洞卷积空洞率
** padding 四周补0长度
** activation 激活函数类别
** batch_normalize 是否进行BN
** binary 是否对权重进行二值化
** xnor 是否对权重以及输入进行二值化
** adam 优化方式
** use_bin_output
** index 分组卷积的时候分组索引
** antialiasing 抗锯齿标志,如果为真强行设置所有的步长为1
** share_layer 标志参数,表示这一个卷积层是否和其它卷积层共享权重
** assisted_excitation
** deform 暂时不知道
** train 标志参数,是否在训练
*/
convolutional_layer make_convolutional_layer(int batch, int steps, int h, int w, int c, int n, int groups, int size, int stride_x, int stride_y, int dilation, int padding, ACTIVATION activation,
int batch_normalize, int binary, int xnor, int adam, int use_bin_output, int index, int antialiasing, convolutional_layer *share_layer, int assisted_excitation, int deform, int train)
{
int total_batch = batch*steps;
int i;
convolutional_layer l = { (LAYER_TYPE)0 };
l.type = CONVOLUTIONAL;
l.train = train;

if (xnor) groups = 1; //对于二值网络,不能使用分组卷积
if (groups < 1) groups = 1;

const int blur_stride_x = stride_x;
const int blur_stride_y = stride_y;
l.antialiasing = antialiasing;
if (antialiasing) {
stride_x = stride_y = l.stride = l.stride_x = l.stride_y = 1; // use stride=1 in host-layer
}

l.deform = deform;
l.assisted_excitation = assisted_excitation;
l.share_layer = share_layer;
l.index = index;
l.h = h;
l.w = w;
l.c = c;
l.groups = groups;
l.n = n;
l.binary = binary;
l.xnor = xnor;
l.use_bin_output = use_bin_output;
l.batch = batch;
l.steps = steps;
l.stride = stride_x;
l.stride_x = stride_x;
l.stride_y = stride_y;
l.dilation = dilation;
l.size = size;
l.pad = padding;
l.batch_normalize = batch_normalize;
l.learning_rate_scale = 1;
// 该卷积层总的权重元素个数(权重元素个数等于输入数据的通道数/分组数*卷积核个数*卷积核的二维尺寸,注意因为每一个卷积核是同时作用于输入数据
// 的多个通道上的,因此实际上卷积核是三维的,包括两个维度的平面尺寸,以及输入数据通道数这个维度,每个通道上的卷积核参数都是独立的训练参数)
l.nweights = (c / groups) * n * size * size;
// 如果是共享卷积层,可以直接用共享的卷积层来赋值(猜测是有预训练权重的时候可以直接赋值)
if (l.share_layer) {
if (l.size != l.share_layer->size || l.nweights != l.share_layer->nweights || l.c != l.share_layer->c || l.n != l.share_layer->n) {
printf("Layer size, nweights, channels or filters don't match for the share_layer");
getchar();
}

l.weights = l.share_layer->weights;
l.weight_updates = l.share_layer->weight_updates;

l.biases = l.share_layer->biases;
l.bias_updates = l.share_layer->bias_updates;
}
else {
// 该卷积层总的权重元素(卷积核元素)个数=输入图像通道数 / 分组数*卷积核个数*卷积核尺寸
l.weights = (float*)xcalloc(l.nweights, sizeof(float));
// bias就是Wx+b中的b(上面的weights就是W),有多少个卷积核,就有多少个b(与W的个数一一对应,每个W的元素个数为c*size*size)
l.biases = (float*)xcalloc(n, sizeof(float));
// 训练期间,需要执行反向传播
if (train) {
// 敏感图和特征图的尺寸应该是一样的
l.weight_updates = (float*)xcalloc(l.nweights, sizeof(float));
// bias的敏感图,维度和bias一致
l.bias_updates = (float*)xcalloc(n, sizeof(float));
}
}

// float scale = 1./sqrt(size*size*c);
// 初始化权重:缩放因子*标准正态分布随机数,缩放因子等于sqrt(2./(size*size*c)),随机初始化
// 此处初始化权重为正态分布,而在全连接层make_connected_layer()中初始化权重是均匀分布的。
// TODO:个人感觉,这里应该加一个if条件语句:if(weightfile),因为如果导入了预训练权重文件,就没有必要这样初始化了(事实上在detector.c的train_detector()函数中,
// 紧接着parse_network_cfg()函数之后,就添加了if(weightfile)语句判断是否导入权重系数文件,如果导入了权重系数文件,也许这里初始化的值也会覆盖掉,
// 总之这里的权重初始化的处理方式还是值得思考的,也许更好的方式是应该设置专门的函数进行权重的初始化,同时偏置也是,不过这里似乎没有考虑偏置的初始化,在make_connected_layer()中倒是有。。。)
float scale = sqrt(2./(size*size*c/groups));
for(i = 0; i < l.nweights; ++i) l.weights[i] = scale*rand_uniform(-1, 1); // rand_normal();
// 根据该层输入图像的尺寸、卷积核尺寸以及跨度计算输出特征图的宽度和高度
int out_h = convolutional_out_height(l);
int out_w = convolutional_out_width(l);
// 输出图像高度
l.out_h = out_h;
// 输出图像宽度
l.out_w = out_w;
// 输出图像通道数(等于卷积核个数,有多少个卷积核,最终就得到多少张特征图,每张特征图是一个通道)
l.out_c = n;
l.outputs = l.out_h * l.out_w * l.out_c; // 对应每张输入图片的所有输出特征图的总元素个数(每张输入图片会得到n也即l.out_c张特征图)
l.inputs = l.w * l.h * l.c; // mini-batch中每张输入图片的像素元素个数
l.activation = activation;

l.output = (float*)xcalloc(total_batch*l.outputs, sizeof(float)); // l.output为该层所有的输出(包括mini-batch所有输入图片的输出)
#ifndef GPU
if (train) l.delta = (float*)xcalloc(total_batch*l.outputs, sizeof(float)); // l.delta 该层的敏感度图,和输出的维度想同
#endif // not GPU

// 卷积层三种指针函数,对应三种计算:前向,反向,更新
l.forward = forward_convolutional_layer;
l.backward = backward_convolutional_layer;
l.update = update_convolutional_layer;

卷积层前向传播的代码解析

代码在src/convolutional_layer.c中,注释如下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
// 卷积层的前向传播核心代码
void forward_convolutional_layer(convolutional_layer l, network_state state)
{
int out_h = convolutional_out_height(l);
int out_w = convolutional_out_width(l);
int i, j;
// l.outputs = l.out_h * l.out_w * l.out_c在make各网络层函数中赋值(比如make_convolutional_layer()),
// 对应每张输入图片的所有输出特征图的总元素个数(每张输入图片会得到n也即l.out_c张特征图)
// 初始化输出l.output全为0.0;输入l.outputs*l.batch为输出的总元素个数,其中l.outputs为batch
// 中一个输入对应的输出的所有元素的个数,l.batch为一个batch输入包含的图片张数;0表示初始化所有输出为0;
fill_cpu(l.outputs*l.batch, 0, l.output, 1);

// 是否进行二值化操作
if (l.xnor && (!l.align_bit_weights || state.train)) {
if (!l.align_bit_weights || state.train) {
binarize_weights(l.weights, l.n, l.nweights, l.binary_weights);
//printf("\n binarize_weights l.align_bit_weights = %p \n", l.align_bit_weights);
}
swap_binary(&l);
binarize_cpu(state.input, l.c*l.h*l.w*l.batch, l.binary_input);
state.input = l.binary_input;
}

int m = l.n / l.groups; // 该层的卷积核个数
int k = l.size*l.size*l.c / l.groups; // 该层每个卷积核的参数元素个数
int n = out_h*out_w; // 该层每个特征图的尺寸(元素个数)

static int u = 0;
u++;
// 该循环即为卷积计算核心代码:所有卷积核对batch中每张图片进行卷积运算
// 每次循环处理一张输入图片(所有卷积核对batch中一张图片做卷积运算)
for(i = 0; i < l.batch; ++i)
{
// 该循环是为了处理分组卷积
for (j = 0; j < l.groups; ++j)
{
// 当前组卷积核(也即权重),元素个数为l.n*l.c/l.groups*l.size*l.size,
// 共有l.n行,l.c/l.gropus,l.c*l.size*l.size列
float *a = l.weights +j*l.nweights / l.groups;
// 对输入图像进行重排之后的图像数据,所以内存空间申请为网络中最大占用内存
float *b = state.workspace;
// 存储一张输入图片(多通道)当前组的输出特征图(输入图片是多通道的,输出
// 图片也是多通道的,有多少组卷积核就有多少组通道,每个分组后的卷积核得到一张特征图即为一个通道)
// 这里似乎有点拗口,可以看下分组卷积原理。
float *c = l.output +(i*l.groups + j)*n*m;

//gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
//gemm_nn_custom(m, n, k, 1, a, k, b, n, c, n);
//二值网络,特殊处理,里面还有一些优化,细节很多,这里暂时不管二值网络这部分,把注意力先放在普通卷积层的计算上
if (l.xnor && l.align_bit_weights && !state.train && l.stride_x == l.stride_y)
{
memset(b, 0, l.bit_align*l.size*l.size*l.c * sizeof(float));

if (l.c % 32 == 0)
{
//printf(" l.index = %d - new XNOR \n", l.index);

int ldb_align = l.lda_align;
size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8;
//size_t t_intput_size = new_ldb * l.bit_align;// n;
//size_t t_bit_input_size = t_intput_size / 8;// +1;

int re_packed_input_size = l.c * l.w * l.h;
memset(state.workspace, 0, re_packed_input_size * sizeof(float));

const size_t new_c = l.c / 32;
size_t in_re_packed_input_size = new_c * l.w * l.h + 1;
memset(l.bin_re_packed_input, 0, in_re_packed_input_size * sizeof(uint32_t));

//float *re_packed_input = calloc(l.c * l.w * l.h, sizeof(float));
//uint32_t *bin_re_packed_input = calloc(new_c * l.w * l.h + 1, sizeof(uint32_t));

// float32x4 by channel (as in cuDNN)
repack_input(state.input, state.workspace, l.w, l.h, l.c);

// 32 x floats -> 1 x uint32_t
float_to_bit(state.workspace, (unsigned char *)l.bin_re_packed_input, l.c * l.w * l.h);

//free(re_packed_input);

// slow - convolution the packed inputs and weights: float x 32 by channel (as in cuDNN)
//convolution_repacked((uint32_t *)bin_re_packed_input, (uint32_t *)l.align_bit_weights, l.output,
// l.w, l.h, l.c, l.n, l.size, l.pad, l.new_lda, l.mean_arr);

// // then exit from if()


im2col_cpu_custom((float *)l.bin_re_packed_input, new_c, l.h, l.w, l.size, l.stride, l.pad, state.workspace);
//im2col_cpu((float *)bin_re_packed_input, new_c, l.h, l.w, l.size, l.stride, l.pad, b);

//free(bin_re_packed_input);

int new_k = l.size*l.size*l.c / 32;

// good for (l.c == 64)
//gemm_nn_bin_32bit_packed(m, n, new_k, 1,
// l.align_bit_weights, l.new_lda/32,
// b, n,
// c, n, l.mean_arr);

// // then exit from if()

transpose_uint32((uint32_t *)state.workspace, (uint32_t*)l.t_bit_input, new_k, n, n, new_ldb);

// the main GEMM function
gemm_nn_custom_bin_mean_transposed(m, n, k, 1, (unsigned char*)l.align_bit_weights, new_ldb, (unsigned char*)l.t_bit_input, new_ldb, c, n, l.mean_arr);

// // alternative GEMM
//gemm_nn_bin_transposed_32bit_packed(m, n, new_k, 1,
// l.align_bit_weights, l.new_lda/32,
// t_bit_input, new_ldb / 32,
// c, n, l.mean_arr);

//free(t_bit_input);

}
else
{ // else (l.c % 32 != 0)

//--------------------------------------------------------
//printf(" l.index = %d - old XNOR \n", l.index);

//im2col_cpu_custom_align(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b, l.bit_align);
im2col_cpu_custom_bin(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, state.workspace, l.bit_align);

//size_t output_size = l.outputs;
//float *count_output = calloc(output_size, sizeof(float));
//size_t bit_output_size = output_size / 8 + 1;
//char *bit_output = calloc(bit_output_size, sizeof(char));

//size_t intput_size = n * k; // (out_h*out_w) X (l.size*l.size*l.c) : after im2col()
//size_t bit_input_size = intput_size / 8 + 1;
//char *bit_input = calloc(bit_input_size, sizeof(char));

//size_t weights_size = k * m; //l.size*l.size*l.c*l.n; // l.nweights
//size_t bit_weights_size = weights_size / 8 + 1;

//char *bit_weights = calloc(bit_weights_size, sizeof(char));
//float *mean_arr = calloc(l.n, sizeof(float));

// transpose B from NxK to KxN (x-axis (ldb = l.size*l.size*l.c) - should be multiple of 8 bits)
{
//size_t ldb_align = 256; // 256 bit for AVX2
int ldb_align = l.lda_align;
size_t new_ldb = k + (ldb_align - k%ldb_align);
size_t t_intput_size = binary_transpose_align_input(k, n, state.workspace, &l.t_bit_input, ldb_align, l.bit_align);

// 5x times faster than gemm()-float32
gemm_nn_custom_bin_mean_transposed(m, n, k, 1, (unsigned char*)l.align_bit_weights, new_ldb, (unsigned char*)l.t_bit_input, new_ldb, c, n, l.mean_arr);

//gemm_nn_custom_bin_mean_transposed(m, n, k, 1, bit_weights, k, t_bit_input, new_ldb, c, n, mean_arr);

//free(t_input);
//free(t_bit_input);
//}
}

}

add_bias(l.output, l.biases, l.batch, l.n, out_h*out_w);

//activate_array(l.output, m*n*l.batch, l.activation);
if (l.activation == SWISH) activate_array_swish(l.output, l.outputs*l.batch, l.activation_input, l.output);
else if (l.activation == MISH) activate_array_mish(l.output, l.outputs*l.batch, l.activation_input, l.output);
else if (l.activation == NORM_CHAN) activate_array_normalize_channels(l.output, l.outputs*l.batch, l.batch, l.out_c, l.out_w*l.out_h, l.output);
else if (l.activation == NORM_CHAN_SOFTMAX) activate_array_normalize_channels_softmax(l.output, l.outputs*l.batch, l.batch, l.out_c, l.out_w*l.out_h, l.output, 0);
else if (l.activation == NORM_CHAN_SOFTMAX_MAXVAL) activate_array_normalize_channels_softmax(l.output, l.outputs*l.batch, l.batch, l.out_c, l.out_w*l.out_h, l.output, 1);
else activate_array_cpu_custom(l.output, m*n*l.batch, l.activation);
return;

}
else {
//printf(" l.index = %d - FP32 \n", l.index);
// 由于有分组卷积,所以获取属于当前组的输入im并按一定存储规则排列的数组b,
// 以方便、高效地进行矩阵(卷积)计算,详细查看该函数注释(比较复杂)
// 这里的im实际上只加载了一张图片的数据
// 关于im2col的原理我会讲
float *im = state.input + (i*l.groups + j)*(l.c / l.groups)*l.h*l.w;
// 如果这里卷积核尺寸为1,是不需要改变内存排布方式
if (l.size == 1) {
b = im;
}
else {
//im2col_cpu(im, l.c / l.groups, l.h, l.w, l.size, l.stride, l.pad, b);
// 将多通道二维图像im变成按一定存储规则排列的数组b,
// 以方便、高效地进行矩阵(卷积)计算,详细查看该函数注释(比较复杂)
// 进行重排,l.c/groups为每张图片的通道数分组,l.h为每张图片的高度,l.w为每张图片的宽度,l.size为卷积核尺寸,l.stride为步长
// 得到的b为一张图片重排后的结果,也是按行存储的一维数组(共有l.c/l.groups*l.size*l.size行,l.out_w*l.out_h列)
im2col_cpu_ext(im, // input
l.c / l.groups, // input channels
l.h, l.w, // input size (h, w)
l.size, l.size, // kernel size (h, w)
l.pad, l.pad, // padding (h, w)
l.stride_y, l.stride_x, // stride (h, w)
l.dilation, l.dilation, // dilation (h, w)
b); // output

}
// 此处在im2col_cpu操作基础上,利用矩阵乘法c=alpha*a*b+beta*c完成对图像卷积的操作
// 0,0表示不对输入a,b进行转置,
// m是输入a,c的行数,具体含义为每个卷积核的个数,
// n是输入b,c的列数,具体含义为每个输出特征图的元素个数(out_h*out_w),
// k是输入a的列数也是b的行数,具体含义为卷积核元素个数乘以输入图像的通道数除以分组数(l.size*l.size*l.c/l.groups),
// a,b,c即为三个参与运算的矩阵(用一维数组存储),alpha=beta=1为常系数,
// a为所有卷积核集合,元素个数为l.n*l.c/l.groups*l.size*l.size,按行存储,共有l*n行,l.c/l.groups*l.size*l.size列,
// 即a中每行代表一个可以作用在3通道上的卷积核,
// b为一张输入图像经过im2col_cpu重排后的图像数据(共有l.c/l.group*l.size*l.size行,l.out_w*l.out_h列),
// c为gemm()计算得到的值,包含一张输入图片得到的所有输出特征图(每个卷积核得到一张特征图),c中一行代表一张特征图,
// 各特征图铺排开成一行后,再将所有特征图并成一大行,存储在c中,因此c可视作有l.n行,l.out_h*l.out_w列。
// 详细查看该函数注释(比较复杂)
gemm(0, 0, m, n, k, 1, a, k, b, n, 1, c, n);
// bit-count to float
}
//c += n*m;
//state.input += l.c*l.h*l.w;
}
}
// 如果卷积层使用了BatchNorm,那么执行forward_batchnorm,如果没有,则添加偏置
if(l.batch_normalize){
forward_batchnorm_layer(l, state);
}
else {
add_bias(l.output, l.biases, l.batch, l.n, out_h*out_w);
}

//activate_array(l.output, m*n*l.batch, l.activation);
// 使用不同的激活函数
if (l.activation == SWISH) activate_array_swish(l.output, l.outputs*l.batch, l.activation_input, l.output);
else if (l.activation == MISH) activate_array_mish(l.output, l.outputs*l.batch, l.activation_input, l.output);
else if (l.activation == NORM_CHAN) activate_array_normalize_channels(l.output, l.outputs*l.batch, l.batch, l.out_c, l.out_w*l.out_h, l.output);
else if (l.activation == NORM_CHAN_SOFTMAX) activate_array_normalize_channels_softmax(l.output, l.outputs*l.batch, l.batch, l.out_c, l.out_w*l.out_h, l.output, 0);
else if (l.activation == NORM_CHAN_SOFTMAX_MAXVAL) activate_array_normalize_channels_softmax(l.output, l.outputs*l.batch, l.batch, l.out_c, l.out_w*l.out_h, l.output, 1);
else activate_array_cpu_custom(l.output, l.outputs*l.batch, l.activation);
// 二值网络,前向传播结束之后转回float
if(l.binary || l.xnor) swap_binary(&l);

//visualize_convolutional_layer(l, "conv_visual", NULL);
//wait_until_press_key_cv();
// 暂时不懂
if(l.assisted_excitation && state.train) assisted_excitation_forward(l, state);
// 暂时不懂
if (l.antialiasing) {
network_state s = { 0 };
s.train = state.train;
s.workspace = state.workspace;
s.net = state.net;
s.input = l.output;
forward_convolutional_layer(*(l.input_layer), s);
//simple_copy_ongpu(l.outputs*l.batch, l.output, l.input_antialiasing);
memcpy(l.output, l.input_layer->output, l.input_layer->outputs * l.input_layer->batch * sizeof(float));
}
}

im2col解析

从上面的代码可以知道,卷积层的前向传播核心点是im2col操作还有sgemm矩阵计算方法对使用im2col进行重排后的数据进行计算。现在来解析一下im2col算法,sgemm算法就是im2col运行后直接调用即可,就不细讲了。

这里考虑到结合图片更容易理解im2col的思想,我利用CSDN Tiger-Gao博主的图描述一下。首先,我们把一个单通道的长宽均为4的图片通过im2col重新排布后会变成什么样呢?看下图:

来具体看一下变化过程:

这是单通道的变化过程,那么多通道的呢?首先来看原图:

多通道的im2col的过程,是首先im2col第一通道,然后再im2col第二通道,最后im2col第三通道。各通道im2col的数据在内存中也是连续存储的。看下图:

这是原图经过im2col的变化,那么kernel呢?看原图:

kernel的通道数据在内存中也是连续存储的。所以上面的kernel图像经过im2col算法后可以表示为下图:

那么我们是如何得到前向传播的结果呢?在DarkNet中和Caffe的实现方式一样,都是Kernel*Img,即是在矩阵乘法中:

1
2
3
M=1 
N=output_h * output_w
K=input_channels * kernel_h * kernel_w

结果如下:

图像数据是连续存储,因此输出图像也可以如下图所示【output_h x output_w】=【2 x 2】:

对于多通道图像过程类似:

同样,多个输出通道图像的数据是连续存储,因此输出图像也可以如下图所示【output_channels x output_h x output_w】=【3 x 2 x 2】

im2col算法的实现在src/im2col.c中,即im2col_cpu函数。代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
/*
** 从输入的多通道数组im(存储图像数据)中获取指定行,列,通道数处的元素值
** im: 函数的输入,所有的数据存成一个一维数组
** height: 每一个通道的高度(即是输入图像的真正高度,补0之前)
** width: 每一个通道的宽度(即是输入图像的真正宽度,补0之前)
** channles:输入通道数
** row: 要提取的元素所在的行(padding之后的行数)
** col: 要提取的元素所在的列(padding之后的列数)
** channel: 要提取的元素所在的通道
** pad: 图像上下左右补0的个数,四周是一样的
** 返回im中channel通道,row-pad行,col-pad列处的元素值
** 在im中并没有存储补0的元素值,因此height,width都是没有补0时输入图像真正的高、宽;
** 而row与col则是补0之后,元素所在的行列,因此,要准确获取在im中的元素值,首先需要
** 减去pad以获取在im中真实的行列数
*/
float im2col_get_pixel(float *im, int height, int width, int channels,
int row, int col, int channel, int pad)
{
//减去补0长度,获取像素真实的行列数
row -= pad;
col -= pad;
// 如果行列数<0,或者超过height/width,则返回0(刚好是补0的效果)
if (row < 0 || col < 0 ||
row >= height || col >= width) return 0;
// im存储多通道二维图像的数据格式为: 各个通道所有的所有行并成1行,再多通道依次并成一行
// 因此width*height*channel首先移位到所在通道的起点位置,再加上width*row移位到所在指定
// 通道行,再加上col移位到所在列
return im[col + width*(row + height*channel)];
}

//From Berkeley Vision's Caffe!
//https://github.com/BVLC/caffe/blob/master/LICENSE


/*
** 将输入图片转为便于计算的数组格式
** data_im: 输入图像
** height: 输入图像的高度(行)
** width: 输入图像的宽度(列)
** ksize: 卷积核尺寸
** stride: 卷积核跨度
** pad: 四周补0的长度
** data_col: 相当于输出,为进行格式重排后的输入图像数据
** 输出data_col的元素个数与data_im个数不相等,一般比data_im个数多,因为stride较小,各个卷积核之间有很多重叠,
** 实际data_col中的元素个数为channels*ksize*ksize*height_col*width_col,其中channels为data_im的通道数,
** ksize为卷积核大小,height_col和width_col如下所注。data_col的还是按行排列,只是行数为channels*ksize*ksize,
** 列数为height_col*width_col,即一张特征图总的元素个数,每整列包含与某个位置处的卷积核计算的所有通道上的像素,
** (比如输入图像通道数为3,卷积核尺寸为3*3,则共有27行,每列有27个元素),不同列对应卷积核在图像上的不同位置做卷积
*/
void im2col_cpu(float* data_im,
int channels, int height, int width,
int ksize, int stride, int pad, float* data_col)
{
int c,h,w;
// 计算该层神经网络的输出图像尺寸(其实没有必要再次计算的,因为在构建卷积层时,make_convolutional_layer()函数
// 已经调用convolutional_out_width(),convolutional_out_height()函数求取了这两个参数,
// 此处直接使用l.out_h,l.out_w即可,函数参数只要传入该层网络指针就可了,没必要弄这么多参数)
int height_col = (height + 2*pad - ksize) / stride + 1;
int width_col = (width + 2*pad - ksize) / stride + 1;
// 卷积核大小:ksize*ksize是一个卷积核的大小,之所以乘以通道数channels,是因为输入图像有多通道,每个卷积核在做卷积时,
// 是同时对同一位置多通道的图像进行卷积运算,这里为了实现这一目的,将三个通道将三通道上的卷积核并在一起以便进行计算,因此卷积核
// 实际上并不是二维的,而是三维的,比如对于3通道图像,卷积核尺寸为3*3,该卷积核将同时作用于三通道图像上,这样并起来就得
// 到含有27个元素的卷积核,且这27个元素都是独立的需要训练的参数。所以在计算训练参数个数时,一定要注意每一个卷积核的实际
// 训练参数需要乘以输入通道数。
int channels_col = channels * ksize * ksize;
// 外循环次数为一个卷积核的尺寸数,循环次数即为最终得到的data_col的总行数
for (c = 0; c < channels_col; ++c) {
// 列偏移,卷积核是一个二维矩阵,并按行存储在一维数组中,利用求余运算获取对应在卷积核中的列数,比如对于
// 3*3的卷积核(3通道),当c=0时,显然在第一列,当c=5时,显然在第2列,当c=9时,在第二通道上的卷积核的第一列,
// 当c=26时,在第三列(第三通道上)
int w_offset = c % ksize;
// 行偏移,卷积核是一个二维的矩阵,且是按行(卷积核所有行并成一行)存储在一维数组中的,
// 比如对于3*3的卷积核,处理3通道的图像,那么一个卷积核具有27个元素,每9个元素对应一个通道上的卷积核(互为一样),
// 每当c为3的倍数,就意味着卷积核换了一行,h_offset取值为0,1,2,对应3*3卷积核中的第1, 2, 3行
int h_offset = (c / ksize) % ksize;
// 通道偏移,channels_col是多通道的卷积核并在一起的,比如对于3通道,3*3卷积核,每过9个元素就要换一通道数,
// 当c=0~8时,c_im=0;c=9~17时,c_im=1;c=18~26时,c_im=2
int c_im = c / ksize / ksize;
// 中循环次数等于该层输出图像行数height_col,说明data_col中的每一行存储了一张特征图,这张特征图又是按行存储在data_col中的某行中
for (h = 0; h < height_col; ++h) {
// 内循环等于该层输出图像列数width_col,说明最终得到的data_col总有channels_col行,height_col*width_col列
for (w = 0; w < width_col; ++w) {
// 由上面可知,对于3*3的卷积核,h_offset取值为0,1,2,当h_offset=0时,会提取出所有与卷积核第一行元素进行运算的像素,
// 依次类推;加上h*stride是对卷积核进行行移位操作,比如卷积核从图像(0,0)位置开始做卷积,那么最先开始涉及(0,0)~(3,3)
// 之间的像素值,若stride=2,那么卷积核进行一次行移位时,下一行的卷积操作是从元素(2,0)(2为图像行号,0为列号)开始
int im_row = h_offset + h * stride;
// 对于3*3的卷积核,w_offset取值也为0,1,2,当w_offset取1时,会提取出所有与卷积核中第2列元素进行运算的像素,
// 实际在做卷积操作时,卷积核对图像逐行扫描做卷积,加上w*stride就是为了做列移位,
// 比如前一次卷积其实像素元素为(0,0),若stride=2,那么下次卷积元素起始像素位置为(0,2)(0为行号,2为列号)
int im_col = w_offset + w * stride;
// col_index为重排后图像中的像素索引,等于c * height_col * width_col + h * width_col +w(还是按行存储,所有通道再并成一行),
// 对应第c通道,h行,w列的元素
int col_index = (c * height_col + h) * width_col + w;
// im2col_get_pixel函数获取输入图像data_im中第c_im通道,im_row,im_col的像素值并赋值给重排后的图像,
// height和width为输入图像data_im的真实高、宽,pad为四周补0的长度(注意im_row,im_col是补0之后的行列号,
// 不是真实输入图像中的行列号,因此需要减去pad获取真实的行列号)
data_col[col_index] = im2col_get_pixel(data_im, height, width, channels,
im_row, im_col, c_im, pad);
}
}
}
}