Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add contracts endpoint #31

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions app/datasources/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
Relationship,
SQLModel,
UniqueConstraint,
col,
select,
)
from sqlmodel.sql._expression_select_cls import SelectBase

from app.datasources.db.utils import get_md5_abi_hash
moisses89 marked this conversation as resolved.
Show resolved Hide resolved


class SqlQueryBase:
Expand Down Expand Up @@ -38,16 +42,20 @@ class AbiSource(SqlQueryBase, SQLModel, table=True):

class Abi(SqlQueryBase, SQLModel, table=True):
id: int | None = Field(default=None, primary_key=True)
abi_hash: bytes = Field(nullable=False, index=True, unique=True)
abi_hash: bytes | None = Field(nullable=False, index=True, unique=True)
Uxio0 marked this conversation as resolved.
Show resolved Hide resolved
relevance: int | None = Field(nullable=False, default=0)
abi_json: dict = Field(default_factory=dict, sa_column=Column(JSON))
source_id: int | None = Field(
nullable=True, default=None, foreign_key="abisource.id"
nullable=False, default=None, foreign_key="abisource.id"
)

source: AbiSource | None = Relationship(back_populates="abis")
contracts: list["Contract"] = Relationship(back_populates="abi")

async def create(self, session):
moisses89 marked this conversation as resolved.
Show resolved Hide resolved
self.abi_hash = get_md5_abi_hash(self.abi_json)
return await self._save(session)


class Project(SqlQueryBase, SQLModel, table=True):
id: int | None = Field(default=None, primary_key=True)
Expand All @@ -72,9 +80,32 @@ class Contract(SqlQueryBase, SQLModel, table=True):
abi_id: bytes | None = Field(
nullable=True, default=None, foreign_key="abi.abi_hash"
)
abi: Abi | None = Relationship(back_populates="contracts")
abi: Abi | None = Relationship(
back_populates="contracts", sa_relationship_kwargs={"lazy": "joined"}
)
project_id: int | None = Field(
nullable=True, default=None, foreign_key="project.id"
)
project: Project | None = Relationship(back_populates="contracts")
project: Project | None = Relationship(
back_populates="contracts", sa_relationship_kwargs={"lazy": "joined"}
)
chain_id: int = Field(default=None)

@classmethod
def get_contracts_query(
cls, address: bytes, chain_ids: list[int] | None = None
) -> SelectBase["Contract"]:
"""
Return a statement to get contracts for the provided address and chain_id

:param address:
:param chain_ids: list of chain_ids, None for all chains
:return:
"""
query = select(cls).where(cls.address == address)
if chain_ids:
query = query.where(col(cls.chain_id).in_(chain_ids)).order_by(
col(cls.chain_id).desc()
)

return query
9 changes: 9 additions & 0 deletions app/datasources/db/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import hashlib
import json


def get_md5_abi_hash(abi: list[dict] | dict) -> bytes:
json_str = json.dumps(abi, sort_keys=True)
md5_hash = hashlib.md5(json_str.encode("utf-8")).hexdigest()
abi_hash = md5_hash[-8:]
return bytes.fromhex(abi_hash)
5 changes: 0 additions & 5 deletions app/models.py

This file was deleted.

3 changes: 2 additions & 1 deletion app/routers/about.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from fastapi import APIRouter

from app.routers.models import About
moisses89 marked this conversation as resolved.
Show resolved Hide resolved

from .. import VERSION
from ..models import About

router = APIRouter(
prefix="/about",
Expand Down
22 changes: 17 additions & 5 deletions app/routers/contracts.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,33 @@
from typing import Sequence
from typing import Annotated

from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException, Query, Request

from hexbytes import HexBytes
from safe_eth.eth.utils import fast_is_checksum_address
from sqlmodel.ext.asyncio.session import AsyncSession

from ..datasources.db.database import get_database_session
from ..datasources.db.models import Contract
from ..services.contract import ContractService
from ..services.pagination import PaginatedResponse, PaginationParams
from .models import ContractsPublic

router = APIRouter(
prefix="/contracts",
tags=["contracts"],
)


@router.get("", response_model=Sequence[Contract])
@router.get("/{address}", response_model=PaginatedResponse[ContractsPublic])
async def list_contracts(
request: Request,
address: str,
pagination_params: PaginationParams = Depends(),
chain_ids: Annotated[list[int] | None, Query()] = None,
moisses89 marked this conversation as resolved.
Show resolved Hide resolved
session: AsyncSession = Depends(get_database_session),
) -> Sequence[Contract]:
return await ContractService.get_all(session)
) -> PaginatedResponse[Contract]:
if not fast_is_checksum_address(address):
raise HTTPException(status_code=400, detail="Address is not checksumed")

contracts_service = ContractService(request)
return await contracts_service.get_contract(session, HexBytes(address), chain_ids)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing the request to the service and instantiating one service for every client does not sound right to me

61 changes: 61 additions & 0 deletions app/routers/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from pydantic import BaseModel, field_validator

from safe_eth.eth.utils import ChecksumAddress, fast_to_checksum_address


class About(BaseModel):
version: str


class ProjectPublic(BaseModel):
description: str
logo_file: str

class Config:
from_attributes = True


class AbiPublic(BaseModel):
abi_json: list[dict] | dict | None
abi_hash: bytes | str

class Config:
from_attributes = True

@field_validator("abi_hash")
@classmethod
def convert_bytes_to_hex(cls, abi_hash):
"""
Convert bytes to hex

:param abi_hash:
:return:
"""
if isinstance(abi_hash, bytes):
return "0x" + abi_hash.hex() # Convert bytes to a hex string
return abi_hash


class ContractsPublic(BaseModel):
address: bytes | ChecksumAddress
name: str
display_name: str | None
chain_id: int
project: ProjectPublic | None
abi: AbiPublic | None

class Config:
from_attributes = True

@field_validator("address")
@classmethod
def convert_to_checksum_address(cls, address):
moisses89 marked this conversation as resolved.
Show resolved Hide resolved
"""
Convert bytes address to checksum address

:param address:
:return:
"""
if isinstance(address, bytes):
return fast_to_checksum_address(address)
return address
24 changes: 23 additions & 1 deletion app/services/contract.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from typing import Sequence
from typing import Any, Sequence

from fastapi import Request

from sqlmodel.ext.asyncio.session import AsyncSession

from app.datasources.db.models import Contract
from app.services.pagination import GenericPagination, PaginatedResponse

moisses89 marked this conversation as resolved.
Show resolved Hide resolved

class ContractService:

def __init__(self, request: Request):
self.pagination = GenericPagination(request=request)

@staticmethod
async def get_all(session: AsyncSession) -> Sequence[Contract]:
"""
Expand All @@ -16,3 +22,19 @@ async def get_all(session: AsyncSession) -> Sequence[Contract]:
:return:
"""
return await Contract.get_all(session)

async def get_contract(
self, session: AsyncSession, address: bytes, chain_ids: list[int] | None
) -> PaginatedResponse[Any]:
"""
Get the contract by address and/or chain_ids

:param session: database session
:param address: contract address
:param chain_ids: list of filtered chains
:return:
"""

return await self.pagination.paginate(
session, Contract.get_contracts_query(address, chain_ids)
)
90 changes: 90 additions & 0 deletions app/services/pagination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from typing import Generic, TypeVar

from fastapi import Query, Request
from pydantic import BaseModel

from sqlalchemy import func
from sqlmodel import select

T = TypeVar("T")


class PaginatedResponse(BaseModel, Generic[T]):
count: int
next: str | None
previous: str | None
results: list[T]


class PaginationParams(BaseModel):
limit: int | None = Query(None, ge=1)
offset: int | None = Query(0, ge=0)


class GenericPagination:
def __init__(
self,
request: Request,
default_page_size: int = 10,
max_page_size: int = 100,
):
self.request = request
self.max_page_size = max_page_size
self.limit = min(
int(self.request.query_params.get("limit", default_page_size)),
max_page_size,
)
self.offset = int(self.request.query_params.get("offset", 0))

def get_next_page(self, count: int) -> str | None:
"""
Calculates the next page of results. If there are no more pages return None

:param base_url:
:param count:
:return:
"""
if self.offset + self.limit < count:
next_offset = self.offset + self.limit
return str(
self.request.url.include_query_params(
limit=self.limit, offset=next_offset
)
)
return None

def get_previous_page(self) -> str | None:
"""
Calculates the previous page of results. If there are no more pages return None

:param base_url:
:return:
"""
if self.offset > 0:
prev_offset = max(0, self.offset - self.limit) # Prevent negative offset
return str(
self.request.url.include_query_params(
limit=self.limit, offset=prev_offset
)
)
return None

async def paginate(self, session, query) -> PaginatedResponse:
"""
Get the paginated response for the provided query

:param session:
:param query:
:param ResponseSchema:
:return:
"""
queryset = await session.exec(query.offset(self.offset).limit(self.limit))
count_query = await session.exec(select(func.count()).where(query._whereclause))
count = count_query.one()
paginated_response = PaginatedResponse(
count=count,
next=self.get_next_page(count),
previous=self.get_previous_page(),
results=queryset.all(),
)
return paginated_response
9 changes: 7 additions & 2 deletions app/tests/db/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ class TestModel(DbAsyncConn):
async def test_contract(self, session: AsyncSession):
contract = Contract(address=b"a", name="A test contract", chain_id=1)
await contract.create(session)
await contract.create(session)
result = await contract.get_all(session)
self.assertEqual(result[0], contract)

Expand All @@ -23,7 +22,13 @@ async def test_project(self, session: AsyncSession):

@database_session
async def test_abi(self, session: AsyncSession):
abi = Abi(abi_hash=b"A Test Abi", abi_json={"name": "A Test Project"})
abi_source = AbiSource(name="A Test Source", url="https://test.com")
await abi_source.create(session)
abi = Abi(
abi_hash=b"A Test Abi",
abi_json={"name": "A Test Project"},
source_id=abi_source.id,
)
await abi.create(session)
result = await abi.get_all(session)
self.assertEqual(result[0], abi)
Expand Down
Empty file added app/tests/mocks/__init__.py
Empty file.
Loading
Loading