You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

113 lines
3.1 KiB

5 days ago
#include <stdio.h>
#include <ctime>
#include <map>
#include <utility>
#include <vector>
#define MAX 1024
// No manual memory management is needed in this file: the demo keeps all
// arrays on the stack, and the standard containers used inside
// sparse_matmul_coo release their storage automatically.

// Multiply two sparse matrices stored in COO (coordinate) format: C = A * B.
//
// A and B are given as parallel arrays (value, row index, column index) of
// length A_nonZeroCount / B_nonZeroCount. Each product A(r,k) * B(k,c) is
// accumulated into the C entry (r,c). C entries appear in the order in which
// their (row, col) pair is first produced, and *C_nonZeroCount receives the
// number of entries written. Entries whose contributions cancel to 0 are
// still stored.
//
// Precondition: C_values / C_rowIndex / C_colIndex must be large enough to
// hold every distinct (row, col) pair the product can produce. This function
// does NOT bounds-check the output arrays.
//
// B's entries are bucketed by row, so each A entry only visits B entries in
// the matching row. A map from (row, col) to output slot replaces a linear
// search over C. Matching (i, j) pairs are still visited in ascending i, then
// ascending j, so the output order and float summation order are the same as
// the straightforward triple loop.
void sparse_matmul_coo(const float* A_values, const int* A_rowIndex, const int* A_colIndex, int A_nonZeroCount,
                       const float* B_values, const int* B_rowIndex, const int* B_colIndex, int B_nonZeroCount,
                       float* C_values, int* C_rowIndex, int* C_colIndex, int* C_nonZeroCount)
{
    // Group B's entries by row; within a row, indices stay in ascending order.
    std::map<int, std::vector<int> > bEntriesByRow;
    for (int j = 0; j < B_nonZeroCount; j++)
    {
        bEntriesByRow[B_rowIndex[j]].push_back(j);
    }

    // (row, col) of an output entry -> its slot in the C arrays.
    std::map<std::pair<int, int>, int> slotOf;
    int currentIndex = 0;

    for (int i = 0; i < A_nonZeroCount; i++)
    {
        // Only B entries whose row equals A's column contribute to the product.
        std::map<int, std::vector<int> >::const_iterator match = bEntriesByRow.find(A_colIndex[i]);
        if (match == bEntriesByRow.end())
        {
            continue;
        }

        const int rowA = A_rowIndex[i];
        const float valueA = A_values[i];
        const std::vector<int>& bRow = match->second;
        for (std::vector<int>::size_type t = 0; t < bRow.size(); t++)
        {
            const int j = bRow[t];
            const int colB = B_colIndex[j];
            const float product = valueA * B_values[j];

            // insert() only adds the key if it is absent; .second reports which case happened.
            std::pair<std::map<std::pair<int, int>, int>::iterator, bool> slot =
                slotOf.insert(std::make_pair(std::make_pair(rowA, colB), currentIndex));
            if (slot.second)
            {
                // First contribution to (rowA, colB): append a new entry to C.
                C_values[currentIndex] = product;
                C_rowIndex[currentIndex] = rowA;
                C_colIndex[currentIndex] = colB;
                currentIndex++;
            }
            else
            {
                // Existing entry: accumulate into its slot.
                C_values[slot.first->second] += product;
            }
        }
    }

    // Report how many non-zero entries were written to C.
    *C_nonZeroCount = currentIndex;
}
// Largest value in idx[0..count), but never less than 0: an empty or
// all-negative index list yields 0. Used to infer a matrix extent from a COO
// index array.
static int max_index(const int* idx, int count)
{
    int largest = 0;
    for (int i = 0; i < count; i++)
    {
        if (idx[i] > largest)
        {
            largest = idx[i];
        }
    }
    return largest;
}

// Demo driver:
// - multiplies two small COO matrices with sparse_matmul_coo;
// - prints the sizes of A and B inferred from their index arrays;
// - prints the CPU time of the multiplication as measured by clock().
int main()
{
    // Matrix A in COO format.
    float A_values[] = {1, 2, 3, 4, 5};
    int A_rowIndex[] = {0, 0, 1, 2, 2};
    int A_colIndex[] = {0, 2, 1, 0, 2};
    // Derived from the initializer so the count cannot drift from the data.
    int A_nonZeroCount = static_cast<int>(sizeof(A_values) / sizeof(A_values[0]));

    // Matrix B in COO format.
    float B_values[] = {6, 8, 7, 9};
    int B_rowIndex[] = {0, 2, 1, 2};
    int B_colIndex[] = {0, 0, 1, 2};
    int B_nonZeroCount = static_cast<int>(sizeof(B_values) / sizeof(B_values[0]));

    // Result matrix C in COO format. sparse_matmul_coo does not bounds-check,
    // so MAX must cover every distinct (row, col) pair the product can hold.
    // A_nonZeroCount * B_nonZeroCount is a safe upper bound.
    float C_values[MAX];
    int C_rowIndex[MAX];
    int C_colIndex[MAX];
    int C_nonZeroCount = 0;

    clock_t start = clock();
    sparse_matmul_coo(A_values, A_rowIndex, A_colIndex, A_nonZeroCount,
                      B_values, B_rowIndex, B_colIndex, B_nonZeroCount,
                      C_values, C_rowIndex, C_colIndex, &C_nonZeroCount);
    clock_t end = clock();

    // Dimensions are inferred as (largest stored index + 1). Trailing rows or
    // columns that hold no non-zeros cannot be detected this way.
    int rowA = max_index(A_rowIndex, A_nonZeroCount);
    int colA = max_index(A_colIndex, A_nonZeroCount);
    int rowB = max_index(B_rowIndex, B_nonZeroCount);
    int colB = max_index(B_colIndex, B_nonZeroCount);
    printf("矩阵A的规模为%d*%d,矩阵B的规模为%d*%d\n", rowA + 1, colA + 1, rowB + 1, colB + 1);

    // Compute and print the CPU time spent in the basic sparse multiplication.
    double coo_time_spent = static_cast<double>(end - start) / CLOCKS_PER_SEC;
    printf("基础的稀疏矩阵乘法时间:%lf秒\n", coo_time_spent);
    return 0;
}