Skip to content

Commit

Permalink
Implement FastAPI Unit Tests (#34)
Browse files Browse the repository at this point in the history
# Description
Adds test coverage for all exposed endpoints with mocked backend methods
Addresses some edge-case bugs in `request.client` object handling
Fixes/completes model type annotations


# Issues
<!-- If this is related to or closes an issue/other PR, please note them
here -->

# Other Notes
<!-- Note any breaking changes, WIP changes, requests for input, etc.
here -->
  • Loading branch information
NeonDaniel authored Nov 19, 2024
1 parent da161ee commit 0223749
Show file tree
Hide file tree
Showing 10 changed files with 594 additions and 13 deletions.
2 changes: 1 addition & 1 deletion neon_hana/app/routers/api_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,4 @@ async def api_proxy_geolocation(query: GeoAPIReverseRequest) -> GeoAPIReverseRes

@proxy_route.post("/wolframalpha")
async def api_proxy_wolframalpha(query: WolframAlphaAPIRequest) -> WolframAlphaAPIResponse:
return mq_connector.query_api_proxy("wolfram_alpha", dict(query))
return mq_connector.query_api_proxy("wolfram_alpha", dict(query))
3 changes: 2 additions & 1 deletion neon_hana/app/routers/assist.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,6 @@ async def get_tts(request: TTSRequest) -> TTSResponse:
async def get_response(skill_request: SkillRequest,
request: Request) -> SkillResponse:
if not skill_request.node_data.networking.public_ip:
skill_request.node_data.networking.public_ip = request.client.host
host = request.client.host if request.client else ""
skill_request.node_data.networking.public_ip = host
return mq_connector.get_response(**dict(skill_request))
3 changes: 2 additions & 1 deletion neon_hana/app/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@
@auth_route.post("/login")
async def check_login(auth_request: AuthenticationRequest,
request: Request) -> AuthenticationResponse:
ip_addr = request.client.host if request.client else "127.0.0.1"
return client_manager.check_auth_request(**dict(auth_request),
origin_ip=request.client.host)
origin_ip=ip_addr)


@auth_route.post("/refresh")
Expand Down
23 changes: 20 additions & 3 deletions neon_hana/app/routers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import re

from fastapi import APIRouter, Request
from starlette.responses import PlainTextResponse

Expand All @@ -32,13 +34,28 @@
util_route = APIRouter(prefix="/util", tags=["utilities"])


def _is_ipv4(address: str) -> bool:
ipv4_regex = re.compile(
r'^(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01'
r']?[0-9][0-9]?)\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.(?:25[0-5]|'
r'2[0-4][0-9]|[01]?[0-9][0-9]?)$')
return ipv4_regex.match(address)


@util_route.get("/client_ip", response_class=PlainTextResponse)
async def api_client_ip(request: Request) -> str:
client_manager.validate_auth("", request.client.host)
return request.client.host
ip_addr = request.client.host if request.client else "127.0.0.1"

if not _is_ipv4(ip_addr):
# Reported host is a hostname, not an IP address. Return a generic
# loopback value
ip_addr = "127.0.0.1"
client_manager.validate_auth("", ip_addr)
return ip_addr


@util_route.get("/headers")
async def api_headers(request: Request):
client_manager.validate_auth("", request.client.host)
ip_addr = request.client.host if request.client else "127.0.0.1"
client_manager.validate_auth("", ip_addr)
return request.headers
3 changes: 2 additions & 1 deletion neon_hana/auth/client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,9 @@ async def __call__(self, request: Request):
if not credentials.scheme == "Bearer":
raise HTTPException(status_code=403,
detail="Invalid authentication scheme.")
host = request.client.host if request.client else "127.0.0.1"
if not self.client_manager.validate_auth(credentials.credentials,
request.client.host):
host):
raise HTTPException(status_code=403,
detail="Invalid or expired token.")
return credentials.credentials
Expand Down
2 changes: 1 addition & 1 deletion neon_hana/mq_service_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _validate_api_proxy_response(response: dict, query_params: dict):
try:
resp = json.loads(response['content'])
if query_params.get('service') == "alpha_vantage":
resp['service'] = query_params['service']
resp['provider'] = query_params['service']
if query_params.get("region") and resp.get('bestMatches'):
filtered = [
stock for stock in resp.get("bestMatches")
Expand Down
7 changes: 5 additions & 2 deletions neon_hana/schema/api_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class WeatherAPIOnecallResponse(BaseModel):
timezone: str
timezone_offset: int
current: Dict[str, Any]
minutely: Optional[List[dict]]
minutely: Optional[List[dict]] = None
hourly: List[dict]
daily: List[dict]

Expand Down Expand Up @@ -1742,7 +1742,8 @@ class WeatherAPIOnecallResponse(BaseModel):


class StockAPIQuoteResponse(BaseModel):
global_quote: Dict[str, str] = Field(..., alias="Global Quote")
provider: str
global_quote: Dict[str, str] = Field(alias="Global Quote")

model_config = {
"extra": "allow",
Expand All @@ -1767,6 +1768,8 @@ class StockAPIQuoteResponse(BaseModel):


class StockAPISearchResponse(BaseModel):
provider: str
best_matches: List[Dict[str, str]] = Field(alias="bestMatches")
model_config = {
"extra": "allow",
"json_schema_extra": {
Expand Down
6 changes: 3 additions & 3 deletions neon_hana/schema/llm_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import List
from typing import List, Tuple

from pydantic import BaseModel

Expand All @@ -42,11 +42,11 @@ class LLMRequest(BaseModel):

class LLMResponse(BaseModel):
response: str
history: List[tuple]
history: List[Tuple[str, str]]
model_config = {
"json_schema_extra": {
"examples": [{
"query": "I am well, how about you?",
"response": "As a large language model, I do not feel",
"history": [("user", "hello"),
("llm", "Hi, how can I help you today?"),
("user", "I am well, how about you?"),
Expand Down
1 change: 1 addition & 0 deletions requirements/test_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pytest
mock
httpx
neon-iris~=0.1
websockets~=12.0
Loading

0 comments on commit 0223749

Please sign in to comment.