[VENTUS][RISCV][fix] Fix workitem function implementation

This commit is contained in:
zhoujing 2023-07-19 17:45:35 +08:00
parent b8223e72bd
commit 1a6ead3f43
2 changed files with 28 additions and 30 deletions

View File

@ -194,9 +194,6 @@ build_libclc() {
if [ ! -d "${DstDir}" ]; then if [ ! -d "${DstDir}" ]; then
mkdir -p ${DstDir} mkdir -p ${DstDir}
fi fi
# TODO: make this copy process done during libclc build process?
cp ${LIBCLC_BUILD_DIR}/riscv32--.bc ${DstDir}/kernel-riscv32.bc
cp ${LIBCLC_BUILD_DIR}/riscv32--.a ${VENTUS_INSTALL_PREFIX}/lib/libriscv32clc.a
} }
# Build icd_loader # Build icd_loader

View File

@ -17,15 +17,15 @@
* get_local_id(1) = (CSR_TID + vid.v) / local_size_x * get_local_id(1) = (CSR_TID + vid.v) / local_size_x
* *
* 3 dims: * 3 dims:
* get_local_id(0) = (CSR_TID + vid.v) % local_size_x
* get_local_id(1) = (CSR_TID + vid.v) % (local_size_x * local_size_y) / local_size_x
* get_local_id(2) = (CSR_TID + vid.v) / (local_size_x * local_size_y)
* *
* *
* global_id (uniform methods in 1/2/3 dims): * global_id (uniform methods in 1/2/3 dims):
* get_global_id(0) = CSR_GID_X * local_size_x + local_id_x
* get_global_id(1) = CSR_GID_Y * local_size_y + local_id_y
* get_global_id(2) = CSR_GID_Z * local_size_z + local_id_z
* *
* *
* global_linear_id: * global_linear_id:
@ -149,15 +149,7 @@ __builtin_riscv_workitem_id_x:
li t2, 1 li t2, 1
beq t0, t2, .WIXR beq t0, t2, .WIXR
lw t3, KNL_LC_SIZE_X(a0) # local_size_x lw t3, KNL_LC_SIZE_X(a0) # local_size_x
li t2, 3
beq t0, t2, .WIX3
.WIX2:
vremu.vx v0, v0, t3 # local_id_x = local_liner_id % local_size_x vremu.vx v0, v0, t3 # local_id_x = local_liner_id % local_size_x
ret
.WIX3: # 3 dims
lw t4, KNL_LC_SIZE_Y(a0) # local_size_y
mul t4, t4, t3 # local_size_x * local_size_y
vremu.vx v0, v0, t4 # local_id_x = local_liner_id % (local_size_x * local_size_y)
.WIXR: .WIXR:
ret ret
@ -205,14 +197,16 @@ __builtin_riscv_workitem_id_z:
/*
 * __builtin_riscv_global_id_x
 *
 *   get_global_id(0) = CSR_GID_X * local_size_x + local_id_x
 *
 * Out:   v0 = per-lane global id in dimension x
 * Uses:  a0, t1, t3, t6, ra, plus whatever __builtin_riscv_workitem_id_x
 *        touches.
 * Stack: one 4-byte slot holding ra across the nested call. sp is raised by 4
 *        on entry and lowered by 4 on exit.
 *        NOTE(review): this implies an upward-growing stack -- confirm
 *        against the Ventus ABI.
 */
  .global __builtin_riscv_global_id_x
  .type   __builtin_riscv_global_id_x, @function
__builtin_riscv_global_id_x:
  addi    sp, sp, 4                      # reserve one word for ra
  sw      ra, -4(sp)                     # the call below overwrites ra
  call    __builtin_riscv_workitem_id_x  # v0 = local_id_x
  csrr    a0, CSR_KNL                    # Get kernel metadata buffer
  csrr    t1, CSR_GID_X                  # Get group_id_x
  lw      t3, KNL_LC_SIZE_X(a0)          # Get local_size_x
  mul     t6, t1, t3                     # CSR_GID_X * local_size_x
  vadd.vx v0, v0, t6                     # global_id_x = group offset + local_id_x
  lw      ra, -4(sp)                     # restore return address
  addi    sp, sp, -4                     # release the slot
  ret
@ -220,21 +214,26 @@ __builtin_riscv_global_id_x:
/*
 * __builtin_riscv_global_id_y
 *
 *   get_global_id(1) = CSR_GID_Y * local_size_y + local_id_y
 *
 * Out:   v0 = per-lane global id in dimension y
 * Uses:  a0, t1, t2, t3, ra, plus whatever __builtin_riscv_workitem_id_y
 *        touches.
 * Stack: one 4-byte slot holding ra across the nested call. sp is raised by 4
 *        on entry and lowered by 4 on exit.
 *        NOTE(review): this implies an upward-growing stack -- confirm
 *        against the Ventus ABI.
 */
  .global __builtin_riscv_global_id_y
  .type   __builtin_riscv_global_id_y, @function
__builtin_riscv_global_id_y:
  addi    sp, sp, 4                      # reserve one word for ra
  sw      ra, -4(sp)                     # the call below overwrites ra
  call    __builtin_riscv_workitem_id_y  # expected to leave local_id_y in v0 (callee not shown here)
  # a0 is caller-saved, so reload the metadata pointer instead of relying on
  # whatever the callee left there (the x and z variants reload it as well).
  csrr    a0, CSR_KNL                    # Get kernel metadata buffer
  csrr    t1, CSR_GID_Y                  # Get group_id_y
  lw      t2, KNL_LC_SIZE_Y(a0)          # Get local_size_y
  mul     t3, t1, t2                     # CSR_GID_Y * local_size_y
  vadd.vx v0, v0, t3                     # global_id_y = group offset + local_id_y
  lw      ra, -4(sp)                     # restore return address
  addi    sp, sp, -4                     # release the slot
  ret
.text .text
/*
 * __builtin_riscv_global_id_z
 *
 *   get_global_id(2) = CSR_GID_Z * local_size_z + local_id_z
 *
 * Out:   v0 = per-lane global id in dimension z
 * Uses:  a0, t1, t3, ra, plus whatever __builtin_riscv_workitem_id_z
 *        touches.
 * Stack: one 4-byte slot holding ra across the nested call. sp is raised by 4
 *        on entry and lowered by 4 on exit.
 *        NOTE(review): this implies an upward-growing stack -- confirm
 *        against the Ventus ABI.
 */
  .global __builtin_riscv_global_id_z
  .type   __builtin_riscv_global_id_z, @function
__builtin_riscv_global_id_z:
  addi    sp, sp, 4                      # reserve one word for ra
  sw      ra, -4(sp)                     # the call below overwrites ra
  call    __builtin_riscv_workitem_id_z  # expected to leave local_id_z in v0 (callee not shown here)
  csrr    a0, CSR_KNL                    # Get kernel metadata buffer
  csrr    t1, CSR_GID_Z                  # Get group_id_z
  lw      t3, KNL_LC_SIZE_Z(a0)          # Get local_size_z
  mul     t3, t1, t3                     # CSR_GID_Z * local_size_z
  # Add the scalar group offset to the local id. The previous vadd.vv summed
  # the linear local id with local_id_z and never used the group offset.
  vadd.vx v0, v0, t3                     # global_id_z = group offset + local_id_z
  lw      ra, -4(sp)                     # restore return address
  addi    sp, sp, -4                     # release the slot
  ret