[VENTUS][RISCV][fix] Fix workitem function implementation
This commit is contained in:
parent
b8223e72bd
commit
1a6ead3f43
|
@ -194,9 +194,6 @@ build_libclc() {
|
|||
if [ ! -d "${DstDir}" ]; then
|
||||
mkdir -p ${DstDir}
|
||||
fi
|
||||
# TODO: make this copy process done during libclc build process?
|
||||
cp ${LIBCLC_BUILD_DIR}/riscv32--.bc ${DstDir}/kernel-riscv32.bc
|
||||
cp ${LIBCLC_BUILD_DIR}/riscv32--.a ${VENTUS_INSTALL_PREFIX}/lib/libriscv32clc.a
|
||||
}
|
||||
|
||||
# Build icd_loader
|
||||
|
|
|
@ -17,15 +17,15 @@
|
|||
* get_local_id(1) = (CSR_TID + vid.v) / local_size_x
|
||||
*
|
||||
* 3 dims:
|
||||
* get_local_id(0) = (CSR_TID + vid.v) % (local_size_x * local_size_y)
|
||||
* get_local_id(1) = ((CSR_TID + vid.v) - get_local_id(2) * (local_size_x * local_size_y)) / local_size_x
|
||||
* get_local_id(0) = (CSR_TID + vid.v) % (local_size_x)
|
||||
* get_local_id(1) = ((CSR_TID + vid.v) % (local_size_x * local_size_y)) / local_size_x
|
||||
* get_local_id(2) = (CSR_TID + vid.v) / (local_size_x * local_size_y)
|
||||
*
|
||||
*
|
||||
* global_id (uniform methods in 1/2/3 dims):
|
||||
* get_global_id(0) = (CSR_GID_X - 1) * local_size_x + local_id_x
|
||||
* get_global_id(1) = (CSR_GID_Y - 1) * local_size_y + local_id_y
|
||||
* get_global_id(2) = (CSR_GID_Z - 1) * local_size_z + local_id_z
|
||||
* get_global_id(0) = CSR_GID_X * local_size_x + local_id_x
|
||||
* get_global_id(1) = CSR_GID_Y * local_size_y + local_id_y
|
||||
* get_global_id(2) = CSR_GID_Z * local_size_z + local_id_z
|
||||
*
|
||||
*
|
||||
* global_linear_id:
|
||||
|
@ -149,15 +149,7 @@ __builtin_riscv_workitem_id_x:
|
|||
li t2, 1
|
||||
beq t0, t2, .WIXR
|
||||
lw t3, KNL_LC_SIZE_X(a0) # local_size_x
|
||||
li t2, 3
|
||||
beq t0, t2, .WIX3
|
||||
.WIX2:
|
||||
vremu.vx v0, v0, t3 # local_id_x = local_linear_id % local_size_x
|
||||
ret
|
||||
.WIX3: # 3 dims
|
||||
lw t4, KNL_LC_SIZE_Y(a0) # local_size_y
|
||||
mul t4, t4, t3 # local_size_x * local_size_y
|
||||
vremu.vx v0, v0, t4 # local_id_x = local_linear_id % (local_size_x * local_size_y)
|
||||
.WIXR:
|
||||
ret
|
||||
|
||||
|
@ -205,14 +197,16 @@ __builtin_riscv_workitem_id_z:
|
|||
.global __builtin_riscv_global_id_x
|
||||
.type __builtin_riscv_global_id_x, @function
|
||||
__builtin_riscv_global_id_x:
|
||||
addi sp, sp, 4
|
||||
sw ra, -4(sp)
|
||||
call __builtin_riscv_workitem_id_x
|
||||
csrr a0, CSR_KNL # Get kernel metadata buffer
|
||||
csrr t1, CSR_GID_X # Get group_id_x
|
||||
csrr t2, CSR_TID
|
||||
vid.v v2
|
||||
vadd.vx v2, v2, t2 # local linear id = CSR_TID + vid.v (equals workitem_id_x only in the 1-dim case)
|
||||
lw t3, KNL_LC_SIZE_X(a0) # Get local_size_x
|
||||
mul t3, t1, t3 # CSR_GID_X * local_size_x
|
||||
vadd.vx v0, v2, t3 # global_id_x
|
||||
mul t6, t1, t3 # CSR_GID_X * local_size_x
|
||||
vadd.vx v0,v0, t6
|
||||
lw ra, -4(sp)
|
||||
addi sp, sp, -4
|
||||
ret
|
||||
|
||||
|
||||
|
@ -220,21 +214,26 @@ __builtin_riscv_global_id_x:
|
|||
.global __builtin_riscv_global_id_y
|
||||
.type __builtin_riscv_global_id_y, @function
|
||||
__builtin_riscv_global_id_y:
|
||||
csrr a0, CSR_KNL # Get kernel metadata buffer
|
||||
addi sp, sp, 4
|
||||
sw ra, -4(sp)
|
||||
call __builtin_riscv_workitem_id_y
|
||||
csrr t1, CSR_GID_Y # Get group_id_y
|
||||
csrr t2, CSR_TID
|
||||
vid.v v2
|
||||
vadd.vx v2, v2, t2 # local linear id = CSR_TID + vid.v (not yet reduced to workitem_id_y)
|
||||
lw t3, KNL_LC_SIZE_Y(a0) # Get local_size_y
|
||||
mul t3, t1, t3 # CSR_GID_Y * local_size_y
|
||||
vadd.vx v0, v2, t3 # global_id_y
|
||||
ret
|
||||
lw t2, KNL_LC_SIZE_Y(a0) # Get local_size_y
|
||||
mul t3, t1, t2 # CSR_GID_Y * local_size_y
|
||||
|
||||
vadd.vx v0, v0, t3 # global_id_y
|
||||
lw ra, -4(sp)
|
||||
addi sp, sp, -4
|
||||
ret
|
||||
|
||||
|
||||
.text
|
||||
.global __builtin_riscv_global_id_z
|
||||
.type __builtin_riscv_global_id_z, @function
|
||||
__builtin_riscv_global_id_z:
|
||||
addi sp, sp, 4
|
||||
sw ra, -4(sp)
|
||||
call __builtin_riscv_workitem_id_z
|
||||
csrr a0, CSR_KNL # Get kernel metadata buffer
|
||||
csrr t1, CSR_GID_Z # Get group_id_z
|
||||
csrr t2, CSR_TID
|
||||
|
@ -242,7 +241,9 @@ __builtin_riscv_global_id_z:
|
|||
vadd.vx v2, v2, t2 # workitem_id_z
|
||||
lw t3, KNL_LC_SIZE_Z(a0) # Get local_size_z
|
||||
mul t3, t1, t3 # CSR_GID_Z * local_size_z
|
||||
vadd.vx v0, v2, t3 # global_id_z
|
||||
vadd.vv v0, v2, v0 # global_id_z
|
||||
lw ra, -4(sp)
|
||||
addi sp, sp, -4
|
||||
ret
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue