Skip to content

Commit

Permalink
Remove debugging leftover
Browse files Browse the repository at this point in the history
  • Loading branch information
jjlk committed Oct 2, 2024
1 parent 040f2eb commit 853bb79
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions src/ai_models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def __exit__(self, *args):


class ArchiveCollector:
UNIQUE = {"date", "hdate", "time", "referenceDate", "type", "stream", "expver"}
UNIQUE = {"date", "hdate", "time",
"referenceDate", "type", "stream", "expver"}

def __init__(self) -> None:
self.expect = 0
Expand All @@ -55,7 +56,8 @@ def add(self, field):
self.request[k].add(str(v))
if k in self.UNIQUE:
if len(self.request[k]) > 1:
raise ValueError(f"Field {field} has different values for {k}: {self.request[k]}")
raise ValueError(
f"Field {field} has different values for {k}: {self.request[k]}")


class Model:
Expand Down Expand Up @@ -160,7 +162,8 @@ def json_default(obj):
raise TypeError

print(
json.dumps(json_requests, separators=(",", ":"), default=json_default, sort_keys=True),
json.dumps(json_requests, separators=(
",", ":"), default=json_default, sort_keys=True),
file=f,
)

Expand All @@ -170,7 +173,8 @@ def download_assets(self, **kwargs):
if not os.path.exists(asset):
os.makedirs(os.path.dirname(asset), exist_ok=True)
LOG.info("Downloading %s", asset)
download(self.download_url.format(file=file), asset + ".download")
download(self.download_url.format(
file=file), asset + ".download")
os.rename(asset + ".download", asset)

@property
Expand Down Expand Up @@ -443,7 +447,8 @@ def _requests(self):
def filter_constant(request):
# We check for 'sfc' because param 'z' can be ambiguous
if request.get("levtype") == "sfc":
param = set(self.constant_fields) & set(request.get("param", []))
param = set(self.constant_fields) & set(
request.get("param", []))
if param:
request["param"] = list(param)
return True
Expand All @@ -454,7 +459,8 @@ def filter_prognostic(request):
# TODO: We assume here that prognostic fields are
# the ones that are not constant. This may not always be true
if request.get("levtype") == "sfc":
param = set(request.get("param", [])) - set(self.constant_fields)
param = set(request.get("param", [])) - \
set(self.constant_fields)
if param:
request["param"] = list(param)
return True
Expand Down Expand Up @@ -496,7 +502,8 @@ def peek_into_checkpoint(self, path):

def parse_model_args(self, args):
if args:
raise NotImplementedError(f"This model does not accept arguments {args}")
raise NotImplementedError(
f"This model does not accept arguments {args}")

def provenance(self):
from .provenance import gather_provenance_info
Expand Down Expand Up @@ -542,8 +549,6 @@ def write_input_fields(
if ignore is None:
ignore = []

fields.save("input.grib")

with self.timer("Writing step 0"):
for field in fields:
if field.metadata("shortName") in ignore:
Expand Down Expand Up @@ -592,7 +597,8 @@ def write_input_fields(
"""

template = base64.b64decode(template)
accumulations_template = ekd.from_source("memory", template)[0]
accumulations_template = ekd.from_source(
"memory", template)[0]

for param in accumulations:
self.write(
Expand Down

0 comments on commit 853bb79

Please sign in to comment.