You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

162 lines
4.8 KiB

#define _CRT_SECURE_NO_WARNINGS
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

#include <arm_neon.h>
#define ROWS 4
#define COLS 4
/**
 * Sparse matrix held as three parallel arrays (coordinate layout):
 * stored entry k has the value values[k] and sits at row rowIndex[k],
 * column colIndex[k].
 */
typedef struct {
    float *values;      /* value of each stored entry */
    int *rowIndex;      /* row position of each stored entry */
    int *colIndex;      /* column position of each stored entry */
    int nonZeroCount;   /* number of stored entries, i.e. the arrays' length */
} SparseMatrix;
/**
 * Allocates a sparse matrix with room for nonZeroCount entries.
 *
 * The three entry arrays are zero-initialised; the caller fills them in and
 * later releases the matrix with freeSparseMatrix().
 *
 * @param nonZeroCount number of entries to reserve; must not be negative.
 * @return the new matrix, or NULL when nonZeroCount is negative or an
 *         allocation fails (nothing is leaked in that case).
 */
SparseMatrix* createSparseMatrix(int nonZeroCount) {
    if (nonZeroCount < 0) {
        return NULL;
    }

    SparseMatrix* matrix = (SparseMatrix*)malloc(sizeof *matrix);
    if (matrix == NULL) {
        return NULL;
    }

    /* calloc rejects a count * size product that would overflow size_t. */
    const size_t count = (size_t)nonZeroCount;
    matrix->values = (float*)calloc(count, sizeof *matrix->values);
    matrix->rowIndex = (int*)calloc(count, sizeof *matrix->rowIndex);
    matrix->colIndex = (int*)calloc(count, sizeof *matrix->colIndex);
    matrix->nonZeroCount = nonZeroCount;

    /* A zero-sized request may legitimately yield NULL, so NULL only counts
     * as a failure when at least one entry was requested. */
    if (count > 0 && (matrix->values == NULL || matrix->rowIndex == NULL ||
                      matrix->colIndex == NULL)) {
        free(matrix->values);
        free(matrix->rowIndex);
        free(matrix->colIndex);
        free(matrix);
        return NULL;
    }

    return matrix;
}
/**
 * Releases a sparse matrix together with its three entry arrays.
 *
 * @param matrix matrix to destroy; NULL is accepted and ignored, mirroring
 *               free(NULL). The pointer must not be used afterwards.
 */
void freeSparseMatrix(SparseMatrix* matrix) {
    if (matrix == NULL) {
        return;
    }

    free(matrix->values);
    free(matrix->rowIndex);
    free(matrix->colIndex);
    free(matrix);
}
/**
 * Expands a sparse matrix into a ROWS x COLS dense array.
 *
 * Every cell of dense is first set to 0.0f, then each stored entry is copied
 * to its (row, column) cell. If several entries share a position, the last
 * one wins. Entries whose row or column lies outside the dense array are
 * skipped rather than written out of bounds.
 *
 * @param sparse source matrix.
 * @param dense  destination array, fully overwritten.
 */
void sparseToDense(SparseMatrix* sparse, float dense[ROWS][COLS]) {
    for (int i = 0; i < ROWS; i++) {
        for (int j = 0; j < COLS; j++) {
            dense[i][j] = 0.0f;
        }
    }

    for (int k = 0; k < sparse->nonZeroCount; k++) {
        const int row = sparse->rowIndex[k];
        const int col = sparse->colIndex[k];

        /* An index outside the array would corrupt memory; ignore the entry. */
        if (row < 0 || row >= ROWS || col < 0 || col >= COLS) {
            continue;
        }
        dense[row][col] = sparse->values[k];
    }
}
/**
 * Writes a ROWS x COLS dense matrix to standard output, one matrix row per
 * line, each element formatted as "%5.1f" followed by a space.
 *
 * @param dense matrix to print.
 */
void printDenseMatrix(float dense[ROWS][COLS]) {
    int r = 0;
    while (r < ROWS) {
        const float* cell = dense[r];
        const float* const rowEnd = cell + COLS;

        for (; cell != rowEnd; ++cell) {
            printf("%5.1f ", *cell);
        }
        printf("\n");
        ++r;
    }
}
/**
 * Multiplies two sparse matrices and returns the product in sparse form.
 *
 * Every pair of entries (a from A, b from B) with a's column equal to b's row
 * contributes a.value * b.value to the result cell (a.row, b.column);
 * contributions to the same cell are summed into a single result entry.
 * Result entries appear in the order their cell was first produced.
 *
 * @param A left operand.
 * @param B right operand.
 * @return a newly allocated matrix owned by the caller (release it with
 *         freeSparseMatrix()), or NULL when an operand is NULL, a stored
 *         count is negative, the worst-case entry count does not fit in an
 *         int, or memory cannot be allocated.
 */
SparseMatrix* sparse_matmul(SparseMatrix* A, SparseMatrix* B) {
    if (A == NULL || B == NULL || A->nonZeroCount < 0 || B->nonZeroCount < 0) {
        return NULL;
    }

    const int aCount = A->nonZeroCount;
    const int bCount = B->nonZeroCount;

    /* Worst case every pair matches a distinct cell; make sure that product
     * is representable before computing it (signed overflow is undefined). */
    if (bCount > 0 && aCount > INT_MAX / bCount) {
        return NULL;
    }
    const int capacity = aCount * bCount;

    SparseMatrix* C = createSparseMatrix(capacity);
    if (C == NULL) {
        return NULL;
    }
    if (capacity > 0 && (C->values == NULL || C->rowIndex == NULL ||
                         C->colIndex == NULL)) {
        freeSparseMatrix(C);
        return NULL;
    }

    int count = 0;
    for (int i = 0; i < aCount; i++) {
        const float aValue = A->values[i];
        const int aRow = A->rowIndex[i];
        const int aCol = A->colIndex[i];

        for (int j = 0; j < bCount; j++) {
            if (B->rowIndex[j] != aCol) {
                continue;
            }

            const float product = aValue * B->values[j];
            const int bCol = B->colIndex[j];

            /* Look for an existing result entry at (aRow, bCol). */
            int slot = 0;
            while (slot < count &&
                   !(C->rowIndex[slot] == aRow && C->colIndex[slot] == bCol)) {
                slot++;
            }

            if (slot < count) {
                C->values[slot] += product;
            } else {
                C->values[count] = product;
                C->rowIndex[count] = aRow;
                C->colIndex[count] = bCol;
                count++;
            }
        }
    }
    C->nonZeroCount = count;

    /* Trim the arrays to the entries actually used. A failed shrink keeps the
     * original (larger) buffer, and a zero-sized realloc is never requested
     * because its behaviour is not portable. */
    if (count > 0 && count < capacity) {
        void* shrunk = realloc(C->values, (size_t)count * sizeof *C->values);
        if (shrunk != NULL) {
            C->values = (float*)shrunk;
        }
        shrunk = realloc(C->rowIndex, (size_t)count * sizeof *C->rowIndex);
        if (shrunk != NULL) {
            C->rowIndex = (int*)shrunk;
        }
        shrunk = realloc(C->colIndex, (size_t)count * sizeof *C->colIndex);
        if (shrunk != NULL) {
            C->colIndex = (int*)shrunk;
        }
    }

    return C;
}
/**
 * Computes the dense ROWS x COLS product C = A * B of two sparse matrices,
 * using NEON for the row updates.
 *
 * C is cleared first. For each stored entry a of A, the entries of B whose
 * row equals a's column are gathered into a dense temporary row (entries at
 * the same column are summed), and that row scaled by a's value is added to
 * row a.row of C, four columns per vector operation with a scalar loop for
 * any remaining columns.
 *
 * Entries of A whose row is outside C, and entries of B whose column is
 * outside C, are skipped instead of being written out of bounds.
 *
 * Note: because a whole row is updated at once, a non-finite value in A
 * turns the other columns of that output row into NaN (inf * 0).
 *
 * @param A left operand.
 * @param B right operand.
 * @param C destination array, fully overwritten.
 */
void neonSparseMatMul(SparseMatrix* A, SparseMatrix* B, float C[ROWS][COLS]) {
    for (int i = 0; i < ROWS; i++) {
        for (int j = 0; j < COLS; j++) {
            C[i][j] = 0.0f;
        }
    }

    for (int i = 0; i < A->nonZeroCount; i++) {
        const float aValue = A->values[i];
        const int aRow = A->rowIndex[i];
        const int aCol = A->colIndex[i];

        if (aRow < 0 || aRow >= ROWS) {
            continue;
        }

        /* Dense copy of the row of B that pairs with this entry of A. */
        float bRow[COLS];
        for (int j = 0; j < COLS; j++) {
            bRow[j] = 0.0f;
        }

        int matched = 0;
        for (int j = 0; j < B->nonZeroCount; j++) {
            const int bCol = B->colIndex[j];
            if (B->rowIndex[j] != aCol || bCol < 0 || bCol >= COLS) {
                continue;
            }
            bRow[bCol] += B->values[j];
            matched = 1;
        }
        if (!matched) {
            continue;
        }

        /* C[aRow][:] += aValue * bRow[:]. Each vector covers four columns of
         * the same row, so no neighbouring row or memory past C is touched. */
        float* cRow = C[aRow];
        int col = 0;
        for (; col + 4 <= COLS; col += 4) {
            float32x4_t acc = vld1q_f32(&cRow[col]);
            acc = vmlaq_n_f32(acc, vld1q_f32(&bRow[col]), aValue);
            vst1q_f32(&cRow[col], acc);
        }
        for (; col < COLS; col++) {
            cRow[col] += aValue * bRow[col];
        }
    }
}
/**
 * Demo driver: builds two small sparse matrices, multiplies them with
 * neonSparseMatMul() while timing the call with clock(), then prints both
 * inputs in dense form, the product, and the elapsed time in seconds.
 *
 * The messages are Chinese text; this source must be saved as UTF-8.
 *
 * @return 0 on success, EXIT_FAILURE if the input matrices cannot be
 *         allocated.
 */
int main(void) {
    SparseMatrix* A = createSparseMatrix(4);
    SparseMatrix* B = createSparseMatrix(4);

    const int aReady = A != NULL && A->values != NULL &&
                       A->rowIndex != NULL && A->colIndex != NULL;
    const int bReady = B != NULL && B->values != NULL &&
                       B->rowIndex != NULL && B->colIndex != NULL;
    if (!aReady || !bReady) {
        fprintf(stderr, "failed to allocate the input matrices\n");
        if (A != NULL) {
            freeSparseMatrix(A);
        }
        if (B != NULL) {
            freeSparseMatrix(B);
        }
        return EXIT_FAILURE;
    }

    A->values[0] = 1.0f; A->rowIndex[0] = 0; A->colIndex[0] = 0;
    A->values[1] = 2.0f; A->rowIndex[1] = 0; A->colIndex[1] = 2;
    A->values[2] = 3.0f; A->rowIndex[2] = 1; A->colIndex[2] = 1;
    A->values[3] = 4.0f; A->rowIndex[3] = 2; A->colIndex[3] = 0;

    B->values[0] = 5.0f; B->rowIndex[0] = 0; B->colIndex[0] = 1;
    B->values[1] = 6.0f; B->rowIndex[1] = 1; B->colIndex[1] = 2;
    B->values[2] = 7.0f; B->rowIndex[2] = 2; B->colIndex[2] = 0;
    B->values[3] = 8.0f; B->rowIndex[3] = 2; B->colIndex[3] = 1;

    float C[ROWS][COLS];
    clock_t start = clock();
    neonSparseMatMul(A, B, C);
    clock_t end = clock();
    double time_taken = (double)(end - start) / CLOCKS_PER_SEC;

    float denseA[ROWS][COLS], denseB[ROWS][COLS];
    sparseToDense(A, denseA);
    sparseToDense(B, denseB);

    /* The literals below were mis-decoded bytes; restored to readable text. */
    printf("普通矩阵 A:\n");
    printDenseMatrix(denseA);
    printf("普通矩阵 B:\n");
    printDenseMatrix(denseB);
    printf("稀疏矩阵乘法结果:\n");
    printDenseMatrix(C);
    printf("稀疏矩阵乘法运行时间: %f 秒\n", time_taken);

    freeSparseMatrix(A);
    freeSparseMatrix(B);
    return 0;
}