forked from pi7mcrg2k/operator_optimization
				
			
							parent
							
								
									5d750743a6
								
							
						
					
					
						commit
						93d0eae3ae
					
				| @ -0,0 +1,70 @@ | |||||||
|  | #include <stdio.h> | ||||||
|  | #include <time.h> | ||||||
|  | #include <stdlib.h> | ||||||
|  | #define SIZE 1024 | ||||||
|  | 
 | ||||||
|  | void sparce_matmul_coo(float*, int*, int*, int, | ||||||
|  | 	float*, int*, int*, int, float*, int*, int*, int*); | ||||||
|  | int main() { | ||||||
|  | 	// 矩阵 A 的 COO 格式
 | ||||||
|  | 	float A_values[] = {1, 2, 3, 4, 5};  | ||||||
|  | 	int A_rowIndex[] = {0, 0, 1, 2, 2};  | ||||||
|  | 	int A_colIndex[] = {0, 2, 1, 0, 2};  | ||||||
|  | 	int A_nonZeroCount = 5; | ||||||
|  | 	// 矩阵 B 的 Coo 格式
 | ||||||
|  | 	float B_values[] = {6, 8, 7, 9};  | ||||||
|  | 	int B_rowIndex[] = {0, 2, 1, 2};  | ||||||
|  | 	int B_colIndex[] = {0, 0, 1, 2};  | ||||||
|  | 	int B_nonZeroCount = 4;// 结果矩阵 C 的 Coo 格式
 | ||||||
|  | 	float C_values[SIZE];  | ||||||
|  | 	int C_rowIndex[SIZE];  | ||||||
|  | 	int C_colIndex[SIZE];  | ||||||
|  | 	int C_nonZeroCount = 0; | ||||||
|  | 
 | ||||||
|  | 	clock_t start = clock(); | ||||||
|  | 	sparce_matmul_coo(A_values, A_rowIndex, A_colIndex, A_nonZeroCount, | ||||||
|  | 	                  B_values, B_rowIndex, B_colIndex, B_nonZeroCount, | ||||||
|  | 					  C_values, C_rowIndex, C_colIndex, &C_nonZeroCount); | ||||||
|  | 	clock_t end = clock(); | ||||||
|  | 	printf("基础的稀疏矩阵乘法时间:%lf\n", (double)(end-start) / CLOCKS_PER_SEC); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | void sparce_matmul_coo(float* A_values, int* A_rowIndex, int* A_colIndex, int A_nonZeroCount, | ||||||
|  |                        float* B_values, int* B_rowIndex, int* B_colIndex, int B_nonZeroCount, | ||||||
|  | 					   float* C_values, int* C_rowIndex, int* C_colIndex, int* C_nonZeroCount) { | ||||||
|  |     int currentIndex = 0; | ||||||
|  | 	int i, j, k; | ||||||
|  | 	int rowA, colA, rowB, colB; | ||||||
|  | 	float valueA, valueB, product; | ||||||
|  | 	// 遍历 A 的非零元素
 | ||||||
|  | 	for(i=0; i<A_nonZeroCount; i++) { | ||||||
|  | 		rowA = A_rowIndex[i]; | ||||||
|  | 		colA = A_colIndex[i]; | ||||||
|  | 		valueA = A_values[i]; | ||||||
|  | 		// 遍历 B 的非零元素
 | ||||||
|  | 		for(j=0; j<B_nonZeroCount; j++) { | ||||||
|  | 			rowB = B_rowIndex[j]; | ||||||
|  | 			colB = B_colIndex[j]; | ||||||
|  | 			valueB = B_values[j]; | ||||||
|  | 			// 如果 A 的列和 B 的行匹配,则计算乘积并存储结果
 | ||||||
|  | 			if (colA == rowB) { | ||||||
|  | 				product = valueA * valueB; | ||||||
|  | 				// 检查是否已有此(rowA, colB) 项
 | ||||||
|  | 				int found = 0; | ||||||
|  | 				for(k=0; k<currentIndex; k++) { | ||||||
|  | 					if(C_rowIndex[k] == rowA && C_colIndex[k] == colB) { | ||||||
|  | 						C_values[k] += product; | ||||||
|  | 						break; | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 				if (!found) { | ||||||
|  | 					C_values[currentIndex] = product; | ||||||
|  | 					C_rowIndex[currentIndex] = rowA; | ||||||
|  | 					C_colIndex[currentIndex] = colB; | ||||||
|  | 					currentIndex++; | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	*C_nonZeroCount = currentIndex; | ||||||
|  | } | ||||||
					Loading…
					
					
				
		Reference in new issue