from typing import Tuple def branch_unit( cur_pc: int, imm: int, jalr_sel: int, branch_taken: int, alu_result: int ) -> Tuple[int, int, int, int]: # 输入值为当前pc(int),立即数,jalr信号,是否跳转信号,alu运算结果(全是int型) # 输出为pc_plus_imm,pc_plus_4,branch_target,pc_sel(忘了这是啥了,需要回头再看) branch_taken = int(branch_taken) jalr_sel = int(jalr_sel) pc_plus_4 = cur_pc + 4 pc_plus_imm = cur_pc + imm pc_sel = jalr_sel | (branch_taken & (alu_result % 2)) if jalr_sel == 1: branch_target = alu_result & (2 ** 32 - 2) else: branch_target = cur_pc + imm * 2 return pc_plus_imm, pc_plus_4, branch_target, pc_sel