def branch_unit(cur_pc,imm,jalr_sel,branch_taken,alu_result):           #输入值为当前pc(int),立即数,jalr信号,是否跳转信号,alu运算结果(全是int型)
    pc_plus_4=cur_pc+4                                                  #输出为pc_plus_imm,pc_plus_4,branch_target,pc_sel(忘了这是啥了,需要回头再看)
    pc_plus_imm=cur_pc+imm
    pc_sel = jalr_sel | (branch_taken & (alu_result%2))
    if jalr_sel==1:
        branch_target=alu_result&(2**32-2)
    else:
        branch_target=cur_pc+imm*2
    return pc_plus_imm,pc_plus_4,branch_target,pc_sel