55import sys
66from io import BytesIO
77
8+ import numpy as np
89import requests
910from PIL import Image
1011
1314
1415class DonutProcessor :
1516 def __init__ (self , model_path = "naver-clova-ix/donut-base-finetuned-cord-v2" ):
16-
1717 self .ensure_installed ("torch" )
1818 self .ensure_installed ("transformers" )
1919
@@ -36,13 +36,31 @@ def ensure_installed(self, package_name):
3636 [sys .executable , "-m" , "pip" , "install" , package_name ]
3737 )
3838
39- async def parse_image (self , image : Image ) -> str :
def preprocess_image(self, image: Image.Image) -> np.ndarray:
    """Convert a PIL image into a three-channel RGB NumPy array.

    Args:
        image: The PIL image to prepare. Any mode is accepted; images that
            are not already RGB are converted first.

    Returns:
        The pixel data as an array laid out (height, width, channels)
        with three channels.
    """
    # Normalise every mode (grayscale, palette, RGBA, ...) to RGB so the
    # resulting array always carries three channels. Because of this, no
    # separate grayscale expansion is needed afterwards: an RGB image never
    # produces a two-dimensional array.
    if image.mode != "RGB":
        image = image.convert("RGB")

    return np.array(image)
53+
54+ async def parse_image (self , image : Image .Image ) -> str :
4055 """Process w/ DonutProcessor and VisionEncoderDecoderModel"""
56+ # Preprocess the image
57+ image_np = self .preprocess_image (image )
58+
4159 task_prompt = "<s_cord-v2>"
4260 decoder_input_ids = self .processor .tokenizer (
4361 task_prompt , add_special_tokens = False , return_tensors = "pt"
4462 ).input_ids
45- pixel_values = self .processor (image , return_tensors = "pt" ).pixel_values
63+ pixel_values = self .processor (images = image_np , return_tensors = "pt" ).pixel_values
4664
4765 outputs = self .model .generate (
4866 pixel_values .to (self .device ),
@@ -71,7 +89,7 @@ def process_url(self, url: str) -> str:
7189 image = self .downloader .download_image (url )
7290 return self .parse_image (image )
7391
74- def download_image (self , url : str ) -> Image :
92+ def download_image (self , url : str ) -> Image . Image :
7593 """Download an image from URL."""
7694 response = requests .get (url )
7795 image = Image .open (BytesIO (response .content ))
0 commit comments