forked from pi7mcrg2k/operator_optimization
				
			
							parent
							
								
									5d750743a6
								
							
						
					
					
						commit
						93d0eae3ae
					
				| @ -0,0 +1,70 @@ | ||||
| #include <stdio.h> | ||||
| #include <time.h> | ||||
| #include <stdlib.h> | ||||
| #define SIZE 1024 | ||||
| 
 | ||||
| void sparce_matmul_coo(float*, int*, int*, int, | ||||
| 	float*, int*, int*, int, float*, int*, int*, int*); | ||||
| int main() { | ||||
| 	// 矩阵 A 的 COO 格式
 | ||||
| 	float A_values[] = {1, 2, 3, 4, 5};  | ||||
| 	int A_rowIndex[] = {0, 0, 1, 2, 2};  | ||||
| 	int A_colIndex[] = {0, 2, 1, 0, 2};  | ||||
| 	int A_nonZeroCount = 5; | ||||
| 	// 矩阵 B 的 Coo 格式
 | ||||
| 	float B_values[] = {6, 8, 7, 9};  | ||||
| 	int B_rowIndex[] = {0, 2, 1, 2};  | ||||
| 	int B_colIndex[] = {0, 0, 1, 2};  | ||||
| 	int B_nonZeroCount = 4;// 结果矩阵 C 的 Coo 格式
 | ||||
| 	float C_values[SIZE];  | ||||
| 	int C_rowIndex[SIZE];  | ||||
| 	int C_colIndex[SIZE];  | ||||
| 	int C_nonZeroCount = 0; | ||||
| 
 | ||||
| 	clock_t start = clock(); | ||||
| 	sparce_matmul_coo(A_values, A_rowIndex, A_colIndex, A_nonZeroCount, | ||||
| 	                  B_values, B_rowIndex, B_colIndex, B_nonZeroCount, | ||||
| 					  C_values, C_rowIndex, C_colIndex, &C_nonZeroCount); | ||||
| 	clock_t end = clock(); | ||||
| 	printf("基础的稀疏矩阵乘法时间:%lf\n", (double)(end-start) / CLOCKS_PER_SEC); | ||||
| } | ||||
| 
 | ||||
| void sparce_matmul_coo(float* A_values, int* A_rowIndex, int* A_colIndex, int A_nonZeroCount, | ||||
|                        float* B_values, int* B_rowIndex, int* B_colIndex, int B_nonZeroCount, | ||||
| 					   float* C_values, int* C_rowIndex, int* C_colIndex, int* C_nonZeroCount) { | ||||
|     int currentIndex = 0; | ||||
| 	int i, j, k; | ||||
| 	int rowA, colA, rowB, colB; | ||||
| 	float valueA, valueB, product; | ||||
| 	// 遍历 A 的非零元素
 | ||||
| 	for(i=0; i<A_nonZeroCount; i++) { | ||||
| 		rowA = A_rowIndex[i]; | ||||
| 		colA = A_colIndex[i]; | ||||
| 		valueA = A_values[i]; | ||||
| 		// 遍历 B 的非零元素
 | ||||
| 		for(j=0; j<B_nonZeroCount; j++) { | ||||
| 			rowB = B_rowIndex[j]; | ||||
| 			colB = B_colIndex[j]; | ||||
| 			valueB = B_values[j]; | ||||
| 			// 如果 A 的列和 B 的行匹配,则计算乘积并存储结果
 | ||||
| 			if (colA == rowB) { | ||||
| 				product = valueA * valueB; | ||||
| 				// 检查是否已有此(rowA, colB) 项
 | ||||
| 				int found = 0; | ||||
| 				for(k=0; k<currentIndex; k++) { | ||||
| 					if(C_rowIndex[k] == rowA && C_colIndex[k] == colB) { | ||||
| 						C_values[k] += product; | ||||
| 						break; | ||||
| 					} | ||||
| 				} | ||||
| 				if (!found) { | ||||
| 					C_values[currentIndex] = product; | ||||
| 					C_rowIndex[currentIndex] = rowA; | ||||
| 					C_colIndex[currentIndex] = colB; | ||||
| 					currentIndex++; | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	*C_nonZeroCount = currentIndex; | ||||
| } | ||||
					Loading…
					
					
				
		Reference in new issue