Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add contracts endpoint #31

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 40 additions & 8 deletions app/datasources/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,31 @@
Relationship,
SQLModel,
UniqueConstraint,
col,
select,
)
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql._expression_select_cls import SelectBase

from .utils import get_md5_abi_hash


class SqlQueryBase:

@classmethod
async def get_all(cls, session):
async def get_all(cls, session: AsyncSession):
result = await session.exec(select(cls))
return result.all()

async def _save(self, session):
async def _save(self, session: AsyncSession):
session.add(self)
await session.commit()
return self

async def update(self, session):
async def update(self, session: AsyncSession):
return await self._save(session)

async def create(self, session):
async def create(self, session: AsyncSession):
return await self._save(session)


Expand All @@ -38,16 +43,20 @@ class AbiSource(SqlQueryBase, SQLModel, table=True):

class Abi(SqlQueryBase, SQLModel, table=True):
id: int | None = Field(default=None, primary_key=True)
abi_hash: bytes = Field(nullable=False, index=True, unique=True)
abi_hash: bytes | None = Field(nullable=False, index=True, unique=True)
Uxio0 marked this conversation as resolved.
Show resolved Hide resolved
relevance: int | None = Field(nullable=False, default=0)
abi_json: dict = Field(default_factory=dict, sa_column=Column(JSON))
source_id: int | None = Field(
nullable=True, default=None, foreign_key="abisource.id"
nullable=False, default=None, foreign_key="abisource.id"
)

source: AbiSource | None = Relationship(back_populates="abis")
contracts: list["Contract"] = Relationship(back_populates="abi")

async def create(self, session):
moisses89 marked this conversation as resolved.
Show resolved Hide resolved
self.abi_hash = get_md5_abi_hash(self.abi_json)
return await self._save(session)


class Project(SqlQueryBase, SQLModel, table=True):
id: int | None = Field(default=None, primary_key=True)
Expand All @@ -72,9 +81,32 @@ class Contract(SqlQueryBase, SQLModel, table=True):
abi_id: bytes | None = Field(
nullable=True, default=None, foreign_key="abi.abi_hash"
)
abi: Abi | None = Relationship(back_populates="contracts")
abi: Abi | None = Relationship(
back_populates="contracts", sa_relationship_kwargs={"lazy": "joined"}
)
project_id: int | None = Field(
nullable=True, default=None, foreign_key="project.id"
)
project: Project | None = Relationship(back_populates="contracts")
project: Project | None = Relationship(
back_populates="contracts", sa_relationship_kwargs={"lazy": "joined"}
)
chain_id: int = Field(default=None)

@classmethod
def get_contracts_query(
cls, address: bytes, chain_ids: list[int] | None = None
) -> SelectBase["Contract"]:
"""
Return a statement to get contracts for the provided address and chain_id

:param address:
:param chain_ids: list of chain_ids, None for all chains
:return:
"""
query = select(cls).where(cls.address == address)
if chain_ids:
query = query.where(col(cls.chain_id).in_(chain_ids)).order_by(
col(cls.chain_id).desc()
)

return query
9 changes: 9 additions & 0 deletions app/datasources/db/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import hashlib
import json


def get_md5_abi_hash(abi: list[dict] | dict) -> bytes:
json_str = json.dumps(abi, sort_keys=True)
md5_hash = hashlib.md5(json_str.encode("utf-8")).hexdigest()
abi_hash = md5_hash[-8:]
return bytes.fromhex(abi_hash)
5 changes: 0 additions & 5 deletions app/models.py

This file was deleted.

2 changes: 1 addition & 1 deletion app/routers/about.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fastapi import APIRouter

from .. import VERSION
from ..models import About
from .models import About

router = APIRouter(
prefix="/about",
Expand Down
30 changes: 25 additions & 5 deletions app/routers/contracts.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,41 @@
from typing import Sequence
from typing import Annotated

from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException, Query, Request

from hexbytes import HexBytes
from safe_eth.eth.utils import fast_is_checksum_address
from sqlmodel.ext.asyncio.session import AsyncSession

from ..datasources.db.database import get_database_session
from ..datasources.db.models import Contract
from ..services.contract import ContractService
from ..services.pagination import (
GenericPagination,
PaginatedResponse,
PaginationQueryParams,
)
from .models import ContractsPublic

router = APIRouter(
prefix="/contracts",
tags=["contracts"],
)


@router.get("", response_model=Sequence[Contract])
@router.get("/{address}", response_model=PaginatedResponse[ContractsPublic])
async def list_contracts(
request: Request,
address: str,
pagination_params: PaginationQueryParams = Depends(),
chain_ids: Annotated[list[int] | None, Query()] = None,
session: AsyncSession = Depends(get_database_session),
) -> Sequence[Contract]:
return await ContractService.get_all(session)
) -> PaginatedResponse[Contract]:
if not fast_is_checksum_address(address):
raise HTTPException(status_code=400, detail="Address is not checksumed")

pagination = GenericPagination(pagination_params.limit, pagination_params.offset)
contracts_service = ContractService(pagination=pagination)
results, count = await contracts_service.get_contracts(
session, HexBytes(address), chain_ids
)
return pagination.serialize(request.url, results, count)
61 changes: 61 additions & 0 deletions app/routers/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from pydantic import BaseModel, field_validator

from safe_eth.eth.utils import ChecksumAddress, fast_to_checksum_address


class About(BaseModel):
version: str


class ProjectPublic(BaseModel):
description: str
logo_file: str

class Config:
from_attributes = True


class AbiPublic(BaseModel):
abi_json: list[dict] | dict | None
abi_hash: bytes | str

class Config:
from_attributes = True

@field_validator("abi_hash")
@classmethod
def convert_bytes_to_hex(cls, abi_hash: bytes):
"""
Convert bytes to hex

:param abi_hash:
:return:
"""
if isinstance(abi_hash, bytes):
return "0x" + abi_hash.hex() # Convert bytes to a hex string
return abi_hash


class ContractsPublic(BaseModel):
address: bytes | ChecksumAddress
name: str
display_name: str | None
chain_id: int
project: ProjectPublic | None
abi: AbiPublic | None

class Config:
from_attributes = True

@field_validator("address")
@classmethod
def convert_to_checksum_address(cls, address: bytes):
"""
Convert bytes address to checksum address

:param address:
:return:
"""
if isinstance(address, bytes):
return fast_to_checksum_address(address)
return address
26 changes: 25 additions & 1 deletion app/services/contract.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from typing import Sequence
from typing import Sequence, Tuple

from sqlmodel.ext.asyncio.session import AsyncSession

from app.datasources.db.models import Contract

from .pagination import GenericPagination


class ContractService:

def __init__(self, pagination: GenericPagination):
self.pagination = pagination

@staticmethod
async def get_all(session: AsyncSession) -> Sequence[Contract]:
"""
Expand All @@ -16,3 +21,22 @@ async def get_all(session: AsyncSession) -> Sequence[Contract]:
:return:
"""
return await Contract.get_all(session)

async def get_contracts(
self, session: AsyncSession, address: bytes, chain_ids: list[int] | None
) -> Tuple[list[Contract], int]:
"""
Get the contract by address and/or chain_ids

:param session: database session
:param address: contract address
:param chain_ids: list of filtered chains
:return:
"""
page = await self.pagination.get_page(
session, Contract.get_contracts_query(address, chain_ids)
)
count = await self.pagination.get_count(
session, Contract.get_contracts_query(address, chain_ids)
)
return page, count
100 changes: 100 additions & 0 deletions app/services/pagination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from typing import Any, Generic, TypeVar

from fastapi import Query
from pydantic import BaseModel

from sqlalchemy import func
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from starlette.datastructures import URL

T = TypeVar("T")


class PaginatedResponse(BaseModel, Generic[T]):
count: int
next: str | None
previous: str | None
results: list[T]


class PaginationQueryParams(BaseModel):
limit: int | None = Query(None, ge=1)
offset: int | None = Query(0, ge=0)


class GenericPagination:
def __init__(
self,
limit: int | None,
offset: int | None,
default_page_size: int = 10,
max_page_size: int = 100,
):
self.max_page_size = max_page_size
self.limit = min(limit, max_page_size) if limit else default_page_size
self.offset = offset if offset else 0

def get_next_page(self, url: URL, count: int) -> str | None:
"""
Calculates the next page of results. If there are no more pages return None

:param url:
:param count:
:return:
"""
if self.offset + self.limit < count:
next_offset = self.offset + self.limit
return str(url.include_query_params(limit=self.limit, offset=next_offset))
return None

def get_previous_page(self, url: URL) -> str | None:
"""
Calculates the previous page of results. If there are no more pages return None

:param url:
:return:
"""
if self.offset > 0:
prev_offset = max(0, self.offset - self.limit) # Prevent negative offset
return str(url.include_query_params(limit=self.limit, offset=prev_offset))
return None

async def get_page(self, session: AsyncSession, query) -> list[Any]:
"""
Get from database the requested page

:param session:
:param query:
:return:
"""
queryset = await session.exec(query.offset(self.offset).limit(self.limit))
return queryset.all()

async def get_count(self, session: AsyncSession, query) -> int:
"""
Get from database the count of rows that fit the query

:param session:
:param query:
:return:
"""
count_query = await session.exec(select(func.count()).where(query._whereclause))
return count_query.one()

def serialize(self, url: URL, results: list[Any], count: int) -> PaginatedResponse:
"""
Get serialized page of results.

:param url:
:param results:
:param count:
:return:
"""
paginated_response = PaginatedResponse(
count=count,
next=self.get_next_page(url, count),
previous=self.get_previous_page(url),
results=results,
)
return paginated_response
9 changes: 7 additions & 2 deletions app/tests/db/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ class TestModel(DbAsyncConn):
async def test_contract(self, session: AsyncSession):
contract = Contract(address=b"a", name="A test contract", chain_id=1)
await contract.create(session)
await contract.create(session)
result = await contract.get_all(session)
self.assertEqual(result[0], contract)

Expand All @@ -23,7 +22,13 @@ async def test_project(self, session: AsyncSession):

@database_session
async def test_abi(self, session: AsyncSession):
abi = Abi(abi_hash=b"A Test Abi", abi_json={"name": "A Test Project"})
abi_source = AbiSource(name="A Test Source", url="https://test.com")
await abi_source.create(session)
abi = Abi(
abi_hash=b"A Test Abi",
abi_json={"name": "A Test Project"},
source_id=abi_source.id,
)
await abi.create(session)
result = await abi.get_all(session)
self.assertEqual(result[0], abi)
Expand Down
Empty file added app/tests/mocks/__init__.py
Empty file.
Loading
Loading