; matrix multiplication
; author: Dmitri Kuvshinov
; MASM syntax, x86-64, SSE2-only
; Caution: this code appears to be correct but has not been thoroughly tested;
; no correctness guarantees are provided.
.code
public addmtxmul
;-----------------------------------------------------------------------
; addmtxmul -- accumulate a matrix product in place: C += A * B
;
; A is NxS, B is SxM, C is NxM. Elements are 8-byte doubles; rows are
; contiguous (row i of A starts at A + i*S*8, of B at B + i*M*8, of C at
; C + i*M*8). A and B are only read, C is read and written.
;
; ABI:   Microsoft x64
; In:    rcx      = A_ptr
;        rdx      = B_ptr
;        r8       = C_ptr
;        r9       = N
;        [rsp+40] = S
;        [rsp+48] = M
;        NOTE(review): all 64 bits of N, S and M are used -- confirm that
;        callers pass 64-bit (size_t-like) values, not 32-bit ints.
; Out:   nothing (C updated in memory)
; Clobb: rax, rcx, r8, r9, r10, r11, xmm0-xmm4, flags
; Stack: 40 bytes, used to save r13, r12, rbx, rsi, rdi
; Align: none required; all packed accesses use movupd
;
; If any of N, S, M is zero the procedure returns without touching memory.
;
; NOTE(review): the procedure adjusts rsp and overwrites nonvolatile
; registers but carries no unwind data (PROC FRAME / .allocstack /
; .savereg / .endprolog), so a fault raised inside it cannot be unwound
; through this frame.
;
; Register roles in the body:
;   rax = a_row        rdx = b_data (const)   rbx = b_row     rcx = c_row
;   r11 = N (const)    r12 = S (const)        r13 = M (const)
;   rsi = M % 4        rdi = M - M % 4
;   r8  = i            r9  = j                r10 = k
;   xmm0 = a_row[j] duplicated in both lanes; xmm1-xmm4 = temporaries
;-----------------------------------------------------------------------
addmtxmul proc ; A_ptr, B_ptr, C_ptr, N, S, M
OPTION PROLOGUE:NONE, EPILOGUE:NONE

        ; Stack arguments, read before rsp moves: 8-byte return address
        ; + 32-byte shadow space = 40 bytes below the fifth argument.
        mov     r10, QWORD PTR [rsp+40]         ; S
        mov     r11, QWORD PTR [rsp+48]         ; M

        ; Empty-dimension guard. Every loop below is bottom-tested
        ; (body first, then "counter != limit"), so a zero limit would
        ; never terminate and would run off the ends of the matrices.
        ; Nothing has been saved yet, so a bare ret is enough.
        test    r9, r9                          ; N == 0 ?
        jz      early_return
        test    r10, r10                        ; S == 0 ?
        jz      early_return
        test    r11, r11                        ; M == 0 ?
        jz      early_return

        ; PROLOGUE -- save the nonvolatile registers used by the body
        sub     rsp, 40
        mov     QWORD PTR [rsp], r13
        mov     QWORD PTR [rsp+8], r12
        mov     QWORD PTR [rsp+16], rbx
        mov     QWORD PTR [rsp+24], rsi
        mov     QWORD PTR [rsp+32], rdi

        ; BODY -- move the arguments into their working registers
        mov     rax, rcx                        ; a_row = A_ptr
        mov     rcx, r8                         ; c_row = C_ptr
        mov     r13, r11                        ; M
        mov     r12, r10                        ; S
        mov     r11, r9                         ; N
        mov     rsi, r13
        and     rsi, 3                          ; rsi = M % 4 (tail length)
        mov     rdi, r13
        and     rdi, -4                         ; rdi = M - M % 4 (unrolled part)

        xor     r8, r8                          ; i = 0
i_loop:
        mov     rbx, rdx                        ; b_row = b_data
        xor     r9, r9                          ; j = 0
j_loop:
        ; A movsd load from memory zeroes the upper lane, so no separate
        ; clearing of xmm0 is needed before it.
        movsd   xmm0, QWORD PTR [rax + r9*8]    ; aij = a_row[j]
        unpcklpd xmm0, xmm0                     ; broadcast aij to both lanes

        xor     r10, r10                        ; k = 0
        test    rdi, rdi
        jz      k_loop_4_end                    ; M < 4: only the tail runs

        ; c_row[k..k+3] += aij * b_row[k..k+3], four doubles per round
k_loop_4_begin:
        movupd  xmm2, XMMWORD PTR [rbx + r10*8]         ; b_row[k:k+1]
        movupd  xmm4, XMMWORD PTR [rbx + r10*8 + 16]    ; b_row[k+2:k+3]
        movupd  xmm1, XMMWORD PTR [rcx + r10*8]         ; c_row[k:k+1]
        movupd  xmm3, XMMWORD PTR [rcx + r10*8 + 16]    ; c_row[k+2:k+3]
        add     r10, 4                                  ; k += 4
        mulpd   xmm2, xmm0
        mulpd   xmm4, xmm0
        addpd   xmm1, xmm2
        addpd   xmm3, xmm4
        ; k was already advanced, hence the negative displacements
        movupd  XMMWORD PTR [rcx + r10*8 - 32], xmm1
        movupd  XMMWORD PTR [rcx + r10*8 - 16], xmm3
        cmp     r10, rdi                        ; k ? M - M % 4
        jne     k_loop_4_begin                  ; equal -> leave the loop
k_loop_4_end:
        test    rsi, rsi
        jnz     k_loop_tail_works               ; 1..3 elements left in the row
k_loop_tail:
        add     r9, 1                           ; ++j
        lea     rbx, QWORD PTR [rbx + r13*8]    ; b_row += M
        cmp     r9, r12                         ; j ? S
        jne     j_loop

        add     r8, 1                           ; ++i
        lea     rax, QWORD PTR [rax + r12*8]    ; a_row += S
        lea     rcx, QWORD PTR [rcx + r13*8]    ; c_row += M
        cmp     r8, r11                         ; i ? N
        jne     i_loop

        ; EPILOGUE
        mov     r13, QWORD PTR [rsp]
        mov     r12, QWORD PTR [rsp+8]
        mov     rbx, QWORD PTR [rsp+16]
        mov     rsi, QWORD PTR [rsp+24]
        mov     rdi, QWORD PTR [rsp+32]
        add     rsp, 40
early_return:
        ret

        ; Scalar tail: handles the last M % 4 (1, 2 or 3) elements of the row.
        ; Kept after the ret so the common path above falls straight through.
k_loop_tail_works:
        movapd  xmm2, xmm0                      ; copy of aij for element k
        movapd  xmm1, xmm0                      ; copy of aij for element k+1
        mulsd   xmm2, QWORD PTR [rbx + r10*8]
        addsd   xmm2, QWORD PTR [rcx + r10*8]
        movsd   QWORD PTR [rcx + r10*8], xmm2   ; c_row[k] += aij * b_row[k]
        add     r10, 1
        cmp     rsi, 1
        je      k_loop_tail

        movapd  xmm2, xmm0                      ; fresh copy of aij for element k+2
        mulsd   xmm1, QWORD PTR [rbx + r10*8]
        addsd   xmm1, QWORD PTR [rcx + r10*8]
        movsd   QWORD PTR [rcx + r10*8], xmm1   ; c_row[k+1] += aij * b_row[k+1]
        add     r10, 1
        cmp     rsi, 2
        je      k_loop_tail

        mulsd   xmm2, QWORD PTR [rbx + r10*8]
        addsd   xmm2, QWORD PTR [rcx + r10*8]
        movsd   QWORD PTR [rcx + r10*8], xmm2   ; c_row[k+2] += aij * b_row[k+2]
        jmp     k_loop_tail
addmtxmul endp
end