Skip to content

Commit

Permalink
#3 Added dedicated enum for categories.
Browse files Browse the repository at this point in the history
  • Loading branch information
hiker committed Apr 15, 2024
1 parent 5a22091 commit d288c0e
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 31 deletions.
1 change: 1 addition & 0 deletions source/fab/newtools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
'''A simple init file to make it shorter to import tools.
'''

from fab.newtools.categories import Categories
from fab.newtools.compiler import Compiler, Gcc, Gfortran, Icc, Ifort
from fab.newtools.flags import Flags
from fab.newtools.tool import Tool
Expand Down
26 changes: 26 additions & 0 deletions source/fab/newtools/categories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
##############################################################################
# (c) Crown copyright Met Office. All rights reserved.
# For further details please refer to the file COPYRIGHT
# which you should have received as part of this distribution
##############################################################################

'''This simple module defines an Enum for all allowed categories.
'''

from enum import auto, Enum


class Categories(Enum):
'''This class defines the allowed tool categories.'''

C_COMPILER = auto()
C_PREPROCESSOR = auto()
FORTRAN_COMPILER = auto()
FORTRAN_PREPROCESSOR = auto()
LINKER = auto()
PSYCLONE = auto()

def __str__(self):
'''Simplify the str output by using only the name (e.g. `C_COMPILER`
instead of `Categories.C_COMPILER)`.'''
return str(self.name)
6 changes: 3 additions & 3 deletions source/fab/newtools/tool_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
'''This file contains the ToolBox class.
'''

from fab.newtools import Tool, ToolRepository
from fab.newtools import Categories, Tool, ToolRepository


class ToolBox:
Expand All @@ -18,15 +18,15 @@ class ToolBox:
def __init__(self):
self._all_tools = {}

def add_tool(self, category: str, tool: Tool):
def add_tool(self, category: Categories, tool: Tool):
'''Adds a tool for a given category.
:param category: the category for which to add a tool
:param tool: the tool to add.
'''
self._all_tools[category] = tool

def get_tool(self, category: str):
def get_tool(self, category: Categories):
'''Returns the tool for the specified category.
:param category: the name of the category in which to look
Expand Down
20 changes: 9 additions & 11 deletions source/fab/newtools/tool_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,14 @@
'''This file contains the ToolRepository class.
'''

from fab.newtools import Gcc, Gfortran, Icc, Ifort
from fab.newtools import Categories, Gcc, Gfortran, Icc, Ifort


class ToolRepository(dict):
'''This class implements the tool repository. It stores a list of
tools for various categories.
'''

C_COMPILER = "c-compiler"
FORTRAN_COMPILER = "fortran-compiler"

_singleton = None

@staticmethod
Expand All @@ -37,10 +34,10 @@ def __init__(self):
"the singleton instance.")
super().__init__()
# The first entry is the default
self[self.C_COMPILER] = [Gcc(), Icc()]
self[self.FORTRAN_COMPILER] = [Gfortran(), Ifort()]
self[Categories.C_COMPILER] = [Gcc(), Icc()]
self[Categories.FORTRAN_COMPILER] = [Gfortran(), Ifort()]

def get_tool(self, category: str, name: str):
def get_tool(self, category: Categories, name: str):
'''Returns the tool with a given name in the specified category.
:param category: the name of the category in which to look
Expand All @@ -62,15 +59,16 @@ def get_tool(self, category: str, name: str):
raise KeyError(f"Unknown tool '{name}' in category '{category}' "
f"in ToolRepository.")

def get_default(self, category: str):
def get_default(self, category: Categories):
'''Returns the default tool for a given category, which is just
the first tool in the category.
:param category: the category for which to return the default tool.
:raises KeyError: if the category does not exist.
'''
if category not in self:
raise KeyError(f"Unknown category '{category}' in "
f"ToolRepository.get_default.")

if not isinstance(category, Categories):
raise RuntimeError(f"Invalid category type "
f"'{type(category).__name__}'.")
return self[category][0]
18 changes: 18 additions & 0 deletions tests/unit_tests/tools/test_categories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
##############################################################################
# (c) Crown copyright Met Office. All rights reserved.
# For further details please refer to the file COPYRIGHT
# which you should have received as part of this distribution
##############################################################################

'''This module tests the Categories.
'''

from fab.newtools import Categories


def test_categories():
'''Tests the categories.'''
# Make sure that str of a category only prints the name (which is more
# useful for error messages).
for cat in list(Categories):
assert str(cat) == cat.name
12 changes: 6 additions & 6 deletions tests/unit_tests/tools/test_tool_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
'''This module tests the TooBox class.
'''

from fab.newtools import ToolBox, ToolRepository
from fab.newtools import Categories, ToolBox, ToolRepository


def test_tool_box_constructor():
Expand All @@ -20,10 +20,10 @@ def test_tool_box_get_tool():
'''Tests get_tool.'''
tb = ToolBox()
tr = ToolRepository.get()
default_compiler = tb.get_tool(tr.FORTRAN_COMPILER)
assert default_compiler is tr.get_default(tr.FORTRAN_COMPILER)
default_compiler = tb.get_tool(Categories.FORTRAN_COMPILER)
assert default_compiler is tr.get_default(Categories.FORTRAN_COMPILER)

tr_gfortran = tr.get_tool(tr.FORTRAN_COMPILER, "gfortran")
tb.add_tool(tr.FORTRAN_COMPILER, tr_gfortran)
gfortran = tb.get_tool(tr.FORTRAN_COMPILER)
tr_gfortran = tr.get_tool(Categories.FORTRAN_COMPILER, "gfortran")
tb.add_tool(Categories.FORTRAN_COMPILER, tr_gfortran)
gfortran = tb.get_tool(Categories.FORTRAN_COMPILER)
assert gfortran is tr_gfortran
22 changes: 11 additions & 11 deletions tests/unit_tests/tools/test_tool_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import pytest

from fab.newtools import (Gcc, Gfortran, Ifort, ToolRepository)
from fab.newtools import Categories, Gcc, Gfortran, Ifort, ToolRepository


def test_tool_repository_get_singleton():
Expand All @@ -31,17 +31,17 @@ def test_tool_repository_get_singleton():
def test_tool_repository_constructor():
'''Tests the ToolRepository constructor.'''
tr: ToolRepository = ToolRepository.get()
assert tr.C_COMPILER in tr
assert tr.FORTRAN_COMPILER in tr
assert Categories.C_COMPILER in tr
assert Categories.FORTRAN_COMPILER in tr


def test_tool_repository_get_tool():
'''Tests get_tool.'''
tr = ToolRepository.get()
gfortran = tr.get_tool(tr.FORTRAN_COMPILER, "gfortran")
gfortran = tr.get_tool(Categories.FORTRAN_COMPILER, "gfortran")
assert isinstance(gfortran, Gfortran)

ifort = tr.get_tool(tr.FORTRAN_COMPILER, "ifort")
ifort = tr.get_tool(Categories.FORTRAN_COMPILER, "ifort")
assert isinstance(ifort, Ifort)


Expand All @@ -53,24 +53,24 @@ def test_tool_repository_get_tool_error():
assert "Unknown category 'unknown-category'" in str(err.value)

with pytest.raises(KeyError) as err:
tr.get_tool(tr.C_COMPILER, "something")
assert ("Unknown tool 'something' in category 'c-compiler'"
tr.get_tool(Categories.C_COMPILER, "something")
assert ("Unknown tool 'something' in category 'C_COMPILER'"
in str(err.value))


def test_tool_repository_get_default():
'''Tests get_default.'''
tr = ToolRepository.get()
gfortran = tr.get_default("fortran-compiler")
gfortran = tr.get_default(Categories.FORTRAN_COMPILER)
assert isinstance(gfortran, Gfortran)

gcc = tr.get_default("c-compiler")
gcc = tr.get_default(Categories.C_COMPILER)
assert isinstance(gcc, Gcc)


def test_tool_repository_get_default_error():
'''Tests error handling in get_default.'''
tr = ToolRepository.get()
with pytest.raises(KeyError) as err:
with pytest.raises(RuntimeError) as err:
tr.get_default("unknown-category")
assert "Unknown category 'unknown-category'" in str(err.value)
assert "Invalid category type 'str'." in str(err.value)

0 comments on commit d288c0e

Please sign in to comment.