Merge pull request #98 from THU-DSP-LAB/56_workitem_function_fix

[VENTUS][fix] Fix workitem function implementation bug
This commit is contained in:
zhoujingya 2024-02-01 15:05:25 +08:00 committed by GitHub
commit b32b529523
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 74 additions and 33 deletions

View File

@ -8,6 +8,7 @@ workitem/get_local_linear_id.cl
workitem/get_local_size.cl
workitem/get_num_groups.cl
workitem/get_work_dim.cl
workitem/get_enqueued_local_size.cl
compiler-rt/nextafterf.cl
compiler-rt/adddf3.cl

View File

@ -0,0 +1,7 @@
#include <clc/clc.h>
// get_global_size(unit dim) / get_num_groups(unit dim)
_CLC_DEF _CLC_OVERLOAD size_t get_enqueued_local_size(uint dim) {
return get_local_size(dim);
}

View File

@ -1,7 +1,20 @@
#include <clc/clc.h>
extern size_t __builtin_riscv_global_linear_id();
_CLC_DEF _CLC_OVERLOAD size_t get_global_linear_id() {
return __builtin_riscv_global_linear_id();
uint dim = get_work_dim() - 1;
switch (dim) {
case 0:
return get_global_id(0) - get_global_offset(0);
;
case 1:
return (get_global_id(1) - get_global_offset(1)) * get_global_size(0) +
(get_global_id(0) - get_global_offset(0));
case 2:
return ((get_global_id(2) - get_global_offset(2)) * get_global_size(1) +
(get_global_id(1) - get_global_offset(1))) *
get_global_size(0) +
(get_global_id(0) - get_global_offset(0));
default:
return 0;
}
}

View File

@ -1,7 +1,17 @@
#include <clc/clc.h>
extern size_t __builtin_riscv_workitem_linear_id();
_CLC_DEF _CLC_OVERLOAD size_t get_local_linear_id() {
return __builtin_riscv_workitem_linear_id();
uint dim = get_work_dim() - 1;
switch (dim) {
case 0:
return get_local_id(0);
case 1:
return get_local_id(1) * get_local_size(0) + get_local_id(0);
case 2:
return (get_local_id(2) * get_local_size(1) + get_local_id(1)) *
get_local_size(0) +
get_local_id(0);
default:
return 0;
}
}

View File

@ -23,9 +23,9 @@
*
*
* global_id (uniform methods in 1/2/3 dims):
* get_global_id(0) = CSR_GID_X * local_size_x + local_id_x
* get_global_id(1) = CSR_GID_Y * local_size_y + local_id_y
* get_global_id(2) = CSR_GID_Z * local_size_z + local_id_z
* get_global_id(0) = _global_offset_x + CSR_GID_X * local_size_x + local_id_x
* get_global_id(1) = _global_offset_y + CSR_GID_Y * local_size_y + local_id_y
* get_global_id(2) = _global_offset_z + CSR_GID_Z * local_size_z + local_id_z
*
*
* global_linear_id:
@ -65,17 +65,6 @@ _local_id_z:
.word 0
// End workaround for pocl driver
.text
.global __builtin_riscv_workitem_linear_id
.type __builtin_riscv_workitem_linear_id, @function
__builtin_riscv_workitem_linear_id:
csrr a0, CSR_KNL # Get kernel metadata buffer
csrr t1, CSR_TID # tid base offset for current warp
vid.v v2 # current thread offset
vadd.vx v0, v2, t1 # local_linear_id
ret
.text
.global __builtin_riscv_global_linear_id
.type __builtin_riscv_global_linear_id, @function
@ -145,16 +134,17 @@ __builtin_riscv_workgroup_id_z:
.global __builtin_riscv_workitem_id_x
.type __builtin_riscv_workitem_id_x, @function
__builtin_riscv_workitem_id_x:
addi sp, sp, 4
sw ra, -4(sp)
csrr a0, CSR_KNL # Get kernel metadata buffer
lw t0, KNL_WORK_DIM(a0) # Get work_dim
csrr t1, CSR_TID # tid base offset for current warp
vid.v v2 # current thread offset
vadd.vx v0, v2, t1 # local_id_x in 1 dim (local_linear_id)
li t2, 1
beq t0, t2, .WIXR
lw t3, KNL_LC_SIZE_X(a0) # local_size_x
vremu.vx v0, v0, t3 # local_id_x = local_liner_id % local_size_x
.WIXR:
lw ra, -4(sp)
addi sp, sp, -4
ret
@ -162,22 +152,29 @@ __builtin_riscv_workitem_id_x:
.global __builtin_riscv_workitem_id_y
.type __builtin_riscv_workitem_id_y, @function
__builtin_riscv_workitem_id_y:
addi sp, sp, 4
sw ra, -4(sp)
csrr a0, CSR_KNL # Get kernel metadata buffer
lw t0, KNL_WORK_DIM(a0) # Get work_dim
csrr t1, CSR_TID # tid base offset for current warp
vid.v v2 # current thread offset
vadd.vx v0, v2, t1 # local_linear_id
lw t3, KNL_LC_SIZE_X(a0) # local_size_x offset in 2 work_dims
li t2, 3
beq t0, t2, .WIY3
.WIY2: # 2 dims
vdivu.vx v0, v0, t3 # local_id_y = local_liner_id / local_size_x
ret
.WIY3: # 3 dims
lw t4, KNL_LC_SIZE_Y(a0) # local_size_y
mul t4, t4, t3 # local_size_x * local_size_y
vremu.vx v0, v0, t4 # x = local_linear_id % (local_size_x * local_size_y)
lw t4, KNL_LC_SIZE_Y(a0) # local_size_y offset in 2 work_dims
mul t5, t4, t3 # local_size_x * local_size_y
vremu.vx v0, v0, t5 # x = local_linear_id % (local_size_x * local_size_y)
vdivu.vx v0, v0, t3 # x / local_size_x
vmv.v.x v1, t4
.hi2:
auipc t1, %pcrel_hi(.end2)
setrpc zero, t1, %pcrel_lo(.hi2)
vblt v0, v1, .end2
li t5, -1
vadd.vx v0, v1, t5
.end2:
join zero, zero, 0
lw ra, -4(sp)
addi sp, sp, -4
ret
@ -185,15 +182,28 @@ __builtin_riscv_workitem_id_y:
.global __builtin_riscv_workitem_id_z
.type __builtin_riscv_workitem_id_z, @function
__builtin_riscv_workitem_id_z:
addi sp, sp, 4
sw ra, -4(sp)
csrr a0, CSR_KNL # Get kernel metadata buffer
csrr t1, CSR_TID # tid base offset for current warp
vid.v v2 # current thread offset
vadd.vx v0, v2, t1 # local_linear_id
lw t3, KNL_LC_SIZE_X(a0) # local_size_x
lw t4, KNL_LC_SIZE_Y(a0) # local_size_y
lw t5, KNL_LC_SIZE_Z(a0) # local_size_y
mul t4, t4, t3 # local_size_x * local_size_y
vdivu.vx v0, v0, t4 # local_linear_id / (local_size_x * local_size_y)
7:
vmv.v.x v1, t5
.hi3:
auipc t1, %pcrel_hi(.end3)
setrpc zero, t1, %pcrel_lo(.hi3)
vblt v0, v1, .end3
li t5, -1
vadd.vx v0, v1, t5
.end3:
join zero, zero, 0
lw ra, -4(sp)
addi sp, sp, -4
ret