Skip to content


操作步骤

项目链接:https://github.com/breezedeus/Pix2Text

安装步骤

bash
# Install Miniconda (ARMv8 / aarch64 build)
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh
bash Miniconda3-latest-Linux-aarch64.sh -b -u -p ./miniconda3
source ./miniconda3/bin/activate
conda init --all

# On RK-series (Rockchip) boards the following system package may also be needed.
# NOTE: on newer Debian/Ubuntu releases this package may have been replaced by
# `libgl1` -- install that instead if apt cannot find libgl1-mesa-glx.
sudo apt-get update
sudo apt-get install -y libgl1-mesa-glx

# Install pix2text. The extras spec is quoted so shells such as zsh do not try
# to glob-expand the square brackets.
pip install pix2text
pip install "fastapi[all]"

# Show help for `p2t predict`
p2t predict -h

# Run a prediction directly with p2t
# (docs/examples/test-doc.pdf presumably refers to the example shipped in the
# Pix2Text repository -- replace it with your own file)
p2t predict -l en,ch_sim --resized-shape 768 --file-type pdf -i docs/examples/test-doc.pdf -o output-md --save-debug-res output-debug

# Show help for `p2t serve`
p2t serve -h

# Start the HTTP service, listening on all interfaces on port 5040
p2t serve -l en,ch_sim -H 0.0.0.0 -p 5040

# Call the service with curl. Keep the leading @ in the image field so curl
# uploads the file content instead of sending the path as text.
# Quoting matters here: inside double quotes the shell would expand `$$` to its
# own PID. The $'...' form keeps `$$` literal and turns `\n` into real newlines,
# matching the Python client below.
# 0.0.0.0 is only the server's bind address; clients should connect to
# 127.0.0.1 (same machine) or the server's actual IP.
curl -X POST \
  -F 'file_type=text_formula' \
  -F 'resized_shape=768' \
  -F 'embed_sep= $,$ ' \
  -F $'isolated_sep=$$\n, \n$$' \
  -F 'image=@path/to/pic.png;type=image/png' \
  http://127.0.0.1:5040/pix2text

此外,也可以通过 Python 脚本调用服务(需先安装 requests),参考如下代码:

python
import requests

# 0.0.0.0 is the address the server *binds* to, not a reliable destination
# (connecting to it fails on e.g. Windows). Use 127.0.0.1 when the client runs
# on the same machine as the server, otherwise replace it with the server's IP.
url = 'http://127.0.0.1:5040/pix2text'
pngnum = input('input the name of the png, omit .png: ')
image_fp = './' + pngnum + '.png'
data = {
    "file_type": "text_formula",  # other options: page, formula, text
    "resized_shape": 768,  # default value; usually does not need changing
    "embed_sep": " $,$ ",
    "isolated_sep": "$$\n, \n$$",
}

# Open the image in a `with` block so the file handle is closed once the
# upload has finished.
with open(image_fp, 'rb') as image:
    files = {
        "image": (image_fp, image, 'image/png')
    }
    r = requests.post(url, data=data, files=files)

# Surface HTTP errors (e.g. 4xx/5xx from the server) directly instead of
# failing later with an obscure JSON decode error or KeyError.
r.raise_for_status()

outs = r.json()['results']
with open('./' + pngnum + '.md', 'w', encoding='utf-8') as file:
    file.write(outs)
print("write success")

基于 Gradio 构建 Pix2Text APP 网页

参考链接:https://huggingface.co/spaces/breezedeus/Pix2Text-Demo

接下来的步骤要求服务器网络通畅(首次运行需要从 Hugging Face 等处下载模型),并假设服务器为 x86 架构且配有 CUDA GPU(onnxruntime-gpu 需要 CUDA 环境)。

bash
# pyyaml is listed explicitly because the app script imports `yaml`
pip install gradio numpy huggingface_hub onnxruntime-gpu pix2text pyyaml

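运行以下脚本。第一次运行时会下载在 Gradio 页面中选择的模型,注意选择可以直接下载的模型(例如默认的 mfd、mfr)。脚本还会读取当前目录下的 languages.yaml(界面语言名称到语言代码的映射)以及占位图 docs/no-det-res.jpg,请先准备好这两个文件(可从参考链接的项目中获取)。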

python
import os
import random
import shutil
import string
import time
import zipfile
from pathlib import Path

import yaml

import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download

from pix2text import Pix2Text
from pix2text.utils import set_logger

logger = set_logger()

# Mapping from the language names shown in the UI dropdown to the language
# values passed to Pix2Text (presumably codes such as 'en' / 'ch_sim' -- confirm
# against languages.yaml). The file is read from the current working directory
# and closed right after parsing.
with open('languages.yaml', 'r', encoding='utf-8') as _lang_fp:
    LANGUAGES = yaml.safe_load(_lang_fp)['languages']
# Directory for per-request outputs (detection images, page Markdown, debug dumps).
OUTPUT_RESULT_DIR = Path('./output-results')
OUTPUT_RESULT_DIR.mkdir(exist_ok=True)


def prepare_mfd_model():
    """Return the local path of the YOLOv7 MFD checkpoint, fetching it if absent.

    If ``./yolov7-model/mfd-yolov7-epoch224-20230613.pt`` already exists it is
    returned as-is. Otherwise the zip archive ``cnstd/1.2/yolov7-model-20230613.zip``
    is downloaded from the ``breezedeus/paid-models`` Hugging Face repo (using
    the ``HF_TOKEN`` environment variable as the access token, if set) and
    extracted into the current directory.

    NOTE(review): this function is not referenced elsewhere in the visible code.
    """
    checkpoint = './yolov7-model/mfd-yolov7-epoch224-20230613.pt'
    if not os.path.exists(checkpoint):
        archive_path = hf_hub_download(
            repo_id='breezedeus/paid-models',
            repo_type="model",
            subfolder='cnstd/1.2',
            filename='yolov7-model-20230613.zip',
            cache_dir='./',
            token=os.environ.get('HF_TOKEN'),
        )
        # Assumes the archive contains yolov7-model/<checkpoint> -- TODO confirm.
        with zipfile.ZipFile(archive_path) as archive:
            archive.extractall('./')
    return checkpoint


# Most recently built Pix2Text instance, keyed by (languages, mfd name, mfr name).
# At most one entry is kept so only one set of models stays in memory.
_P2T_CACHE = {}


def get_p2t_model(lan_list: list, mfd_model_name: str, mfr_model_name: str):
    """Return a Pix2Text instance for the given languages and formula models.

    Constructing Pix2Text presumably loads all of its underlying models, so the
    previous implementation paid that cost on every request. The last instance
    is now reused while the settings stay the same. When the settings change,
    the old instance is dropped *before* the new one is built, so peak memory
    matches the old behaviour.

    NOTE(review): the shared instance assumes requests are handled one at a
    time (Gradio's default queue concurrency) -- confirm before raising it.

    Args:
        lan_list: language values for text OCR (as passed to Pix2Text).
        mfd_model_name: formula detection (MFD) model name; ONNX backend is used.
        mfr_model_name: formula recognition (MFR) model name; ONNX backend is used.

    Returns:
        A Pix2Text instance built via ``Pix2Text.from_config``.
    """
    key = (tuple(lan_list), mfd_model_name, mfr_model_name)
    p2t = _P2T_CACHE.get(key)
    if p2t is None:
        _P2T_CACHE.clear()  # release the previous instance before loading a new one
        mfd_config = dict(model_name=mfd_model_name, model_backend='onnx')  # MFD init args
        formula_config = dict(model_name=mfr_model_name, model_backend='onnx')  # MFR init args
        text_formula_config = dict(
            languages=list(lan_list), mfd=mfd_config, formula=formula_config,
        )
        total_config = {
            'layout': {'scores_thresh': 0.45},
            'text_formula': text_formula_config,
        }
        p2t = Pix2Text.from_config(total_configs=total_config)
        _P2T_CACHE[key] = p2t
    return p2t


def latex_render(latex_str):
    """Wrap ``latex_str`` in ``$$`` delimiters, each on its own line (display math)."""
    return '$$\n{}\n$$'.format(latex_str)


def _random_suffix(length=6):
    """Return ``length`` distinct random ASCII letters for unique output names."""
    letters = list(string.ascii_letters)
    random.shuffle(letters)
    return ''.join(letters[:length])


def _prune_output_dir(max_entries=100):
    """Delete the oldest entries of OUTPUT_RESULT_DIR beyond ``max_entries``.

    Entries are either files (detection images) or directories (page output /
    debug dirs). Only the oldest entries (by modification time) are removed, so
    results just produced for other requests survive, unlike wiping the whole
    directory.
    """
    entries = list(OUTPUT_RESULT_DIR.iterdir())
    if len(entries) <= max_entries:
        return
    entries.sort(key=lambda p: p.stat().st_mtime)
    for entry in entries[: len(entries) - max_entries]:
        if entry.is_dir():
            shutil.rmtree(entry, ignore_errors=True)
        else:
            entry.unlink(missing_ok=True)


def recognize(
    lang_list, mfd_model_name, mfr_model_name, rec_type, resized_shape, image_file
):
    """Run Pix2Text on an uploaded image for the Gradio UI.

    Args:
        lang_list: language display names selected in the UI; mapped through
            LANGUAGES before being passed to Pix2Text.
        mfd_model_name: formula detection model name.
        mfr_model_name: formula recognition model name.
        rec_type: Pix2Text file type ('page', 'text_formula', 'formula', 'text').
        resized_shape: resize target from the slider.
        image_file: image from the ``gr.Image`` component (None if nothing uploaded).

    Returns:
        ``(out_text, out_det_fp)``. ``out_text`` is the recognized text
        (Markdown for 'page'). ``out_det_fp`` is the path of a visualisation
        image, or None when that file does not exist.

    Raises:
        gr.Error: if no image was uploaded.
    """
    if image_file is None:
        raise gr.Error('Please upload an image first.')

    languages = [LANGUAGES[name] for name in lang_list]
    p2t = get_p2t_model(languages, mfd_model_name, mfr_model_name)

    # Keep the output directory bounded by removing its oldest entries.
    _prune_output_dir()

    out_det_fp = './docs/no-det-res.jpg'  # placeholder for types with no visualisation
    # The slider value may arrive as a float; Pix2Text gets an integer size.
    kwargs = dict(resized_shape=int(resized_shape), return_text=True, auto_line_break=True,)
    output_dir = None
    if rec_type == 'page':
        fp_suffix = f'{time.time()}-{_random_suffix()}'
        output_dir = OUTPUT_RESULT_DIR / f'output-{fp_suffix}'
        kwargs['save_debug_res'] = OUTPUT_RESULT_DIR / f'out-debug-{fp_suffix}'
    elif rec_type == 'text_formula':
        out_det_fp = str(OUTPUT_RESULT_DIR / f'out-det-{time.time()}-{_random_suffix()}.jpg')
        kwargs['save_analysis_res'] = out_det_fp

    out = p2t.recognize(image_file, file_type=rec_type, **kwargs)
    out_text = out
    if rec_type == 'page':
        out_text = out.to_markdown(output_dir)
        out_det_fp = kwargs['save_debug_res'] / 'layout_res.jpg'

    # Return None rather than a non-existent path (e.g. a missing placeholder
    # image) so the output image is simply cleared.
    if not os.path.exists(out_det_fp):
        out_det_fp = None
    return out_text, out_det_fp


def example_func(lang_list, rec_type, resized_shape, image_file):
    """Run ``recognize`` with the formula models pinned to 'mfd-pro' / 'mfr-pro'.

    Only the languages, file type, resize target and image vary.
    """
    pinned_models = {'mfd_model_name': 'mfd-pro', 'mfr_model_name': 'mfr-pro'}
    return recognize(
        lang_list,
        rec_type=rec_type,
        resized_shape=resized_shape,
        image_file=image_file,
        **pinned_models,
    )


def main():
    """Build the Pix2Text Gradio UI and launch it.

    The bind address and port come from the ``GRADIO_SERVER_NAME`` and
    ``GRADIO_SERVER_PORT`` environment variables (Gradio's own variable
    names). They default to 0.0.0.0:7860 so the page is reachable from other
    machines on the network. Set ``GRADIO_SERVER_NAME=127.0.0.1`` to restrict
    access to the local machine. The previous hard-coded private LAN IP only
    worked on the author's machine.
    """
    langs = list(LANGUAGES.keys())
    langs.sort(key=lambda x: x.lower())

    server_name = os.environ.get('GRADIO_SERVER_NAME', '0.0.0.0')
    server_port = int(os.environ.get('GRADIO_SERVER_PORT', '7860'))

    with gr.Blocks() as demo:

        with gr.Row(equal_height=False):
            with gr.Column(min_width=200, variant='panel', scale=3):
                gr.Markdown('### Settings')
                lang_list = gr.Dropdown(
                    label='Text Languages',
                    choices=langs,
                    value=['English', 'Chinese Simplified'],
                    multiselect=True,
                )
                mfd_model_name = gr.Dropdown(
                    label='MFD Models',
                    choices=['mfd', 'mfd-advanced', 'mfd-pro'],
                    value='mfd',
                )
                mfr_model_name = gr.Dropdown(
                    label='MFR Models',
                    choices=['mfr', 'mfr-pro', 'mfr-plus'],
                    value='mfr',
                )
                rec_type = gr.Dropdown(
                    label='file_type',
                    choices=['page', 'text_formula', 'formula', 'text'],
                    value='formula',
                )
                with gr.Accordion('More Options', open=False):
                    resized_shape = gr.Slider(
                        label='resized_shape',
                        minimum=512,
                        maximum=2048,
                        value=768,
                        step=32,
                    )

            with gr.Column(scale=6, variant='compact'):
                gr.Markdown('### 上传图片')
                image_file = gr.Image(
                    label='Image', type="pil", image_mode='RGB', show_label=False
                )
                sub_btn = gr.Button("Submit", variant="primary")

        with gr.Row(equal_height=False):
            with gr.Column(scale=1, variant='compact'):
                gr.Markdown('**Detection Result**')
                det_result = gr.Image(
                    label='Detection Result', scale=1, show_label=False
                )
            with gr.Column(scale=1, variant='compact'):
                gr.Markdown(
                    '**输出结果**'
                )
                rec_result = gr.Textbox(
                    label='Recognition Result',
                    lines=5,
                    value='',
                    scale=1,
                    show_label=False,
                    show_copy_button=True,
                )
        # The input order must match recognize()'s positional parameters.
        sub_btn.click(
            recognize,
            inputs=[
                lang_list,
                mfd_model_name,
                mfr_model_name,
                rec_type,
                resized_shape,
                image_file,
            ],
            outputs=[rec_result, det_result],
        )

    demo.queue(max_size=10)
    demo.launch(share=False, server_name=server_name, server_port=server_port)


if __name__ == '__main__':
    main()