You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

113 lines
3.1 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include <stdio.h>
#include <ctime>
#define MAX 1024
//因为所有变量都是局部变量,未采用动态分配内存的办法,故本代码文件不手动清理内存
/*
 * Multiply two sparse matrices stored in COO (coordinate) format: C = A * B.
 *
 * Every non-zero A(rowA, colA) is paired with every non-zero B(rowB, colB);
 * when colA == rowB the product contributes to C(rowA, colB). Contributions
 * to the same (row, col) position are summed into a single output entry, and
 * entries appear in C in the order their first contribution was found.
 *
 * A_values / A_rowIndex / A_colIndex : the A_nonZeroCount entries of A (read only).
 * B_values / B_rowIndex / B_colIndex : the B_nonZeroCount entries of B (read only).
 * C_values / C_rowIndex / C_colIndex : output arrays. This function does no
 *     bounds checking, so the caller must size them for the worst case of
 *     A_nonZeroCount * B_nonZeroCount entries.
 * C_nonZeroCount : receives the number of entries written to C.
 *
 * Time is O(nnz(A) * nnz(B) * nnz(C)) because each product does a linear
 * search of C for an existing entry at the same position.
 */
void sparse_matmul_coo(const float* A_values, const int* A_rowIndex, const int* A_colIndex, int A_nonZeroCount,
                       const float* B_values, const int* B_rowIndex, const int* B_colIndex, int B_nonZeroCount,
                       float* C_values, int* C_rowIndex, int* C_colIndex, int* C_nonZeroCount)
{
    int currentIndex = 0;
    for (int i = 0; i < A_nonZeroCount; i++)
    {
        const int rowA = A_rowIndex[i];
        const int colA = A_colIndex[i];
        const float valueA = A_values[i];
        for (int j = 0; j < B_nonZeroCount; j++)
        {
            /* Only pairs whose inner dimension matches (colA == rowB) contribute. */
            if (B_rowIndex[j] != colA)
            {
                continue;
            }
            const int colB = B_colIndex[j];
            const float product = valueA * B_values[j];

            /* Look for an existing (rowA, colB) entry to accumulate into. */
            int k = 0;
            while (k < currentIndex && !(C_rowIndex[k] == rowA && C_colIndex[k] == colB))
            {
                k++;
            }
            if (k < currentIndex)
            {
                C_values[k] += product;
            }
            else
            {
                /* First contribution to (rowA, colB): append a new non-zero entry. */
                C_values[currentIndex] = product;
                C_rowIndex[currentIndex] = rowA;
                C_colIndex[currentIndex] = colB;
                currentIndex++;
            }
        }
    }
    /* Report how many entries of C were filled. */
    *C_nonZeroCount = currentIndex;
}
/* Return the largest value in idx[0..count), or 0 if count <= 0 or every value is below 0. */
static int max_index(const int* idx, int count)
{
    int best = 0;
    for (int i = 0; i < count; i++)
    {
        if (idx[i] > best)
        {
            best = idx[i];
        }
    }
    return best;
}

/*
 * Demo driver. It multiplies two small fixed COO matrices with
 * sparse_matmul_coo, prints the inferred matrix sizes, and prints the
 * processor time the multiplication took, as measured by clock().
 * It returns 1 if the result buffers could be too small for the worst case.
 */
int main()
{
    /* Matrix A in COO format. */
    float A_values[] = {1, 2, 3, 4, 5};
    int A_rowIndex[] = {0, 0, 1, 2, 2};
    int A_colIndex[] = {0, 2, 1, 0, 2};
    int A_nonZeroCount = 5;
    /* Matrix B in COO format. */
    float B_values[] = {6, 8, 7, 9};
    int B_rowIndex[] = {0, 2, 1, 2};
    int B_colIndex[] = {0, 0, 1, 2};
    int B_nonZeroCount = 4;
    /* Result matrix C in COO format; it can hold at most MAX entries. */
    float C_values[MAX];
    int C_rowIndex[MAX];
    int C_colIndex[MAX];
    int C_nonZeroCount = 0;

    /* sparse_matmul_coo writes without bounds checking and can produce up to
       nnz(A) * nnz(B) distinct entries, so refuse to run if C might overflow. */
    const long long worstCase = (long long)A_nonZeroCount * B_nonZeroCount;
    if (worstCase > MAX)
    {
        fprintf(stderr, "结果缓冲区容量不足:最多需要%lld项,仅有%d项\n", worstCase, MAX);
        return 1;
    }

    clock_t start = clock();
    sparse_matmul_coo(A_values, A_rowIndex, A_colIndex, A_nonZeroCount,
                      B_values, B_rowIndex, B_colIndex, B_nonZeroCount,
                      C_values, C_rowIndex, C_colIndex, &C_nonZeroCount);
    clock_t end = clock();

    /* Dimensions are inferred as (largest index used + 1), so trailing
       all-zero rows or columns are not counted. */
    int rowA = max_index(A_rowIndex, A_nonZeroCount);
    int colA = max_index(A_colIndex, A_nonZeroCount);
    int rowB = max_index(B_rowIndex, B_nonZeroCount);
    int colB = max_index(B_colIndex, B_nonZeroCount);
    printf("矩阵A的规模为%d*%d,矩阵B的规模为%d*%d\n", rowA + 1, colA + 1, rowB + 1, colB + 1);

    /* Report the elapsed processor time of the basic sparse multiplication.
       A C-style cast keeps this valid in both C and C++. */
    double coo_time_spent = (double)(end - start) / CLOCKS_PER_SEC;
    printf("基础的稀疏矩阵乘法时间:%lf秒\n", coo_time_spent);
    return 0;
}