Files
c_soft/cpu/riscv.c
2025-04-16 19:39:04 +08:00

501 lines
12 KiB
C
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include "riscv.h"

#include <stdint.h>
#include <stdio.h>
#include <string.h>
/*
每个指令长度为32bit分为6个类型
Every instruction is 32 bits wide and uses one of six encoding formats:
|----------------+---------+---------+---------+----------------+---------|
| [31-25]        | [24-20] | [19-15] | [14-12] | [11-7]         | [6-0]   |
|----------------+---------+---------+---------+----------------+---------|
| funct7         | rs2     | rs1     | funct3  | rd             | opcode  | R-type 计算指令 (register-register arithmetic)
|----------------+---------+---------+---------+----------------+---------|
| imm[11:0]                | rs1     | funct3  | rd             | opcode  | I-type 加载指令 (loads, ALU immediates, jalr, system)
|----------------+---------+---------+---------+----------------+---------|
| imm[11:5]      | rs2     | rs1     | funct3  | imm[4:0]       | opcode  | S-type 存储指令 (stores)
|----------------+---------+---------+---------+----------------+---------|
| imm[12|10:5]   | rs2     | rs1     | funct3  | imm[4:1|11]    | opcode  | B-type 分支指令 (conditional branches)
|----------------+---------+---------+---------+----------------+---------|
| imm[31:12]                                   | rd             | opcode  | U-type 立即数指令 (upper immediate: lui, auipc)
|----------------------------------------------+----------------+---------|
| imm[20|10:1|11|19:12]                        | rd             | opcode  | J-type 跳转指令 (jump: jal)
|----------------------------------------------+----------------+---------|
*/
// Major opcodes (instruction bits [6:0]) for RV32I + Zicsr.
// Many mnemonics share one major opcode and are told apart only by funct3/funct7
// (and funct12 for SYSTEM). A decoder should therefore switch on the group names;
// using two aliases of the same group as `case` labels in one switch is a
// duplicate-case compile error.
#define opcode_load     0x03 // LOAD     (I-type)
#define opcode_misc_mem 0x0f // MISC-MEM (fence, fence.i)
#define opcode_op_imm   0x13 // OP-IMM   (I-type ALU)
#define opcode_store    0x23 // STORE    (S-type)
#define opcode_op       0x33 // OP       (R-type)
#define opcode_branch   0x63 // BRANCH   (B-type)
#define opcode_system   0x73 // SYSTEM   (ecall/ebreak/CSR)

#define opcode_lui   0x37 // U-type
#define opcode_auipc 0x17 // U-type
#define opcode_jal   0x6f // J-type
#define opcode_jalr  0x67 // I-type

// Per-mnemonic aliases of the groups above.
// B-type
#define opcode_beq  opcode_branch
#define opcode_bne  opcode_branch
#define opcode_blt  opcode_branch
#define opcode_bge  opcode_branch
#define opcode_bltu opcode_branch
#define opcode_bgeu opcode_branch
// I-type loads
#define opcode_lb  opcode_load
#define opcode_lh  opcode_load
#define opcode_lw  opcode_load
#define opcode_lbu opcode_load
#define opcode_lhu opcode_load
// S-type
#define opcode_sb opcode_store
#define opcode_sh opcode_store
#define opcode_sw opcode_store
// I-type ALU
#define opcode_addi  opcode_op_imm
#define opcode_slti  opcode_op_imm
#define opcode_sltiu opcode_op_imm
#define opcode_xori  opcode_op_imm
#define opcode_ori   opcode_op_imm
#define opcode_andi  opcode_op_imm
#define opcode_slli  opcode_op_imm
#define opcode_srli  opcode_op_imm
#define opcode_srai  opcode_op_imm
// R-type
#define opcode_add  opcode_op
#define opcode_sub  opcode_op
#define opcode_sll  opcode_op
#define opcode_slt  opcode_op
#define opcode_sltu opcode_op
#define opcode_xor  opcode_op
#define opcode_srl  opcode_op
#define opcode_sra  opcode_op
#define opcode_or   opcode_op
#define opcode_and  opcode_op
// Memory ordering
#define opcode_fence   opcode_misc_mem
#define opcode_fence_i opcode_misc_mem
// System calls and CSR access
#define opcode_ecall  opcode_system
#define opcode_ebreak opcode_system
#define opcode_csrrw  opcode_system
#define opcode_csrrs  opcode_system
#define opcode_csrrc  opcode_system
#define opcode_csrrwi opcode_system
#define opcode_csrrsi opcode_system
#define opcode_csrrci opcode_system
// Guest RAM geometry: riscv_t.mem holds MEM_SIZE 32-bit words (4 * MEM_SIZE bytes)
// starting at guest address MEM_ADDR_BASE.
// Parenthesised so that expressions like `x % MEM_SIZE` expand correctly.
#define MEM_SIZE (1024 * 1024)
// The base must fit in a 32-bit guest address. A value of 2^32 or more makes every
// mem_wr* offset negative. NOTE(review): 0x10000000 assumes the earlier
// 0x100000000 had one zero too many; confirm it against the guest program's
// memory map. It can be overridden with -DMEM_ADDR_BASE=...
#ifndef MEM_ADDR_BASE
#define MEM_ADDR_BASE 0x10000000u
#endif
// Standard RISC-V ABI names for the 32 integer registers.
// Each name expands to a bare `reg[N]` with no object prefix. That suggests use as
// a member access, e.g. `riscv->sp` -> `riscv->reg[2]` (TODO confirm intended use).
// NOTE(review): these object-like macros have very common names (`fp`, `t0`, `a0`,
// `zero`, ...). Any later identifier with one of these names in this file is
// rewritten by the preprocessor, e.g. a `FILE *fp` local becomes `FILE *reg[8]`.
#define zero reg[0]
#define ra reg[1]
#define sp reg[2]
#define gp reg[3]
#define tp reg[4]
#define t0 reg[5]
#define t1 reg[6]
#define t2 reg[7]
#define s0 reg[8]
#define s1 reg[9]
#define a0 reg[10]
#define a1 reg[11]
#define a2 reg[12]
#define a3 reg[13]
#define a4 reg[14]
#define a5 reg[15]
#define a6 reg[16]
#define a7 reg[17]
#define s2 reg[18]
#define s3 reg[19]
#define s4 reg[20]
#define s5 reg[21]
#define s6 reg[22]
#define s7 reg[23]
#define s8 reg[24]
#define s9 reg[25]
#define s10 reg[26]
#define s11 reg[27]
#define t3 reg[28]
#define t4 reg[29]
#define t5 reg[30]
#define t6 reg[31]
#define fp reg[8] // frame pointer: second name for s0 (x8)
// State of one emulated RV32 core.
typedef struct {
uint32_t mem[MEM_SIZE];   // guest RAM: MEM_SIZE 32-bit words, addressed through mem_wr/mem_wrb/mem_wrh relative to MEM_ADDR_BASE
uint32_t* rom;            // program image pointer stored by riscv_init (not copied)
uint32_t rom_size;        // size value passed to riscv_init (unit not established here - TODO confirm bytes vs words)
uint32_t rom_addr_base;   // guest base address passed to riscv_init for the ROM
uint32_t pc;              // program counter; branch/jump handlers add offsets to it
uint32_t reg[32];         // integer registers x0..x31
uint32_t csrs[4096];      // control/status registers indexed by CSR number (4096 = 2^12 entries)
}riscv_t;
// Guest-RAM accessors. Each expands to an lvalue into riscv->mem, usable on either
// side of `=`, and expects a variable named `riscv` (riscv_t *) in scope.
// `addr` is a guest address. The offset from MEM_ADDR_BASE is computed in 32-bit
// unsigned arithmetic, so it wraps like the guest address space instead of going
// negative when MEM_ADDR_BASE is a wider constant.
// No bounds or alignment checking is done: mem_wrh/mem_wr round misaligned
// addresses down, and addresses outside RAM index out of bounds.
#define mem_offset(addr) ((uint32_t)(addr) - (uint32_t)(MEM_ADDR_BASE))
#define mem_wr(addr) (riscv->mem[mem_offset(addr) >> 2])
#define mem_wrb(addr) (((uint8_t *)riscv->mem)[mem_offset(addr)])
#define mem_wrh(addr) (((uint16_t *)riscv->mem)[mem_offset(addr) >> 1])
void ins_add(riscv_t* riscv, int rs2, int rs1, int rd) {
    /* ADD (R-type): rd = rs1 + rs2. Registers are uint32_t, so overflow wraps. */
    uint32_t* const x = riscv->reg;
    x[rd] = x[rs1] + x[rs2];
}
void ins_addi(riscv_t* riscv, int rs1, int imm, int rd) {
    /* ADDI (I-type): rd = rs1 + sign-extended 12-bit immediate. */
    int offset = imm;
    if ((offset & 0x800) != 0) {
        offset |= ~0x7ff; /* bit 11 set: fill bits 31..11 with ones */
    }
    riscv->reg[rd] = riscv->reg[rs1] + offset;
}
void ins_addiw(riscv_t* riscv, int rs1, int imm, int rd) {
    /* ADDIW is an RV64I instruction. With the 32-bit registers used here it
       computes the same result as ADDI: rd = rs1 + sign-extended 12-bit immediate. */
    const int negative = (imm & 0x800) != 0;
    const int offset = negative ? (imm | ~0x7ff) : imm;
    riscv->reg[rd] = riscv->reg[rs1] + offset;
}
void ins_addw(riscv_t* riscv, int rs2, int rs1, int rd) {
    /* ADDW (RV64I). With 32-bit registers it is identical to ADD. */
    const uint32_t lhs = riscv->reg[rs1];
    const uint32_t rhs = riscv->reg[rs2];
    riscv->reg[rd] = lhs + rhs;
}
void ins_and(riscv_t* riscv, int rs2, int rs1, int rd) {
    /* AND (R-type): bitwise rd = rs1 & rs2. */
    uint32_t* const x = riscv->reg;
    x[rd] = x[rs2] & x[rs1];
}
void ins_andi(riscv_t* riscv, int rs1, int imm, int rd) {
    /* ANDI (I-type): rd = rs1 & sign-extended 12-bit immediate. */
    uint32_t mask = (uint32_t)imm;
    if (mask & 0x800u) {
        mask |= 0xfffff800u;
    }
    riscv->reg[rd] = riscv->reg[rs1] & mask;
}
// AUIPC: rd = pc + imm. imm is expected to be the U-type immediate (upper 20 bits,
// low 12 bits zero) produced by imm_u_type.
void ins_auipc(riscv_t* riscv, int imm, int rd) {
    const uint32_t base = riscv->pc;
    riscv->reg[rd] = base + imm;
}
void ins_beq(riscv_t* riscv, int rs2, int rs1, int imm) {
    /* BEQ: add imm to pc when rs1 == rs2; otherwise pc is left unchanged. */
    const int taken = riscv->reg[rs1] == riscv->reg[rs2];
    if (!taken) {
        return;
    }
    riscv->pc += imm;
}
void ins_bge(riscv_t* riscv, int rs2, int rs1, int imm) {
    /* BGE: add imm to pc when rs1 >= rs2, compared as signed values. */
    const int lhs = (int)riscv->reg[rs1];
    const int rhs = (int)riscv->reg[rs2];
    if (!(lhs < rhs)) {
        riscv->pc += imm;
    }
}
void ins_bgeu(riscv_t* riscv, int rs2, int rs1, int imm) {
    /* BGEU: add imm to pc when rs1 >= rs2, compared as unsigned values. */
    if (riscv->reg[rs1] < riscv->reg[rs2]) {
        return;
    }
    riscv->pc += imm;
}
void ins_blt(riscv_t* riscv, int rs2, int rs1, int imm) {
    /* BLT: add imm to pc when rs1 < rs2, compared as signed values. */
    const int lhs = (int)riscv->reg[rs1];
    const int rhs = (int)riscv->reg[rs2];
    if (lhs < rhs) {
        riscv->pc += imm;
    }
}
void ins_bltu(riscv_t* riscv, int rs2, int rs1, int imm) {
    /* BLTU: add imm to pc when rs1 < rs2, compared as unsigned values. */
    const int taken = riscv->reg[rs1] < riscv->reg[rs2];
    riscv->pc += taken ? imm : 0;
}
void ins_bne(riscv_t* riscv, int rs2, int rs1, int imm) {
    /* BNE: add imm to pc when rs1 != rs2. */
    if (riscv->reg[rs1] == riscv->reg[rs2]) {
        return;
    }
    riscv->pc += imm;
}
// CSRRC: read the CSR into rd, then clear the CSR bits that are set in rs1.
void ins_csrrc(riscv_t* riscv, int rs1, int rd, int csr) {
    uint32_t* const slot = &riscv->csrs[csr];
    const uint32_t old_value = *slot;
    *slot = old_value & ~riscv->reg[rs1]; /* rs1 is read before rd is written */
    riscv->reg[rd] = old_value;
}
void ins_csrrs(riscv_t* riscv, int rs1, int rd, int csr) {
    /* CSRRS: read the CSR into rd, then set the CSR bits that are set in rs1. */
    uint32_t* const slot = &riscv->csrs[csr];
    const uint32_t old_value = *slot;
    *slot = old_value | riscv->reg[rs1];
    riscv->reg[rd] = old_value;
}
void ins_csrrw(riscv_t* riscv, int rs1, int rd, int csr) {
    /* CSRRW: swap the CSR with rs1. The new value is taken from rs1 before rd is
       written, so rd == rs1 behaves correctly. */
    const uint32_t previous = riscv->csrs[csr];
    const uint32_t incoming = riscv->reg[rs1];
    riscv->csrs[csr] = incoming;
    riscv->reg[rd] = previous;
}
void ins_csrrci(riscv_t* riscv, int imm, int rd, int csr) {
    /* CSRRCI: read the CSR into rd, then clear the CSR bits set in the immediate. */
    uint32_t* const slot = &riscv->csrs[csr];
    const uint32_t old_value = *slot;
    *slot = old_value & ~(uint32_t)imm;
    riscv->reg[rd] = old_value;
}
void ins_csrrsi(riscv_t* riscv, int imm, int rd, int csr) {
    /* CSRRSI: read the CSR into rd, then set the CSR bits set in the immediate. */
    uint32_t* const slot = &riscv->csrs[csr];
    const uint32_t old_value = *slot;
    *slot = old_value | (uint32_t)imm;
    riscv->reg[rd] = old_value;
}
void ins_csrrwi(riscv_t* riscv, int imm, int rd, int csr) {
    /* CSRRWI: read the CSR into rd and replace it with the immediate. */
    uint32_t* const slot = &riscv->csrs[csr];
    const uint32_t old_value = *slot;
    *slot = (uint32_t)imm;
    riscv->reg[rd] = old_value;
}
void ins_ecall(riscv_t* riscv) {
    /* ECALL: only traced to stdout; no trap is taken and riscv is not used. */
    (void)riscv;
    printf("ecall\n");
}
void ins_ebreak(riscv_t* riscv) {
    /* EBREAK: only traced to stdout; no trap is taken and riscv is not used. */
    (void)riscv;
    printf("ebreak\n");
}
void ins_jal(riscv_t* riscv, int imm, int rd) {
    /* JAL: rd = pc + 4 (link address), then pc += imm. */
    const uint32_t here = riscv->pc;
    riscv->pc = here + imm;
    riscv->reg[rd] = here + 4;
}
void ins_jalr(riscv_t* riscv, int rs1, int imm, int rd) {
    /* JALR: pc = (rs1 + sign_extend(imm[11:0])) with bit 0 cleared; rd = old pc + 4.
       imm_i_type yields the raw 12-bit field, so it is sign-extended here in the
       same way ins_addi does. This is idempotent if the caller already extended it. */
    const uint32_t link = riscv->pc + 4;
    if (imm & (1 << 11)) {
        imm |= ~0x7ff;
    }
    /* rs1 is read before rd is written, so rd == rs1 still jumps to the old rs1. */
    riscv->pc = (riscv->reg[rs1] + imm) & ~(uint32_t)1;
    riscv->reg[rd] = link;
}
void ins_lb(riscv_t* riscv, int rs1, int imm, int rd) {
    /* LB: load one byte from rs1 + sign_extend(imm[11:0]) and sign-extend it into rd.
       Casting the uint8_t to int (as before) would zero-extend, i.e. behave like LBU. */
    uint32_t addr;
    uint32_t value;
    if (imm & (1 << 11)) {
        imm |= ~0x7ff; /* idempotent if already sign-extended */
    }
    addr = riscv->reg[rs1] + imm;
    value = mem_wrb(addr);
    if (value & 0x80u) {
        value |= 0xffffff00u;
    }
    riscv->reg[rd] = value;
}
void ins_lbu(riscv_t* riscv, int rs1, int imm, int rd) {
    /* LBU: load one byte from rs1 + sign_extend(imm[11:0]) and zero-extend it into rd.
       The offset is sign-extended here like in ins_addi, because imm_i_type yields
       the raw 12-bit field. */
    uint32_t addr;
    if (imm & (1 << 11)) {
        imm |= ~0x7ff;
    }
    addr = riscv->reg[rs1] + imm;
    riscv->reg[rd] = mem_wrb(addr);
}
void ins_lh(riscv_t* riscv, int rs1, int imm, int rd) {
    /* LH: load a halfword from rs1 + sign_extend(imm[11:0]) and sign-extend it into rd.
       Casting the uint16_t to int (as before) would zero-extend, i.e. behave like LHU.
       mem_wrh rounds a misaligned address down to a halfword boundary. */
    uint32_t addr;
    uint32_t value;
    if (imm & (1 << 11)) {
        imm |= ~0x7ff;
    }
    addr = riscv->reg[rs1] + imm;
    value = mem_wrh(addr);
    if (value & 0x8000u) {
        value |= 0xffff0000u;
    }
    riscv->reg[rd] = value;
}
void ins_lhu(riscv_t* riscv, int rs1, int imm, int rd) {
    /* LHU: load a halfword from rs1 + sign_extend(imm[11:0]) and zero-extend it into rd. */
    uint32_t addr;
    if (imm & (1 << 11)) {
        imm |= ~0x7ff;
    }
    addr = riscv->reg[rs1] + imm;
    riscv->reg[rd] = mem_wrh(addr);
}
void ins_lw(riscv_t* riscv, int rs1, int imm, int rd) {
    /* LW: load a word from rs1 + sign_extend(imm[11:0]) into rd.
       The sign extension makes negative offsets such as `lw a0, -4(sp)` work
       when imm is the raw 12-bit field. */
    uint32_t addr;
    if (imm & (1 << 11)) {
        imm |= ~0x7ff;
    }
    addr = riscv->reg[rs1] + imm;
    riscv->reg[rd] = mem_wr(addr);
}
void ins_lwu(riscv_t* riscv, int rs1, int imm, int rd) {
    /* LWU is an RV64I instruction. With 32-bit registers it is identical to LW:
       load a word from rs1 + sign_extend(imm[11:0]) into rd. */
    uint32_t addr;
    if (imm & (1 << 11)) {
        imm |= ~0x7ff;
    }
    addr = riscv->reg[rs1] + imm;
    riscv->reg[rd] = mem_wr(addr);
}
void ins_lui(riscv_t* riscv, int imm, int rd) {
    /* LUI: rd = imm, the U-type immediate (upper 20 bits, low 12 bits zero). */
    const uint32_t value = (uint32_t)imm;
    riscv->reg[rd] = value;
}
void ins_mret(riscv_t* riscv) {
    /* MRET: return from a machine-mode trap. CSR numbers and mstatus bit positions
       follow the RISC-V privileged specification. */
    enum {
        CSR_MSTATUS = 0x300,
        CSR_MEPC = 0x341
    };
    const uint32_t mstatus_mie = 1u << 3;
    const uint32_t mstatus_mpie = 1u << 7;
    const uint32_t mstatus_mpp = 3u << 11;
    uint32_t status = riscv->csrs[CSR_MSTATUS];

    /* Resume at the saved trap pc (mepc). Reading 0x305 (mtvec, the trap entry
       vector) would jump back into the trap handler instead of returning. */
    riscv->pc = riscv->csrs[CSR_MEPC];

    /* MIE <- MPIE, MPIE <- 1, MPP <- M. Only machine mode is modelled here. */
    status = (status & mstatus_mpie) ? (status | mstatus_mie) : (status & ~mstatus_mie);
    status |= mstatus_mpie | mstatus_mpp;
    riscv->csrs[CSR_MSTATUS] = status;

    printf("mret\n");
}
void ins_or(riscv_t* riscv, int rs2, int rs1, int rd) {
    /* OR (R-type): bitwise rd = rs1 | rs2. */
    const uint32_t lhs = riscv->reg[rs1];
    const uint32_t rhs = riscv->reg[rs2];
    riscv->reg[rd] = lhs | rhs;
}
void ins_ori(riscv_t* riscv, int rs1, int imm, int rd) {
    /* ORI (I-type): rd = rs1 | sign-extended 12-bit immediate. */
    uint32_t bits = (uint32_t)imm;
    if (bits & 0x800u) {
        bits |= 0xfffff800u;
    }
    riscv->reg[rd] = riscv->reg[rs1] | bits;
}
void ins_sb(riscv_t* riscv, int rs2, int rs1, int imm) {
    /* SB: store the low byte of rs2 at rs1 + sign_extend(imm[11:0]).
       The sign extension is idempotent, so an already-extended imm is also accepted. */
    uint32_t addr;
    if (imm & (1 << 11)) {
        imm |= ~0x7ff;
    }
    addr = riscv->reg[rs1] + imm;
    mem_wrb(addr) = (uint8_t)riscv->reg[rs2];
}
void ins_sh(riscv_t* riscv, int rs2, int rs1, int imm) {
    /* SH: store the low halfword of rs2 at rs1 + sign_extend(imm[11:0]).
       mem_wrh rounds a misaligned address down to a halfword boundary. */
    uint32_t addr;
    if (imm & (1 << 11)) {
        imm |= ~0x7ff;
    }
    addr = riscv->reg[rs1] + imm;
    mem_wrh(addr) = (uint16_t)riscv->reg[rs2];
}
void ins_sw(riscv_t* riscv, int rs2, int rs1, int imm) {
    /* SW: store rs2 at rs1 + sign_extend(imm[11:0]). This makes negative offsets
       such as `sw ra, -4(sp)` work when imm is the raw 12-bit field. */
    uint32_t addr;
    if (imm & (1 << 11)) {
        imm |= ~0x7ff;
    }
    addr = riscv->reg[rs1] + imm;
    mem_wr(addr) = riscv->reg[rs2];
}
void ins_sll(riscv_t* riscv, int rs2, int rs1, int rd) {
    /* SLL (R-type): rd = rs1 << rs2[4:0]. Masking to 5 bits matches the ISA and
       keeps the C shift count below 32; larger counts are undefined behaviour. */
    riscv->reg[rd] = riscv->reg[rs1] << (riscv->reg[rs2] & 0x1f);
}
void ins_slli(riscv_t* riscv, int rs1, int imm, int rd) {
    /* SLLI: rd = rs1 << shamt, where shamt = imm[4:0]. The mask keeps a raw
       I-type field (or any out-of-range value) from causing undefined behaviour. */
    riscv->reg[rd] = riscv->reg[rs1] << (imm & 0x1f);
}
void ins_slliw(riscv_t* riscv, int rs1, int imm, int rd) {
    /* SLLIW (RV64I). It uses a 5-bit shift amount, so it matches SLLI on 32-bit
       registers. The mask prevents undefined behaviour for counts >= 32. */
    riscv->reg[rd] = riscv->reg[rs1] << (imm & 0x1f);
}
void ins_sllw(riscv_t* riscv, int rs2, int rs1, int rd) {
    /* SLLW (RV64I). The shift amount is rs2[4:0], so it matches SLL on 32-bit registers. */
    riscv->reg[rd] = riscv->reg[rs1] << (riscv->reg[rs2] & 0x1f);
}
void ins_slt(riscv_t* riscv, int rs2, int rs1, int rd) {
    /* SLT: rd = 1 if rs1 < rs2 as signed values, else 0. */
    const int lhs = (int)riscv->reg[rs1];
    const int rhs = (int)riscv->reg[rs2];
    riscv->reg[rd] = (lhs < rhs) ? 1u : 0u;
}
void ins_slti(riscv_t* riscv, int rs1, int imm, int rd) {
    /* SLTI: rd = 1 if rs1 < sign-extended 12-bit immediate (signed compare), else 0. */
    const int bound = (imm & 0x800) ? (imm | ~0x7ff) : imm;
    const int lhs = (int)riscv->reg[rs1];
    riscv->reg[rd] = (lhs < bound) ? 1u : 0u;
}
void ins_sltiu(riscv_t* riscv, int rs1, int imm, int rd) {
    /* SLTIU: the immediate is sign-extended first, then compared as unsigned.
       rd = 1 if rs1 < immediate, else 0. */
    uint32_t bound = (uint32_t)imm;
    if (bound & 0x800u) {
        bound |= 0xfffff800u;
    }
    riscv->reg[rd] = (riscv->reg[rs1] < bound) ? 1u : 0u;
}
void ins_sltu(riscv_t* riscv, int rs2, int rs1, int rd) {
    /* SLTU: rd = 1 if rs1 < rs2 as unsigned values, else 0. */
    const uint32_t lhs = riscv->reg[rs1];
    const uint32_t rhs = riscv->reg[rs2];
    riscv->reg[rd] = (lhs < rhs) ? 1u : 0u;
}
void ins_sra(riscv_t* riscv, int rs2, int rs1, int rd) {
    /* SRA (R-type): arithmetic right shift of rs1 by rs2[4:0].
       The count is masked, because shifts >= 32 are undefined behaviour in C.
       Sign bits are replicated explicitly, because `>>` on a negative int is
       implementation-defined. */
    const uint32_t value = riscv->reg[rs1];
    const uint32_t shamt = riscv->reg[rs2] & 0x1f;
    uint32_t result = value >> shamt;
    if (value & 0x80000000u) {
        result |= ~(0xffffffffu >> shamt);
    }
    riscv->reg[rd] = result;
}
void ins_srai(riscv_t* riscv, int rs1, int imm, int rd) {
    /* SRAI: arithmetic right shift of rs1 by imm[4:0]. The raw I-type field of
       SRAI also carries funct7 = 0x20 in imm[11:5]. Without masking, that gives a
       shift count of 1024+ (undefined behaviour). */
    const uint32_t value = riscv->reg[rs1];
    const uint32_t shamt = (uint32_t)imm & 0x1f;
    uint32_t result = value >> shamt;
    if (value & 0x80000000u) {
        result |= ~(0xffffffffu >> shamt);
    }
    riscv->reg[rd] = result;
}
void ins_sraiw(riscv_t* riscv, int rs1, int imm, int rd) {
    /* SRAIW (RV64I). It uses a 5-bit shift amount, so it matches SRAI on 32-bit
       registers. The arithmetic shift is done portably and the count is masked. */
    const uint32_t value = riscv->reg[rs1];
    const uint32_t shamt = (uint32_t)imm & 0x1f;
    uint32_t result = value >> shamt;
    if (value & 0x80000000u) {
        result |= ~(0xffffffffu >> shamt);
    }
    riscv->reg[rd] = result;
}
void ins_sraw(riscv_t* riscv, int rs2, int rs1, int rd) {
    /* SRAW (RV64I). The shift amount is rs2[4:0], so it matches SRA on 32-bit
       registers. The arithmetic shift is done portably and the count is masked. */
    const uint32_t value = riscv->reg[rs1];
    const uint32_t shamt = riscv->reg[rs2] & 0x1f;
    uint32_t result = value >> shamt;
    if (value & 0x80000000u) {
        result |= ~(0xffffffffu >> shamt);
    }
    riscv->reg[rd] = result;
}
void ins_srl(riscv_t* riscv, int rs2, int rs1, int rd) {
    /* SRL (R-type): logical right shift of rs1 by rs2[4:0]. The mask matches the
       ISA and avoids undefined behaviour for counts >= 32. */
    riscv->reg[rd] = riscv->reg[rs1] >> (riscv->reg[rs2] & 0x1f);
}
void ins_srli(riscv_t* riscv, int rs1, int imm, int rd) {
    /* SRLI: logical right shift of rs1 by imm[4:0]. The mask avoids undefined
       behaviour for out-of-range counts. */
    riscv->reg[rd] = riscv->reg[rs1] >> (imm & 0x1f);
}
void ins_srliw(riscv_t* riscv, int rs1, int imm, int rd) {
    /* SRLIW (RV64I). It uses a 5-bit shift amount, so it matches SRLI on 32-bit registers. */
    riscv->reg[rd] = riscv->reg[rs1] >> (imm & 0x1f);
}
void ins_srlw(riscv_t* riscv, int rs2, int rs1, int rd) {
    /* SRLW (RV64I). The shift amount is rs2[4:0], so it matches SRL on 32-bit registers. */
    riscv->reg[rd] = riscv->reg[rs1] >> (riscv->reg[rs2] & 0x1f);
}
void ins_sub(riscv_t* riscv, int rs2, int rs1, int rd) {
    /* SUB (R-type): rd = rs1 - rs2, wrapping modulo 2^32. */
    const uint32_t minuend = riscv->reg[rs1];
    const uint32_t subtrahend = riscv->reg[rs2];
    riscv->reg[rd] = minuend - subtrahend;
}
void ins_subw(riscv_t* riscv, int rs2, int rs1, int rd) {
    /* SUBW (RV64I). With 32-bit registers it is identical to SUB. */
    uint32_t* const x = riscv->reg;
    x[rd] = x[rs1] - x[rs2];
}
void ins_wfi(riscv_t* riscv) {
    /* WFI: only traced to stdout; the core does not actually wait, and riscv is unused. */
    (void)riscv;
    printf("wfi\n");
}
void ins_xor(riscv_t* riscv, int rs2, int rs1, int rd) {
    /* XOR (R-type): bitwise rd = rs1 ^ rs2. */
    const uint32_t lhs = riscv->reg[rs1];
    const uint32_t rhs = riscv->reg[rs2];
    riscv->reg[rd] = lhs ^ rhs;
}
void ins_xori(riscv_t* riscv, int rs1, int imm, int rd) {
    /* XORI (I-type): rd = rs1 ^ sign-extended 12-bit immediate. */
    uint32_t pattern = (uint32_t)imm;
    if (pattern & 0x800u) {
        pattern |= 0xfffff800u;
    }
    riscv->reg[rd] = riscv->reg[rs1] ^ pattern;
}
/* U-type immediate: instruction bits [31:12] kept in place, low 12 bits zero.
   do/while(0) and parenthesised arguments make the macro safe in if/else and
   with expression arguments. */
#define imm_u_type(_ins, _imm) do { \
    (_imm) = (int)((uint32_t)(_ins) & 0xfffff000u); \
} while (0)
/* J-type immediate: reassembles imm[20|10:1|11|19:12] and sign-extends from bit 20,
   because jal offsets are signed (+/-1 MiB). Without the sign extension every
   backward jump lands about 2 MiB forward. */
#define imm_j_type(_ins, _imm) do { \
    uint32_t imm_j_bits_ = (uint32_t)(_ins); \
    uint32_t imm_j_val_ = ((imm_j_bits_ >> 21) & 0x3ffu) << 1;  /* imm[10:1]  = ins[30:21] */ \
    imm_j_val_ |= ((imm_j_bits_ >> 20) & 0x1u) << 11;            /* imm[11]    = ins[20]    */ \
    imm_j_val_ |= ((imm_j_bits_ >> 12) & 0xffu) << 12;           /* imm[19:12] = ins[19:12] */ \
    imm_j_val_ |= ((imm_j_bits_ >> 31) & 0x1u) << 20;            /* imm[20]    = ins[31]    */ \
    (_imm) = (int)(imm_j_val_ ^ 0x100000u) - 0x100000;           /* sign-extend from bit 20 */ \
} while (0)
/* I-type immediate: the raw 12-bit field ins[31:20], deliberately NOT sign-extended.
   SYSTEM instructions use the same field for CSR numbers (0..0xfff), so callers
   that need a signed offset sign-extend it themselves (as ins_addi does). */
#define imm_i_type(_ins, _imm) do { \
    (_imm) = (int)(((uint32_t)(_ins) >> 20) & 0xfffu); \
} while (0)
/* B-type immediate: reassembles imm[12|10:5|4:1|11] and sign-extends from bit 12.
   imm[10:5] is 6 bits wide (ins[30:25]). Masking it with 0x1f dropped imm[10].
   Branch offsets are signed, so backward branches need the sign extension. */
#define imm_b_type(_ins, _imm) do { \
    uint32_t imm_b_bits_ = (uint32_t)(_ins); \
    uint32_t imm_b_val_ = ((imm_b_bits_ >> 31) & 0x1u) << 12;  /* imm[12]   = ins[31]    */ \
    imm_b_val_ |= ((imm_b_bits_ >> 7) & 0x1u) << 11;            /* imm[11]   = ins[7]     */ \
    imm_b_val_ |= ((imm_b_bits_ >> 8) & 0xfu) << 1;             /* imm[4:1]  = ins[11:8]  */ \
    imm_b_val_ |= ((imm_b_bits_ >> 25) & 0x3fu) << 5;           /* imm[10:5] = ins[30:25] */ \
    (_imm) = (int)(imm_b_val_ ^ 0x1000u) - 0x1000;              /* sign-extend from bit 12 */ \
} while (0)
/* S-type immediate: reassembles imm[11:5|4:0] and sign-extends from bit 11, so
   negative store offsets (e.g. `sw ra, -4(sp)`) come out negative. */
#define imm_s_type(_ins, _imm) do { \
    uint32_t imm_s_bits_ = (uint32_t)(_ins); \
    uint32_t imm_s_val_ = (imm_s_bits_ >> 7) & 0x1fu;            /* imm[4:0]  = ins[11:7]  */ \
    imm_s_val_ |= ((imm_s_bits_ >> 25) & 0x7fu) << 5;            /* imm[11:5] = ins[31:25] */ \
    (_imm) = (int)(imm_s_val_ ^ 0x800u) - 0x800;                 /* sign-extend from bit 11 */ \
} while (0)
// Instruction decoder / executor.

/* Sign-extend the low `bits` bits of `value` (1 <= bits <= 31).
 * Idempotent: a value already sign-extended from that width is returned unchanged,
 * so it is safe whether or not an imm_*_type macro already extended it. */
static int riscv_sext(int value, int bits) {
    uint32_t u = (uint32_t)value;
    const uint32_t sign = 1u << (bits - 1);
    const uint32_t low = sign - 1u;
    if (u & sign) {
        u |= ~low;
    } else {
        u &= low;
    }
    return (int)u;
}

/* BRANCH group (beq/bne/blt/bge/bltu/bgeu). On success pc points at the next
 * instruction: this pc + imm when taken, this pc + 4 otherwise.
 * Returns -1 without touching any state for a reserved funct3. */
static int riscv_exec_branch(riscv_t* riscv, int funct3, int rs1, int rs2, int imm) {
    void (*handler)(riscv_t*, int, int, int);
    switch (funct3) {
    case 0: handler = ins_beq; break;
    case 1: handler = ins_bne; break;
    case 4: handler = ins_blt; break;
    case 5: handler = ins_bge; break;
    case 6: handler = ins_bltu; break;
    case 7: handler = ins_bgeu; break;
    default: return -1;
    }
    /* The ins_b* handlers do `pc += imm` only when the branch is taken. Moving pc
     * to the fall-through address first and passing imm - 4 gives pc + 4 when not
     * taken and pc + imm when taken (including a branch-to-self with imm == 0),
     * without repeating the comparisons here. */
    riscv->pc += 4;
    handler(riscv, rs2, rs1, imm - 4);
    return 0;
}

/* LOAD group. Returns -1 for a reserved funct3. */
static int riscv_exec_load(riscv_t* riscv, int funct3, int rs1, int imm, int rd) {
    switch (funct3) {
    case 0: ins_lb(riscv, rs1, imm, rd); return 0;
    case 1: ins_lh(riscv, rs1, imm, rd); return 0;
    case 2: ins_lw(riscv, rs1, imm, rd); return 0;
    case 4: ins_lbu(riscv, rs1, imm, rd); return 0;
    case 5: ins_lhu(riscv, rs1, imm, rd); return 0;
    default: return -1;
    }
}

/* STORE group. Returns -1 for a reserved funct3. */
static int riscv_exec_store(riscv_t* riscv, int funct3, int rs1, int rs2, int imm) {
    switch (funct3) {
    case 0: ins_sb(riscv, rs2, rs1, imm); return 0;
    case 1: ins_sh(riscv, rs2, rs1, imm); return 0;
    case 2: ins_sw(riscv, rs2, rs1, imm); return 0;
    default: return -1;
    }
}

/* OP-IMM group. Shifts take their 5-bit amount from bits [24:20] (shamt), so the
 * funct7 bits carried in the raw immediate never reach the shift handlers.
 * Returns -1 for an invalid funct7 on a shift. */
static int riscv_exec_op_imm(riscv_t* riscv, int funct3, int funct7, int rs1, int shamt, int imm, int rd) {
    switch (funct3) {
    case 0: ins_addi(riscv, rs1, imm, rd); return 0;
    case 2: ins_slti(riscv, rs1, imm, rd); return 0;
    case 3: ins_sltiu(riscv, rs1, imm, rd); return 0;
    case 4: ins_xori(riscv, rs1, imm, rd); return 0;
    case 6: ins_ori(riscv, rs1, imm, rd); return 0;
    case 7: ins_andi(riscv, rs1, imm, rd); return 0;
    case 1:
        if (funct7 != 0x00) {
            return -1;
        }
        ins_slli(riscv, rs1, shamt, rd);
        return 0;
    case 5:
        if (funct7 == 0x00) {
            ins_srli(riscv, rs1, shamt, rd);
            return 0;
        }
        if (funct7 == 0x20) {
            ins_srai(riscv, rs1, shamt, rd);
            return 0;
        }
        return -1;
    default:
        return -1;
    }
}

/* OP group (register-register ALU). Returns -1 for an unsupported funct7/funct3
 * combination, including funct7 == 0x01 (M extension, not implemented). */
static int riscv_exec_op(riscv_t* riscv, int funct3, int funct7, int rs1, int rs2, int rd) {
    /* funct7 == 0x00 instructions, indexed by funct3. */
    static void (*const base_ops[8])(riscv_t*, int, int, int) = {
        ins_add, ins_sll, ins_slt, ins_sltu, ins_xor, ins_srl, ins_or, ins_and
    };
    if (funct7 == 0x00) {
        base_ops[funct3](riscv, rs2, rs1, rd);
        return 0;
    }
    if (funct7 == 0x20 && funct3 == 0) {
        ins_sub(riscv, rs2, rs1, rd);
        return 0;
    }
    if (funct7 == 0x20 && funct3 == 5) {
        ins_sra(riscv, rs2, rs1, rd);
        return 0;
    }
    return -1;
}

/* SYSTEM group: ecall/ebreak/wfi/mret and the Zicsr instructions.
 * *redirected is set to 1 when the handler itself wrote the next pc (mret).
 * Returns -1 for an unsupported encoding. */
static int riscv_exec_system(riscv_t* riscv, uint32_t ins, int funct3, int rs1, int rd, int* redirected) {
    int csr = 0;
    imm_i_type(ins, csr); /* raw 12-bit CSR number / funct12, not sign-extended */
    switch (funct3) {
    case 0:
        if (rs1 != 0 || rd != 0) {
            return -1;
        }
        switch (csr) {
        case 0x000: ins_ecall(riscv); return 0;
        case 0x001: ins_ebreak(riscv); return 0;
        case 0x105: ins_wfi(riscv); return 0;
        case 0x302: ins_mret(riscv); *redirected = 1; return 0;
        default: return -1;
        }
    case 1: ins_csrrw(riscv, rs1, rd, csr); return 0;
    case 2: ins_csrrs(riscv, rs1, rd, csr); return 0;
    case 3: ins_csrrc(riscv, rs1, rd, csr); return 0;
    case 5: ins_csrrwi(riscv, rs1 /* zimm */, rd, csr); return 0;
    case 6: ins_csrrsi(riscv, rs1 /* zimm */, rd, csr); return 0;
    case 7: ins_csrrci(riscv, rs1 /* zimm */, rd, csr); return 0;
    default: return -1;
    }
}

/*
 * Decode and execute one instruction `ins` located at riscv->pc.
 *
 * On success pc is left pointing at the next instruction to execute: pc + 4 for
 * sequential instructions, or the target chosen by a jump, taken branch or mret.
 * x0 is forced back to zero afterwards, because the ins_* handlers write rd
 * unconditionally.
 *
 * Returns 0 on success. Returns -1 for an unsupported or illegal encoding; no
 * register, CSR, memory or pc state is modified in that case.
 */
int riscv_decode(riscv_t* riscv, uint32_t ins) {
    int imm = 0;
    int rs1 = (ins >> 15) & 0x1f;
    int rs2 = (ins >> 20) & 0x1f; /* also the 5-bit shamt for OP-IMM shifts */
    int rd = (ins >> 7) & 0x1f;
    int funct3 = (ins >> 12) & 0x7;
    int funct7 = (ins >> 25) & 0x7f;
    int opcode = ins & 0x7f;
    const uint32_t this_pc = riscv->pc; /* address of the instruction being executed */
    int redirected = 0; /* 1 when the handler already set riscv->pc */
    int rc = 0;

    switch (opcode) {
    case opcode_lui:
        imm_u_type(ins, imm);
        ins_lui(riscv, imm, rd);
        break;
    case opcode_auipc:
        imm_u_type(ins, imm);
        ins_auipc(riscv, imm, rd); /* relies on pc still being this_pc */
        break;
    case opcode_jal:
        imm_j_type(ins, imm);
        ins_jal(riscv, riscv_sext(imm, 21), rd);
        redirected = 1;
        break;
    case opcode_jalr:
        if (funct3 != 0) {
            rc = -1;
            break;
        }
        imm_i_type(ins, imm);
        ins_jalr(riscv, rs1, riscv_sext(imm, 12), rd);
        redirected = 1;
        break;
    case opcode_beq: /* BRANCH: all six branches share this opcode */
        imm_b_type(ins, imm);
        rc = riscv_exec_branch(riscv, funct3, rs1, rs2, riscv_sext(imm, 13));
        redirected = 1;
        break;
    case opcode_lb: /* LOAD */
        imm_i_type(ins, imm);
        rc = riscv_exec_load(riscv, funct3, rs1, riscv_sext(imm, 12), rd);
        break;
    case opcode_sb: /* STORE */
        imm_s_type(ins, imm);
        rc = riscv_exec_store(riscv, funct3, rs1, rs2, riscv_sext(imm, 12));
        break;
    case opcode_addi: /* OP-IMM */
        imm_i_type(ins, imm);
        rc = riscv_exec_op_imm(riscv, funct3, funct7, rs1, rs2, riscv_sext(imm, 12), rd);
        break;
    case opcode_add: /* OP */
        rc = riscv_exec_op(riscv, funct3, funct7, rs1, rs2, rd);
        break;
    case opcode_fence: /* MISC-MEM */
        /* fence / fence.i run as no-ops: this interpreter models no caches and no
           other harts. */
        if (funct3 > 1) {
            rc = -1;
        }
        break;
    case opcode_ecall: /* SYSTEM */
        rc = riscv_exec_system(riscv, ins, funct3, rs1, rd, &redirected);
        break;
    default:
        rc = -1;
        break;
    }

    if (rc != 0) {
        return -1;
    }
    riscv->reg[0] = 0; /* x0 is hard-wired to zero */
    if (!redirected) {
        riscv->pc = this_pc + 4;
    }
    return 0;
}
/*
 * Attach a program ROM to the core and reset its architectural state.
 *
 * riscv:         core to initialise.
 * rom:           program image. It is stored by pointer, not copied, so it must
 *                outlive the core.
 * rom_addr_base: guest address the ROM is associated with. It is also used as
 *                the reset pc, since nothing else is mapped at 0.
 * rom_size:      ROM size as given by the caller (unit not established here -
 *                TODO confirm bytes vs words).
 *
 * The integer registers (including x0) and all CSRs are cleared. Guest RAM
 * (riscv->mem) is left as is.
 *
 * Returns 0 on success. Returns -1 if riscv is NULL, or if rom is NULL with a
 * non-zero rom_size.
 */
int riscv_init(riscv_t* riscv, uint32_t* rom, uint32_t rom_addr_base, uint32_t rom_size) {
    if (riscv == NULL || (rom == NULL && rom_size != 0)) {
        return -1;
    }
    riscv->rom = rom;
    riscv->rom_size = rom_size;
    riscv->rom_addr_base = rom_addr_base;
    memset(riscv->reg, 0, sizeof riscv->reg);
    memset(riscv->csrs, 0, sizeof riscv->csrs);
    riscv->pc = rom_addr_base;
    return 0;
}