
Commit

Update g4f/Provider/PollinationsAI.py
kqlio67 committed Dec 27, 2024
1 parent 22cc553 commit eebd075
Showing 1 changed file with 12 additions and 43 deletions.
55 changes: 12 additions & 43 deletions g4f/Provider/PollinationsAI.py
@@ -5,7 +5,6 @@
import requests
from typing import Optional
from aiohttp import ClientSession
from urllib.parse import quote

from ..requests.raise_for_status import raise_for_status
from ..typing import AsyncResult, Messages
@@ -75,11 +74,6 @@ def get_models(cls, **kwargs):
# Return combined models
return cls.text_models + cls.image_models

@classmethod
def get_model(cls, model: str) -> str:
"""Convert model alias to actual model name"""
return cls.model_aliases.get(model, model)

@classmethod
async def create_async_generator(
cls,
@@ -106,10 +100,7 @@
) -> AsyncResult:
model = cls.get_model(model)

# Ensure models are loaded
if not cls.image_models or not cls.text_models:
cls.get_models()

# Check if models
# Image generation
if model in cls.image_models:
async for result in cls._generate_image(
@@ -159,6 +150,7 @@ async def _generate_image(
if seed is None:
seed = random.randint(0, 10000)


headers = {
'Accept': '*/*',
'Accept-Language': 'en-US,en;q=0.9',
@@ -178,15 +170,13 @@ async def _generate_image(
params = {k: v for k, v in params.items() if v is not None}

async with ClientSession(headers=headers) as session:
# Use the last message's content as the prompt if no specific prompt provided
prompt_text = prompt if prompt else messages[-1]["content"]
encoded_prompt = quote(prompt_text)
prompt = quote(messages[-1]["content"])
param_string = "&".join(f"{k}={v}" for k, v in params.items())
url = f"{cls.image_api_endpoint}/prompt/{encoded_prompt}?{param_string}"
url = f"{cls.image_api_endpoint}/prompt/{prompt}?{param_string}"

async with session.head(url, proxy=proxy) as response:
if response.status == 200:
image_response = ImageResponse(images=[url], alt=prompt_text)
image_response = ImageResponse(images=url, alt=messages[-1]["content"])
yield image_response
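
Note: the reworked _generate_image path no longer downloads image bytes. It URL-encodes the last user message with quote, joins the remaining options into a query string, sends a single HEAD request to confirm the URL resolves, and yields the URL itself inside an ImageResponse. Below is a minimal, synchronous sketch of that URL construction; the endpoint constant, default model, and parameter defaults are assumptions for illustration, not values taken from this diff.

import random
from typing import Optional
from urllib.parse import quote
import requests

IMAGE_API = "https://image.pollinations.ai"  # assumed value of cls.image_api_endpoint

def build_image_url(prompt: str, model: str = "flux", width: int = 1024,
                    height: int = 1024, seed: Optional[int] = None) -> str:
    # Mirror the provider: drop None values, then join the rest into a query string
    params = {"width": width, "height": height, "model": model,
              "seed": seed if seed is not None else random.randint(0, 10000)}
    params = {k: v for k, v in params.items() if v is not None}
    param_string = "&".join(f"{k}={v}" for k, v in params.items())
    return f"{IMAGE_API}/prompt/{quote(prompt)}?{param_string}"

url = build_image_url("a lighthouse at dawn")
if requests.head(url).status_code == 200:  # HEAD is enough; the image is fetched by the consumer
    print(url)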

@classmethod
@@ -203,7 +193,7 @@ async def _generate_text(
stream: bool
) -> AsyncResult:
if api_key is None:
api_key = "dummy"
api_key = "dummy" # Default value if api_key is not provided

headers = {
"accept": "*/*",
@@ -229,31 +219,10 @@ async def _generate_text(
response.raise_for_status()
async for chunk in response.content:
if chunk:
decoded_chunk = chunk.decode()
try:
decoded_chunk = chunk.decode('utf-8').strip()
# Skip empty lines
if not decoded_chunk:
continue

# Handle SSE format
if decoded_chunk.startswith('data: '):
decoded_chunk = decoded_chunk[6:]

# Skip heartbeat messages
if decoded_chunk == '[DONE]':
continue

try:
json_response = json.loads(decoded_chunk)
if isinstance(json_response, dict):
if 'choices' in json_response:
content = json_response['choices'][0]['message']['content']
yield content
elif 'content' in json_response:
yield json_response['content']
except json.JSONDecodeError:
# If it's not JSON, yield the raw chunk
yield decoded_chunk
except Exception as e:
print(f"Error processing chunk: {e}")
continue
json_response = json.loads(decoded_chunk)
content = json_response['choices'][0]['message']['content']
yield content
except json.JSONDecodeError:
yield decoded_chunk
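
Note: the simplified text path assumes each streamed chunk decodes to a complete OpenAI-style JSON object and falls back to yielding the raw text when json.loads fails. A small standalone sketch of that parsing rule, using a made-up chunk for illustration:

import json

def extract_content(chunk: bytes) -> str:
    # Mirror the provider: try JSON first, otherwise return the decoded text unchanged
    decoded = chunk.decode()
    try:
        return json.loads(decoded)["choices"][0]["message"]["content"]
    except json.JSONDecodeError:
        return decoded

sample = b'{"choices": [{"message": {"content": "Hello!"}}]}'  # hypothetical chunk shape
print(extract_content(sample))  # -> Hello!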
