Skip to content

Commit

Permalink
update: fix inits
Browse files (browse the repository at this point in the history)
  • Loading branch information
VsevolodX committed Oct 25, 2024
1 parent 784ba8c commit abfdb20
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 20 deletions.
45 changes: 30 additions & 15 deletions src/py/mat3ra/standata/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Dict, List, Optional

Check failure on line 2 in src/py/mat3ra/standata/base.py

View workflow job for this annotation

GitHub Actions / run-py-linter (3.8.6)

Ruff (F401)

src/py/mat3ra/standata/base.py:2:32: F401 `typing.Optional` imported but unused

import pandas as pd
from pydantic import BaseModel
from pydantic import BaseModel, Field


class StandataEntity(BaseModel):
Expand Down Expand Up @@ -30,6 +30,16 @@ def get_categories_as_list(self, separator: str = "/") -> List[str]:
category_groups = [list(map(lambda x: f"{key}{separator}{x}", val)) for key, val in self.categories.items()]
return [item for sublist in category_groups for item in sublist]

def get(self, key: str, default=None):
    """
    Look up a key in the categories dictionary, falling back to a default.

    Args:
        key: The key to look for in the categories dictionary.
        default: The value to return if the key is not found.

    Returns:
        The value stored under ``key`` in ``self.categories``, or ``default`` when the key is absent.
    """
    categories = self.categories
    return categories.get(key, default)

def convert_tags_to_categories_list(self, *tags: str):
"""
Converts simple tags to '<category_type>/<tag>' format.
Expand Down Expand Up @@ -96,7 +106,8 @@ def __lookup_table(self) -> pd.DataFrame:
return df


class StandataFilesMapByName(Dict[str, dict]):
class StandataFilesMapByName(BaseModel):
dictionary: Dict[str, dict] = Field(default_factory=dict)

def get_objects_by_filenames(self, filenames: List[str]) -> List[dict]:
"""
Expand All @@ -106,7 +117,7 @@ def get_objects_by_filenames(self, filenames: List[str]) -> List[dict]:
filenames: Filenames of the entities.
"""
matching_objects = []
for key, entity in self.items():
for key, entity in self.dictionary.items():
if key in filenames:
matching_objects.append(entity)
return matching_objects
Expand All @@ -116,33 +127,37 @@ class StandataData(BaseModel):
class Config:
arbitrary_types_allowed = True

filesMapByName: Optional[StandataFilesMapByName] = StandataFilesMapByName()
standataConfig: Optional[StandataConfig] = StandataConfig()
filesMapByName: StandataFilesMapByName = StandataFilesMapByName()
standataConfig: StandataConfig = StandataConfig()

def __init__(self, /, **kwargs):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.filesMapByName = StandataFilesMapByName(kwargs.get("filesMapByName", {}))
self.standataConfig = StandataConfig(**kwargs.get("standataConfig", {}))
self.filesMapByName = StandataFilesMapByName(dictionary=kwargs.get("filesMapByName", {}))
self.standataConfig = StandataConfig(
categories=kwargs.get("standataConfig", {}).get("categories", {}),
entities=[
StandataEntity(filename=entity["filename"], categories=entity["categories"])
for entity in kwargs.get("standataConfig", {}).get("entities", [])
],
)


class Standata(BaseModel):
# Override in children
data: StandataData = StandataData()

@classmethod
def get_as_list(cls):
return list(cls.data.filesMapByName.values())
def get_as_list(self) -> List[dict]:
    """
    Return all objects held in the files map as a list.

    Returns:
        The values of ``self.data.filesMapByName.dictionary``, in the dictionary's iteration order.
    """
    files_map = self.data.filesMapByName
    return list(files_map.dictionary.values())

@classmethod
def get_by_name(cls, name: str) -> List[dict]:
def get_by_name(self, name: str) -> List[dict]:
"""
Returns entities by name regex.
Args:
name: Name of the entity.
"""
matching_filenames = cls.data.standataConfig.get_filenames_by_regex(name)
return cls.data.filesMapByName.get_objects_by_filenames(matching_filenames)
matching_filenames = self.data.standataConfig.get_filenames_by_regex(name)
return self.data.filesMapByName.get_objects_by_filenames(matching_filenames)

def get_by_categories(self, *tags: str) -> List[dict]:
"""
Expand Down
20 changes: 15 additions & 5 deletions src/py/mat3ra/standata/materials.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
from .base import Standata, StandataData, StandataFilesMapByName
from typing import Dict

from .base import Standata, StandataConfig, StandataData, StandataEntity, StandataFilesMapByName
from .data.materials import materials_data


class Materials(Standata):
data: StandataData = StandataData(
filesMapByName=StandataFilesMapByName(materials_data["filesMapByName"]),
standataConfig=materials_data.get("standataConfig", {}),
)
def __init__(self, data: Dict = materials_data):
    """
    Build the Standata payload for materials from a raw data dictionary.

    Args:
        data: Dictionary that may carry a "filesMapByName" mapping and a "standataConfig"
            mapping (with optional "categories" and "entities" keys). Missing keys fall back
            to empty containers. Defaults to ``materials_data`` imported from ``.data.materials``.
    """
    # Fetch the config section once; it is consulted for both categories and entities.
    raw_config = data.get("standataConfig", {})
    files_map = StandataFilesMapByName(dictionary=data.get("filesMapByName", {}))
    config = StandataConfig(
        categories=raw_config.get("categories", {}),
        entities=[
            StandataEntity(filename=item["filename"], categories=item["categories"])
            for item in raw_config.get("entities", [])
        ],
    )
    super().__init__(data=StandataData(filesMapByName=files_map, standataConfig=config))

0 comments on commit abfdb20

Please sign in to comment.