diff --git a/src/math/mat4.c b/src/math/mat4.c index 3b46aa9..3eb2d23 100644 --- a/src/math/mat4.c +++ b/src/math/mat4.c @@ -164,7 +164,7 @@ Mat4f_t mat4f_scale(const Mat4f_t *__restrict m, float scalar) return mout; } -Mat4f_t* mat4_mul_r(Mat4f_t* out, const Mat4f_t* m2) +Mat4f_t* mat4f_mul_r(Mat4f_t* out, const Mat4f_t* m2) { Mat4f_t clone = mat4f_clone(out); @@ -191,28 +191,39 @@ Mat4f_t* mat4_mul_r(Mat4f_t* out, const Mat4f_t* m2) out->m[row * 4 + col] = mres; } #elif defined (SIMD_ARCH) + float32x4_t mrow = vld1q_f32(&clone.m[row*4]); + + for (int col = 0; col<4; col++) { + float32x4_t mcol = { + m2->m[0 * 4 + col], + m2->m[1 * 4 + col], + m2->m[2 * 4 + col], + m2->m[3 * 4 + col] + }; + + float32x4_t mmul = vmulq_f32(mrow, mcol); + float32x2_t sum_pair = vadd_f32(vget_low_f32(mmul), vget_high_f32(mmul)); + float32x2_t final_sum = vpadd_f32(sum_pair, sum_pair); + + float mres = vget_lane_f32(final_sum, 0); + out->m[row * 4 + col] = mres; + } #else + for (int col = 0; col < 4; col++) { + float sum = 0.0f; + for (int k = 0; k < 4; k++) { + sum += clone.m[row * 4 + k] * m2->m[k * 4 + col]; + } + out->m[row * 4 + col] = sum; + } #endif } return out; } - -// Mat4_t mat4_mul(const Mat4_t* m1, const Mat4_t* m2) -// { -// Mat4_t mat; - -// for(int i = 0; i<4; i++) { -// int i3 = i * 3; -// for (int j = 0; j < 4; j++) { -// float sum = 0; - -// for (int k = 0; k < 3; k++) { -// sum += m1->m[i3 + k] * m2->m[k*3 + j]; -// } - -// mat.m[i3 + j] = sum; -// } -// } - -// return mat; -// } +// Returns m1 * m2 by value: copies m1, then multiplies the copy in place by m2 via mat4f_mul_r. +Mat4f_t mat4f_mul(const Mat4f_t* m1, const Mat4f_t* m2) +{ + Mat4f_t mout = mat4f_clone(m1); + mat4f_mul_r(&mout, m2); + return mout; +} \ No newline at end of file diff --git a/src/math/mat4.h b/src/math/mat4.h index 1f575aa..fe25bea 100644 --- a/src/math/mat4.h +++ b/src/math/mat4.h @@ -30,8 +30,8 @@ Mat4f_t mat4f_scale(const Mat4f_t *__restrict m, float scalar); Mat4f_t* mat4f_scale_r(Mat4f_t *out, float scalar); // row * col -Mat4f_t mat4_mul(const Mat4f_t* m1, const Mat4f_t* m2); 
-Mat4f_t* mat4_mul_r(Mat4f_t* out, const Mat4f_t* m2); +Mat4f_t mat4f_mul(const Mat4f_t* m1, const Mat4f_t* m2); +Mat4f_t* mat4f_mul_r(Mat4f_t* out, const Mat4f_t* m2); // Mat4_t mat4_tpo(const Mat4_t* m);