From 6959b660221fe23fd4ac12cdffc94e308d72a559 Mon Sep 17 00:00:00 2001
From: Aries
Date: Fri, 6 Jan 2023 09:32:07 +0800
Subject: [PATCH] Add CSRs and kernel metadata buffer offset constant definition
 ventus.h

---
Review notes (text here is ignored when the patch is applied):

 * workitem.S, __builtin_riscv_global_linear_id: the added loads used
   KL_GL_SIZE_X / KL_GL_SIZE_Y / KL_GL_OFFSET_Z, which ventus.h does not
   define. They now use the KNL_* names that the header provides.
 * Same function, 2-dim path: t5 replaces "lw t5, 40(a3)" and is subtracted
   from global_id_y, so it must load KNL_GL_OFFSET_Y (40), not a
   global_size_y offset (16).
 * Added lines only: fixed the "local_liner_id" comment typo, kept the
   "# 1 dim" comment on the .WIXR branch, dropped trailing whitespace.
 * Context/removed lines are left token-for-token as they were so the hunks
   still match the old file. NOTE(review): those lines contain "t7",
   "vdivu.ux" and a ".size" without a comma -- they look like pre-existing
   defects in workitem.S; confirm and fix in a follow-up.
 * NOTE(review): workitem.S includes "ventus.h", which this patch creates
   under compiler-rt/lib/ventus/ -- confirm the libclc build has that
   directory on its include path.

 compiler-rt/lib/ventus/crt0.S        | 18 ++----
 compiler-rt/lib/ventus/ventus.h      | 49 ++++++++++++++
 libclc/riscv/lib/workitem/workitem.S | 97 ++++++++++++----------------
 3 files changed, 94 insertions(+), 70 deletions(-)
 create mode 100644 compiler-rt/lib/ventus/ventus.h

diff --git a/compiler-rt/lib/ventus/crt0.S b/compiler-rt/lib/ventus/crt0.S
index 6ac63d8cd20a..8df3d5d64c49 100644
--- a/compiler-rt/lib/ventus/crt0.S
+++ b/compiler-rt/lib/ventus/crt0.S
@@ -9,20 +9,12 @@
 /**
  * crt0.S : Entry point for Ventus OpenCL C kernel programs
  *
- * kernel metadata buffer:
- * +-------4---------+----------4----------+-----4----+-------4-------+-------4-------
- * | kernel func ptr | kernel arg base ptr | work_dim | global_size_x | global_size_y
- * +-------4-------+------4-------+------4-------+------4-------+-------4---------
- * | global_size_z | local_size_x | local_size_y | local_size_z | global_offset_x
- * +--------4--------+--------4--------+----
- * | global_offset_y | global_offset_z | ...
- *
- *
  * kernel arg buffer:
  * +-------+-------+--------+-----
  * | arg_0 | arg_1 | arg_2 | ...
  */
-
+
+#include "ventus.h"
 
   .text
   .global _start
@@ -49,9 +41,9 @@ _start:
   bltu      a0, a2, 1b
 
 2:
-  csrr      t0, CSR_KNL # get addr of kernel metadata
-  lw        t1, 0(t0) # get kernel program address
-  lw        a0, 4(t0) # get kernel arg buffer base address
+  csrr      t0, CSR_KNL              # get addr of kernel metadata
+  lw        t1, KNL_ENTRY(t0)        # get kernel program address
+  lw        a0, KNL_ARG_BASE(t0)     # get kernel arg buffer base address
   jalr      t1                       # call kernel program
 
   # call exit routine
diff --git a/compiler-rt/lib/ventus/ventus.h b/compiler-rt/lib/ventus/ventus.h
new file mode 100644
index 000000000000..a8baa25d91bd
--- /dev/null
+++ b/compiler-rt/lib/ventus/ventus.h
@@ -0,0 +1,49 @@
+/**
+ * Copyright (c) 2023 Terapines Technology (Wuhan) Co., Ltd
+ *
+ * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+ * See https://llvm.org/LICENSE.txt for license information.
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ */
+/**
+ * This file defines the hardware CSR and kernel metadata buffer related constants.
+ *
+ * kernel metadata buffer layout:
+ * +-------4---------+----------4----------+-----4----+-------4-------+-------4-------
+ * | kernel func ptr | kernel arg base ptr | work_dim | global_size_x | global_size_y
+ * +-------4-------+------4-------+------4-------+------4-------+-------4---------
+ * | global_size_z | local_size_x | local_size_y | local_size_z | global_offset_x
+ * +--------4--------+--------4--------+----
+ * | global_offset_y | global_offset_z | ...
+ */
+
+#ifndef __VENTUS_H__
+#define __VENTUS_H__
+
+// CSRs
+#define CSR_TID         0x800
+#define CSR_NUMW        0x801
+#define CSR_NUMT        0x802
+#define CSR_KNL         0x803 // Kernel metadata buffer base address
+#define CSR_WID         0x805
+#define CSR_LDS         0x806
+#define CSR_GDS         0x807
+#define CSR_GID_X       0x808 // group_id_x
+#define CSR_GID_Y       0x809 // group_id_y
+#define CSR_GID_Z       0x80a // group_id_z
+
+// Kernel metadata buffer offsets
+#define KNL_ENTRY       0
+#define KNL_ARG_BASE    4
+#define KNL_WORK_DIM    8
+#define KNL_GL_SIZE_X   12
+#define KNL_GL_SIZE_Y   16
+#define KNL_GL_SIZE_Z   20
+#define KNL_LC_SIZE_X   24
+#define KNL_LC_SIZE_Y   28
+#define KNL_LC_SIZE_Z   32
+#define KNL_GL_OFFSET_X 36
+#define KNL_GL_OFFSET_Y 40
+#define KNL_GL_OFFSET_Z 44
+
+#endif // __VENTUS_H__
diff --git a/libclc/riscv/lib/workitem/workitem.S b/libclc/riscv/lib/workitem/workitem.S
index a384d1203067..61f44f9c9551 100644
--- a/libclc/riscv/lib/workitem/workitem.S
+++ b/libclc/riscv/lib/workitem/workitem.S
@@ -6,7 +6,7 @@
  * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  */
 /**
- * See crt0.S for kernel metadata buffer detailed layout.
+ * See ventus.h for kernel metadata buffer detailed layout.
  *
  * workitem_id:
  * 1 dim:
@@ -40,6 +40,7 @@
  *
  */
 
+#include "ventus.h"
 
   .text
   .global __builtin_riscv_workitem_linear_id
@@ -57,17 +58,17 @@ __builtin_riscv_workitem_linear_id:
   .global __builtin_riscv_global_linear_id
   .type __builtin_riscv_global_linear_id, @function
 __builtin_riscv_global_linear_id:
-  csrr      a3, CSR_KNL # Get kernel metadata buffer
-  lw        t0, 8(a3) # Get work_dims
+  csrr      a3, CSR_KNL              # Get kernel metadata buffer
+  lw        t0, KNL_WORK_DIM(a3)     # Get work_dims
   call      __builtin_riscv_global_id_x
-  lw        t4, 36(a3) # Get global_offset_x
-  vsub.vx   v5, v0, t4 # global_linear_id1
+  lw        t4, KNL_GL_OFFSET_X(a3)  # global_offset_x
+  vsub.vx   v5, v0, t4               # global_linear_id1
   li        t5, 1
   beq       t0, t5, .GLR             # Return global_linear_id for 1 dim
 .GL_2DIM:
   call      __builtin_riscv_global_id_y
-  lw        t6, 12(a3) # global_size_x
-  lw        t5, 40(a3) # global_offset_y
+  lw        t6, KNL_GL_SIZE_X(a3)    # global_size_x
+  lw        t5, KNL_GL_OFFSET_Y(a3)  # global_offset_y
   vsub.vx   v6, v0, t5               # tmp = global_id_y - global_offset_y
   vmul.vx   v6, v6, t6               # tmp = tmp * global_size_x
   vadd.vv   v5, v5, v6               # global_linear_id2 = tmp + global_linear_id1
@@ -75,9 +76,9 @@ __builtin_riscv_global_linear_id:
   beq       t0, t5, .GLR             # Return global_linear_id for 2 dim
 .GL_3DIM:
   call      __builtin_riscv_global_id_z
-  lw        t6, 12(a3) # global_size_x
-  lw        t7, 16(a3) # global_size_y
-  lw        t5, 44(a3) # global_offset_z
+  lw        t6, KNL_GL_SIZE_X(a3)    # global_size_x
+  lw        t7, KNL_GL_SIZE_Y(a3)    # global_size_y
+  lw        t5, KNL_GL_OFFSET_Z(a3)  # global_offset_z
   vsub.vx   v6, v0, t5               # tmp = global_id_z - global_offset_z
   vmul.vx   v6, v6, t6               # tmp = tmp * global_size_x
   vmul.vx   v6, v6, t7               # tmp = tmp * global_size_y
@@ -123,22 +124,22 @@ __builtin_riscv_workgroup_id_z:
   .type __builtin_riscv_workitem_id_x, @function
 __builtin_riscv_workitem_id_x:
   csrr      a0, CSR_KNL              # Get kernel metadata buffer
-  lw        t0, 8(a0) # Get work_dim
+  lw        t0, KNL_WORK_DIM(a0)     # Get work_dim
   csrr      t1, CSR_TID              # tid base offset for current warp
   vid.v     v2                       # current thread offset
   vadd.vx   v0, v2, t1               # local_id_x in 1 dim (local_linear_id)
   li        t2, 1
-  beq       t0, t2, .WIXR # 1 dim
-  lw        t3, 24(a0) # local_size_x
+  beq       t0, t2, .WIXR            # 1 dim
+  lw        t3, KNL_LC_SIZE_X(a0)    # local_size_x
   li        t2, 3
   beq       t0, t2, .WIX3
 .WIX2:
-  vremu.vx  v0, v0, t3 # local_id_x = local_liner_id % local_size_x
+  vremu.vx  v0, v0, t3               # local_id_x = local_linear_id % local_size_x
   ret
-.WIX3: # 3 dims
-  lw        t4, 28(a0) # local_size_y
-  mul       t4, t4, t3 # local_size_x * local_size_y
-  vremu.vx  v0, v0, t4 # local_id_x = local_liner_id % (local_size_x * local_size_y)
+.WIX3:                               # 3 dims
+  lw        t4, KNL_LC_SIZE_Y(a0)    # local_size_y
+  mul       t4, t4, t3               # local_size_x * local_size_y
+  vremu.vx  v0, v0, t4               # local_id_x = local_linear_id % (local_size_x * local_size_y)
 .WIXR:
   ret
   .size __builtin_riscv_workitem_id_x .- __builtin_riscv_workitem_id_x
@@ -149,18 +150,18 @@ __builtin_riscv_workitem_id_x:
   .type __builtin_riscv_workitem_id_y, @function
 __builtin_riscv_workitem_id_y:
   csrr      a0, CSR_KNL              # Get kernel metadata buffer
-  lw        t0, 8(a0) # Get work_dim
+  lw        t0, KNL_WORK_DIM(a0)     # Get work_dim
   csrr      t1, CSR_TID              # tid base offset for current warp
   vid.v     v2                       # current thread offset
   vadd.vx   v0, v2, t1               # local_linear_id
-  lw        t3, 24(a0) # local_size_x offset in 2 work_dims
+  lw        t3, KNL_LC_SIZE_X(a0)    # local_size_x offset in 2 work_dims
   li        t2, 3
   beq       t0, t2, .WIY3
 .WIY2:                               # 2 dims
   vdivu.vx  v0, v0, t3               # local_id_y = local_liner_id / local_size_x
   ret
 .WIY3:                               # 3 dims
-  lw        t4, 28(a0) # local_size_y
+  lw        t4, KNL_LC_SIZE_Y(a0)    # local_size_y
   mul       t4, t4, t3               # local_size_x * local_size_y
   vremu.vx  v0, v0, t4               # x = local_linear_id % (local_size_x * local_size_y)
   vdivu.ux  v0, v0, t3               # x / local_size_x
@@ -176,8 +177,8 @@ __builtin_riscv_workitem_id_z:
   csrr      t1, CSR_TID              # tid base offset for current warp
   vid.v     v2                       # current thread offset
   vadd.vx   v0, v2, t1               # local_linear_id
-  lw        t3, 24(a0) # local_size_x
-  lw        t4, 28(a0) # local_size_y
+  lw        t3, KNL_LC_SIZE_X(a0)    # local_size_x
+  lw        t4, KNL_LC_SIZE_Y(a0)    # local_size_y
   mul       t4, t4, t3               # local_size_x * local_size_y
   vdivu.vx  v0, v0, t4               # local_linear_id / (local_size_x * local_size_y)
 7:
@@ -195,7 +196,7 @@ __builtin_riscv_global_id_x:
   csrr      t2, CSR_TID
   vid.v     v2
   vadd.vx   v2, v2, t2               # workitem_id_x
-  lw        t3, 24(a0) # Get local_size_x
+  lw        t3, KNL_LC_SIZE_X(a0)    # Get local_size_x
   mul       t3, t1, t3               # (CSR_GID_X - 1) * local_size_x
   vadd.vx   v0, v2, t3               # global_id_x
   ret
@@ -212,7 +213,7 @@ __builtin_riscv_global_id_y:
   csrr      t2, CSR_TID
   vid.v     v2
   vadd.vx   v2, v2, t2               # workitem_id_y
-  lw        t3, 28(a0) # Get local_size_y
+  lw        t3, KNL_LC_SIZE_Y(a0)    # Get local_size_y
   mul       t3, t1, t3               # (CSR_GID_Y - 1) * local_size_y
   vadd.vx   v0, v2, t3               # global_id_y
   ret
@@ -229,7 +230,7 @@ __builtin_riscv_global_id_z:
   csrr      t2, CSR_TID
   vid.v     v2
   vadd.vx   v2, v2, t2               # workitem_id_z
-  lw        t3, 32(a0) # Get local_size_z
+  lw        t3, KNL_LC_SIZE_Z(a0)    # Get local_size_z
   mul       t3, t1, t3               # (CSR_GID_Z - 1) * local_size_z
   vadd.vx   v0, v2, t3               # global_id_z
   ret
@@ -241,11 +242,7 @@ __builtin_riscv_global_id_z:
   .type __builtin_riscv_local_size_x, @function
 __builtin_riscv_local_size_x:
   csrr      a0, CSR_KNL              # Get kernel metadata buffer
-  lw        t0, 8(a0) # Get work_dim
-  addi      t2, zero, 4
-  mul       t0, t0, t2 # Skip offset of global_size_xyz slots
-  add       a0, a0, t0
-  lw        t0, 0(a0) # load local_size_x
+  lw        t0, KNL_LC_SIZE_X(a0)    # Load local_size_x
   vmv.s.x   v0, t0
   ret
   .size __builtin_riscv_local_size_x, .-__builtin_riscv_local_size_x
@@ -256,10 +253,7 @@ __builtin_riscv_local_size_x:
   .type __builtin_riscv_local_size_y, @function
 __builtin_riscv_local_size_y:
   csrr      a0, CSR_KNL              # Get kernel metadata buffer
-  lw        t0, 8(a0) # Get work_dim
-  addi      t2, zero, 4
-  mul       t0, t0, t2 # Skip offset of global_size_xyz slots
-  lw        t0, 4(a0) # Load local_size_y
+  lw        t0, KNL_LC_SIZE_Y(a0)    # Load local_size_y
   vmv.s.x   v0, t0
   ret
   .size __builtin_riscv_local_size_y, .-__builtin_riscv_local_size_y
@@ -270,10 +264,7 @@ __builtin_riscv_local_size_y:
   .type __builtin_riscv_local_size_z, @function
 __builtin_riscv_local_size_z:
   csrr      a0, CSR_KNL              # Get kernel metadata buffer
-  lw        t0, 8(a0) # Get work_dim
-  addi      t2, zero, 4
-  mul       t0, t0, t2 # Skip offset of global_size_xyz slots
-  lw        t0, 8(a0) # Load local_size_x
+  lw        t0, KNL_LC_SIZE_Z(a0)    # Load local_size_z
   vmv.s.x   v0, t0
   ret
   .size __builtin_riscv_local_size_z, .-__builtin_riscv_local_size_z
@@ -284,7 +275,7 @@ __builtin_riscv_local_size_z:
   .type __builtin_riscv_global_size_x, @function
 __builtin_riscv_global_size_x:
   csrr      a0, CSR_KNL              # Get kernel metadata buffer
-  lw        t0, 12(a0) # Get global_size_x
+  lw        t0, KNL_GL_SIZE_X(a0)    # Get global_size_x
   vmv.s.x   v0, t0
   ret
   .size __builtin_riscv_global_size_x, .-__builtin_riscv_global_size_x
@@ -295,7 +286,7 @@ __builtin_riscv_global_size_x:
   .type __builtin_riscv_global_size_y, @function
 __builtin_riscv_global_size_y:
   csrr      a0, CSR_KNL              # Get kernel metadata buffer
-  lw        t0, 16(a0) # Get global_size_y
+  lw        t0, KNL_GL_SIZE_Y(a0)    # Get global_size_y
   vmv.s.x   v0, t0
   ret
   .size __builtin_riscv_global_size_y, .-__builtin_riscv_global_size_y
@@ -306,7 +297,7 @@ __builtin_riscv_global_size_y:
   .type __builtin_riscv_global_size_z, @function
 __builtin_riscv_global_size_z:
   csrr      a0, CSR_KNL              # Get kernel metadata buffer
-  lw        t0, 20(a0) # Get global_size_z
+  lw        t0, KNL_GL_SIZE_Z(a0)    # Get global_size_z
   vmv.s.x   v0, t0
   ret
   .size __builtin_riscv_global_size_z, .-__builtin_riscv_global_size_z
@@ -317,12 +308,8 @@ __builtin_riscv_global_size_z:
   .type __builtin_riscv_num_groups_x, @function
 __builtin_riscv_num_groups_x:
   csrr      a0, CSR_KNL              # Get kernel metadata buffer
-  lw        t0, 8(a0) # Get work_dim
-  lw        t1, 12(a0) # Get global_size_x
-  addi      t2, zero, 4 # 4 bytes per slot
-  mul       t0, t0, t2
-  add       a0, a0, t0 # Skip offset of global_size_xyz slots
-  lw        t0, 0(a0) # Get local_size_x
+  lw        t1, KNL_GL_SIZE_X(a0)    # Get global_size_x
+  lw        t0, KNL_LC_SIZE_X(a0)    # Get local_size_x
   divu      t1, t1, t0               # global_size_x / local_size_x
   vmv.s.x   v0, t1
   ret
@@ -334,12 +321,8 @@ __builtin_riscv_num_groups_x:
   .type __builtin_riscv_num_groups_y, @function
 __builtin_riscv_num_groups_y:
   csrr      a0, CSR_KNL              # Get kernel metadata buffer
-  lw        t0, 8(a0) # Get work_dim
-  lw        t1, 16(a0) # Get global_size_y
-  addi      t2, zero, 4 # 4 bytes per slot
-  mul       t0, t0, t2
-  add       a0, a0, t0 # Skip offset of global_size_xyz slots
-  lw        t0, 4(a0) # Get local_size_y
+  lw        t1, KNL_GL_SIZE_Y(a0)    # Get global_size_y
+  lw        t0, KNL_LC_SIZE_Y(a0)    # Get local_size_y
   divu      t1, t1, t0               # global_size_y / local_size_y
   vmv.s.x   v0, t1
   ret
@@ -351,8 +334,8 @@ __builtin_riscv_num_groups_y:
   .type __builtin_riscv_num_groups_z, @function
 __builtin_riscv_num_groups_z:
   csrr      a0, CSR_KNL              # Get kernel metadata buffer
-  lw        t1, 20(a0) # Get global_size_z
-  lw        t2, 32(a0) # Get local_size_z
+  lw        t1, KNL_GL_SIZE_Z(a0)    # Get global_size_z
+  lw        t2, KNL_LC_SIZE_Z(a0)    # Get local_size_z
   divu      t1, t1, t2               # global_size_z / local_size_z
   vmv.s.x   v0, t1
   ret
@@ -364,7 +347,7 @@ __builtin_riscv_num_groups_z:
   .type __builtin_riscv_work_dim, @function
 __builtin_riscv_work_dim:
   csrr      a0, CSR_KNL              # Get kernel metadata buffer
-  lw        t0, 8(a0) # Get work_dim
+  lw        t0, KNL_WORK_DIM(a0)     # Get work_dim
   vmv.s.x   v0, t0
   ret
   .size __builtin_riscv_work_dim, .-__builtin_riscv_work_dim