You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

150 lines
5.6 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

// 如果定义了__ARM_NEON__则包含<arm_neon.h>头文件用于支持ARM NEON指令集相关的操作
#ifdef __ARM_NEON__
#include <arm_neon.h>
#endif
// 定义向量的大小为1024
const int VECTOR_SIZE = 1024;
// 定义矩阵的大小为32这里表示矩阵是32x32的
const int MATRIX_SIZE = 32;
// 基础的向量加法函数
// 该函数实现了简单的向量加法操作将输入向量a和b对应位置的元素相加结果存储在result向量中
// 参数:
// a输入向量a的指针
// b输入向量b的指针
// result用于存储相加结果的向量指针
// size向量的元素个数
void vectorAddBase(float* a, float* b, float* result, int size) {
// 循环遍历向量的每个元素
for (int i = 0; i < size; ++i) {
// 将a和b对应位置的元素相加并存入result向量的对应位置
result[i] = a[i] + b[i];
}
}
// 如果定义了__ARM_NEON__则执行以下NEON优化的向量加法函数
#ifdef __ARM_NEON__
// NEON优化的向量加法函数
// 该函数利用ARM NEON指令集对向量加法进行优化以4个元素为一组进行处理当向量大小是4的倍数时
// 参数:
// a输入向量a的指针
// b输入向量b的指针
// result用于存储相加结果的向量指针
// size向量的元素个数
void vectorAddNeon(float* a, float* b, float* result, int size) {
int i;
// 以4个元素为一组进行处理循环直到剩余元素不足4个
for (i = 0; i <= size - 4; i += 4) {
// 使用vld1q_f32指令将向量a中从索引i开始的4个元素加载到NEON寄存器vecA中
float32x4_t vecA = vld1q_f32(&a[i]);
// 使用vld1q_f32指令将向量b中从索引i开始的4个元素加载到NEON寄存器vecB中
float32x4_t vecB = vld1q_f32(&b[i]);
// 使用vaddq_f32指令对vecA和vecB中的对应元素进行加法操作结果存储在vecResult中
float32x4_t vecResult = vaddq_f32(vecA, vecB);
// 使用vst1q_f32指令将vecResult中的4个元素存储回result向量中从索引i开始的位置
vst1q_f32(&result[i], vecResult);
}
// 处理向量大小不是4的倍数时剩余的元素
for (; i < size; ++i) {
// 对于剩余元素使用基础的加法方式将a和b对应位置的元素相加并存入result向量的对应位置
result[i] = a[i] + b[i];
}
}
// 如果未定义__ARM_NEON__则执行以下fallback操作即调用基础的向量加法函数
#else
void vectorAddNeon(float* a, float* b, float* result, int size) {
vectorAddBase(a, b, result, size); // Fallback to base implementation
}
#endif
// 基础的矩阵乘法函数
// 该函数实现了常规的矩阵乘法算法按照矩阵乘法的规则计算两个矩阵A和B的乘积结果存储在矩阵C中
// 参数:
// A输入矩阵A的指针
// B输入矩阵B的指针
// C用于存储乘法结果的矩阵指针
// N矩阵的边长这里假设矩阵是方阵所以只需要一个边长参数
void matrixMultiplyBase(float* A, float* B, float* C, int N) {
// 外层循环遍历矩阵A的行
for (int i = 0; i < N; ++i) {
// 内层循环遍历矩阵B的列
for (int j = 0; j < N; ++j) {
// 先将结果矩阵C中当前位置的元素初始化为0.0f
C[i * N + j] = 0.0f;
// 中间循环遍历矩阵A的列和矩阵B的行用于计算乘积并累加
for (int k = 0; k < N; ++k) {
C[i * N + j] += A[i * N + k] * B[k * N + j];
}
}
}
}
// 如果定义了__ARM_NEON__则执行以下NEON优化的矩阵乘法函数
#ifdef __ARM_NEON__
// NEON优化的矩阵乘法函数
// 该函数利用ARM NEON指令集对矩阵乘法进行优化以4个元素为一组处理矩阵B的列当矩阵边长是4的倍数时
// 参数:
// A输入矩阵A的指针
// B输入矩阵B的指针
// C用于存储乘法结果的矩阵指针
// N矩阵的边长这里假设矩阵是方阵所以只需要一个边长参数
void matrixMultiplyNeon(float* A, float* B, float* C, int N) {
// 外层循环遍历矩阵A的行
for (int i = 0; i < N; ++i) {
// 内层循环遍历矩阵B的列每次处理4个元素一组
for (int j = 0; j < N; j += 4) {
// 使用vmovq_n_f32指令将浮点数0.0f初始化为NEON寄存器sum用于累加乘积结果
float32x4_t sum = vmovq_n_f32(0.0f);
// 中间循环遍历矩阵A的列和矩阵B的行用于计算乘积并累加
for (int k = 0; k < N; ++k) {
// 使用vdupq_n_f32指令复制矩阵A中当前行、当前列k对应的元素到NEON寄存器vecA中
float32x4_t vecA = vdupq_n_f32(A[i * N + k]);
// 使用vld1q_f32指令将矩阵B中当前列j从索引k开始的4个元素加载到NEON寄存器vecB中
float32x4_t vecB = vld1q_f32(&B[k * N + j]);
// 使用vfmaq_f32指令将vecA和vecB中的对应元素相乘并累加到sum寄存器中
sum = vfmaq_f32(sum, vecA, vecB);
}
// 使用vst1q_f32指令将sum寄存器中的4个元素存储回结果矩阵C中当前行、当前列j开始的位置
vst1q_f32(&C[i * N + j], sum);
}
}
}
// 如果未定义__ARM_NEON__则执行以下fallback操作即调用基础的矩阵乘法函数
#else
void matrixMultiplyNeon(float* A, float* B, float* C, int N) {
matrixMultiplyBase(A, B, C, N); // Fallback to base implementation
}
#endif
// 计算两个timespec结构体表示的时间点之间的时间差的函数
// 参数:
// start起始时间点的timespec结构体
// end结束时间点的timespec结构体
// 返回值:以秒为单位的时间差
double get_time_diff(struct timespec start, struct timespec end) {
// 计算秒数差
double diff_sec = end.tv_sec - start.tv_sec;
// 计算纳秒数差
double diff_nsec = end.tv_nsec - start.tv_nsec;
// 将纳秒数差转换为秒并与秒数差相加,得到总时间差
return diff_sec + diff_nsec / 1e9;
}
// 如果定义了_WIN32说明是在Windows平台下包含<windows.h>头文件,并定义以下函数
#ifdef _WIN32
#include <windows.h>
// 计算两个LARGE_INTEGER结构体表示的时间点之间的时间差的函数用于Windows平台
// 参数:
// start起始时间点的LARGE_INTEGER结构体
// end结束时间点的LARGE_INTEGER结构体
// frequency时间频率的LARGE_INTEGER结构体
// 返回值:以秒为单位的时间差
double get_time_diff_windows(LARGE_INTEGER start, LARGE_INTEGER end, LARGE_INTEGER frequency) {
// 计算经过的时间通过两个时间点的QuadPart差值除以时间频率的QuadPart得到
double elapsed = (double)(end.QuadPart - start.QuadPart) / frequency.QuadPart;
// 返回计算得到的时间差
return elapsed;
}
#endif