from Crypto.Cipher import AES
from gmssl import func  # 该模块用于bytes与列表之间的转换


# 密钥设置,请设置为16.24.32
# ENCKEY = '1234567812345678'

def po_main(plain=None, iv=None, enckey=None, attacker_enc=None, plain_want=None):
    """
    函数作用:主函数,用于基本信息的输入和整体提示信息的输出
            调用其他函数完成攻击流程
    """
    print("=== Padding Oricle Attrack POC(CBC-MODE) ===")

    # 服务端默认设置
    cipher = "aes"  # 加密方式(当前代码只支持AES)
    block_size = 16  # 分组大小(AES默认为16)

    # 测试设置(实际使用时请注释掉)
    # iv = "1234567812345678".encode('utf-8')            # 初始向量
    # plain = "111111111111111122222".encode('utf-8')    # 明文
    # plain_want = "opaas".encode('utf-8')               # 想要获取密文的明文

    iv = iv.encode('utf-8')
    # 判断IV是否长度是否符合要求
    if len(iv) != block_size:
        print("[-] IV must be" + str(block_size) + "bytes long!")
        return False

    if plain:
        plain = plain.encode('utf-8')

        print("=== Generate Target Ciphertext ===")
        # 调用函数加密
        ciphertext = encrypt(plain, iv, cipher, enckey)
        if not ciphertext:
            print("[-] Encrypt Error!")
            return False

        # 输出一些基本信息
        print("[+] plaintext is: " + str(plain, encoding="utf-8"))
        print("[+] iv is: ", iv.decode("utf-8"))
        print("[+] ciphertext is: ", ciphertext)

        return ciphertext

    if attacker_enc:
        print(attacker_enc)
        print(type(attacker_enc))

    if plain_want:
        plain_want = plain_want.encode('utf-8')

    # 开始进行Padding Oracle攻击
    print("=== Start Padding Oracle Decrypt ===")
    print("[+] Choosing Cipher: " + cipher.upper())
    guess = padding_oracle_decrypt(cipher, attacker_enc, iv, block_size=block_size, enckey=enckey)
    if guess:
        print("[+] Guess intermediary value is: ", guess["intermediary"])
        print("[+] plaintext = intermediary_value XOR original_IV")
        print("[+] Guess plaintext is: ", guess["plaintext"])

        # 开始利用获取的中间值求解指定明文加密后的密文
        if plain_want:
            print("=== Start Padding Oracle Encrypt ===")
            print("[+] plaintext want to encrypt is: ", plain_want)
            print("[+] Choosing Cipher: " + cipher.upper())
            en = padding_oracle_encrypt(cipher, attacker_enc, plain_want, iv, block_size=block_size, enckey=enckey)
            if en:
                print("[+] Encrypt Success!")
                print("[+] The ciphertext you want is: ", en[block_size:])
                print("[+] IV is: ", iv)
                print("=== Let's verify the custom encrypt result ===")
                print("[+] Decrypt of ciphertext ", en[block_size:], " is:")
                de = decrypt(en[block_size:], en[:block_size], cipher, enckey)
                if de == add_PKCS5_padding(plain_want, block_size):
                    print(de)
                    print("[+] Bingo!")

                else:
                    print("[-] It seems something wrong happened!")
                    return False
        return guess["plaintext"], en[block_size:]

    else:
        return False


def padding_oracle_encrypt(cipher, ciphertext, plaintext, iv, enckey, block_size=16):
    """"
    函数功能:完成指定明文加密的功能
    输入:
        cipher:加密类型
        ciphertext:已知的密文
        plaintext:需要加密的明文
        iv:初始向量
        block_size:分组大小
    输出:
        guess_cipher:猜测的密文(包含IV和密文)
    注意:这个函数功能的实现是在前面攻击的基础上实现的,即已经获取到了中间值
    """

    guess_cipher = ciphertext[0 - block_size:]

    plaintext = add_PKCS5_padding(plaintext, block_size)

    print("[*] After padding, plaintext becomes to: ", plaintext)

    block = len(plaintext)

    iv_nouse = iv  # 用不到,只需要中间值

    prev_cipher = ciphertext[0 - block_size:]  # init with the last cipher block

    while block > 0:
        tmp = padding_oracle_decrypt_block(cipher, prev_cipher, iv_nouse, block_size=block_size, debug=False, enckey=enckey)
        prev_cipher = xor_str(plaintext[block - block_size:block], tmp["intermediary"])
        guess_cipher = prev_cipher + guess_cipher.decode(encoding="ISO-8859-1")
        block = block - block_size
    guess_cipher = guess_cipher.encode(encoding="ISO-8859-1")
    return guess_cipher


# 一个简单的异或操作,用于求解明文
def xor(first, second):
    return bytearray(x ^ y for x, y in zip(first, second))


def padding_oracle_decrypt(cipher, ciphertext, iv, enckey, block_size=8, debug=True):
    """"
    函数功能:对密文进行分组,并调用函数攻击每一个分组以获取中间值
    输入:
        cipher:加密类型
        ciphertext:已知的密文
        iv:初始向量
        block_size:分组大小
    输出:
        result:获取的中间值和明文信息
    注意:这个函数不是主要的攻击函数,而是完成分组和调用攻击函数的功能
    """

    # 将密文进行分组
    cipher_block = split_cipher_block(ciphertext, block_size)

    if cipher_block:
        result = {}
        result["intermediary"] = ''
        result["plaintext"] = ''
        counter = 0
        for c in cipher_block:
            if debug:
                print("[*] Now try to decrypt block " + str(counter))
                print("[*] Block " + str(counter) + "'s ciphertext is: ", c)

            # 调用真正的攻击函数(每一次攻击一个分组)
            guess = padding_oracle_decrypt_block(cipher, c, iv, block_size=block_size, debug=debug, enckey=enckey)

            if guess:
                iv = c
                guess_inter = guess["intermediary"].decode(encoding="ISO-8859-1")
                result["intermediary"] += guess_inter
                guess_plain = guess["plaintext"].decode(encoding="ISO-8859-1")
                result["plaintext"] += guess_plain
                if debug:
                    print("[+] Block " + str(counter) + " decrypt!")
                    print("[+] intermediary value is: ", guess["intermediary"])
                    print("[+] The plaintext of block " + str(counter) + " is: ", guess["plaintext"])
                counter = counter + 1
            else:
                print("[-] padding oracle decrypt error!")
                return False
        # 返回值的格式要进行转换
        result["intermediary"] = result["intermediary"].encode(encoding="ISO-8859-1")
        result["plaintext"] = result["plaintext"].encode(encoding="ISO-8859-1")
        return result
    else:
        print("[-] ciphertext's block_size is incorrect!")
        return False


def padding_oracle_decrypt_block(cipher, ciphertext, iv, enckey, block_size=16, debug=True):
    """"
    函数功能:对单个分组进行攻击。获取中间值和明文
    输入:
        cipher:加密类型
        ciphertext:已知的密文(分组)
        iv:初始向量
        block_size:分组大小
    输出:
        result:获取的中间值和明文信息(单个分组)
    注意:这个函数不是主要的攻击函数,而是完成分组和调用攻击函数的功能
    """
    result = {}
    intermediary = bytearray(16)  # 中间值

    iv_p = bytearray(16)  # 用于猜测的IV向量

    # 通过for循环实现猜测的功能
    for i in range(1, block_size + 1):
        for b in range(0, 256):
            iv_p[16 - i] = b
            iv_bytes = bytes(iv_p)
            plain = decrypt(ciphertext, iv_bytes, cipher, enckey)

            # 最关键的一步,用于模拟服务器的回显,即padding是否符合要求
            if check_PKCS5_padding(plain, i):
                if debug:
                    print("[*] Try IV: ", func.bytes_to_list(iv_bytes))
                    print("[*] Found padding oracle: ", hex_s(plain))
                intermediary[16 - i] = i ^ b
        for m in range(1, i + 1):
            iv_p[16 - m] = (i + 1) ^ intermediary[16 - m]
    plain = xor(iv, intermediary)
    result["plaintext"] = plain
    result["intermediary"] = bytes(intermediary)
    return result


# 一个简单的分组函数
def split_cipher_block(ciphertext, block_size=8):
    if len(ciphertext) % block_size != 0:
        return False
    result = []
    length = 0
    while length < len(ciphertext):
        # 每一个分组都是一个字符串
        result.append(ciphertext[length:length + block_size])
        length += block_size
    return result


def check_PKCS5_padding(plain, p):
    """"
    函数功能:服务器回显判断
    输入:
        plain:明文
        p:正在猜测的位数
    输出:
        True or False
    注意:这个函数用于模拟服务端,给出回显
    """
    if len(plain) % 16 != 0:
        return False

    # 字符串反向,为了下面便于比较
    plain = plain[::-1]
    ch = 0
    found = 0  # 相等的位数

    while ch < p:
        if bytes(plain[ch]) == bytes(p):
            found += 1
        ch += 1

    if found == p:
        return True
    else:
        return False


def add_PKCS5_padding(plaintext, block_size):
    """"
    函数功能:PKCS5填充
    输入:
        plaintext:明文
        block_size:分组大小
    输出:
        plaintext:填充后的明文
    注意:这个函数用于填充缺失值,是本次攻击可以实施的本源
    """
    s = ''

    if len(plaintext) % block_size == 0:
        return plaintext

    if len(plaintext) < block_size:
        padding = block_size - len(plaintext)

    else:
        padding = block_size - (len(plaintext) % block_size)
    plaintext = str(plaintext, encoding="utf-8")
    for i in range(0, padding):
        plaintext += chr(padding)
    plaintext = bytes(plaintext, encoding="utf-8")
    return plaintext


def decrypt(ciphertext, iv, cipher, enckey):
    """"
    函数功能:解密密文(AES)
    输入:
        ciphertext:密文
        iv:初始向量
        cipher:加解密方式
    输出:
        o.decrypt(ciphertext):解密后的明文
    注意:这个函数在本次实验里主要是用于验证密文的有效性
    """

    key = enckey
    if cipher.lower() == "aes":
        o = AES.new(key, AES.MODE_CBC, iv)
    else:
        return False

    if len(iv) % 16 != 0:
        return False

    if len(ciphertext) % 16 != 0:
        return False

    return o.decrypt(ciphertext)


def encrypt(plaintext, iv, cipher, enckey):
    """"
    函数功能:加密明文(AES)
    输入:
        plaintext:明文
        iv:初始向量
        cipher:加解密方式
    输出:
        o.encrypt(plaintext):加密后的密文
    注意:这个函数在本次实验里主要是用于初始服务端加密以获取到密文
    """
    key = enckey

    if cipher.lower() == "aes":

        if len(key) != 16 and len(key) != 24 and len(key) != 32:
            print("[-] AES key must be 16/24/32 bytes long!")
            return False

        o = AES.new(key, AES.MODE_CBC, iv)
    else:
        return False

    plaintext = add_PKCS5_padding(plaintext, len(iv))
    return o.encrypt(plaintext)


# 这个函数主要用于第二段攻击时某个值的计算
def xor_str(a, b):
    if len(a) != len(b):
        return False

    c = []
    for i in range(0, len(a)):
        c.append(a[i] ^ b[i])
    c = bytes(c)
    return c.decode(encoding="ISO-8859-1")


# 可有可无
def hex_s(str2):
    re = []
    re = func.bytes_to_list(str2)
    return re


if __name__ == "__main__":
    po_main()