Add CSR and kernel metadata buffer offset constant definitions to ventus.h

This commit is contained in:
Aries 2023-01-06 09:32:07 +08:00
parent 2e4e32e87c
commit 6959b66022
3 changed files with 94 additions and 70 deletions

View File

@ -9,20 +9,12 @@
/**
* crt0.S : Entry point for Ventus OpenCL C kernel programs
*
* kernel metadata buffer:
* +-------4---------+----------4----------+-----4----+-------4-------+-------4-------
* | kernel func ptr | kernel arg base ptr | work_dim | global_size_x | global_size_y
* +-------4-------+------4-------+------4-------+------4-------+-------4---------
* | global_size_z | local_size_x | local_size_y | local_size_z | global_offset_x
* +--------4--------+--------4--------+----
* | global_offset_y | global_offset_z | ...
*
*
* kernel arg buffer:
* +-------+-------+--------+-----
* | arg_0 | arg_1 | arg_2 | ...
*/
#include "ventus.h"
.text
.global _start
@ -49,9 +41,9 @@ _start:
bltu a0, a2, 1b
2:
csrr t0, CSR_KNL # get addr of kernel metadata
lw t1, 0(t0) # get kernel program address
lw a0, 4(t0) # get kernel arg buffer base address
csrr t0, CSR_KNL # get addr of kernel metadata
lw t1, KNL_ENTRY(t0) # get kernel program address
lw a0, KNL_ARG_BASE(t0) # get kernel arg buffer base address
jalr t1 # call kernel program
# call exit routine

View File

@ -0,0 +1,49 @@
/**
* Copyright (c) 2023 Terapines Technology (Wuhan) Co., Ltd
*
* Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
* See https://llvm.org/LICENSE.txt for license information.
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
*/
/**
* This file defines the hardware CSR and kernel metadata buffer related const.
*
* kernel metadata buffer layout:
* +-------4---------+----------4----------+-----4----+-------4-------+-------4-------
* | kernel func ptr | kernel arg base ptr | work_dim | global_size_x | global_size_y
* +-------4-------+------4-------+------4-------+------4-------+-------4---------
* | global_size_z | local_size_x | local_size_y | local_size_z | global_offset_x
* +--------4--------+--------4--------+----
* | global_offset_y | global_offset_z | ...
*/
#ifndef __VENTUS_H__
#define __VENTUS_H__

/* ---- Hardware CSR numbers ---- */
#define CSR_TID         0x800   /* tid base offset for the current warp */
#define CSR_NUMW        0x801
#define CSR_NUMT        0x802
#define CSR_KNL         0x803   /* kernel metadata buffer base address */
#define CSR_WID         0x805
#define CSR_LDS         0x806
#define CSR_GDS         0x807
#define CSR_GID_X       0x808   /* group_id_x */
#define CSR_GID_Y       0x809
#define CSR_GID_Z       0x80a

/* ---- Byte offsets into the kernel metadata buffer (4 bytes per slot) ---- */
#define KNL_ENTRY       0       /* kernel func ptr */
#define KNL_ARG_BASE    4       /* kernel arg base ptr */
#define KNL_WORK_DIM    8       /* work_dim */
#define KNL_GL_SIZE_X   12      /* global_size_x */
#define KNL_GL_SIZE_Y   16      /* global_size_y */
#define KNL_GL_SIZE_Z   20      /* global_size_z */
#define KNL_LC_SIZE_X   24      /* local_size_x */
#define KNL_LC_SIZE_Y   28      /* local_size_y */
#define KNL_LC_SIZE_Z   32      /* local_size_z */
#define KNL_GL_OFFSET_X 36      /* global_offset_x */
#define KNL_GL_OFFSET_Y 40      /* global_offset_y */
#define KNL_GL_OFFSET_Z 44      /* global_offset_z */

#endif /* __VENTUS_H__ */

View File

@ -6,7 +6,7 @@
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
*/
/**
* See crt0.S for kernel metadata buffer detailed layout.
* See ventus.h for kernel metadata buffer detailed layout.
*
* workitem_id:
* 1 dim:
@ -40,6 +40,7 @@
*
*/
#include "ventus.h"
.text
.global __builtin_riscv_workitem_linear_id
@ -57,17 +58,17 @@ __builtin_riscv_workitem_linear_id:
# __builtin_riscv_global_linear_id
#   a3 = kernel metadata buffer base (CSR_KNL), t0 = work_dim.
#   1 dim : v5 = global_id_x - global_offset_x
#   2 dims: v5 += (global_id_y - global_offset_y) * global_size_x
#   3 dims: v6 = (global_id_z - global_offset_z) * global_size_x * global_size_y
#   Each global id is taken from v0 after calling the matching
#   __builtin_riscv_global_id_{x,y,z} helper.
#   All metadata offsets use the KNL_* names defined in ventus.h.
# NOTE(review): t7 is not a standard RISC-V ABI register name (t0-t6 only);
#   confirm the Ventus toolchain accepts it.
.global __builtin_riscv_global_linear_id
.type __builtin_riscv_global_linear_id, @function
__builtin_riscv_global_linear_id:
csrr a3, CSR_KNL # Get kernel metadata buffer
lw t0, 8(a3) # Get work_dims
csrr a3, CSR_KNL # Get kernel metadata buffer
lw t0, KNL_WORK_DIM(a3) # Get work_dims
call __builtin_riscv_global_id_x
lw t4, 36(a3) # Get global_offset_x
vsub.vx v5, v0, t4 # global_linear_id1
lw t4, KNL_GL_OFFSET_X(a3) # global_offset_x
vsub.vx v5, v0, t4 # global_linear_id1
li t5, 1
beq t0, t5, .GLR # Return global_linear_id for 1 dim
.GL_2DIM:
call __builtin_riscv_global_id_y
lw t6, 12(a3) # global_size_x
lw t5, 40(a3) # global_offset_y
lw t6, KNL_GL_SIZE_X(a3) # global_size_x
lw t5, KNL_GL_OFFSET_Y(a3) # global_offset_y
vsub.vx v6, v0, t5 # tmp = global_id_y - global_offset_y
vmul.vx v6, v6, t6 # tmp = tmp * global_size_x
vadd.vv v5, v5, v6 # global_linear_id2 = tmp + global_linear_id1
@ -75,9 +76,9 @@ __builtin_riscv_global_linear_id:
beq t0, t5, .GLR # Return global_linear_id for 2 dim
.GL_3DIM:
call __builtin_riscv_global_id_z
lw t6, 12(a3) # global_size_x
lw t7, 16(a3) # global_size_y
lw t5, 44(a3) # global_offset_z
lw t6, KNL_GL_SIZE_X(a3) # global_size_x
lw t7, KNL_GL_SIZE_Y(a3) # global_size_y
lw t5, KNL_GL_OFFSET_Z(a3) # global_offset_z
vsub.vx v6, v0, t5 # tmp = global_id_z - global_offset_z
vmul.vx v6, v6, t6 # tmp = tmp * global_size_x
vmul.vx v6, v6, t7 # tmp = tmp * global_size_y
@ -123,22 +124,22 @@ __builtin_riscv_workgroup_id_z:
# __builtin_riscv_workitem_id_x
#   Builds local_linear_id in v0 as vid.v + CSR_TID, then reduces it according
#   to work_dim read from the kernel metadata buffer:
#     1 dim : v0 = local_linear_id
#     2 dims: v0 = local_linear_id % local_size_x
#     3 dims: v0 = local_linear_id % (local_size_x * local_size_y)
# NOTE(review): __builtin_riscv_workitem_id_y divides this same 3-dim remainder
#   by local_size_x to obtain y, so the 3-dim value here still seems to contain
#   the y component; confirm whether "% local_size_x" was intended.
.type __builtin_riscv_workitem_id_x, @function
__builtin_riscv_workitem_id_x:
csrr a0, CSR_KNL # Get kernel metadata buffer
lw t0, 8(a0) # Get work_dim
lw t0, KNL_WORK_DIM(a0) # Get work_dim
csrr t1, CSR_TID # tid base offset for current warp
vid.v v2 # current thread offset
vadd.vx v0, v2, t1 # local_id_x in 1 dim (local_linear_id)
li t2, 1
beq t0, t2, .WIXR # 1 dim
lw t3, 24(a0) # local_size_x
beq t0, t2, .WIXR # 1 dim
lw t3, KNL_LC_SIZE_X(a0) # local_size_x
li t2, 3
beq t0, t2, .WIX3
.WIX2:
vremu.vx v0, v0, t3 # local_id_x = local_liner_id % local_size_x
vremu.vx v0, v0, t3 # local_id_x = local_linear_id % local_size_x
ret
.WIX3: # 3 dims
lw t4, 28(a0) # local_size_y
mul t4, t4, t3 # local_size_x * local_size_y
vremu.vx v0, v0, t4 # local_id_x = local_liner_id % (local_size_x * local_size_y)
.WIX3: # 3 dims
lw t4, KNL_LC_SIZE_Y(a0) # local_size_y
mul t4, t4, t3 # local_size_x * local_size_y
vremu.vx v0, v0, t4 # local_id_x = local_linear_id % (local_size_x * local_size_y)
.WIXR:
ret
.size __builtin_riscv_workitem_id_x, .-__builtin_riscv_workitem_id_x
@ -149,18 +150,18 @@ __builtin_riscv_workitem_id_x:
# __builtin_riscv_workitem_id_y
#   Builds local_linear_id in v0 as vid.v + CSR_TID, then:
#     2 dims: v0 = local_linear_id / local_size_x
#     3 dims: v0 = (local_linear_id % (local_size_x * local_size_y)) / local_size_x
#   Sizes are read from the kernel metadata buffer addressed by CSR_KNL.
.type __builtin_riscv_workitem_id_y, @function
__builtin_riscv_workitem_id_y:
csrr a0, CSR_KNL # Get kernel metadata buffer
lw t0, 8(a0) # Get work_dim
lw t0, KNL_WORK_DIM(a0) # Get work_dim
csrr t1, CSR_TID # tid base offset for current warp
vid.v v2 # current thread offset
vadd.vx v0, v2, t1 # local_linear_id
lw t3, 24(a0) # local_size_x offset in 2 work_dims
lw t3, KNL_LC_SIZE_X(a0) # local_size_x offset in 2 work_dims
li t2, 3
beq t0, t2, .WIY3
.WIY2: # 2 dims
vdivu.vx v0, v0, t3 # local_id_y = local_linear_id / local_size_x
ret
.WIY3: # 3 dims
lw t4, 28(a0) # local_size_y
lw t4, KNL_LC_SIZE_Y(a0) # local_size_y
mul t4, t4, t3 # local_size_x * local_size_y
vremu.vx v0, v0, t4 # x = local_linear_id % (local_size_x * local_size_y)
vdivu.vx v0, v0, t3 # x / local_size_x (vector-scalar form; ".ux" is not a valid suffix)
@ -176,8 +177,8 @@ __builtin_riscv_workitem_id_z:
csrr t1, CSR_TID # tid base offset for current warp
vid.v v2 # current thread offset
vadd.vx v0, v2, t1 # local_linear_id
lw t3, 24(a0) # local_size_x
lw t4, 28(a0) # local_size_y
lw t3, KNL_LC_SIZE_X(a0) # local_size_x
lw t4, KNL_LC_SIZE_Y(a0) # local_size_y
mul t4, t4, t3 # local_size_x * local_size_y
vdivu.vx v0, v0, t4 # local_linear_id / (local_size_x * local_size_y)
7:
@ -195,7 +196,7 @@ __builtin_riscv_global_id_x:
csrr t2, CSR_TID
vid.v v2
vadd.vx v2, v2, t2 # workitem_id_x
lw t3, 24(a0) # Get local_size_x
lw t3, KNL_LC_SIZE_X(a0) # Get local_size_x
mul t3, t1, t3 # (CSR_GID_X - 1) * local_size_x
vadd.vx v0, v2, t3 # global_id_x
ret
@ -212,7 +213,7 @@ __builtin_riscv_global_id_y:
csrr t2, CSR_TID
vid.v v2
vadd.vx v2, v2, t2 # workitem_id_y
lw t3, 28(a0) # Get local_size_y
lw t3, KNL_LC_SIZE_Y(a0) # Get local_size_y
mul t3, t1, t3 # (CSR_GID_Y - 1) * local_size_y
vadd.vx v0, v2, t3 # global_id_y
ret
@ -229,7 +230,7 @@ __builtin_riscv_global_id_z:
csrr t2, CSR_TID
vid.v v2
vadd.vx v2, v2, t2 # workitem_id_z
lw t3, 32(a0) # Get local_size_z
lw t3, KNL_LC_SIZE_Z(a0) # Get local_size_z
mul t3, t1, t3 # (CSR_GID_Z - 1) * local_size_z
vadd.vx v0, v2, t3 # global_id_z
ret
@ -241,11 +242,7 @@ __builtin_riscv_global_id_z:
# __builtin_riscv_local_size_x
#   Reads local_size_x from the kernel metadata buffer (base address taken
#   from CSR_KNL) and moves it into v0 with vmv.s.x before returning.
.type __builtin_riscv_local_size_x, @function
__builtin_riscv_local_size_x:
csrr a0, CSR_KNL # Get kernel metadata buffer
lw t0, 8(a0) # Get work_dim
addi t2, zero, 4
mul t0, t0, t2 # Skip offset of global_size_xyz slots
add a0, a0, t0
lw t0, 0(a0) # load local_size_x
lw t0, KNL_LC_SIZE_X(a0) # Load local_size_x from its fixed metadata slot
vmv.s.x v0, t0 # v0 <- local_size_x
ret
.size __builtin_riscv_local_size_x, .-__builtin_riscv_local_size_x
@ -256,10 +253,7 @@ __builtin_riscv_local_size_x:
# __builtin_riscv_local_size_y
#   Reads local_size_y from the kernel metadata buffer (base address taken
#   from CSR_KNL) and moves it into v0 with vmv.s.x before returning.
.type __builtin_riscv_local_size_y, @function
__builtin_riscv_local_size_y:
csrr a0, CSR_KNL # Get kernel metadata buffer
lw t0, 8(a0) # Get work_dim
addi t2, zero, 4
mul t0, t0, t2 # Skip offset of global_size_xyz slots
lw t0, 4(a0) # Load local_size_y
lw t0, KNL_LC_SIZE_Y(a0) # Load local_size_y from its fixed metadata slot
vmv.s.x v0, t0 # v0 <- local_size_y
ret
.size __builtin_riscv_local_size_y, .-__builtin_riscv_local_size_y
@ -270,10 +264,7 @@ __builtin_riscv_local_size_y:
# __builtin_riscv_local_size_z
#   Reads local_size_z from the kernel metadata buffer (base address taken
#   from CSR_KNL) and moves it into v0 with vmv.s.x before returning.
.type __builtin_riscv_local_size_z, @function
__builtin_riscv_local_size_z:
csrr a0, CSR_KNL # Get kernel metadata buffer
lw t0, 8(a0) # Get work_dim
addi t2, zero, 4
mul t0, t0, t2 # Skip offset of global_size_xyz slots
lw t0, 8(a0) # Load local_size_x
lw t0, KNL_LC_SIZE_Z(a0) # Load local_size_z from its fixed metadata slot
vmv.s.x v0, t0 # v0 <- local_size_z
ret
.size __builtin_riscv_local_size_z, .-__builtin_riscv_local_size_z
@ -284,7 +275,7 @@ __builtin_riscv_local_size_z:
# __builtin_riscv_global_size_x
#   Reads global_size_x from the kernel metadata buffer (base address taken
#   from CSR_KNL) and moves it into v0 with vmv.s.x before returning.
.type __builtin_riscv_global_size_x, @function
__builtin_riscv_global_size_x:
csrr a0, CSR_KNL # Get kernel metadata buffer
lw t0, 12(a0) # Get global_size_x
lw t0, KNL_GL_SIZE_X(a0) # Get global_size_x
vmv.s.x v0, t0 # v0 <- global_size_x
ret
.size __builtin_riscv_global_size_x, .-__builtin_riscv_global_size_x
@ -295,7 +286,7 @@ __builtin_riscv_global_size_x:
# __builtin_riscv_global_size_y
#   Reads global_size_y from the kernel metadata buffer (base address taken
#   from CSR_KNL) and moves it into v0 with vmv.s.x before returning.
.type __builtin_riscv_global_size_y, @function
__builtin_riscv_global_size_y:
csrr a0, CSR_KNL # Get kernel metadata buffer
lw t0, 16(a0) # Get global_size_y
lw t0, KNL_GL_SIZE_Y(a0) # Get global_size_y
vmv.s.x v0, t0 # v0 <- global_size_y
ret
.size __builtin_riscv_global_size_y, .-__builtin_riscv_global_size_y
@ -306,7 +297,7 @@ __builtin_riscv_global_size_y:
# __builtin_riscv_global_size_z
#   Reads global_size_z from the kernel metadata buffer (base address taken
#   from CSR_KNL) and moves it into v0 with vmv.s.x before returning.
.type __builtin_riscv_global_size_z, @function
__builtin_riscv_global_size_z:
csrr a0, CSR_KNL # Get kernel metadata buffer
lw t0, 20(a0) # Get global_size_z
lw t0, KNL_GL_SIZE_Z(a0) # Get global_size_z
vmv.s.x v0, t0 # v0 <- global_size_z
ret
.size __builtin_riscv_global_size_z, .-__builtin_riscv_global_size_z
@ -317,12 +308,8 @@ __builtin_riscv_global_size_z:
# __builtin_riscv_num_groups_x
#   Computes global_size_x / local_size_x (unsigned divide) from the kernel
#   metadata buffer addressed by CSR_KNL and moves the quotient into v0.
.type __builtin_riscv_num_groups_x, @function
__builtin_riscv_num_groups_x:
csrr a0, CSR_KNL # Get kernel metadata buffer
lw t0, 8(a0) # Get work_dim
lw t1, 12(a0) # Get global_size_x
addi t2, zero, 4 # 4 bytes per slot
mul t0, t0, t2
add a0, a0, t0 # Skip offset of global_size_xyz slots
lw t0, 0(a0) # Get local_size_x
lw t1, KNL_GL_SIZE_X(a0) # Get global_size_x
lw t0, KNL_LC_SIZE_X(a0) # Get local_size_x
divu t1, t1, t0 # global_size_x / local_size_x
vmv.s.x v0, t1 # v0 <- number of groups in x
ret
@ -334,12 +321,8 @@ __builtin_riscv_num_groups_x:
# __builtin_riscv_num_groups_y
#   Computes global_size_y / local_size_y (unsigned divide) from the kernel
#   metadata buffer addressed by CSR_KNL and moves the quotient into v0.
.type __builtin_riscv_num_groups_y, @function
__builtin_riscv_num_groups_y:
csrr a0, CSR_KNL # Get kernel metadata buffer
lw t0, 8(a0) # Get work_dim
lw t1, 16(a0) # Get global_size_y
addi t2, zero, 4 # 4 bytes per slot
mul t0, t0, t2
add a0, a0, t0 # Skip offset of global_size_xyz slots
lw t0, 4(a0) # Get local_size_y
lw t1, KNL_GL_SIZE_Y(a0) # Get global_size_y
lw t0, KNL_LC_SIZE_Y(a0) # Get local_size_y
divu t1, t1, t0 # global_size_y / local_size_y
vmv.s.x v0, t1 # v0 <- number of groups in y
ret
@ -351,8 +334,8 @@ __builtin_riscv_num_groups_y:
# __builtin_riscv_num_groups_z
#   Computes global_size_z / local_size_z (unsigned divide) from the kernel
#   metadata buffer addressed by CSR_KNL and moves the quotient into v0.
.type __builtin_riscv_num_groups_z, @function
__builtin_riscv_num_groups_z:
csrr a0, CSR_KNL # Get kernel metadata buffer
lw t1, 20(a0) # Get global_size_z
lw t2, 32(a0) # Get local_size_z
lw t1, KNL_GL_SIZE_Z(a0) # Get global_size_z
lw t2, KNL_LC_SIZE_Z(a0) # Get local_size_z
divu t1, t1, t2 # global_size_z / local_size_z
vmv.s.x v0, t1 # v0 <- number of groups in z
ret
@ -364,7 +347,7 @@ __builtin_riscv_num_groups_z:
# __builtin_riscv_work_dim
#   Reads work_dim from the kernel metadata buffer (base address taken from
#   CSR_KNL) and moves it into v0 with vmv.s.x before returning.
.type __builtin_riscv_work_dim, @function
__builtin_riscv_work_dim:
csrr a0, CSR_KNL # Get kernel metadata buffer
lw t0, 8(a0) # Get work_dim
lw t0, KNL_WORK_DIM(a0) # Get work_dim
vmv.s.x v0, t0 # v0 <- work_dim
ret
.size __builtin_riscv_work_dim, .-__builtin_riscv_work_dim