Skip to content

Commit

Permalink
[Bugfix] Convert image to RGB by default (vllm-project#6430)
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 authored and dtrifiro committed Jul 17, 2024
1 parent 09a8d63 commit af62b48
Showing 1 changed file with 30 additions and 10 deletions.
40 changes: 30 additions & 10 deletions vllm/multimodal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@ def _load_image_from_data_url(image_url: str):
return load_image_from_base64(image_base64)


def fetch_image(image_url: str) -> Image.Image:
"""Load PIL image from a url or base64 encoded openai GPT4V format"""
def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
"""
Load a PIL image from a HTTP or base64 data URL.
By default, the image is converted into RGB format.
"""
if image_url.startswith('http'):
_validate_remote_url(image_url, name="image_url")

Expand All @@ -53,7 +57,7 @@ def fetch_image(image_url: str) -> Image.Image:
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.")

return image
return image.convert(image_mode)


class ImageFetchAiohttp:
Expand All @@ -70,8 +74,17 @@ def get_aiohttp_client(cls) -> aiohttp.ClientSession:
return cls.aiohttp_client

@classmethod
async def fetch_image(cls, image_url: str) -> Image.Image:
"""Load PIL image from a url or base64 encoded openai GPT4V format"""
async def fetch_image(
cls,
image_url: str,
*,
image_mode: str = "RGB",
) -> Image.Image:
"""
Asynchronously load a PIL image from a HTTP or base64 data URL.
By default, the image is converted into RGB format.
"""

if image_url.startswith('http'):
_validate_remote_url(image_url, name="image_url")
Expand All @@ -91,20 +104,27 @@ async def fetch_image(cls, image_url: str) -> Image.Image:
"Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image' or 'http'.")

return image
return image.convert(image_mode)


async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
image = await ImageFetchAiohttp.fetch_image(image_url)
return {"image": image}


def encode_image_base64(image: Image.Image, format: str = 'JPEG') -> str:
"""Encode a pillow image to base64 format."""
def encode_image_base64(
image: Image.Image,
*,
image_mode: str = "RGB",
format: str = "JPEG",
) -> str:
"""
Encode a pillow image to base64 format.
By default, the image is converted into RGB format before being encoded.
"""
buffered = BytesIO()
if format == 'JPEG':
image = image.convert('RGB')
image = image.convert(image_mode)
image.save(buffered, format)
return base64.b64encode(buffered.getvalue()).decode('utf-8')

Expand Down

0 comments on commit af62b48

Please sign in to comment.