[Ventus][Work-item] Optimize Work-item built-in functions

This commit is contained in:
Jules-Kong 2024-11-04 13:22:18 +08:00
parent 0f31d6ca08
commit 5032490e31
3 changed files with 17 additions and 138 deletions

View File

@ -33,6 +33,13 @@
#define CSR_GID_Y 0x809 // group_id_y #define CSR_GID_Y 0x809 // group_id_y
#define CSR_GID_Z 0x80a // group_id_z #define CSR_GID_Z 0x80a // group_id_z
#define CSR_PRINT 0x80b // for print buffer #define CSR_PRINT 0x80b // for print buffer
#define CSR_GL_ID_X 0x80d // global_id_x
#define CSR_GL_ID_Y 0x80e // global_id_y
#define CSR_GL_ID_Z 0x80f // global_id_z
#define CSR_GLL_ID 0x810 // global_linear_id
#define CSR_LC_ID_X 0x811 // local_id_x
#define CSR_LC_ID_Y 0x812 // local_id_y
#define CSR_LC_ID_Z 0x813 // local_id_z
// Kernel metadata buffer offsets // Kernel metadata buffer offsets
#define KNL_ENTRY 0 #define KNL_ENTRY 0

View File

@ -1,20 +1,7 @@
#include <clc/clc.h> #include <clc/clc.h>
extern size_t __builtin_riscv_global_linear_id();
_CLC_DEF _CLC_OVERLOAD size_t get_global_linear_id() { _CLC_DEF _CLC_OVERLOAD size_t get_global_linear_id() {
uint dim = get_work_dim() - 1; return __builtin_riscv_global_linear_id();
switch (dim) {
case 0:
return get_global_id(0) - get_global_offset(0);
;
case 1:
return (get_global_id(1) - get_global_offset(1)) * get_global_size(0) +
(get_global_id(0) - get_global_offset(0));
case 2:
return ((get_global_id(2) - get_global_offset(2)) * get_global_size(1) +
(get_global_id(1) - get_global_offset(1))) *
get_global_size(0) +
(get_global_id(0) - get_global_offset(0));
default:
return 0;
}
} }

View File

@ -69,37 +69,7 @@ _local_id_z:
.global __builtin_riscv_global_linear_id .global __builtin_riscv_global_linear_id
.type __builtin_riscv_global_linear_id, @function .type __builtin_riscv_global_linear_id, @function
__builtin_riscv_global_linear_id: __builtin_riscv_global_linear_id:
addi sp, sp, 4 csrr.v v0, CSR_GLL_ID # Read global_linear_id
sw ra, -4(sp)
csrr a3, CSR_KNL # Get kernel metadata buffer
lw t0, KNL_WORK_DIM(a3) # Get work_dims
call __builtin_riscv_global_id_x
lw t4, KNL_GL_OFFSET_X(a3) # global_offset_x
vsub.vx v5, v0, t4 # global_linear_id1
li t5, 1
beq t0, t5, .GLR # Return global_linear_id for 1 dim
.GL_2DIM:
call __builtin_riscv_global_id_y
lw t6, KNL_GL_SIZE_X(a3) # global_size_x
lw t5, KNL_GL_OFFSET_Y(a3) # global_offset_y
vsub.vx v6, v0, t5 # tmp = global_id_y - global_offset_y
vmul.vx v6, v6, t6 # tmp = tmp * global_size_x
vadd.vv v5, v5, v6 # global_linear_id2 = tmp + global_linear_id1
li t5, 2
beq t0, t5, .GLR # Return global_linear_id for 2 dim
.GL_3DIM:
call __builtin_riscv_global_id_z
lw t6, KNL_GL_SIZE_X(a3) # global_size_x
lw t1, KNL_GL_SIZE_Y(a3) # global_size_y
lw t5, KNL_GL_OFFSET_Z(a3) # global_offset_z
vsub.vx v6, v0, t5 # tmp = global_id_z - global_offset_z
vmul.vx v6, v6, t6 # tmp = tmp * global_size_x
vmul.vx v6, v6, t1 # tmp = tmp * global_size_y
vadd.vv v5, v5, v6 # global_linear_id3 = tmp + global_linear_id2
.GLR:
vadd.vx v0, v5, zero # Return global_linear_id for 1/2/3 dims
lw ra, -4(sp)
addi sp, sp, -4
ret ret
@ -134,17 +104,7 @@ __builtin_riscv_workgroup_id_z:
.global __builtin_riscv_workitem_id_x .global __builtin_riscv_workitem_id_x
.type __builtin_riscv_workitem_id_x, @function .type __builtin_riscv_workitem_id_x, @function
__builtin_riscv_workitem_id_x: __builtin_riscv_workitem_id_x:
addi sp, sp, 4 csrr.v v0, CSR_LC_ID_X # Read local_id_x
sw ra, -4(sp)
csrr a0, CSR_KNL # Get kernel metadata buffer
lw t0, KNL_WORK_DIM(a0) # Get work_dim
csrr t1, CSR_TID # tid base offset for current warp
vid.v v2 # current thread offset
vadd.vx v0, v2, t1 # local_id_x in 1 dim (local_linear_id)
lw t3, KNL_LC_SIZE_X(a0) # local_size_x
vremu.vx v0, v0, t3 # local_id_x = local_liner_id % local_size_x
lw ra, -4(sp)
addi sp, sp, -4
ret ret
@ -152,29 +112,7 @@ __builtin_riscv_workitem_id_x:
.global __builtin_riscv_workitem_id_y .global __builtin_riscv_workitem_id_y
.type __builtin_riscv_workitem_id_y, @function .type __builtin_riscv_workitem_id_y, @function
__builtin_riscv_workitem_id_y: __builtin_riscv_workitem_id_y:
addi sp, sp, 4 csrr.v v0, CSR_LC_ID_Y # Read local_id_y
sw ra, -4(sp)
csrr a0, CSR_KNL # Get kernel metadata buffer
lw t0, KNL_WORK_DIM(a0) # Get work_dim
csrr t1, CSR_TID # tid base offset for current warp
vid.v v2 # current thread offset
vadd.vx v0, v2, t1 # local_linear_id
lw t3, KNL_LC_SIZE_X(a0) # local_size_x offset in 2 work_dims
lw t4, KNL_LC_SIZE_Y(a0) # local_size_y offset in 2 work_dims
mul t5, t4, t3 # local_size_x * local_size_y
vremu.vx v0, v0, t5 # x = local_linear_id % (local_size_x * local_size_y)
vdivu.vx v0, v0, t3 # x / local_size_x
vmv.v.x v1, t4
.hi2:
auipc t1, %pcrel_hi(.end2)
setrpc zero, t1, %pcrel_lo(.hi2)
vblt v0, v1, .end2
li t5, -1
vadd.vx v0, v1, t5
.end2:
join zero, zero, 0
lw ra, -4(sp)
addi sp, sp, -4
ret ret
@ -182,28 +120,7 @@ __builtin_riscv_workitem_id_y:
.global __builtin_riscv_workitem_id_z .global __builtin_riscv_workitem_id_z
.type __builtin_riscv_workitem_id_z, @function .type __builtin_riscv_workitem_id_z, @function
__builtin_riscv_workitem_id_z: __builtin_riscv_workitem_id_z:
addi sp, sp, 4 csrr.v v0, CSR_LC_ID_Z # Read local_id_z
sw ra, -4(sp)
csrr a0, CSR_KNL # Get kernel metadata buffer
csrr t1, CSR_TID # tid base offset for current warp
vid.v v2 # current thread offset
vadd.vx v0, v2, t1 # local_linear_id
lw t3, KNL_LC_SIZE_X(a0) # local_size_x
lw t4, KNL_LC_SIZE_Y(a0) # local_size_y
lw t5, KNL_LC_SIZE_Z(a0) # local_size_z
mul t4, t4, t3 # local_size_x * local_size_y
vdivu.vx v0, v0, t4 # local_linear_id / (local_size_x * local_size_y)
vmv.v.x v1, t5
.hi3:
auipc t1, %pcrel_hi(.end3)
setrpc zero, t1, %pcrel_lo(.hi3)
vblt v0, v1, .end3
li t5, -1
vadd.vx v0, v1, t5
.end3:
join zero, zero, 0
lw ra, -4(sp)
addi sp, sp, -4
ret ret
@ -211,18 +128,7 @@ __builtin_riscv_workitem_id_z:
.global __builtin_riscv_global_id_x .global __builtin_riscv_global_id_x
.type __builtin_riscv_global_id_x, @function .type __builtin_riscv_global_id_x, @function
__builtin_riscv_global_id_x: __builtin_riscv_global_id_x:
addi sp, sp, 4 csrr.v v0, CSR_GL_ID_X # Read global_id_x
sw ra, -4(sp)
call __builtin_riscv_workitem_id_x
csrr a0, CSR_KNL # Get kernel metadata buffer
csrr t1, CSR_GID_X # Get group_id_x
lw t3, KNL_LC_SIZE_X(a0) # Get local_size_x
lw t4, KNL_GL_OFFSET_X(a0) # Get global_offset_x
mul t6, t1, t3 # CSR_GID_X * local_size_x
add t6, t6, t4 # Get global_offset_x + CSR_GID_X * local_size_x
vadd.vx v0,v0, t6
lw ra, -4(sp)
addi sp, sp, -4
ret ret
@ -230,17 +136,7 @@ __builtin_riscv_global_id_x:
.global __builtin_riscv_global_id_y .global __builtin_riscv_global_id_y
.type __builtin_riscv_global_id_y, @function .type __builtin_riscv_global_id_y, @function
__builtin_riscv_global_id_y: __builtin_riscv_global_id_y:
addi sp, sp, 4 csrr.v v0, CSR_GL_ID_Y # Read global_id_y
sw ra, -4(sp)
call __builtin_riscv_workitem_id_y
csrr t1, CSR_GID_Y # Get group_id_y
lw t2, KNL_LC_SIZE_Y(a0) # Get local_size_y
lw t4, KNL_GL_OFFSET_Y(a0) # Get global_offset_y
mul t3, t1, t2 # CSR_GID_Y * local_size_y
add t3, t3, t4 # global_offset_y + (CSR_GID_Y * local_size_y)
vadd.vx v0, v0, t3 # global_id_y
lw ra, -4(sp)
addi sp, sp, -4
ret ret
@ -248,18 +144,7 @@ __builtin_riscv_global_id_y:
.global __builtin_riscv_global_id_z .global __builtin_riscv_global_id_z
.type __builtin_riscv_global_id_z, @function .type __builtin_riscv_global_id_z, @function
__builtin_riscv_global_id_z: __builtin_riscv_global_id_z:
addi sp, sp, 4 csrr.v v0, CSR_GL_ID_Z # Read global_id_z
sw ra, -4(sp)
call __builtin_riscv_workitem_id_z
csrr a0, CSR_KNL # Get kernel metadata buffer
csrr t1, CSR_GID_Z # Get group_id_z
lw t2, KNL_LC_SIZE_Z(a0) # Get local_size_z
lw t3, KNL_GL_OFFSET_Z(a0) # Get global_offset_z
mul t2, t2, t1 # CSR_GID_Z * local_size_z
add t2, t2, t3 # global_offset_z + (CSR_GID_Z * local_size_z)
vadd.vx v0, v0, t2 # global_id_z
lw ra, -4(sp)
addi sp, sp, -4
ret ret