From e8c048e5aa16189f0659fc7d8b6a3112b3305cbf Mon Sep 17 00:00:00 2001 From: piqxznopb <1207967813@qq.com> Date: Mon, 20 Jan 2025 15:07:17 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0Model=5Fset.py=E6=96=87?= =?UTF-8?q?=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Models_set.py | 100 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 Models_set.py diff --git a/Models_set.py b/Models_set.py new file mode 100644 index 0000000..ee2b19a --- /dev/null +++ b/Models_set.py @@ -0,0 +1,100 @@ +import requests +import torch +from PIL import Image +from transformers import pipeline +from paddleocr import PaddleOCR +from transformers import BlipProcessor, BlipForConditionalGeneration + + +"""事实证明除中文外,英语转成其他语言更难""" +language_pairs = { + "en-zh": "Helsinki-NLP/opus-mt-en-zh", # 1:英语到中文 + "zh-en": "Helsinki-NLP/opus-mt-zh-en", # 2:中文到英语 + "en-fr": "Helsinki-NLP/opus-mt-en-fr", # 3:英语到法语 + "fr-en": "Helsinki-NLP/opus-mt-fr-en", # 4:法语到英语 + "en-es": "Helsinki-NLP/opus-mt-en-es", # 5:英语到西班牙语 + "es-en": "Helsinki-NLP/opus-mt-es-en", # 6:西班牙语到英语 + "en-de": "Helsinki-NLP/opus-mt-en-de", # 7:英语到德语 + "de-en": "Helsinki-NLP/opus-mt-de-en", # 8:德语到英语 + "en-ru": "Helsinki-NLP/opus-mt-en-ru", # 9:英语到俄语 + "ru-en": "Helsinki-NLP/opus-mt-ru-en", # 10:俄语到英语 + "zh-fr": "Helsinki-NLP/opus-mt-zh-fr", # 11:中文到法语 + "fr-zh": "Helsinki-NLP/opus-mt-fr-zh", # 12:法语到中文 +} + +flag = int(input('请选择要进行的模式(文字识别请输入1,图片理解请输入2):')) +pairs = 1 # 默认为英转中 +# image_path_identify = '测试图片/img_12.jpg' # 用来文字识别的图片所在的文件路径 +# image_path_understand = r"测试图片/birds-6968147_640.jpg" # 用来进行图片理解的文件路径 +while True: + if flag == 1: + image_path_identify = input("请输入测试图片本地路径:") # 用来文字识别的图片所在的文件路径 + if torch.cuda.is_available(): + # 使用 GPU + ocr = PaddleOCR(use_angle_cls=True, lang='ch', use_gpu=True) + print("使用 GPU 进行 OCR") + else: + # 使用 CPU + ocr = PaddleOCR(use_angle_cls=True, lang='ch', use_gpu=False) + print("使用 CPU 进行 OCR") + flag = 0 + result = ocr.ocr(image_path_identify, cls=True) + + # 提取纯文字内容 + extracted_text = [res[1][0] for res in result[0]] + + """-----------------------------------------这是文字识别的输出--------------------------------------------------""" + print("文字识别结果如下:") + for text in extracted_text: + print(text) + + if flag == 2: + image_path_understand = input("请输入测试图片本地路径:") # 用来进行图片理解的文件路径 + # 图片理解模块 + processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") + model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base") + + image_path = image_path_understand # 本地图片路径 + image = Image.open(image_path) + + # 处理图片 + inputs = processor(images=image, return_tensors="pt") + output = model.generate(**inputs) + + # 解码生成的描述 + description = processor.decode(output[0], skip_special_tokens=True) + print(description) + + if pairs == 1: + language_pair = "en-zh" + if pairs == 2: + language_pair = "zh-en" + if pairs == 3: + language_pair = "en-fr" + if pairs == 4: + language_pair = "fr-en" + if pairs == 5: + language_pair = "en-es" + if pairs == 6: + language_pair = "es-en" + if pairs == 7: + language_pair = "en-de" + if pairs == 8: + language_pair = "de-en" + if pairs == 9: + language_pair = "en-ru" + if pairs == 10: + language_pair = "ru-en" + if pairs == 11: + language_pair = "zh-fr" + if pairs == 12: + language_pair = "fr-zh" + translator = pipeline("translation", model=language_pairs[language_pair]) + + # 翻译文本 + result = translator(description, max_length=40) + translation_text = result[0]['translation_text'] + print(translation_text) + flag_break = input("还要继续吗,退出请按1再回车:") + if flag_break == '1': + break \ No newline at end of file