Skip to content

Commit

Permalink
Allow for different configurations of DeployablePredictionAgent.same_…
Browse files Browse the repository at this point in the history
…market_trade_interval (#531)
  • Loading branch information
evangriffiths authored Oct 29, 2024
1 parent e9bce61 commit ae454b3
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 2 deletions.
10 changes: 8 additions & 2 deletions prediction_market_agent_tooling/deploy/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
gcp_function_is_active,
gcp_resolve_api_keys_secrets,
)
from prediction_market_agent_tooling.deploy.trade_interval import (
FixedInterval,
TradeInterval,
)
from prediction_market_agent_tooling.gtypes import xDai, xdai_type
from prediction_market_agent_tooling.loggers import logger
from prediction_market_agent_tooling.markets.agent_market import (
Expand Down Expand Up @@ -281,7 +285,7 @@ class DeployablePredictionAgent(DeployableAgent):
bet_on_n_markets_per_run: int = 1
min_balance_to_keep_in_native_currency: xDai | None = xdai_type(0.1)
allow_invalid_questions: bool = False
same_market_bet_interval: timedelta = timedelta(hours=24)
same_market_trade_interval: TradeInterval = FixedInterval(timedelta(hours=24))
# Only Metaculus allows to post predictions without trading (buying/selling of outcome tokens).
supported_markets: t.Sequence[MarketType] = [MarketType.METACULUS]

Expand Down Expand Up @@ -345,7 +349,9 @@ def verify_market(self, market_type: MarketType, market: AgentMarket) -> bool:
Subclasses can implement their own logic instead of this one, or on top of this one.
By default, it allows only markets where user didn't bet recently and it's a reasonable question.
"""
if self.have_bet_on_market_since(market, since=self.same_market_bet_interval):
if self.have_bet_on_market_since(
market, since=self.same_market_trade_interval.get(market=market)
):
return False

# Manifold allows to bet only on markets with probability between 1 and 99.
Expand Down
46 changes: 46 additions & 0 deletions prediction_market_agent_tooling/deploy/trade_interval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from abc import ABC, abstractmethod
from datetime import timedelta

from prediction_market_agent_tooling.markets.agent_market import AgentMarket
from prediction_market_agent_tooling.tools.utils import check_not_none


class TradeInterval(ABC):
@abstractmethod
def get(
self,
market: AgentMarket,
) -> timedelta:
raise NotImplementedError("Subclass should implement this.")


class FixedInterval(TradeInterval):
"""
For trades at a fixed interval.
"""

def __init__(self, interval: timedelta):
self.interval = interval

def get(
self,
market: AgentMarket,
) -> timedelta:
return self.interval


class MarketLifetimeProportionalInterval(TradeInterval):
"""
For uniformly distributed trades over the market's lifetime.
"""

def __init__(self, max_trades: int):
self.max_trades = max_trades

def get(
self,
market: AgentMarket,
) -> timedelta:
created_time = check_not_none(market.created_time)
close_time = check_not_none(market.close_time)
return (close_time - created_time) / self.max_trades
32 changes: 32 additions & 0 deletions tests/test_trade_interval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from datetime import timedelta
from unittest.mock import Mock

import pytest

from prediction_market_agent_tooling.deploy.trade_interval import (
FixedInterval,
MarketLifetimeProportionalInterval,
)
from prediction_market_agent_tooling.markets.agent_market import AgentMarket
from prediction_market_agent_tooling.tools.utils import utcnow


@pytest.fixture
def mock_market() -> Mock:
mock_market = Mock(AgentMarket, wraps=AgentMarket)
mock_market.created_time = utcnow()
mock_market.close_time = mock_market.created_time + timedelta(days=10)
return mock_market


def test_fixed_interval(mock_market: Mock) -> None:
interval = timedelta(days=1)
fixed_interval = FixedInterval(interval=interval)

assert fixed_interval.get(mock_market) == interval


def test_market_lifetime_proportional_interval(mock_market: Mock) -> None:
proportional_interval = MarketLifetimeProportionalInterval(max_trades=5)

assert proportional_interval.get(mock_market) == timedelta(days=2)

0 comments on commit ae454b3

Please sign in to comment.