diff --git a/shor/shor.py b/shor/shor.py new file mode 100644 index 0000000..198e763 --- /dev/null +++ b/shor/shor.py @@ -0,0 +1,298 @@ +from pyqpanda import * +import math +import matplotlib.pyplot as plt + + +def MAJ(a, b, c): + circ = QCircuit() + circ.insert(CNOT(c, b)) + circ.insert(CNOT(c, a)) + circ.insert(Toffoli(a, b, c)) + return circ + + +def UMA(a, b, c): + circ = QCircuit() + circ.insert(Toffoli(a, b, c)) + circ.insert(CNOT(c, a)) + circ.insert(CNOT(a, b)) + return circ + + +def MAJ2(a, b, c): + lena = len(a) + lenb = len(b) + if a == 0: + raise RuntimeError("Empty List") + if not lena == lenb: + raise RuntimeError("Length error") + + circ = QCircuit() + circ.insert(MAJ(c, a[0], b[0])) + for i in range(1, lena): + circ.insert(MAJ(b[i-1], a[i], b[i])) + return circ + + +def Adder(a, b, c): + lena = len(a) + lenb = len(b) + if a == 0: + raise RuntimeError("Empty List") + if not lena == lenb: + raise RuntimeError("Length error") + + circ = QCircuit() + circ.insert(MAJ(c, a[0], b[0])) + + for i in range(1, lena): + circ.insert(MAJ(b[i-1], a[i], b[i])) + + for i in range(lena-1, 0, -1): + circ.insert(MAJ(b[i-1], a[i], b[i])) + + circ.insert(UMA(c, a[0], b[0])) + return circ + + +def isCarry(a, b, c, carry): + circ = QCircuit() + circ.insert(MAJ2(a, b, c)) + circ.insert(CNOT(b[-1], carry)) + circ.insert(MAJ2(a, b, c).dagger()) + return circ + + +def bindData(qlist, data): + i = 0 + circ = QCircuit() + while data >= 1: + if data & 1 == 1: + circ.insert(X(qlist[i])) + data >>= 1 + i += 1 + return circ + + +def constModAdd(qa, C, M, qb, qs1): + circ = QCircuit() + q_num = len(qa) + tmp_value = (1 << q_num)-M+C + + circ.insert(bindData(qb, tmp_value))\ + .insert(isCarry(qa, qb, qs1[1], qs1[0]))\ + .insert(bindData(qb, tmp_value)) + + tmp_circ1 = QCircuit() + + tmp_circ1.insert(bindData(qb, tmp_value)) + tmp_circ1.insert(Adder(qa, qb, qs1[1])) + tmp_circ1.insert(bindData(qb, tmp_value)) + tmp_circ1 = tmp_circ1.control([qs1[0]]) + + circ.insert(tmp_circ1) + circ.insert(X(qs1[0])) + + tmp_circ2 = QCircuit() + tmp_circ2.insert(bindData(qb, C)) + tmp_circ2.insert(Adder(qa, qb, qs1[1])) + tmp_circ2.insert(bindData(qb, C)) + tmp_circ2 = tmp_circ2.control([qs1[0]]) + + circ.insert(tmp_circ2) + circ.insert(X(qs1[0])) + + tmp_value = (1 << q_num)-C + + circ.insert(bindData(qb, tmp_value))\ + .insert(isCarry(qa, qb, qs1[1], qs1[0]))\ + .insert(bindData(qb, tmp_value))\ + .insert(X(qs1[0])) + + return circ + + +def constModMul(qa, const_num, M, qs1, qs2, qs3): + circ = QCircuit() + qnum = len(qa) + + for i in range(0, qnum): + tmp_circ = QCircuit() + tmp = const_num * pow(2, i) % M + tmp_circ.insert(constModAdd(qs1, tmp, M, qs2, qs3)) + tmp_circ = tmp_circ.control([qa[i]]) + circ.insert(tmp_circ) + + for i in range(0, qnum): + circ.insert(CNOT(qa[i], qs1[i]))\ + .insert(CNOT(qs1[i], qa[i]))\ + .insert(CNOT(qa[i], qs1[i])) + + Crev = modReverse(const_num, M) + + tmp_circ1 = QCircuit() + tmp_circ2 = QCircuit() + + for i in range(0, qnum): + tmp = Crev * pow(2, i) % M + tmp_circ1 = QCircuit() + tmp_circ1.insert(constModAdd(qs1, tmp, M, qs2, qs3)) + tmp_circ1 = tmp_circ1.control([qa[i]]) + tmp_circ2.insert(tmp_circ1) + + circ.insert(tmp_circ2.dagger()) + + return circ + + +def constModExp(qa, qb, base, M, qs1, qs2, qs3): + circ = QCircuit() + qnum = len(qa) + tmp = base + + for i in range(0, qnum): + circ.insert(constModMul(qb, tmp, M, qs1, qs2, qs3).control([qa[i]])) + tmp = tmp * tmp % M + + return circ + + +def qft(qlist): + circ = QCircuit() + qnum = len(qlist) + for i in range(0, qnum): + circ.insert(H(qlist[qnum-1-i])) + for j in range(i+1, qnum): + circ.insert( + CR(qlist[qnum-1-j], qlist[qnum-1-i], math.pi/(1 << (j-i)))) + + for i in range(0, qnum//2): + circ.insert(CNOT(qlist[i], qlist[qnum - 1 - i])) + circ.insert(CNOT(qlist[qnum - 1 - i], qlist[i])) + circ.insert(CNOT(qlist[i], qlist[qnum - 1 - i])) + + return circ + + +def gcd(m, n): + if not n: + return m + else: + return gcd(n, m % n) + + +def modReverse(c, m): + if (c == 0): + raise RecursionError('c is zero!') + + if (c == 1): + return 1 + + m1 = m + quotient = [] + quo = m // c + remainder = m % c + + quotient.append(quo) + + while (remainder != 1): + m = c + c = remainder + quo = m // c + remainder = m % c + quotient.append(quo) + + if (len(quotient) == 1): + return m - quo + + if (len(quotient) == 2): + return 1 + quotient[0] * quotient[1] + + rev1 = 1 + rev2 = quotient[-1] + reverse_list = quotient[0:-1] + reverse_list.reverse() + for i in reverse_list: + rev1 = rev1 + rev2 * i + temp = rev1 + rev1 = rev2 + rev2 = temp + + if ((len(quotient) % 2) == 0): + return rev2 + + return m1 - rev2 + + +def plotBar(xdata, ydata): + fig, ax = plt.subplots() + fig.set_size_inches(6, 6) + fig.set_dpi(100) + + rects = ax.bar(xdata, ydata, color='b') + + for rect in rects: + height = rect.get_height() + plt.text(rect.get_x() + rect.get_width() / 2, height, + str(height), ha="center", va="bottom") + + plt.title("Origin Q", loc='right', alpha=0.5) + plt.ylabel('Times') + plt.xlabel('States') + + plt.show() + + +def reorganizeData(measure_qubits, quick_meausre_result): + xdata = [] + ydata = [] + + for i in quick_meausre_result: + xdata.append(i) + ydata.append(quick_meausre_result[i]) + + return xdata, ydata + + +def shorAlg(base, M): + if ((base < 2) or (base > M - 1)): + raise('Invalid base!') + + if (gcd(base, M) != 1): + raise('Invalid base! base and M must be mutually prime') + + binary_len = 0 + while M >> binary_len != 0: + binary_len = binary_len + 1 + + machine = init_quantum_machine(QMachineType.CPU_SINGLE_THREAD) + + qa = machine.qAlloc_many(binary_len*2) + qb = machine.qAlloc_many(binary_len) + + qs1 = machine.qAlloc_many(binary_len) + qs2 = machine.qAlloc_many(binary_len) + qs3 = machine.qAlloc_many(2) + + prog = QProg() + + prog.insert(X(qb[0])) + prog.insert(single_gate_apply_to_all(H, qa)) + prog.insert(constModExp(qa, qb, base, M, qs1, qs2, qs3)) + prog.insert(qft(qa).dagger()) + + directly_run(prog) + result = quick_measure(qa, 100) + + print(result) + + xdata, ydata = reorganizeData(qa, result) + plotBar(xdata, ydata) + + return result + + +if __name__ == "__main__": + base = 7 + N = 15 + shorAlg(base, N)