Skip to content

Commit

Permalink
#189 Predict directly from DataFrame objects
Browse files Browse the repository at this point in the history
  • Loading branch information
ofrancon committed Jan 22, 2021
1 parent ed03d55 commit 77842ad
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 31 deletions.
26 changes: 7 additions & 19 deletions covid_xprize/scoring/prescriptor_scoring.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import os

import pandas as pd

from covid_xprize.standard_predictor.predict import predict
from covid_xprize.standard_predictor.xprize_predictor import XPrizePredictor
from covid_xprize.standard_predictor.xprize_predictor import NPI_COLUMNS


Expand All @@ -17,28 +15,18 @@ def weight_prescriptions_by_cost(pres_df, cost_df):


def generate_cases_and_stringency_for_prescriptions(start_date, end_date, prescription_file, costs_file):
# Load prescriptions
pres_df = pd.read_csv(prescription_file)
# Load the prescriptions, handling Date and regions
pres_df = XPrizePredictor.load_original_data(prescription_file)

# Generate predictions for all prescriptions
predictor = XPrizePredictor()
pred_dfs = []
for idx in pres_df['PrescriptionIndex'].unique():
idx_df = pres_df[pres_df['PrescriptionIndex'] == idx]
idx_df = idx_df.drop(columns='PrescriptionIndex') # Predictor doesn't need this
ip_file_path = 'prescriptions/prescription_{}.csv'.format(idx)
os.makedirs(os.path.dirname(ip_file_path), exist_ok=True)
idx_df.to_csv(ip_file_path)
preds_file_path = 'predictions/predictions_{}.csv'.format(idx)
os.makedirs(os.path.dirname(preds_file_path), exist_ok=True)

# Run predictor
predict(start_date, end_date, ip_file_path, preds_file_path)

# Collect predictions
pred_df = pd.read_csv(preds_file_path,
parse_dates=['Date'],
encoding="ISO-8859-1",
error_bad_lines=True)
# Generate the predictions
pred_df = predictor.predict_from_df(start_date, end_date, idx_df)
print(f"Generated predictions for PrescriptionIndex {idx}")
pred_df['PrescriptionIndex'] = idx
pred_dfs.append(pred_df)
pred_df = pd.concat(pred_dfs)
Expand Down
7 changes: 1 addition & 6 deletions covid_xprize/standard_predictor/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,6 @@

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))

# Fixed weights for the standard predictor.
MODEL_WEIGHTS_FILE = os.path.join(ROOT_DIR, "models", "trained_model_weights.h5")

DATA_FILE = os.path.join(ROOT_DIR, 'data', "OxCGRT_latest.csv")


def predict(start_date: str,
end_date: str,
Expand All @@ -29,7 +24,7 @@ def predict(start_date: str,
with columns "CountryName,RegionName,Date,PredictedDailyNewCases"
"""
# !!! YOUR CODE HERE !!!
predictor = XPrizePredictor(MODEL_WEIGHTS_FILE, DATA_FILE)
predictor = XPrizePredictor()
# Generate the predictions
preds_df = predictor.predict(start_date, end_date, path_to_ips_file)
# Create the output path
Expand Down
19 changes: 13 additions & 6 deletions covid_xprize/standard_predictor/xprize_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
ADDITIONAL_US_STATES_CONTEXT = os.path.join(DATA_PATH, "US_states_populations.csv")
ADDITIONAL_UK_CONTEXT = os.path.join(DATA_PATH, "uk_populations.csv")
ADDITIONAL_BRAZIL_CONTEXT = os.path.join(DATA_PATH, "brazil_populations.csv")
# Fixed weights for the standard predictor.
MODEL_WEIGHTS_FILE = os.path.join(ROOT_DIR, "models", "trained_model_weights.h5")

NPI_COLUMNS = ['C1_School closing',
'C2_Workplace closing',
Expand Down Expand Up @@ -72,7 +74,7 @@ class XPrizePredictor(object):
A class that computes a fitness for Prescriptor candidates.
"""

def __init__(self, path_to_model_weights, data_url):
def __init__(self, path_to_model_weights=MODEL_WEIGHTS_FILE, data_url=DATA_FILE_PATH):
if path_to_model_weights:

# Load model weights
Expand All @@ -94,13 +96,18 @@ def predict(self,
start_date_str: str,
end_date_str: str,
path_to_ips_file: str) -> pd.DataFrame:
# Load the npis into a DataFrame, handling regions
npis_df = self.load_original_data(path_to_ips_file)
return self.predict_from_df(start_date_str, end_date_str, npis_df)

def predict_from_df(self,
start_date_str: str,
end_date_str: str,
npis_df: pd.DataFrame) -> pd.DataFrame:
start_date = pd.to_datetime(start_date_str, format='%Y-%m-%d')
end_date = pd.to_datetime(end_date_str, format='%Y-%m-%d')
nb_days = (end_date - start_date).days + 1

# Load the npis into a DataFrame, handling regions
npis_df = self._load_original_data(path_to_ips_file)

# Prepare the output
forecast = {"CountryName": [],
"RegionName": [],
Expand Down Expand Up @@ -177,7 +184,7 @@ def _prepare_dataframe(self, data_url: str) -> pd.DataFrame:
:return: a Pandas DataFrame with the historical data
"""
# Original df from Oxford
df1 = self._load_original_data(data_url)
df1 = self.load_original_data(data_url)

# Additional context df (e.g Population for each country)
df2 = self._load_additional_context_df()
Expand Down Expand Up @@ -224,7 +231,7 @@ def _prepare_dataframe(self, data_url: str) -> pd.DataFrame:
return df

@staticmethod
def _load_original_data(data_url):
def load_original_data(data_url):
latest_df = pd.read_csv(data_url,
parse_dates=['Date'],
encoding="ISO-8859-1",
Expand Down

0 comments on commit 77842ad

Please sign in to comment.