
Commit 2b46943

Add GOT-OCR 2.0 to Transformers (huggingface#34721)
* init modular got_ocr2
* Get correct got_ocr architecture
* add processing
* run modular with processing
* add working inference
* apply modular
* Refactor and fix style
* Refactor, cleanup, fix style
* fix init order
* Fix docs
* add base modeling tests
* fix style and consistency
* rename doc file
* fix repo consistency
* fix inference with box
* add image processing and support for crop_to_multi_page
* Fix batch inference
* add tests
* fixup
* fix slow test
* fix docstrings
* Add model doc
* update to new init
* fix input autocast pixel_values dtype
* update doc
* move doc to multimodal
* Reformat crop_image_to_patches and add docstrings
* Fix example in forward docstring
* Address Pablo review
* [run slow] got_ocr2
* remove defaults defined twice
* apply modular
* add torch_device to integration tests
* update modular
* follow-up Pavel review
* add device variable in doc
* fix doc multi-page
* Force eager attention for vision encoder to avoid attn implementation conflict
* revert qwen2vl doc changes
* use Qwen2ForCausalLM instead of Qwen2Model
* make fixup
* refactor gotocr2 to llava style
* uniformize function names and reduce checks
* final nits
* fix pixel_values dtype error
* change checkpoint names
* fix modular
1 parent 5bbee12 commit 2b46943

26 files changed: +4184 -3 lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -872,6 +872,8 @@
      title: FLAVA
    - local: model_doc/git
      title: GIT
+   - local: model_doc/got_ocr2
+     title: GOT-OCR2
    - local: model_doc/grounding-dino
      title: Grounding DINO
    - local: model_doc/groupvit

docs/source/en/index.md

Lines changed: 1 addition & 0 deletions
@@ -161,6 +161,7 @@ Flax), PyTorch, and/or TensorFlow.
| [GIT](model_doc/git) ||||
| [GLM](model_doc/glm) ||||
| [GLPN](model_doc/glpn) ||||
+| [GOT-OCR2](model_doc/got_ocr2) ||||
| [GPT Neo](model_doc/gpt_neo) ||||
| [GPT NeoX](model_doc/gpt_neox) ||||
| [GPT NeoX Japanese](model_doc/gpt_neox_japanese) ||||

docs/source/en/model_doc/got_ocr2.md

Lines changed: 269 additions & 0 deletions
@@ -0,0 +1,269 @@
<!--Copyright 2024 StepFun and The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# GOT-OCR2

## Overview

The GOT-OCR2 model was proposed in [General OCR Theory: Towards OCR-2.0 via a Unified End-to-end Model](https://arxiv.org/abs/2409.01704) by Haoran Wei, Chenglong Liu, Jinyue Chen, Jia Wang, Lingyu Kong, Yanming Xu, Zheng Ge, Liang Zhao, Jianjian Sun, Yuang Peng, Chunrui Han, Xiangyu Zhang.

The abstract from the paper is the following:

*Traditional OCR systems (OCR-1.0) are increasingly unable to meet people's usage due to the growing demand for intelligent processing of man-made optical characters. In this paper, we collectively refer to all artificial optical signals (e.g., plain texts, math/molecular formulas, tables, charts, sheet music, and even geometric shapes) as "characters" and propose the General OCR Theory along with an excellent model, namely GOT, to promote the arrival of OCR-2.0. The GOT, with 580M parameters, is a unified, elegant, and end-to-end model, consisting of a high-compression encoder and a long-contexts decoder. As an OCR-2.0 model, GOT can handle all the above "characters" under various OCR tasks. On the input side, the model supports commonly used scene- and document-style images in slice and whole-page styles. On the output side, GOT can generate plain or formatted results (markdown/tikz/smiles/kern) via an easy prompt. Besides, the model enjoys interactive OCR features, i.e., region-level recognition guided by coordinates or colors. Furthermore, we also adapt dynamic resolution and multipage OCR technologies to GOT for better practicality. In experiments, we provide sufficient results to prove the superiority of our model.*

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/got_ocr_overview.png"
alt="drawing" width="600"/>

<small> GOT-OCR2 training stages. Taken from the <a href="https://arxiv.org/abs/2409.01704">original paper.</a> </small>

Tips:

GOT-OCR2 works on a wide range of tasks, including plain document OCR, scene text OCR, formatted document OCR, and even OCR for tables, charts, mathematical formulas, geometric shapes, molecular formulas and sheet music. While this implementation of the model will only output plain text, the outputs can be further processed to render the desired format, with packages like `pdftex`, `mathpix`, `matplotlib`, `tikz`, `verovio` or `pyecharts`.
The model can also be used for interactive OCR, where the user can specify the region to be recognized by providing the coordinates or the color of the region's bounding box.

This model was contributed by [yonigozlan](https://huggingface.co/yonigozlan).
The original code can be found [here](https://github.com/Ucas-HaoranWei/GOT-OCR2.0).

## Usage example

### Plain text inference
```python
>>> import torch
>>> from transformers import AutoProcessor, AutoModelForImageTextToText

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")

>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg"
>>> inputs = processor(image, return_tensors="pt").to(device)

>>> generate_ids = model.generate(
...     **inputs,
...     do_sample=False,
...     tokenizer=processor.tokenizer,
...     stop_strings="<|im_end|>",
...     max_new_tokens=4096,
... )

>>> processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
"R&D QUALITY IMPROVEMENT\nSUGGESTION/SOLUTION FORM\nName/Phone Ext. : (...)"
```

### Plain text inference batched

```python
>>> import torch
>>> from transformers import AutoProcessor, AutoModelForImageTextToText

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")

>>> image1 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png"
>>> image2 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg"

>>> inputs = processor([image1, image2], return_tensors="pt").to(device)

>>> generate_ids = model.generate(
...     **inputs,
...     do_sample=False,
...     tokenizer=processor.tokenizer,
...     stop_strings="<|im_end|>",
...     max_new_tokens=4,
... )

>>> processor.batch_decode(generate_ids[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
["Reducing the number", "R&D QUALITY"]
```

### Formatted text inference

GOT-OCR2 can also generate formatted text, such as Markdown or LaTeX. Here is an example of how to generate formatted text:

```python
>>> import torch
>>> from transformers import AutoProcessor, AutoModelForImageTextToText

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")

>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/latex.png"
>>> inputs = processor(image, return_tensors="pt", format=True).to(device)

>>> generate_ids = model.generate(
...     **inputs,
...     do_sample=False,
...     tokenizer=processor.tokenizer,
...     stop_strings="<|im_end|>",
...     max_new_tokens=4096,
... )

>>> processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
"\\author{\nHanwen Jiang* \\(\\quad\\) Arjun Karpur \\({ }^{\\dagger} \\quad\\) Bingyi Cao \\({ }^{\\dagger} \\quad\\) (...)"
```

### Inference on multiple pages

Although it might be reasonable in most cases to use a “for loop” for multi-page processing, some text data with formatting across several pages makes it necessary to process all pages at once. GOT introduces a multi-page OCR feature (without a “for loop”), where multiple pages can be processed by the model at once, with the output being one continuous text.
Here is an example of how to process multiple pages at once:

```python
>>> import torch
>>> from transformers import AutoProcessor, AutoModelForImageTextToText

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")

>>> image1 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/page1.png"
>>> image2 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/page2.png"
>>> inputs = processor([image1, image2], return_tensors="pt", multi_page=True, format=True).to(device)

>>> generate_ids = model.generate(
...     **inputs,
...     do_sample=False,
...     tokenizer=processor.tokenizer,
...     stop_strings="<|im_end|>",
...     max_new_tokens=4096,
... )

>>> processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
"\\title{\nGeneral OCR Theory: Towards OCR-2.0 via a Unified End-to-end Model\n}\n\\author{\nHaoran Wei (...)"
```

### Inference on cropped patches

GOT supports a 1024×1024 input resolution, which is sufficient for most OCR tasks, such as scene OCR or processing A4-sized PDF pages. However, certain scenarios, like horizontally stitched two-page PDFs commonly found in academic papers or images with unusual aspect ratios, can lead to accuracy issues when processed as a single image. To address this, GOT can dynamically crop an image into patches, process them all at once, and merge the results for better accuracy with such inputs.
Here is an example of how to process cropped patches:

```python
>>> import torch
>>> from transformers import AutoProcessor, AutoModelForImageTextToText

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", torch_dtype=torch.bfloat16, device_map=device)
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")

>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/one_column.png"
>>> inputs = processor(image, return_tensors="pt", format=True, crop_to_patches=True, max_patches=3).to(device)

>>> generate_ids = model.generate(
...     **inputs,
...     do_sample=False,
...     tokenizer=processor.tokenizer,
...     stop_strings="<|im_end|>",
...     max_new_tokens=4096,
... )

>>> processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
"on developing architectural improvements to make learnable matching methods generalize.\nMotivated by the above observations, (...)"
```

### Inference on a specific region

GOT supports interactive OCR, where the user can specify the region to be recognized by providing the coordinates or the color of the region's bounding box. Here is an example of how to process a specific region:

```python
>>> import torch
>>> from transformers import AutoProcessor, AutoModelForImageTextToText

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")

>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png"
>>> inputs = processor(image, return_tensors="pt", color="green").to(device)  # or box=[x1, y1, x2, y2] for coordinates (image pixels)

>>> generate_ids = model.generate(
...     **inputs,
...     do_sample=False,
...     tokenizer=processor.tokenizer,
...     stop_strings="<|im_end|>",
...     max_new_tokens=4096,
... )

>>> processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
"You should keep in mind what features from the module should be used, especially \nwhen you’re planning to sell a template."
```
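
The example above selects the region by the color of its bounding box. As noted in the comment, the processor also accepts pixel coordinates through the `box` argument; a minimal variant is sketched below, with placeholder coordinates rather than values taken from the test image:

```python
>>> import torch
>>> from transformers import AutoProcessor, AutoModelForImageTextToText

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")

>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png"
>>> # Placeholder region in image pixels, [x1, y1, x2, y2]; replace with the box you want to read
>>> inputs = processor(image, return_tensors="pt", box=[100, 200, 400, 250]).to(device)

>>> generate_ids = model.generate(
...     **inputs,
...     do_sample=False,
...     tokenizer=processor.tokenizer,
...     stop_strings="<|im_end|>",
...     max_new_tokens=4096,
... )

>>> processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
```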

### Inference on general OCR data example: sheet music

Although this implementation of the model will only output plain text, the outputs can be further processed to render the desired format, with packages like `pdftex`, `mathpix`, `matplotlib`, `tikz`, `verovio` or `pyecharts`.
Here is an example of how to process sheet music:

```python
>>> import torch
>>> import verovio
>>> from transformers import AutoProcessor, AutoModelForImageTextToText

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")

>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/sheet_music.png"
>>> inputs = processor(image, return_tensors="pt", format=True).to(device)

>>> generate_ids = model.generate(
...     **inputs,
...     do_sample=False,
...     tokenizer=processor.tokenizer,
...     stop_strings="<|im_end|>",
...     max_new_tokens=4096,
... )

>>> outputs = processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
>>> tk = verovio.toolkit()
>>> tk.loadData(outputs)
>>> tk.setOptions(
...     {
...         "pageWidth": 2100,
...         "pageHeight": 800,
...         "footer": "none",
...         "barLineWidth": 0.5,
...         "beamMaxSlope": 15,
...         "staffLineWidth": 0.2,
...         "spacingStaff": 6,
...     }
... )
>>> tk.getPageCount()
>>> svg = tk.renderToSVG()
>>> svg = svg.replace('overflow="inherit"', 'overflow="visible"')
>>> with open("output.svg", "w") as f:
...     f.write(svg)
```
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/sheet_music.svg"
alt="drawing" width="600"/>

## GotOcr2Config

[[autodoc]] GotOcr2Config

## GotOcr2VisionConfig

[[autodoc]] GotOcr2VisionConfig

## GotOcr2ImageProcessor

[[autodoc]] GotOcr2ImageProcessor

## GotOcr2Processor

[[autodoc]] GotOcr2Processor

## GotOcr2ForConditionalGeneration

[[autodoc]] GotOcr2ForConditionalGeneration
    - forward

docs/source/en/perf_infer_gpu_one.md

Lines changed: 2 additions & 0 deletions
@@ -52,6 +52,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Emu3](https://huggingface.co/docs/transformers/model_doc/emu3)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
+* [GotOcr2](https://huggingface.co/docs/transformers/model_doc/got_ocr2#transformers.GotOcr2ForConditionalGeneration)
* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel)
@@ -253,6 +254,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
+* [GotOcr2](https://huggingface.co/docs/transformers/model_doc/got_ocr2#transformers.GotOcr2ForConditionalGeneration)
* [Granite](https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel)
* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
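
Both lists now include GotOcr2. As a rough sketch of what this enables, an attention backend can be requested explicitly when loading the checkpoint used in the model doc; `attn_implementation` is the standard argument for this, and `flash_attention_2` additionally assumes the `flash-attn` package and a supported GPU are available:

```python
import torch
from transformers import AutoModelForImageTextToText

# Request a specific attention backend for GOT-OCR2: "sdpa" is PyTorch's native
# scaled-dot-product attention; "flash_attention_2" requires flash-attn to be installed.
model = AutoModelForImageTextToText.from_pretrained(
    "stepfun-ai/GOT-OCR-2.0-hf",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # or "sdpa"
    device_map="cuda",
)
```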

src/transformers/__init__.py

Lines changed: 18 additions & 0 deletions
@@ -476,6 +476,11 @@
    ],
    "models.glm": ["GlmConfig"],
    "models.glpn": ["GLPNConfig"],
+    "models.got_ocr2": [
+        "GotOcr2Config",
+        "GotOcr2Processor",
+        "GotOcr2VisionConfig",
+    ],
    "models.gpt2": [
        "GPT2Config",
        "GPT2Tokenizer",
@@ -1238,6 +1243,7 @@
    _import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaImageProcessor", "FlavaProcessor"])
    _import_structure["models.fuyu"].extend(["FuyuImageProcessor", "FuyuProcessor"])
    _import_structure["models.glpn"].extend(["GLPNFeatureExtractor", "GLPNImageProcessor"])
+    _import_structure["models.got_ocr2"].extend(["GotOcr2ImageProcessor"])
    _import_structure["models.grounding_dino"].extend(["GroundingDinoImageProcessor"])
    _import_structure["models.idefics"].extend(["IdeficsImageProcessor"])
    _import_structure["models.idefics2"].extend(["Idefics2ImageProcessor"])
@@ -2426,6 +2432,12 @@
            "GLPNPreTrainedModel",
        ]
    )
+    _import_structure["models.got_ocr2"].extend(
+        [
+            "GotOcr2ForConditionalGeneration",
+            "GotOcr2PreTrainedModel",
+        ]
+    )
    _import_structure["models.gpt2"].extend(
        [
            "GPT2DoubleHeadsModel",
@@ -5540,6 +5552,7 @@
    )
    from .models.glm import GlmConfig
    from .models.glpn import GLPNConfig
+    from .models.got_ocr2 import GotOcr2Config, GotOcr2Processor, GotOcr2VisionConfig
    from .models.gpt2 import (
        GPT2Config,
        GPT2Tokenizer,
@@ -6342,6 +6355,7 @@
    )
    from .models.fuyu import FuyuImageProcessor, FuyuProcessor
    from .models.glpn import GLPNFeatureExtractor, GLPNImageProcessor
+    from .models.got_ocr2 import GotOcr2ImageProcessor
    from .models.grounding_dino import GroundingDinoImageProcessor
    from .models.idefics import IdeficsImageProcessor
    from .models.idefics2 import Idefics2ImageProcessor
@@ -7346,6 +7360,10 @@
        GLPNModel,
        GLPNPreTrainedModel,
    )
+    from .models.got_ocr2 import (
+        GotOcr2ForConditionalGeneration,
+        GotOcr2PreTrainedModel,
+    )
    from .models.gpt2 import (
        GPT2DoubleHeadsModel,
        GPT2ForQuestionAnswering,
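
With these lazy-import entries in place, the new classes become importable from the top level of the package. A quick sanity-check sketch (assuming torch and the vision extras are installed; the randomly initialized model here downloads no weights, and the printed parameter count is illustrative only):

```python
from transformers import (
    GotOcr2Config,
    GotOcr2ForConditionalGeneration,
    GotOcr2ImageProcessor,
    GotOcr2Processor,
)

# Build a randomly initialized model from the default configuration.
config = GotOcr2Config()
model = GotOcr2ForConditionalGeneration(config)
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")
```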

src/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -106,6 +106,7 @@
    git,
    glm,
    glpn,
+    got_ocr2,
    gpt2,
    gpt_bigcode,
    gpt_neo,

src/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions
@@ -124,6 +124,7 @@
        ("git", "GitConfig"),
        ("glm", "GlmConfig"),
        ("glpn", "GLPNConfig"),
+        ("got_ocr2", "GotOcr2Config"),
        ("gpt-sw3", "GPT2Config"),
        ("gpt2", "GPT2Config"),
        ("gpt_bigcode", "GPTBigCodeConfig"),
@@ -450,6 +451,7 @@
        ("git", "GIT"),
        ("glm", "GLM"),
        ("glpn", "GLPN"),
+        ("got_ocr2", "GOT-OCR2"),
        ("gpt-sw3", "GPT-Sw3"),
        ("gpt2", "OpenAI GPT-2"),
        ("gpt_bigcode", "GPTBigCode"),

src/transformers/models/auto/image_processing_auto.py

Lines changed: 1 addition & 0 deletions
@@ -87,6 +87,7 @@
        ("fuyu", ("FuyuImageProcessor",)),
        ("git", ("CLIPImageProcessor",)),
        ("glpn", ("GLPNImageProcessor",)),
+        ("got_ocr2", ("GotOcr2ImageProcessor",)),
        ("grounding-dino", ("GroundingDinoImageProcessor",)),
        ("groupvit", ("CLIPImageProcessor",)),
        ("hiera", ("BitImageProcessor",)),
