; matrix multiplication
; author: Dmitri Kuvshinov
; MASM syntax, x86-64, SSE2-only
; Beware, this code seems to be correct but has not been tested really thoroughly,
; no correctness guarantees are provided.

.code

public addmtxmul

;-----------------------------------------------------------------------
; addmtxmul(A_ptr, B_ptr, C_ptr, N, S, M)
;
; Accumulates a matrix product into C:   C += A * B
;   A is NxS, B is SxM, C is NxM.
;   Elements are 8-byte doubles; each row is stored contiguously and the
;   row stride equals the column count (no padding between rows).
;   All loads/stores of matrix data are unaligned-safe (movupd / movsd).
;
; Arguments (Microsoft x64 register/stack layout):
;   rcx      = A_ptr
;   rdx      = B_ptr
;   r8       = C_ptr
;   r9       = N
;   [rsp+40] = S   (64-bit)
;   [rsp+48] = M   (64-bit)
;
; If any of N, S, M is zero there is nothing to accumulate and the
; procedure returns immediately without touching memory.
;
; Register roles inside the body:
;   rax = a_row   (current row of A)
;   rbx = b_row   (current row of B)
;   rcx = c_row   (current row of C)
;   rdx = B_ptr   (never modified)
;   r8  = i, r9 = j, r10 = k          (loop counters)
;   r11 = N, r12 = S, r13 = M         (loop bounds)
;   rsi = M % 4                       (length of the scalar tail, 0..3)
;   rdi = M - M % 4                   (length handled 4 doubles at a time)
;   xmm0 = a[i][j] duplicated in both lanes
;   xmm1..xmm4 = temporaries
;
; Clobbers: rax, rcx, r8, r9, r10, r11, xmm0-xmm4, flags.
; Saved and restored: rbx, rsi, rdi, r12, r13.
; No return value is set up.
;
; NOTE(review): declared as a plain `proc` (no FRAME / unwind directives)
; although it moves rsp and saves non-volatile registers -- confirm
; whether unwind data is required in the environment this is used in.
;-----------------------------------------------------------------------
addmtxmul proc
        OPTION PROLOGUE:NONE, EPILOGUE:NONE

        ; Fetch the stack arguments while rsp still has its entry value:
        ; 32-byte shadow space + 8-byte return address = 40 bytes.
        mov     r10, QWORD PTR [rsp+40]         ; r10 = S
        mov     r11, QWORD PTR [rsp+48]         ; r11 = M

        ; The loops below are bottom-tested (do/while with jne), so a zero
        ; bound would still execute one iteration and then never hit the
        ; exit condition.  Bail out before anything is saved or modified.
        test    r9, r9                          ; N == 0 ?
        jz      nothing_to_do
        test    r10, r10                        ; S == 0 ?
        jz      nothing_to_do
        test    r11, r11                        ; M == 0 ?
        jz      nothing_to_do

        ; PROLOGUE: five 8-byte save slots for the non-volatile registers
        sub     rsp, 40
        mov     QWORD PTR [rsp], r13
        mov     QWORD PTR [rsp+8], r12
        mov     QWORD PTR [rsp+16], rbx
        mov     QWORD PTR [rsp+24], rsi
        mov     QWORD PTR [rsp+32], rdi

        ; BODY
        ; move everything into the register roles listed in the header
        mov     rax, rcx                        ; a_row = A_ptr
        mov     rcx, r8                         ; c_row = C_ptr
        mov     r13, r11                        ; r13 = M
        mov     r12, r10                        ; r12 = S
        mov     r11, r9                         ; r11 = N

        mov     rsi, r13
        and     rsi, 3                          ; rsi = M % 4
        mov     rdi, r13
        and     rdi, -4                         ; rdi = M - M % 4

        ; for i in [0, N)
        xor     r8, r8                          ; i = 0
i_loop:
        mov     rbx, rdx                        ; b_row = B_ptr (first row of B)

        ; for j in [0, S)
        xor     r9, r9                          ; j = 0
j_loop:
        xorps   xmm0, xmm0
        movsd   xmm0, QWORD PTR [rax + r9*8]    ; aij = a_row[j]
        unpcklpd xmm0, xmm0                     ; xmm0 = { aij, aij }

        ; for k in [0, M - M % 4), four doubles per round
        xor     r10, r10                        ; k = 0
        test    rdi, rdi
        jz      k_loop_4_end                    ; M < 4: only the tail exists

k_loop_4_begin:
        movupd  xmm2, XMMWORD PTR [rbx + r10*8]         ; b_row[k], b_row[k+1]
        movupd  xmm4, XMMWORD PTR [rbx + r10*8 + 16]    ; b_row[k+2], b_row[k+3]
        movupd  xmm1, XMMWORD PTR [rcx + r10*8]         ; c_row[k], c_row[k+1]
        movupd  xmm3, XMMWORD PTR [rcx + r10*8 + 16]    ; c_row[k+2], c_row[k+3]
        add     r10, 4                                  ; k += 4
        mulpd   xmm2, xmm0
        mulpd   xmm4, xmm0
        addpd   xmm1, xmm2
        addpd   xmm3, xmm4
        ; k was already advanced, hence the negative displacements
        movupd  XMMWORD PTR [rcx + r10*8 - 32], xmm1
        movupd  XMMWORD PTR [rcx + r10*8 - 16], xmm3
        cmp     r10, rdi                        ; k ? M - M % 4
        jne     k_loop_4_begin                  ; == --> break

k_loop_4_end:
        test    rsi, rsi
        jnz     k_loop_tail_works               ; 1..3 elements left in the row

k_loop_tail:
        add     r9, 1                           ; ++j
        lea     rbx, QWORD PTR [rbx + r13*8]    ; b_row += M
        cmp     r9, r12                         ; j ? S
        jne     j_loop                          ; != -- repeat

        add     r8, 1                           ; ++i
        lea     rax, QWORD PTR [rax + r12*8]    ; a_row += S
        lea     rcx, QWORD PTR [rcx + r13*8]    ; c_row += M
        cmp     r8, r11                         ; i ? N
        jne     i_loop                          ; != -- repeat

        ; EPILOGUE
        mov     r13, QWORD PTR [rsp]
        mov     r12, QWORD PTR [rsp+8]
        mov     rbx, QWORD PTR [rsp+16]
        mov     rsi, QWORD PTR [rsp+24]
        mov     rdi, QWORD PTR [rsp+32]
        add     rsp, 40
        ret

        ; Scalar tail: handles the last M % 4 (1, 2 or 3) elements of the
        ; row one double at a time, then rejoins the j-loop at k_loop_tail.
k_loop_tail_works:
        movapd  xmm2, xmm0
        movapd  xmm1, xmm0
        mulsd   xmm2, QWORD PTR [rbx + r10*8]
        addsd   xmm2, QWORD PTR [rcx + r10*8]
        movsd   QWORD PTR [rcx + r10*8], xmm2   ; c_row[k] += aij * b_row[k]
        add     r10, 1
        cmp     rsi, 1
        je      k_loop_tail                     ; tail length was 1

        movapd  xmm2, xmm0                      ; fresh aij for a possible third element
        mulsd   xmm1, QWORD PTR [rbx + r10*8]
        addsd   xmm1, QWORD PTR [rcx + r10*8]
        movsd   QWORD PTR [rcx + r10*8], xmm1   ; c_row[k+1] += aij * b_row[k+1]
        add     r10, 1
        cmp     rsi, 2
        je      k_loop_tail                     ; tail length was 2

        mulsd   xmm2, QWORD PTR [rbx + r10*8]
        addsd   xmm2, QWORD PTR [rcx + r10*8]
        movsd   QWORD PTR [rcx + r10*8], xmm2   ; c_row[k+2] += aij * b_row[k+2]
        jmp     k_loop_tail

        ; Early exit for an empty product: taken before the prologue, so
        ; rsp and every non-volatile register still hold their entry values.
nothing_to_do:
        ret

addmtxmul endp

end