-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpreprocess.s
More file actions
132 lines (112 loc) · 4.76 KB
/
preprocess.s
File metadata and controls
132 lines (112 loc) · 4.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
/*
Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause-Clear
*/
/*----------------------------------------------------------------------------
* This function is used to re-arrange the elements of input matrix to
* make it suitable for matrix outer product computation using SME for matrix
* multiplication. It should be used to pre-process the left matrix (A) in the
* matrix multiplication (C = A*B) using sgemm_direct_sme1_2VLx2VL().
*
* The pre-processing transposes a block of SVLs rows of the input matrix and
* stores it contiguously. The same is applied to remaining blocks of SVLs
* rows. The last block of SVLs rows is zero-padded to SVLs rows if needed.
*
* Usage of function:
* sgemm_direct_sme1_preprocess(uint64_t nrow, uint64_t ncol, \
* const float * restrict mat, float * mat_mod);
*
----------------------------------------------------------------------------*/
/* Arguments, in the order they are passed (x0-x3) */
#define nrow             x0     // Row count of the input matrix
#define ncol             x1     // Column count of the input matrix
#define mat              x2     // Base address of the input matrix
#define mat_mod          x3     // Base address of the re-arranged output matrix
/* Working pointers and loop state */
#define mat_mod_ptr      x4     // Running store pointer into the output matrix
#define mat_ptr0         x5     // Input pointer: start of the current column chunk
#define mat_ptr1         x6     // Input pointer: walks down the rows of that chunk
#define outer_loop_cntr  x7     // Index of the first row of the current row block
#define inner_loop_exit  x8     // Address at which the column (inner) loop stops
/* Constants, counted in 32-bit elements. x12 is deliberately not aliased:  */
/* the routine uses w12 as the ZA slice index register.                     */
#define C1               x9     // SVLs   (32-bit elements per streaming vector)
#define C2               x10    // 3*SVLs
#define C3               x11    // ncol*SVLs
#define C4               x13    // 2*SVLs
#define C5               x14    // 2*ncol
#define C6               x15    // 3*ncol
/*----------------------------------------------------------------------------
 * _sgemm_direct_sme1_preprocess
 *
 * In:    x0 = nrow, x1 = ncol, x2 = mat, x3 = mat_mod
 * Out:   no register result; the re-arranged matrix is written via mat_mod
 * Uses:  x0-x11, x13-x15, w12, p0, p2-p5, p8, p9, ZA tile za0.s, flags
 * Stack: 64 bytes (spill of d8-d15)
 *
 * Rows of the input are read with a stride of ncol 32-bit elements. For each
 * block of SVLs rows and each chunk of SVLs columns, the rows are loaded into
 * horizontal slices of za0.s and written back out as vertical slices, i.e.
 * transposed, as consecutive SVLs-element vectors. Rows at or beyond nrow are
 * loaded as zeros (zeroing predication), columns at or beyond ncol are not
 * stored. Each row block therefore produces ncol*SVLs output elements.
 *
 * The routine switches streaming mode and ZA on (smstart) and off (smstop)
 * itself. Both transitions invalidate the SIMD&FP/SVE registers, so the
 * callee-saved low halves of v8-v15 (d8-d15) are preserved around them.
 *
 * NOTE(review): ZA is enabled and disabled unconditionally. If a caller can
 * hold live ZA state, the AAPCS64 lazy-save protocol would have to be
 * honoured here as well - confirm against the callers.
 *--------------------------------------------------------------------------*/
    .text
    .global _sgemm_direct_sme1_preprocess
_sgemm_direct_sme1_preprocess:
    // Empty input: nothing to re-arrange. The loops below are entered
    // without a pre-check, so without this guard nrow == 0 (ncol != 0)
    // would still store one block of zeros through mat_mod.
    cbz     nrow, .Preprocess_exit
    cbz     ncol, .Preprocess_exit

    // Preserve d8-d15 before the mode switch destroys them. No general
    // purpose callee-saved register (x19-x28) is used by this routine.
    stp     d8,  d9,  [sp, #-64]!
    stp     d10, d11, [sp, #16]
    stp     d12, d13, [sp, #32]
    stp     d14, d15, [sp, #48]

    smstart

    // Loop-invariant constants (must be computed after smstart so that the
    // element counts reflect the streaming vector length).
    cntw    C1                              //SVLs
    mul     C3, C1, ncol                    //SVLs*ncol
    lsl     C5, ncol, #1                    //2*ncol
    add     C6, C5, ncol                    //3*ncol
    cnth    C4                              //2*SVLs
    add     C2, C1, C1, lsl #1              //3*SVLs

    mov     outer_loop_cntr, #0
    //Tile predicate (M dimension): one lane per valid row of this block
    whilelt p0.s, outer_loop_cntr, nrow
    //All-true predicate for stores (padded rows are stored as zeros)
    ptrue   p9.s

.M_Loop:
    mov     mat_ptr0, mat                   //First row of this row block
    mov     mat_mod_ptr, mat_mod            //Output position of this row block
    add     inner_loop_exit, mat, ncol, lsl #2  //End of the first row (bytes)
    //Tile predicate (K dimension): one lane per remaining byte of the row
    whilelt p8.b, mat_ptr0, inner_loop_exit

.Loop_process:
    mov     mat_ptr1, mat_ptr0
    //Load_to_tile loop counter (ZA slice index, four slices per iteration)
    mov     w12, #0

.Load_to_tile:
    //Row predicate = column predicate if the row is valid, else all-false
    psel    p2, p8, p0.s[w12, 0]
    psel    p3, p8, p0.s[w12, 1]
    psel    p4, p8, p0.s[w12, 2]
    psel    p5, p8, p0.s[w12, 3]
    //Load 1st row from mat_ptr1
    ld1w    {za0h.s[w12, #0]}, p2/z, [mat_ptr1]
    //Load 2nd row from mat_ptr1 + ncol
    ld1w    {za0h.s[w12, #1]}, p3/z, [mat_ptr1, ncol, lsl #2]
    //Load 3rd row from mat_ptr1 + 2*ncol
    ld1w    {za0h.s[w12, #2]}, p4/z, [mat_ptr1, C5, lsl #2]
    //Load 4th row from mat_ptr1 + 3*ncol
    ld1w    {za0h.s[w12, #3]}, p5/z, [mat_ptr1, C6, lsl #2]
    //mat_ptr1 += 4*ncol FP32 elements
    add     mat_ptr1, mat_ptr1, ncol, lsl #4
    //Increment counter
    add     w12, w12, #4
    cmp     w12, w9                         //w9 = low word of C1 (SVLs)
    b.mi    .Load_to_tile

    //Store_from_tile loop counter
    mov     w12, #0

.Store_from_tile:
    //Column predicate = all-true if the column exists, else all-false
    psel    p2, p9, p8.s[w12, 0]
    psel    p3, p9, p8.s[w12, 1]
    psel    p4, p9, p8.s[w12, 2]
    psel    p5, p9, p8.s[w12, 3]
    //Store 1st col to mat_mod
    st1w    {za0v.s[w12, #0]}, p2, [mat_mod_ptr]
    //Store 2nd col to mat_mod + SVLs
    st1w    {za0v.s[w12, #1]}, p3, [mat_mod_ptr, C1, lsl #2]
    //Store 3rd col to mat_mod + 2*SVLs
    st1w    {za0v.s[w12, #2]}, p4, [mat_mod_ptr, C4, lsl #2]
    //Store 4th col to mat_mod + 3*SVLs
    st1w    {za0v.s[w12, #3]}, p5, [mat_mod_ptr, C2, lsl #2]
    addvl   mat_mod_ptr, mat_mod_ptr, #4    //mat_mod_ptr += 4*SVLb
    add     w12, w12, #4                    //Increment counter
    cmp     w12, w9                         //w9 = low word of C1 (SVLs)
    b.mi    .Store_from_tile

    //Next chunk of SVLs columns of the same row block
    addvl   mat_ptr0, mat_ptr0, #1          //mat_ptr0 += SVLb
    whilelt p8.b, mat_ptr0, inner_loop_exit
    b.first .Loop_process

    //Next block of SVLs rows
    add     mat_mod, mat_mod, C3, lsl #2    //mat_mod += SVLs*ncol FP32 elements
    add     mat, mat, C3, lsl #2            //mat += SVLs*ncol FP32 elements
    incw    outer_loop_cntr
    whilelt p0.s, outer_loop_cntr, nrow
    b.first .M_Loop

    smstop

    // Restore the callee-saved FP registers now that streaming mode is off
    ldp     d14, d15, [sp, #48]
    ldp     d12, d13, [sp, #32]
    ldp     d10, d11, [sp, #16]
    ldp     d8,  d9,  [sp], #64

.Preprocess_exit:
    ret