"""Pure-Python SM3 hash, compression function and SM3-based KDF."""
import binascii
from math import ceil

from lib.gmssl.func import rotl, bytes_to_list

# Initial chaining value fed to the first sm3_cf() call in sm3_hash().
IV = [
    1937774191, 1226093241, 388252375, 3666478592,
    2842636476, 372324522, 3817729613, 2969243214,
]

# Per-round additive constants: one value for rounds 0-15 and another for
# rounds 16-63 (only indices 0..63 are ever read).
T_j = [2043430169] * 16 + [2055708042] * 48

_MASK32 = 0xFFFFFFFF


def sm3_ff_j(x, y, z, j):
    """Boolean function FF for round ``j``.

    Rounds 0-15 use ``x ^ y ^ z``; rounds 16-63 use the bitwise majority of
    the three words.

    Raises:
        ValueError: if ``j`` is outside ``0..63``.
    """
    if 0 <= j < 16:
        return x ^ y ^ z
    if 16 <= j < 64:
        return (x & y) | (x & z) | (y & z)
    raise ValueError("round index j must be in range 0..63, got %r" % (j,))


def sm3_gg_j(x, y, z, j):
    """Boolean function GG for round ``j``.

    Rounds 0-15 use ``x ^ y ^ z``; rounds 16-63 select bits of ``y`` where
    ``x`` is set and bits of ``z`` where it is clear.

    Raises:
        ValueError: if ``j`` is outside ``0..63``.
    """
    if 0 <= j < 16:
        return x ^ y ^ z
    if 16 <= j < 64:
        # ``~x`` is negative in Python, but AND-ing it with the non-negative
        # ``z`` yields a non-negative result, so no explicit mask is needed.
        return (x & y) | ((~x) & z)
    raise ValueError("round index j must be in range 0..63, got %r" % (j,))


def sm3_p_0(x):
    """Permutation P0: ``x`` XOR-ed with its ``rotl`` by 9 and by 17."""
    return x ^ rotl(x, 9) ^ rotl(x, 17)


def sm3_p_1(x):
    """Permutation P1: ``x`` XOR-ed with its ``rotl`` by 15 and by 23."""
    return x ^ rotl(x, 15) ^ rotl(x, 23)


def sm3_cf(v_i, b_i):
    """Compress one 64-byte block into the chaining value.

    Args:
        v_i: current chaining value, a sequence of exactly eight integers.
        b_i: the message block, an indexable sequence of 64 byte values.

    Returns:
        A new list of eight integers: the updated working registers XOR-ed
        element-wise with ``v_i``. Neither argument is modified.
    """
    # Sixteen message words, each assembled big-endian from four bytes.
    w = [
        (b_i[k] << 24) | (b_i[k + 1] << 16) | (b_i[k + 2] << 8) | b_i[k + 3]
        for k in range(0, 64, 4)
    ]
    # Expand to 68 words; each new word depends only on earlier ones.
    for j in range(16, 68):
        w.append(
            sm3_p_1(w[j - 16] ^ w[j - 9] ^ rotl(w[j - 3], 15))
            ^ rotl(w[j - 13], 7)
            ^ w[j - 6]
        )
    # Second schedule: word j combined with word j + 4.
    w_1 = [w[j] ^ w[j + 4] for j in range(64)]

    a, b, c, d, e, f, g, h = v_i
    for j in range(64):
        a_rot = rotl(a, 12)
        ss_1 = rotl((a_rot + e + rotl(T_j[j], j % 32)) & _MASK32, 7)
        ss_2 = ss_1 ^ a_rot
        tt_1 = (sm3_ff_j(a, b, c, j) + d + ss_2 + w_1[j]) & _MASK32
        tt_2 = (sm3_gg_j(e, f, g, j) + h + ss_1 + w[j]) & _MASK32
        # Right-hand sides are evaluated before assignment, so every register
        # is computed from the previous round's values.
        a, b, c, d = tt_1, a, rotl(b, 9), c
        e, f, g, h = sm3_p_0(tt_2), e, rotl(f, 19), g
        # Keep every register within 32 bits regardless of what rotl returns.
        a, b, c, d, e, f, g, h = (
            reg & _MASK32 for reg in (a, b, c, d, e, f, g, h)
        )

    return [new ^ old for new, old in zip((a, b, c, d, e, f, g, h), v_i)]


def sm3_hash(msg):
    """Return the SM3 digest of ``msg`` as a 64-character lowercase hex string.

    Args:
        msg: the message as an iterable of byte values (integers 0..255),
            for example a list of ints, ``bytes`` or ``bytearray``.

    The input is copied before padding, so the caller's object is never
    modified.
    """
    data = list(msg)
    bit_length = len(data) * 8

    # Padding: a single 0x80 byte, then zero bytes until the length is
    # 56 mod 64, then the original bit length as 8 big-endian bytes.
    data.append(0x80)
    data.extend([0x00] * ((56 - len(data)) % 64))
    data.extend((bit_length >> shift) & 0xFF for shift in range(56, -8, -8))

    # Chain the compression function over consecutive 64-byte blocks.
    state = IV
    for offset in range(0, len(data), 64):
        state = sm3_cf(state, data[offset:offset + 64])

    return "".join("%08x" % word for word in state)


def sm3_kdf(z, klen):
    """Derive ``klen`` bytes of key material from ``z`` using SM3.

    Args:
        z: the shared input as a hexadecimal string; either ``bytes`` holding
            the hex text (as originally required) or a ``str``.
        klen: requested key length in bytes (converted with ``int()``).

    Returns:
        A hex string of ``2 * klen`` characters, built by hashing ``z``
        followed by a 4-byte big-endian counter that starts at 1 and
        truncating the concatenated digests.
    """
    klen = int(klen)
    if isinstance(z, (bytes, bytearray)):
        z = z.decode('utf8')
    zin = list(bytes.fromhex(z))

    # Each SM3 digest contributes 32 bytes of output.
    rcnt = ceil(klen / 32)
    digests = []
    for ct in range(1, rcnt + 1):
        counter = list(binascii.a2b_hex(('%08x' % ct).encode('utf8')))
        digests.append(sm3_hash(zin + counter))
    return "".join(digests)[0: klen * 2]


def hash(data):
    """Return the SM3 hex digest of ``data`` encoded as ``bytes``.

    ``data`` is converted with ``bytes_to_list`` before hashing.

    NOTE: this name shadows the built-in ``hash`` inside this module; it is
    kept because callers import it under this name.
    """
    return sm3_hash(bytes_to_list(data)).encode()