This repository has been archived by the owner on Nov 11, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Deploy to hugging face
- Loading branch information
Showing
15 changed files
with
385 additions
and
147 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
name: Sync to Hugging Face hub | ||
on: | ||
push: | ||
branches: [main] | ||
|
||
# to run this workflow manually from the Actions tab | ||
workflow_dispatch: | ||
|
||
jobs: | ||
sync-to-hub: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v3 | ||
with: | ||
fetch-depth: 0 | ||
lfs: true | ||
- name: Push to hub | ||
env: | ||
HF_TOKEN: ${{ secrets.HF_TOKEN }} | ||
run: | ||
git push https://posix4e:[email protected]/spaces/ttt246/brain main |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
**/*.swp | ||
.vscode |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,23 @@ | ||
# puppet | ||
--- | ||
title: puppet | ||
colorTo: indigo | ||
app_file: backend/backend.py | ||
sdk: gradio | ||
sdk_version: 2.9.1 | ||
python_version: 3.11.2 | ||
pinned: false | ||
license: mit | ||
--- | ||
|
||
Fun! | ||
|
||
## MVP | ||
|
||
- [ ] Run commands on android | ||
- [ ] Fix assist on android | ||
- [ ] Testing | ||
- [ ] Docs | ||
|
||
## Todo Soon | ||
|
||
- [ ] Other clients (browser extension/watch) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,144 +1,180 @@ | ||
from sqlalchemy import Column, Integer, String, DateTime, create_engine, Table, MetaData | ||
from sqlalchemy.sql import text, select, insert | ||
from sqlalchemy.orm import sessionmaker, Session | ||
from sqlalchemy.ext.declarative import declarative_base | ||
from datetime import datetime | ||
import uuid | ||
from fastapi import FastAPI, HTTPException | ||
from pydantic import BaseModel | ||
from firebase_admin import credentials, messaging, initialize_app, auth | ||
import openai | ||
import asyncio | ||
import gradio as gr | ||
from fastapi.middleware.wsgi import WSGIMiddleware | ||
from fastapi.staticfiles import StaticFiles | ||
import uvicorn | ||
from dotenv import load_dotenv | ||
import requests | ||
from uvicorn import Config, Server | ||
import os | ||
from typing import Optional | ||
import gradio as gr | ||
from enum import Enum | ||
|
||
Base = declarative_base() | ||
|
||
|
||
class User(Base): | ||
__tablename__ = "user_data" | ||
|
||
id = Column(Integer, primary_key=True, autoincrement=True) | ||
uuid = Column(String, nullable=False) | ||
openai_key = Column(String) | ||
last_assist = Column(DateTime) | ||
|
||
|
||
engine = create_engine("sqlite:///users.db") | ||
Base.metadata.create_all(bind=engine) | ||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) | ||
|
||
load_dotenv() | ||
app = FastAPI(debug=True) | ||
|
||
|
||
class EventItem(BaseModel): | ||
uid: str | ||
event: str | ||
|
||
|
||
@app.post("/send_event") | ||
async def send_event(item: EventItem): | ||
print(f"Received event from {item.uid}:\n{item.event}") | ||
|
||
app = FastAPI() | ||
with open(f"{item.uid}_events.txt", "a") as f: | ||
f.write(f"{datetime.now()} - {item.event}\n") | ||
|
||
# User credentials will be stored in this dictionary | ||
user_data = {} | ||
return {"message": "Event received"} | ||
|
||
|
||
class RegisterItem(BaseModel): | ||
apiKey: str | ||
authDomain: str | ||
databaseURL: str | ||
storageBucket: str | ||
openai_key: str | ||
|
||
|
||
@app.post("/register") | ||
async def register(item: RegisterItem): | ||
# Firebase initialization with user-specific credentials | ||
cred = credentials.Certificate( | ||
{ | ||
"apiKey": item.apiKey, | ||
"authDomain": item.authDomain, | ||
"databaseURL": item.databaseURL, | ||
"storageBucket": item.storageBucket, | ||
} | ||
) | ||
firebase_app = initialize_app(cred, name=str(len(user_data))) | ||
# Add the Firebase app and auth details to the user_data dictionary | ||
user_data[str(len(user_data))] = { | ||
"firebase_app": firebase_app, | ||
"authDomain": item.authDomain, | ||
} | ||
return {"uid": str(len(user_data) - 1)} # Return the user ID | ||
db: Session = SessionLocal() | ||
new_user = User(uuid=str(uuid.uuid4()), openai_key=item.openai_key) | ||
db.add(new_user) | ||
db.commit() | ||
db.refresh(new_user) | ||
return {"uid": new_user.uuid} | ||
|
||
|
||
class ProcessItem(BaseModel): | ||
class AssistItem(BaseModel): | ||
uid: str | ||
prompt: str | ||
version: str | ||
|
||
|
||
@app.post("/process_request") | ||
async def process_request(item: ProcessItem): | ||
# Get the user's Firebase app from the user_data dictionary | ||
firebase_app = user_data.get(item.uid, {}).get("firebase_app", None) | ||
authDomain = user_data.get(item.uid, {}).get("authDomain", None) | ||
def generate_quick_completion(prompt, gpt_version): | ||
response = openai.Completion.create( | ||
engine=gpt_version, prompt=prompt, max_tokens=1500 | ||
) | ||
return response | ||
|
||
|
||
if not firebase_app or not authDomain: | ||
@app.post("/assist") | ||
async def assist(item: AssistItem): | ||
db: Session = SessionLocal() | ||
user = db.query(User).filter(User.uuid == item.uid).first() | ||
if not user: | ||
raise HTTPException(status_code=400, detail="Invalid uid") | ||
|
||
# Call OpenAI | ||
response = openai.Completion.create( | ||
engine="text-davinci-002", prompt=item.prompt, max_tokens=150 | ||
) | ||
openai.api_key = user.openai_key | ||
response = generate_quick_completion(item.prompt, item.version) | ||
|
||
# The message data that will be sent to the client | ||
message = messaging.Message( | ||
data={ | ||
"message": response.choices[0].text.strip(), | ||
}, | ||
topic="updates", | ||
app=firebase_app, # Use the user-specific Firebase app | ||
) | ||
# Update the last time assist was called | ||
user.last_assist = datetime.now() | ||
db.commit() | ||
|
||
# Send the message asynchronously | ||
asyncio.run(send_notification(message)) | ||
return response | ||
|
||
return {"message": "Notification sent"} | ||
|
||
def assist_interface(uid, prompt, gpt_version): | ||
response = requests.post( | ||
"http://localhost:8000/assist", | ||
json={"uid": uid, "prompt": prompt, "version": gpt_version}, | ||
) | ||
return response.text | ||
|
||
def send_notification(message): | ||
# Send a message to the devices subscribed to the provided topic. | ||
response = messaging.send(message) | ||
print("Successfully sent message:", response) | ||
|
||
def get_user_interface(uid): | ||
db: Session = SessionLocal() | ||
user = db.query(User).filter(User.uuid == uid).first() | ||
if not user: | ||
return {"message": "No user with this uid found"} | ||
return str(user) | ||
|
||
def gradio_interface(): | ||
def register(apiKey, authDomain, databaseURL, storageBucket): | ||
response = requests.post( | ||
"http://localhost:8000/register", | ||
json={ | ||
"apiKey": apiKey, | ||
"authDomain": authDomain, | ||
"databaseURL": databaseURL, | ||
"storageBucket": storageBucket, | ||
}, | ||
) | ||
return response.json() | ||
|
||
def process_request(uid, prompt): | ||
response = requests.post( | ||
"http://localhost:8000/process_request", json={"uid": uid, "prompt": prompt} | ||
) | ||
return response.json() | ||
def get_assist_interface(): | ||
gpt_version_dropdown = gr.inputs.Dropdown( | ||
label="GPT Version", | ||
choices=["text-davinci-002", "text-davinci-003", "text-davinci-004"], | ||
default="text-davinci-002", | ||
) | ||
|
||
demo = gr.Interface( | ||
fn=[register, process_request], | ||
return gr.Interface( | ||
fn=assist_interface, | ||
inputs=[ | ||
[ | ||
gr.inputs.Textbox(label="apiKey"), | ||
gr.inputs.Textbox(label="authDomain"), | ||
gr.inputs.Textbox(label="databaseURL"), | ||
gr.inputs.Textbox(label="storageBucket"), | ||
], | ||
[gr.inputs.Textbox(label="uid"), gr.inputs.Textbox(label="prompt")], | ||
gr.inputs.Textbox(label="UID", type="text"), | ||
gr.inputs.Textbox(label="Prompt", type="text"), | ||
gpt_version_dropdown, | ||
], | ||
outputs="json", | ||
title="API Explorer", | ||
description="Use this tool to make requests to the Register and Process Request APIs", | ||
outputs="text", | ||
title="OpenAI Text Generation", | ||
description="Generate text using OpenAI's GPT-4 model.", | ||
) | ||
return demo | ||
|
||
|
||
def process_request_interface(uid, prompt): | ||
item = ProcessItem(uid=uid, prompt=prompt) | ||
response = process_request(item) | ||
return response | ||
def get_db_interface(): | ||
return gr.Interface( | ||
fn=get_user_interface, | ||
inputs="text", | ||
outputs="text", | ||
title="Get User Details", | ||
description="Get user details from the database", | ||
) | ||
|
||
|
||
def register_interface(openai_key): | ||
response = requests.post( | ||
"http://localhost:8000/register", | ||
json={"openai_key": openai_key}, | ||
) | ||
return response.json() | ||
|
||
|
||
def get_gradle_interface(): | ||
def get_register_interface(): | ||
return gr.Interface( | ||
fn=process_request_interface, | ||
inputs=[ | ||
gr.inputs.Textbox(label="UID", type="text"), | ||
gr.inputs.Textbox(label="Prompt", type="text"), | ||
], | ||
fn=register_interface, | ||
inputs=[gr.inputs.Textbox(label="OpenAI Key", type="text")], | ||
outputs="text", | ||
title="OpenAI Text Generation", | ||
description="Generate text using OpenAI's GPT-3 model.", | ||
title="Register New User", | ||
description="Register a new user by entering an OpenAI key.", | ||
) | ||
|
||
|
||
app = gr.mount_gradio_app(app, get_gradle_interface(), path="/") | ||
app = gr.mount_gradio_app( | ||
app, | ||
gr.TabbedInterface( | ||
[ | ||
get_assist_interface(), | ||
get_db_interface(), | ||
get_register_interface(), | ||
] | ||
), | ||
path="/", | ||
) | ||
|
||
if __name__ == "__main__": | ||
config = Config("backend:app", host="127.0.0.1", port=8000, reload=True) | ||
server = Server(config) | ||
server.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,7 @@ pinecone-io | |
pydantic | ||
python-dotenv | ||
swagger-ui-bundle | ||
|
||
sqlalchemy | ||
pytest-asyncio | ||
pytest | ||
requests | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,4 +11,3 @@ local.properties | |
.idea/assetWizardSettings.xml | ||
.DS_Store | ||
build | ||
local.properties |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.