-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Allow for different configurations of DeployablePredictionAgent.same_…
…market_trade_interval (#531)
- Loading branch information
1 parent
e9bce61
commit ae454b3
Showing
3 changed files
with
86 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from abc import ABC, abstractmethod | ||
from datetime import timedelta | ||
|
||
from prediction_market_agent_tooling.markets.agent_market import AgentMarket | ||
from prediction_market_agent_tooling.tools.utils import check_not_none | ||
|
||
|
||
class TradeInterval(ABC): | ||
@abstractmethod | ||
def get( | ||
self, | ||
market: AgentMarket, | ||
) -> timedelta: | ||
raise NotImplementedError("Subclass should implement this.") | ||
|
||
|
||
class FixedInterval(TradeInterval): | ||
""" | ||
For trades at a fixed interval. | ||
""" | ||
|
||
def __init__(self, interval: timedelta): | ||
self.interval = interval | ||
|
||
def get( | ||
self, | ||
market: AgentMarket, | ||
) -> timedelta: | ||
return self.interval | ||
|
||
|
||
class MarketLifetimeProportionalInterval(TradeInterval): | ||
""" | ||
For uniformly distributed trades over the market's lifetime. | ||
""" | ||
|
||
def __init__(self, max_trades: int): | ||
self.max_trades = max_trades | ||
|
||
def get( | ||
self, | ||
market: AgentMarket, | ||
) -> timedelta: | ||
created_time = check_not_none(market.created_time) | ||
close_time = check_not_none(market.close_time) | ||
return (close_time - created_time) / self.max_trades |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from datetime import timedelta | ||
from unittest.mock import Mock | ||
|
||
import pytest | ||
|
||
from prediction_market_agent_tooling.deploy.trade_interval import ( | ||
FixedInterval, | ||
MarketLifetimeProportionalInterval, | ||
) | ||
from prediction_market_agent_tooling.markets.agent_market import AgentMarket | ||
from prediction_market_agent_tooling.tools.utils import utcnow | ||
|
||
|
||
@pytest.fixture | ||
def mock_market() -> Mock: | ||
mock_market = Mock(AgentMarket, wraps=AgentMarket) | ||
mock_market.created_time = utcnow() | ||
mock_market.close_time = mock_market.created_time + timedelta(days=10) | ||
return mock_market | ||
|
||
|
||
def test_fixed_interval(mock_market: Mock) -> None: | ||
interval = timedelta(days=1) | ||
fixed_interval = FixedInterval(interval=interval) | ||
|
||
assert fixed_interval.get(mock_market) == interval | ||
|
||
|
||
def test_market_lifetime_proportional_interval(mock_market: Mock) -> None: | ||
proportional_interval = MarketLifetimeProportionalInterval(max_trades=5) | ||
|
||
assert proportional_interval.get(mock_market) == timedelta(days=2) |