#include "riscv.h"

#include <stdint.h>
#include <stdio.h>
#include <string.h>

/*
 * RV32I instruction formats. Every instruction is 32 bits wide and uses one
 * of the six layouts below (bit ranges in the header row).
 *
 * +------------+---------+---------+---------+-------------+---------+
 * |  [31:25]   | [24:20] | [19:15] | [14:12] |   [11:7]    |  [6:0]  |
 * +------------+---------+---------+---------+-------------+---------+
 * |   funct7   |   rs2   |   rs1   | funct3  |     rd      | opcode  | R-type: register-register ALU
 * +------------+---------+---------+---------+-------------+---------+
 * |      imm[11:0]       |   rs1   | funct3  |     rd      | opcode  | I-type: loads, ALU-immediate, jalr
 * +------------+---------+---------+---------+-------------+---------+
 * | imm[11:5]  |   rs2   |   rs1   | funct3  |  imm[4:0]   | opcode  | S-type: stores
 * +------------+---------+---------+---------+-------------+---------+
 * |imm[12|10:5]|   rs2   |   rs1   | funct3  | imm[4:1|11] | opcode  | B-type: conditional branches
 * +------------+---------+---------+---------+-------------+---------+
 * |               imm[31:12]                 |     rd      | opcode  | U-type: lui, auipc
 * +------------------------------------------+-------------+---------+
 * |          imm[20|10:1|11|19:12]           |     rd      | opcode  | J-type: jal
 * +------------------------------------------+-------------+---------+
 */

/*
 * Major opcodes (instruction bits [6:0]). Many mnemonics share one opcode and
 * are told apart by funct3/funct7 (or the 12-bit funct field for SYSTEM), so
 * only one name per value can be used as a switch case label.
 */
#define opcode_lui     0x37 /* U-type */
#define opcode_auipc   0x17 /* U-type */
#define opcode_jal     0x6f /* J-type */
#define opcode_jalr    0x67 /* I-type */

/* B-type */
#define opcode_beq     0x63
#define opcode_bne     0x63
#define opcode_blt     0x63
#define opcode_bge     0x63
#define opcode_bltu    0x63
#define opcode_bgeu    0x63

/* I-type loads */
#define opcode_lb      0x03
#define opcode_lh      0x03
#define opcode_lw      0x03
#define opcode_lbu     0x03
#define opcode_lhu     0x03

/* S-type stores */
#define opcode_sb      0x23
#define opcode_sh      0x23
#define opcode_sw      0x23

/* I-type ALU-immediate */
#define opcode_addi    0x13
#define opcode_slti    0x13
#define opcode_sltiu   0x13
#define opcode_xori    0x13
#define opcode_ori     0x13
#define opcode_andi    0x13
#define opcode_slli    0x13
#define opcode_srli    0x13
#define opcode_srai    0x13

/* R-type ALU */
#define opcode_add     0x33
#define opcode_sub     0x33
#define opcode_sll     0x33
#define opcode_slt     0x33
#define opcode_sltu    0x33
#define opcode_xor     0x33
#define opcode_srl     0x33
#define opcode_sra     0x33
#define opcode_or      0x33
#define opcode_and     0x33

#define opcode_fence   0x0f
#define opcode_fence_i 0x0f

/* SYSTEM: environment calls and CSR access */
#define opcode_ecall   0x73
#define opcode_ebreak  0x73
#define opcode_csrrw   0x73
#define opcode_csrrs   0x73
#define opcode_csrrc   0x73
#define opcode_csrrwi  0x73
#define opcode_csrrsi  0x73
#define opcode_csrrci  0x73

/* Guest RAM size in 32-bit words (4 MiB). */
#define MEM_SIZE (1024 * 1024)

/*
 * Guest address of the first RAM byte.
 * NOTE(review): 0x100000000 does not fit in a 32-bit guest address. Offsets
 * are computed modulo 2^32 (see RISCV_MEM_OFFSET), which effectively maps RAM
 * at guest address 0. Confirm the intended base (0x10000000? 0x80000000?).
 */
#define MEM_ADDR_BASE 0x100000000

/*
 * ABI register names as lvalues on a `reg` array in scope.
 * NOTE(review): these short object-like macros (sp, fp, t0, a0, ...) rewrite
 * every later identifier with the same spelling, including in headers
 * included after this point.
 */
#define zero reg[0]
#define ra   reg[1]
#define sp   reg[2]
#define gp   reg[3]
#define tp   reg[4]
#define t0   reg[5]
#define t1   reg[6]
#define t2   reg[7]
#define s0   reg[8]
#define s1   reg[9]
#define a0   reg[10]
#define a1   reg[11]
#define a2   reg[12]
#define a3   reg[13]
#define a4   reg[14]
#define a5   reg[15]
#define a6   reg[16]
#define a7   reg[17]
#define s2   reg[18]
#define s3   reg[19]
#define s4   reg[20]
#define s5   reg[21]
#define s6   reg[22]
#define s7   reg[23]
#define s8   reg[24]
#define s9   reg[25]
#define s10  reg[26]
#define s11  reg[27]
#define t3   reg[28]
#define t4   reg[29]
#define t5   reg[30]
#define t6   reg[31]
#define fp   reg[8]

/* Complete architectural state of one emulated hart. */
typedef struct {
    uint32_t mem[MEM_SIZE];  /* guest RAM, stored as words in host byte order */
    uint32_t* rom;           /* instruction image; not owned by this struct */
    uint32_t rom_size;       /* size of the instruction image in bytes */
    uint32_t rom_addr_base;  /* guest address of rom[0] */
    uint32_t pc;             /* program counter */
    uint32_t reg[32];        /* integer registers x0..x31 */
    uint32_t csrs[4096];     /* CSRs, indexed by their 12-bit address */
} riscv_t;

/* Byte offset of guest address `addr` into RAM, computed modulo 2^32. */
#define RISCV_MEM_OFFSET(addr) ((uint32_t)(addr) - (uint32_t)MEM_ADDR_BASE)

/*
 * Raw, UNCHECKED lvalue accessors into guest RAM (host byte order; the index
 * is rounded down to the access size). They expect a `riscv` pointer in
 * scope. Kept for compatibility only: the ins_* handlers below use the
 * bounds-checked riscv_mem_* helpers instead.
 */
#define mem_wr(addr)  (riscv->mem[RISCV_MEM_OFFSET(addr) >> 2])
#define mem_wrb(addr) (((uint8_t *)riscv->mem)[RISCV_MEM_OFFSET(addr)])
#define mem_wrh(addr) (((uint16_t *)riscv->mem)[RISCV_MEM_OFFSET(addr) >> 1])

/*
 * Return a host pointer to `len` bytes of guest RAM starting at guest
 * address `addr`, or NULL (after printing a diagnostic) if any byte of the
 * range falls outside RAM. The guest program is untrusted input, so every
 * access goes through this check.
 */
static uint8_t* riscv_mem_at(riscv_t* riscv, uint32_t addr, uint32_t len)
{
    const uint32_t off = RISCV_MEM_OFFSET(addr);

    if (off > sizeof riscv->mem || sizeof riscv->mem - off < len) {
        printf("memory access out of range: addr=0x%08x len=%u\n",
               (unsigned)addr, (unsigned)len);
        return NULL;
    }
    return (uint8_t*)riscv->mem + off;
}

/* Bounds-checked guest loads; an out-of-range load reads as 0. memcpy avoids
 * both misaligned host accesses and strict-aliasing violations. */
static uint8_t riscv_mem_load8(riscv_t* riscv, uint32_t addr)
{
    const uint8_t* p = riscv_mem_at(riscv, addr, 1);
    return p != NULL ? *p : 0;
}

static uint16_t riscv_mem_load16(riscv_t* riscv, uint32_t addr)
{
    uint16_t v = 0;
    const uint8_t* p = riscv_mem_at(riscv, addr, sizeof v);
    if (p != NULL) {
        memcpy(&v, p, sizeof v);
    }
    return v;
}

static uint32_t riscv_mem_load32(riscv_t* riscv, uint32_t addr)
{
    uint32_t v = 0;
    const uint8_t* p = riscv_mem_at(riscv, addr, sizeof v);
    if (p != NULL) {
        memcpy(&v, p, sizeof v);
    }
    return v;
}

/* Bounds-checked guest stores; an out-of-range store is dropped. */
static void riscv_mem_store8(riscv_t* riscv, uint32_t addr, uint8_t v)
{
    uint8_t* p = riscv_mem_at(riscv, addr, 1);
    if (p != NULL) {
        *p = v;
    }
}

static void riscv_mem_store16(riscv_t* riscv, uint32_t addr, uint16_t v)
{
    uint8_t* p = riscv_mem_at(riscv, addr, sizeof v);
    if (p != NULL) {
        memcpy(p, &v, sizeof v);
    }
}

static void riscv_mem_store32(riscv_t* riscv, uint32_t addr, uint32_t v)
{
    uint8_t* p = riscv_mem_at(riscv, addr, sizeof v);
    if (p != NULL) {
        memcpy(p, &v, sizeof v);
    }
}

/*
 * Sign-extend a 12-bit immediate: when bit 11 is set, set bits 11..31.
 * A value that is already sign-extended passes through unchanged, so it is
 * safe whether or not the decoder extended the immediate first.
 */
static int riscv_sext12(int imm)
{
    if (imm & (1 << 11)) {
        imm |= ~0x7ff;
    }
    return imm;
}

/* add: rd = rs1 + rs2 (wrapping). */
void ins_add(riscv_t* riscv, int rs2, int rs1, int rd)
{
    riscv->reg[rd] = riscv->reg[rs1] + riscv->reg[rs2];
}

/* addi: rd = rs1 + sext(imm[11:0]). */
void ins_addi(riscv_t* riscv, int rs1, int imm, int rd)
{
    riscv->reg[rd] = riscv->reg[rs1] + riscv_sext12(imm);
}

/* addiw (RV64 encoding): with 32-bit registers this is the same as addi. */
void ins_addiw(riscv_t* riscv, int rs1, int imm, int rd)
{
    riscv->reg[rd] = riscv->reg[rs1] + riscv_sext12(imm);
}

/* addw (RV64 encoding): with 32-bit registers this is the same as add. */
void ins_addw(riscv_t* riscv, int rs2, int rs1, int rd)
{
    riscv->reg[rd] = riscv->reg[rs1] + riscv->reg[rs2];
}

/* and: rd = rs1 & rs2. */
void ins_and(riscv_t* riscv, int rs2, int rs1, int rd)
{
    riscv->reg[rd] = riscv->reg[rs1] & riscv->reg[rs2];
}

/* andi: rd = rs1 & sext(imm[11:0]). */
void ins_andi(riscv_t* riscv, int rs1, int imm, int rd)
{
    riscv->reg[rd] = riscv->reg[rs1] & (uint32_t)riscv_sext12(imm);
}

/* auipc: rd = pc + imm, where imm already has its low 12 bits clear.
 * pc must hold the address of the auipc itself when this is called. */
void ins_auipc(riscv_t* riscv, int imm, int rd)
{
    riscv->reg[rd] = riscv->pc + imm;
}

/*
 * Conditional branches: add `imm` to pc when the condition holds, otherwise
 * leave pc untouched. What pc holds on entry is the caller's convention.
 */
void ins_beq(riscv_t* riscv, int rs2, int rs1, int imm)
{
    if (riscv->reg[rs1] == riscv->reg[rs2]) {
        riscv->pc += imm;
    }
}

void ins_bge(riscv_t* riscv, int rs2, int rs1, int imm)
{
    if ((int32_t)riscv->reg[rs1] >= (int32_t)riscv->reg[rs2]) {
        riscv->pc += imm;
    }
}

void ins_bgeu(riscv_t* riscv, int rs2, int rs1, int imm)
{
    if (riscv->reg[rs1] >= riscv->reg[rs2]) {
        riscv->pc += imm;
    }
}

void ins_blt(riscv_t* riscv, int rs2, int rs1, int imm)
{
    if ((int32_t)riscv->reg[rs1] < (int32_t)riscv->reg[rs2]) {
        riscv->pc += imm;
    }
}

void ins_bltu(riscv_t* riscv, int rs2, int rs1, int imm)
{
    if (riscv->reg[rs1] < riscv->reg[rs2]) {
        riscv->pc += imm;
    }
}

void ins_bne(riscv_t* riscv, int rs2, int rs1, int imm)
{
    if (riscv->reg[rs1] != riscv->reg[rs2]) {
        riscv->pc += imm;
    }
}

/*
 * CSR read-modify-write instructions. Each reads the old CSR value first and
 * writes it to rd last, so rd == rs1 behaves correctly.
 */

/* csrrc: read the CSR, then clear the bits that are set in rs1. */
void ins_csrrc(riscv_t* riscv, int rs1, int rd, int csr)
{
    uint32_t t = riscv->csrs[csr];
    riscv->csrs[csr] = riscv->csrs[csr] & ~riscv->reg[rs1];
    riscv->reg[rd] = t;
}

/* csrrs: read the CSR, then set the bits that are set in rs1. */
void ins_csrrs(riscv_t* riscv, int rs1, int rd, int csr)
{
    uint32_t t = riscv->csrs[csr];
    riscv->csrs[csr] = riscv->csrs[csr] | riscv->reg[rs1];
    riscv->reg[rd] = t;
}

/* csrrw: swap the CSR with rs1. */
void ins_csrrw(riscv_t* riscv, int rs1, int rd, int csr)
{
    uint32_t t = riscv->csrs[csr];
    riscv->csrs[csr] = riscv->reg[rs1];
    riscv->reg[rd] = t;
}

/* csrrci/csrrsi/csrrwi: as above, with a 5-bit zero-extended immediate. */
void ins_csrrci(riscv_t* riscv, int imm, int rd, int csr)
{
    uint32_t t = riscv->csrs[csr];
    riscv->csrs[csr] = riscv->csrs[csr] & ~imm;
    riscv->reg[rd] = t;
}

void ins_csrrsi(riscv_t* riscv, int imm, int rd, int csr)
{
    uint32_t t = riscv->csrs[csr];
    riscv->csrs[csr] = riscv->csrs[csr] | imm;
    riscv->reg[rd] = t;
}

void ins_csrrwi(riscv_t* riscv, int imm, int rd, int csr)
{
    uint32_t t = riscv->csrs[csr];
    riscv->csrs[csr] = imm;
    riscv->reg[rd] = t;
}

/* ecall/ebreak: only trace the event; no trap is taken. */
void ins_ecall(riscv_t* riscv)
{
    printf("ecall\n");
}

void ins_ebreak(riscv_t* riscv)
{
    printf("ebreak\n");
}

/* jal: rd = pc + 4; pc += imm. pc must hold the address of the jal itself. */
void ins_jal(riscv_t* riscv, int imm, int rd)
{
    riscv->reg[rd] = riscv->pc + 4;
    riscv->pc += imm;
}

/* jalr: rd = pc + 4; pc = (rs1 + sext(imm)) & ~1. rs1 is read before rd is
 * written, so rd == rs1 works. pc must hold the address of the jalr itself. */
void ins_jalr(riscv_t* riscv, int rs1, int imm, int rd)
{
    const uint32_t ret_addr = riscv->pc + 4;
    riscv->pc = (riscv->reg[rs1] + riscv_sext12(imm)) & ~(uint32_t)1;
    riscv->reg[rd] = ret_addr;
}

/* lb: load a byte and sign-extend it from bit 7. */
void ins_lb(riscv_t* riscv, int rs1, int imm, int rd)
{
    const uint32_t addr = riscv->reg[rs1] + riscv_sext12(imm);
    const uint32_t byte = riscv_mem_load8(riscv, addr);
    /* (x ^ 0x80) - 0x80 sign-extends without implementation-defined narrowing. */
    riscv->reg[rd] = (uint32_t)((int32_t)(byte ^ 0x80u) - 0x80);
}

/* lbu: load a byte and zero-extend it. */
void ins_lbu(riscv_t* riscv, int rs1, int imm, int rd)
{
    const uint32_t addr = riscv->reg[rs1] + riscv_sext12(imm);
    riscv->reg[rd] = riscv_mem_load8(riscv, addr);
}

/* lh: load a halfword and sign-extend it from bit 15. */
void ins_lh(riscv_t* riscv, int rs1, int imm, int rd)
{
    const uint32_t addr = riscv->reg[rs1] + riscv_sext12(imm);
    const uint32_t half = riscv_mem_load16(riscv, addr);
    riscv->reg[rd] = (uint32_t)((int32_t)(half ^ 0x8000u) - 0x8000);
}

/* lhu: load a halfword and zero-extend it. */
void ins_lhu(riscv_t* riscv, int rs1, int imm, int rd)
{
    const uint32_t addr = riscv->reg[rs1] + riscv_sext12(imm);
    riscv->reg[rd] = riscv_mem_load16(riscv, addr);
}

/* lw: load a word. */
void ins_lw(riscv_t* riscv, int rs1, int imm, int rd)
{
    const uint32_t addr = riscv->reg[rs1] + riscv_sext12(imm);
    riscv->reg[rd] = riscv_mem_load32(riscv, addr);
}

/* lwu (RV64 encoding): with 32-bit registers this is the same as lw. */
void ins_lwu(riscv_t* riscv, int rs1, int imm, int rd)
{
    const uint32_t addr = riscv->reg[rs1] + riscv_sext12(imm);
    riscv->reg[rd] = riscv_mem_load32(riscv, addr);
}

/* lui: rd = imm, where imm already has its low 12 bits clear. */
void ins_lui(riscv_t* riscv, int imm, int rd)
{
    riscv->reg[rd] = imm;
}

/* mret: resume at mepc (CSR 0x341). mstatus (MIE/MPIE/MPP) is not updated. */
void ins_mret(riscv_t* riscv)
{
    riscv->pc = riscv->csrs[0x341];
    printf("mret\n");
}

/* or: rd = rs1 | rs2. */
void ins_or(riscv_t* riscv, int rs2, int rs1, int rd)
{
    riscv->reg[rd] = riscv->reg[rs1] | riscv->reg[rs2];
}

/* ori: rd = rs1 | sext(imm[11:0]). */
void ins_ori(riscv_t* riscv, int rs1, int imm, int rd)
{
    riscv->reg[rd] = riscv->reg[rs1] | (uint32_t)riscv_sext12(imm);
}

/* sb/sh/sw: store the low 8/16/32 bits of rs2 at rs1 + sext(imm). */
void ins_sb(riscv_t* riscv, int rs2, int rs1, int imm)
{
    const uint32_t addr = riscv->reg[rs1] + riscv_sext12(imm);
    riscv_mem_store8(riscv, addr, (uint8_t)riscv->reg[rs2]);
}

void ins_sh(riscv_t* riscv, int rs2, int rs1, int imm)
{
    const uint32_t addr = riscv->reg[rs1] + riscv_sext12(imm);
    riscv_mem_store16(riscv, addr, (uint16_t)riscv->reg[rs2]);
}

void ins_sw(riscv_t* riscv, int rs2, int rs1, int imm)
{
    const uint32_t addr = riscv->reg[rs1] + riscv_sext12(imm);
    riscv_mem_store32(riscv, addr, riscv->reg[rs2]);
}
#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/*
 * Shift instructions. Only the low 5 bits of the shift amount are used, as
 * RV32I specifies; in C, shifting a 32-bit value by 32 or more is undefined.
 */

/* sll: rd = rs1 << rs2[4:0]. */
void ins_sll(riscv_t* riscv, int rs2, int rs1, int rd)
{
    riscv->reg[rd] = riscv->reg[rs1] << (riscv->reg[rs2] & 0x1f);
}

/* slli: rd = rs1 << shamt. */
void ins_slli(riscv_t* riscv, int rs1, int imm, int rd)
{
    riscv->reg[rd] = riscv->reg[rs1] << (imm & 0x1f);
}

/* slliw (RV64 encoding): same as slli with 32-bit registers. */
void ins_slliw(riscv_t* riscv, int rs1, int imm, int rd)
{
    riscv->reg[rd] = riscv->reg[rs1] << (imm & 0x1f);
}

/* sllw (RV64 encoding): same as sll with 32-bit registers. */
void ins_sllw(riscv_t* riscv, int rs2, int rs1, int rd)
{
    riscv->reg[rd] = riscv->reg[rs1] << (riscv->reg[rs2] & 0x1f);
}

/* slt: rd = (signed)rs1 < (signed)rs2. */
void ins_slt(riscv_t* riscv, int rs2, int rs1, int rd)
{
    riscv->reg[rd] = (int32_t)riscv->reg[rs1] < (int32_t)riscv->reg[rs2];
}

/* slti: rd = (signed)rs1 < sext(imm[11:0]). */
void ins_slti(riscv_t* riscv, int rs1, int imm, int rd)
{
    if (imm & (1 << 11)) {
        imm |= ~0x7ff;
    }
    riscv->reg[rd] = (int32_t)riscv->reg[rs1] < imm;
}

/* sltiu: rd = rs1 < (unsigned)sext(imm[11:0]). */
void ins_sltiu(riscv_t* riscv, int rs1, int imm, int rd)
{
    if (imm & (1 << 11)) {
        imm |= ~0x7ff;
    }
    riscv->reg[rd] = riscv->reg[rs1] < (uint32_t)imm;
}

/* sltu: rd = rs1 < rs2 (unsigned). */
void ins_sltu(riscv_t* riscv, int rs2, int rs1, int rd)
{
    riscv->reg[rd] = riscv->reg[rs1] < riscv->reg[rs2];
}

/*
 * Arithmetic right shifts rely on `>>` of a negative int32_t being arithmetic,
 * which is implementation-defined in C but true for GCC, Clang and MSVC.
 */
void ins_sra(riscv_t* riscv, int rs2, int rs1, int rd)
{
    riscv->reg[rd] = (uint32_t)((int32_t)riscv->reg[rs1] >> (riscv->reg[rs2] & 0x1f));
}

void ins_srai(riscv_t* riscv, int rs1, int imm, int rd)
{
    riscv->reg[rd] = (uint32_t)((int32_t)riscv->reg[rs1] >> (imm & 0x1f));
}

void ins_sraiw(riscv_t* riscv, int rs1, int imm, int rd)
{
    riscv->reg[rd] = (uint32_t)((int32_t)riscv->reg[rs1] >> (imm & 0x1f));
}

void ins_sraw(riscv_t* riscv, int rs2, int rs1, int rd)
{
    riscv->reg[rd] = (uint32_t)((int32_t)riscv->reg[rs1] >> (riscv->reg[rs2] & 0x1f));
}

/* Logical right shifts. */
void ins_srl(riscv_t* riscv, int rs2, int rs1, int rd)
{
    riscv->reg[rd] = riscv->reg[rs1] >> (riscv->reg[rs2] & 0x1f);
}

void ins_srli(riscv_t* riscv, int rs1, int imm, int rd)
{
    riscv->reg[rd] = riscv->reg[rs1] >> (imm & 0x1f);
}

void ins_srliw(riscv_t* riscv, int rs1, int imm, int rd)
{
    riscv->reg[rd] = riscv->reg[rs1] >> (imm & 0x1f);
}

void ins_srlw(riscv_t* riscv, int rs2, int rs1, int rd)
{
    riscv->reg[rd] = riscv->reg[rs1] >> (riscv->reg[rs2] & 0x1f);
}

/* sub: rd = rs1 - rs2 (wrapping). */
void ins_sub(riscv_t* riscv, int rs2, int rs1, int rd)
{
    riscv->reg[rd] = riscv->reg[rs1] - riscv->reg[rs2];
}

/* subw (RV64 encoding): same as sub with 32-bit registers. */
void ins_subw(riscv_t* riscv, int rs2, int rs1, int rd)
{
    riscv->reg[rd] = riscv->reg[rs1] - riscv->reg[rs2];
}

/* wfi: only traces the event; no waiting is modeled. */
void ins_wfi(riscv_t* riscv)
{
    printf("wfi\n");
}

/* xor: rd = rs1 ^ rs2. */
void ins_xor(riscv_t* riscv, int rs2, int rs1, int rd)
{
    riscv->reg[rd] = riscv->reg[rs1] ^ riscv->reg[rs2];
}

/* xori: rd = rs1 ^ sext(imm[11:0]). */
void ins_xori(riscv_t* riscv, int rs1, int imm, int rd)
{
    if (imm & (1 << 11)) {
        imm |= ~0x7ff;
    }
    riscv->reg[rd] = riscv->reg[rs1] ^ (uint32_t)imm;
}

/*
 * Sign-extend the low `bits` bits of `value` (1 <= bits <= 31).
 * (v ^ sign) - sign maps [0, 2^bits) onto [-2^(bits-1), 2^(bits-1)) using
 * only well-defined integer arithmetic.
 */
static int imm_sext(uint32_t value, unsigned bits)
{
    const uint32_t sign = 1u << (bits - 1);
    value &= (sign << 1) - 1u;
    return (int)(value ^ sign) - (int)sign;
}

/* Immediate decoders for each format. All except the CSR number are
 * sign-extended, as the base ISA requires. */
static int decode_imm_u(uint32_t ins)
{
    /* imm[31:12] << 12; multiplying avoids left-shifting a negative value. */
    return imm_sext(ins >> 12, 20) * 4096;
}

static int decode_imm_j(uint32_t ins)
{
    const uint32_t raw = (((ins >> 21) & 0x3ffu) << 1)   /* imm[10:1]  */
                       | (((ins >> 20) & 0x1u) << 11)    /* imm[11]    */
                       | (((ins >> 12) & 0xffu) << 12)   /* imm[19:12] */
                       | (((ins >> 31) & 0x1u) << 20);   /* imm[20]    */
    return imm_sext(raw, 21);
}

static int decode_imm_i(uint32_t ins)
{
    return imm_sext(ins >> 20, 12);
}

static int decode_imm_b(uint32_t ins)
{
    const uint32_t raw = (((ins >> 31) & 0x1u) << 12)    /* imm[12]   */
                       | (((ins >> 7) & 0x1u) << 11)     /* imm[11]   */
                       | (((ins >> 8) & 0xfu) << 1)      /* imm[4:1]  */
                       | (((ins >> 25) & 0x3fu) << 5);   /* imm[10:5] */
    return imm_sext(raw, 13);
}

static int decode_imm_s(uint32_t ins)
{
    const uint32_t raw = ((ins >> 7) & 0x1fu)            /* imm[4:0]  */
                       | (((ins >> 25) & 0x7fu) << 5);   /* imm[11:5] */
    return imm_sext(raw, 12);
}

/* Statement-style wrappers kept with their original (instruction, destination) argument order. */
#define imm_u_type(_ins, _imm) do { (_imm) = decode_imm_u(_ins); } while (0)
#define imm_j_type(_ins, _imm) do { (_imm) = decode_imm_j(_ins); } while (0)
#define imm_i_type(_ins, _imm) do { (_imm) = decode_imm_i(_ins); } while (0)
#define imm_b_type(_ins, _imm) do { (_imm) = decode_imm_b(_ins); } while (0)
#define imm_s_type(_ins, _imm) do { (_imm) = decode_imm_s(_ins); } while (0)
/* 12-bit CSR number: an unsigned array index, never sign-extended. */
#define imm_csr(_ins, _imm) do { (_imm) = (int)(((uint32_t)(_ins) >> 20) & 0xfffu); } while (0)

/*
 * Decode and execute one instruction.
 *
 * On entry riscv->pc must already point past `ins` (riscv_run() advances it
 * before calling). The pc-relative handlers (auipc, jal, jalr, branches) are
 * given the instruction's own address, as the ISA defines.
 *
 * Returns 0 to continue, or -1 for an unrecognised major opcode.
 */
int riscv_decode(riscv_t* riscv, uint32_t ins)
{
    const int rs1 = (ins >> 15) & 0x1f;
    const int rs2 = (ins >> 20) & 0x1f;
    const int rd = (ins >> 7) & 0x1f;
    const int funct3 = (ins >> 12) & 0x7;
    const int funct7 = (ins >> 25) & 0x7f;
    const int opcode = ins & 0x7f;
    const uint32_t next_pc = riscv->pc;
    const uint32_t cur_pc = next_pc - 4;
    int imm = 0;
    int ret = 0;

    switch (opcode) {
    case opcode_lui:
        imm_u_type(ins, imm);
        ins_lui(riscv, imm, rd);
        break;

    case opcode_auipc:
        /* auipc adds to its own address; restore the fall-through pc afterwards. */
        imm_u_type(ins, imm);
        riscv->pc = cur_pc;
        ins_auipc(riscv, imm, rd);
        riscv->pc = next_pc;
        break;

    case opcode_jal:
        /* ins_jal links pc + 4 and jumps to pc + imm relative to this instruction. */
        imm_j_type(ins, imm);
        riscv->pc = cur_pc;
        ins_jal(riscv, imm, rd);
        break;

    case opcode_jalr:
        imm_i_type(ins, imm);
        riscv->pc = cur_pc;
        ins_jalr(riscv, rs1, imm, rd);
        break;

    case opcode_beq: /* every B-type branch */
        imm_b_type(ins, imm);
        /*
         * pc still holds cur_pc + 4 (the fall-through). Biasing the offset by
         * -4 makes a taken branch land on cur_pc + imm, while a not-taken
         * branch leaves the fall-through in place.
         */
        imm -= 4;
        switch (funct3) {
        case 0x0: ins_beq(riscv, rs2, rs1, imm); break;
        case 0x1: ins_bne(riscv, rs2, rs1, imm); break;
        case 0x4: ins_blt(riscv, rs2, rs1, imm); break;
        case 0x5: ins_bge(riscv, rs2, rs1, imm); break;
        case 0x6: ins_bltu(riscv, rs2, rs1, imm); break;
        case 0x7: ins_bgeu(riscv, rs2, rs1, imm); break;
        default: break;
        }
        break;

    case opcode_lb: /* every load */
        imm_i_type(ins, imm);
        switch (funct3) {
        case 0x0: ins_lb(riscv, rs1, imm, rd); break;
        case 0x1: ins_lh(riscv, rs1, imm, rd); break;
        case 0x2: ins_lw(riscv, rs1, imm, rd); break;
        case 0x4: ins_lbu(riscv, rs1, imm, rd); break;
        case 0x5: ins_lhu(riscv, rs1, imm, rd); break;
        default: break;
        }
        break;

    case opcode_sb: /* every store */
        imm_s_type(ins, imm);
        switch (funct3) {
        case 0x0: ins_sb(riscv, rs2, rs1, imm); break;
        case 0x1: ins_sh(riscv, rs2, rs1, imm); break;
        case 0x2: ins_sw(riscv, rs2, rs1, imm); break;
        default: break;
        }
        break;

    case opcode_addi: /* every ALU-immediate instruction */
        imm_i_type(ins, imm);
        switch (funct3) {
        case 0x0: ins_addi(riscv, rs1, imm, rd); break;
        case 0x2: ins_slti(riscv, rs1, imm, rd); break;
        case 0x3: ins_sltiu(riscv, rs1, imm, rd); break;
        case 0x4: ins_xori(riscv, rs1, imm, rd); break;
        case 0x6: ins_ori(riscv, rs1, imm, rd); break;
        case 0x7: ins_andi(riscv, rs1, imm, rd); break;
        case 0x1:
            imm = imm & 0x1f; /* shamt */
            ins_slli(riscv, rs1, imm, rd);
            break;
        case 0x5:
            imm = imm & 0x1f; /* shamt; funct7 selects arithmetic vs logical */
            if (funct7 == 0x20) {
                ins_srai(riscv, rs1, imm, rd);
            } else {
                ins_srli(riscv, rs1, imm, rd);
            }
            break;
        default: break;
        }
        break;

    case opcode_add: /* every register-register ALU instruction */
        switch (funct3) {
        case 0x0:
            if (funct7 == 0x20) {
                ins_sub(riscv, rs2, rs1, rd);
            } else {
                ins_add(riscv, rs2, rs1, rd);
            }
            break;
        case 0x1: ins_sll(riscv, rs2, rs1, rd); break;
        case 0x2: ins_slt(riscv, rs2, rs1, rd); break;
        case 0x3: ins_sltu(riscv, rs2, rs1, rd); break;
        case 0x4: ins_xor(riscv, rs2, rs1, rd); break;
        case 0x5:
            if (funct7 == 0x20) {
                ins_sra(riscv, rs2, rs1, rd);
            } else {
                ins_srl(riscv, rs2, rs1, rd);
            }
            break;
        case 0x6: ins_or(riscv, rs2, rs1, rd); break;
        case 0x7: ins_and(riscv, rs2, rs1, rd); break;
        default: break;
        }
        break;

    case opcode_fence:
        /* fence / fence.i: nothing to order in this single-threaded interpreter. */
        break;

    case opcode_ecall: /* SYSTEM */
        imm_csr(ins, imm);
        switch (funct3) {
        case 0x0:
            /* funct3 == 0: the 12-bit field selects the instruction. */
            switch (imm) {
            case 0x000: ins_ecall(riscv); break;
            case 0x001: ins_ebreak(riscv); break;
            case 0x105: ins_wfi(riscv); break;
            case 0x302: ins_mret(riscv); break;
            default: break;
            }
            break;
        case 0x1: ins_csrrw(riscv, rs1, rd, imm); break;
        case 0x2: ins_csrrs(riscv, rs1, rd, imm); break;
        case 0x3: ins_csrrc(riscv, rs1, rd, imm); break;
        /* For the *i variants the rs1 field holds the 5-bit zimm value. */
        case 0x5: ins_csrrwi(riscv, rs1, rd, imm); break;
        case 0x6: ins_csrrsi(riscv, rs1, rd, imm); break;
        case 0x7: ins_csrrci(riscv, rs1, rd, imm); break;
        default: break;
        }
        break;

    default:
        printf("illegal instruction 0x%08x at 0x%08x\n", (unsigned)ins, (unsigned)cur_pc);
        ret = -1;
        break;
    }

    /* x0 is hard-wired to zero: discard any write an instruction made to it. */
    riscv->reg[0] = 0;
    return ret;
}

/*
 * Attach an instruction image to `riscv` and reset pc to its first word.
 * `rom` is borrowed, not copied, and must outlive execution. Returns 0, or
 * -1 if `riscv` is NULL.
 */
int riscv_init(riscv_t* riscv, uint32_t* rom, uint32_t rom_addr_base, uint32_t rom_size)
{
    if (riscv == NULL) {
        return -1;
    }
    riscv->rom = rom;
    riscv->rom_size = rom_size;
    riscv->rom_addr_base = rom_addr_base;
    /* riscv_run() fetches at pc - rom_addr_base, so start at the base. */
    riscv->pc = rom_addr_base;
    return 0;
}

/*
 * Fetch/execute loop. Stops when pc leaves the ROM image or riscv_decode()
 * reports an error. Returns the last riscv_decode() result (0 when the run
 * ended by leaving the ROM).
 */
int riscv_run(riscv_t* riscv)
{
    int ret = 0;

    for (;;) {
        /* Wraps to a huge value if pc < rom_addr_base, which the check rejects. */
        const uint32_t offset = riscv->pc - riscv->rom_addr_base;
        uint32_t instr;

        if (offset >= riscv->rom_size) {
            printf("riscv run out of rom\n");
            break;
        }
        instr = riscv->rom[offset >> 2];
        /* riscv_decode() expects pc to already point past this instruction. */
        riscv->pc += 4;
        ret = riscv_decode(riscv, instr);
        if (ret) {
            break;
        }
    }

    printf("riscv run end\n");
    return ret;
}

/*
 * Return the size in bytes of `stream` (measured by seeking to the end), or
 * -1 on failure. The stream position is restored before returning.
 */
long get_file_size(FILE *stream)
{
    long file_size = -1;
    const long cur_offset = ftell(stream); /* remember the current position */

    if (cur_offset == -1) {
        printf("ftell failed :%s\n", strerror(errno));
        return -1;
    }
    if (fseek(stream, 0, SEEK_END) != 0) {
        printf("fseek failed: %s\n", strerror(errno));
        return -1;
    }
    file_size = ftell(stream); /* offset of the end == file size */
    if (file_size == -1) {
        printf("ftell failed :%s\n", strerror(errno));
    }
    if (fseek(stream, cur_offset, SEEK_SET) != 0) { /* restore the position */
        printf("fseek failed: %s\n", strerror(errno));
        return -1;
    }
    return file_size;
}

/*
 * Load "riscv.bin" as the ROM image mapped at 0x80000000 and run it.
 * `t` is unused. Returns riscv_run()'s result, or -1 on a setup failure.
 */
int thread_fun(void* t)
{
    static const uint32_t rom_addr_base = 0x80000000u;
    riscv_t* cpu = NULL;
    uint32_t* rom = NULL;
    FILE* file = NULL;
    long size = -1;
    int ret = -1;

    (void)t;

    /* riscv_t embeds several MiB of guest RAM: keep it off the (thread) stack. */
    cpu = calloc(1, sizeof *cpu);
    if (cpu == NULL) {
        printf("out of memory\n");
        return -1;
    }

    file = fopen("riscv.bin", "rb");
    if (file == NULL) {
        printf("open file error\n");
        goto out;
    }

    size = get_file_size(file);
    if (size <= 0 || (unsigned long)size > UINT32_MAX) {
        printf("invalid rom size\n");
        goto out;
    }

    /* Round up to whole words so the last (possibly partial) word can be fetched. */
    rom = calloc(((size_t)size + 3) / 4, 4);
    if (rom == NULL) {
        printf("out of memory\n");
        goto out;
    }
    if (fread(rom, 1, (size_t)size, file) != (size_t)size) {
        printf("read file error\n");
        goto out;
    }
    fclose(file);
    file = NULL;

    riscv_init(cpu, rom, rom_addr_base, (uint32_t)size);
    ret = riscv_run(cpu);

out:
    if (file != NULL) {
        fclose(file);
    }
    free(rom);
    free(cpu);
    return ret;
}