diff --git a/g4f/Provider/PollinationsAI.py b/g4f/Provider/PollinationsAI.py index 92ea7cdf4cf..f1ad0031301 100644 --- a/g4f/Provider/PollinationsAI.py +++ b/g4f/Provider/PollinationsAI.py @@ -5,7 +5,6 @@ import requests from typing import Optional from aiohttp import ClientSession -from urllib.parse import quote from ..requests.raise_for_status import raise_for_status from ..typing import AsyncResult, Messages @@ -75,11 +74,6 @@ def get_models(cls, **kwargs): # Return combined models return cls.text_models + cls.image_models - @classmethod - def get_model(cls, model: str) -> str: - """Convert model alias to actual model name""" - return cls.model_aliases.get(model, model) - @classmethod async def create_async_generator( cls, @@ -106,10 +100,7 @@ async def create_async_generator( ) -> AsyncResult: model = cls.get_model(model) - # Ensure models are loaded - if not cls.image_models or not cls.text_models: - cls.get_models() - + # Check if models # Image generation if model in cls.image_models: async for result in cls._generate_image( @@ -159,6 +150,7 @@ async def _generate_image( if seed is None: seed = random.randint(0, 10000) + headers = { 'Accept': '*/*', 'Accept-Language': 'en-US,en;q=0.9', @@ -178,15 +170,13 @@ async def _generate_image( params = {k: v for k, v in params.items() if v is not None} async with ClientSession(headers=headers) as session: - # Use the last message's content as the prompt if no specific prompt provided - prompt_text = prompt if prompt else messages[-1]["content"] - encoded_prompt = quote(prompt_text) + prompt = quote(messages[-1]["content"]) param_string = "&".join(f"{k}={v}" for k, v in params.items()) - url = f"{cls.image_api_endpoint}/prompt/{encoded_prompt}?{param_string}" + url = f"{cls.image_api_endpoint}/prompt/{prompt}?{param_string}" async with session.head(url, proxy=proxy) as response: if response.status == 200: - image_response = ImageResponse(images=[url], alt=prompt_text) + image_response = ImageResponse(images=url, alt=messages[-1]["content"]) yield image_response @classmethod @@ -203,7 +193,7 @@ async def _generate_text( stream: bool ) -> AsyncResult: if api_key is None: - api_key = "dummy" + api_key = "dummy" # Default value if api_key is not provided headers = { "accept": "*/*", @@ -229,31 +219,10 @@ async def _generate_text( response.raise_for_status() async for chunk in response.content: if chunk: + decoded_chunk = chunk.decode() try: - decoded_chunk = chunk.decode('utf-8').strip() - # Skip empty lines - if not decoded_chunk: - continue - - # Handle SSE format - if decoded_chunk.startswith('data: '): - decoded_chunk = decoded_chunk[6:] - - # Skip heartbeat messages - if decoded_chunk == '[DONE]': - continue - - try: - json_response = json.loads(decoded_chunk) - if isinstance(json_response, dict): - if 'choices' in json_response: - content = json_response['choices'][0]['message']['content'] - yield content - elif 'content' in json_response: - yield json_response['content'] - except json.JSONDecodeError: - # If it's not JSON, yield the raw chunk - yield decoded_chunk - except Exception as e: - print(f"Error processing chunk: {e}") - continue + json_response = json.loads(decoded_chunk) + content = json_response['choices'][0]['message']['content'] + yield content + except json.JSONDecodeError: + yield decoded_chunk