Skip to content

Commit

Permalink
[Refactor] remote registries towards upstream (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
g5t authored Dec 4, 2023
1 parent 3d8d255 commit c2e1e69
Show file tree
Hide file tree
Showing 7 changed files with 186 additions and 72 deletions.
36 changes: 20 additions & 16 deletions mccode_antlr/io/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,23 +180,25 @@ def save(group, data, **kwargs):
HDF5IO.save(group=group.create_group('member_names'), data=[member.name for member in data.members])


class RemoteRegistryIO:
from mccode_antlr.reader.registry import Registry, RemoteRegistry
def RemoteRegistryIO(actual_type):
from mccode_antlr.reader.registry import RemoteRegistry

@staticmethod
def load(group, **kwargs) -> Registry:
values = _standard_read(RemoteRegistryIO.RemoteRegistry, group, ('name', 'filename', 'url'), (), (), **kwargs)
try:
return RemoteRegistryIO.RemoteRegistry(**values)
except RuntimeError:
log.warn(f'Unable to reconstruct remote registry from {values}')
return RemoteRegistryIO.Registry()
class _RemoteRegistryIO:
@staticmethod
def load(group, **kwargs) -> RemoteRegistry:
_check_header(group, actual_type)
values = _standard_read(actual_type, group, actual_type.file_keys(), (), (), **kwargs)
try:
return actual_type(**values)
except RuntimeError:
log.warn(f'Unable to reconstruct {actual_type.__name__} registry from {values}')
return RemoteRegistry('loading error', None, None, None)

@staticmethod
def save(group, data, **kwargs):
_standard_save(RemoteRegistryIO.RemoteRegistry, group, data, ('name', 'filename'), (), **kwargs)
if data.pooch.base_url is not None:
group.attrs['url'] = data.pooch.base_url
@staticmethod
def save(group, data, **kwargs):
_standard_save(actual_type, group, data, actual_type.file_keys(), (), **kwargs)

return _RemoteRegistryIO


class LocalRegistryIO:
Expand Down Expand Up @@ -397,6 +399,7 @@ class HDF5IO:
from mccode_antlr.instr.jump import Jump
from mccode_antlr.instr.orientation import (Matrix, Vector, Angles, Rotation, Seitz, RotationX, RotationY,
RotationZ, TranslationPart, Orient, Parts, Part)
from mccode_antlr.reader.registry import ModuleRemoteRegistry, GitHubRegistry

_handlers = {
'Instr': InstrIO,
Expand All @@ -406,7 +409,8 @@ class HDF5IO:
'Instance': InstanceIO,
'RawC': _dataclass_io(RawC, attrs=('filename', 'line'), required=('source',), optional=('translated',)),
'Group': GroupIO,
'RemoteRegistry': RemoteRegistryIO,
'ModuleRemoteRegistry': RemoteRegistryIO(ModuleRemoteRegistry),
'GitHubRegistry': RemoteRegistryIO(GitHubRegistry),
'LocalRegistry': LocalRegistryIO,
'ComponentParameter': _dataclass_io(ComponentParameter, attrs=('name',), required=('value',)),
'Comp': _dataclass_io(Comp, attrs=('name', 'category', 'dependency', 'acc'),
Expand Down
8 changes: 6 additions & 2 deletions mccode_antlr/reader/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from .reader import Reader
from .registry import Registry, LocalRegistry, RemoteRegistry, MCSTAS_REGISTRY, MCXTRACE_REGISTRY, LIBC_REGISTRY
from .registry import (Registry, LocalRegistry, RemoteRegistry, ModuleRemoteRegistry, GitHubRegistry,
MCSTAS_REGISTRY, MCXTRACE_REGISTRY, LIBC_REGISTRY, FIXED_LIBC_REGISTRY)

__all__ = [
'Reader',
'Registry',
'LocalRegistry',
'RemoteRegistry',
'ModuleRemoteRegistry',
'GitHubRegistry',
'MCSTAS_REGISTRY',
'MCXTRACE_REGISTRY',
'LIBC_REGISTRY',
]
'FIXED_LIBC_REGISTRY',
]
182 changes: 133 additions & 49 deletions mccode_antlr/reader/registry.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import pooch
from pathlib import Path
from re import Pattern
Expand All @@ -13,6 +15,27 @@ def ensure_regex_pattern(pattern):
return pattern


def simple_url_validator(url: str, file_ok=False):
    """Return True when ``url`` looks like a usable URL, and False otherwise.

    This is a shallow, purely syntactic check based on ``urllib.parse.urlparse``;
    nothing is resolved or contacted.

    :param url: candidate URL; any non-string value is rejected.
    :param file_ok: when true, a ``file`` scheme URL is accepted even without a
        network location (a notice is printed, since using one for a remote
        registry will likely duplicate files). All other schemes always need a
        network location.
    :return: True if the URL has a scheme and, unless it is an accepted ``file``
        URL, a network location; False otherwise.
    """
    from urllib.parse import urlparse

    if not isinstance(url, str):
        return False
    try:
        result = urlparse(url)
    except (AttributeError, ValueError):
        # urlparse raises ValueError for malformed input (e.g. an unbalanced
        # IPv6 bracket); an invalid URL should be reported, not crash the caller.
        return False
    if not result.scheme:
        return False
    if file_ok and result.scheme == 'file':
        print("Constructing a RemoteRegistry for a file:// URL will likely duplicate files!")
        return True
    # Every other case requires a network location (host) to be present
    return bool(result.netloc)


class Registry:
name = None
root = None
Expand Down Expand Up @@ -69,7 +92,7 @@ def _name_plus_suffix(name: str, suffix: str = None):

def find_registry_file(name: str):
"""Find a registry file in the mccode_antlr package"""
from importlib.resources import files, as_file
from importlib.resources import files
from importlib.metadata import distribution
from json import loads
if isinstance(name, Path):
Expand All @@ -84,31 +107,25 @@ def find_registry_file(name: str):


class RemoteRegistry(Registry):
def __init__(self, name: str, url: str, filename=None, version=None):
def __init__(self, name: str, url: str | None, version: str | None, filename: str | None):
self.name = name
self.url = url
self.version = version
self.filename = filename
self.pooch = pooch.create(
path=pooch.os_cache(f'mccode_antlr-{name}'),
base_url=url,
version=version or mccode_antlr_version(),
version_dev="main",
registry=None,
)
if isinstance(filename, Path):
self.pooch.load_registry(filename)
else:
filepath = find_registry_file(filename)
if filepath is None:
raise RuntimeError(f"The provided filename {filename} is not a path or file packaged with this module")
self.pooch.load_registry(filepath)
self.pooch = None

@classmethod
def file_keys(cls) -> tuple[str, ...]:
return 'name', 'url', 'version', 'filename'

def file_contents(self) -> dict[str, str]:
return {key: getattr(self, key) or '' for key in self.file_keys()}

def to_file(self, output, wrapper):
contents = '(' + ', '.join([
wrapper.parameter('name') + '=' + wrapper.value(self.name),
wrapper.parameter('url') + ('' if self.pooch is None else ('=' + wrapper.url(self.pooch.base_url))),
wrapper.parameter('filename') + '=' + wrapper.value(self.filename),
]) + ')'
print(wrapper.line('RemoteRegistry', [contents], ''), file=output)
wp = wrapper.parameter
wv = wrapper.value
contents = '(' + ', '.join([wp(n) + '=' + wv(v) for n, v in self.file_contents().items()]) + ')'
print(wrapper.line(self.__class__.__name__, [contents], ''), file=output)

def known(self, name: str, ext: str = None):
compare = _name_plus_suffix(name, ext)
Expand Down Expand Up @@ -173,6 +190,68 @@ def __eq__(self, other):
return True


class ModuleRemoteRegistry(RemoteRegistry):
    """A remote registry whose pooch registry file is a local path or is packaged with this module."""

    def __init__(self, name: str, url: str, filename=None, version=None):
        """Create the pooch fetcher and load its file registry.

        :param name: registry name, also used to build the cache folder name.
        :param url: base URL handed to pooch.
        :param filename: a ``Path`` to a pooch registry file, or a name that
            ``find_registry_file`` can locate.
        :param version: version handed to pooch; when falsy, the value of
            ``mccode_antlr_version()`` is used instead.
        :raises RuntimeError: if ``filename`` is not a ``Path`` and
            ``find_registry_file`` returns None for it.
        """
        super().__init__(name, url, version, filename)
        cache_dir = pooch.os_cache(f'mccode_antlr-{self.name}')
        self.pooch = pooch.create(
            path=cache_dir,
            base_url=self.url,
            version=self.version or mccode_antlr_version(),
            version_dev="main",
            registry=None,
        )
        # A Path is used as given; anything else is looked up by find_registry_file
        if isinstance(self.filename, Path):
            registry_file = self.filename
        else:
            registry_file = find_registry_file(self.filename)
        if registry_file is None:
            raise RuntimeError(f"Provided filename {self.filename} is not a path or file packaged with this module")
        self.pooch.load_registry(registry_file)


class GitHubRegistry(RemoteRegistry):
    """A remote registry whose files are served from ``{url}/raw/{version}/``.

    The pooch file registry (a mapping of file name to hash) is either supplied
    directly as a dictionary or downloaded as a text file when the object is
    constructed.
    """

    def __init__(self, name: str, url: str, version: str, filename: str | None = None,
                 registry: str | dict | None = None):
        """Create the pooch fetcher, downloading the registry file if needed.

        :param name: registry name; also the pooch cache folder name.
        :param url: repository URL; files are fetched from ``{url}/raw/{version}/``.
        :param version: tag or branch used in the raw-file URL and handed to pooch.
        :param filename: registry file name, defaults to ``{name}-registry.txt``.
        :param registry: a ready-made ``{file name: hash}`` dictionary, or the URL
            of a second repository that hosts the registry file under
            ``{registry}/raw/{version}/``. When omitted, the registry file is
            expected under the same base URL as the files.
        :raises RuntimeError: if the registry file can not be retrieved.
        :raises ValueError: if the registry file has a line that is not
            ``<file name> <hash>``.
        """
        if filename is None:
            filename = f'{name}-registry.txt'
        super().__init__(name, url, version, filename)
        import requests
        base_url = f'{self.url}/raw/{self.version}/'
        # If registry is a string url, we expect the registry file to be available from _that_ url
        self._stashed_registry = None
        if isinstance(registry, str) and simple_url_validator(registry, file_ok=True):
            self._stashed_registry = registry
            registry = f'{registry}/raw/{self.version}/'
        # We allow a full-dictionary to be provided, otherwise we expect the registry file to be available from the
        # base_url where all subsequent files are also expected to be available
        if not isinstance(registry, dict):
            r = requests.get((registry or base_url) + (self.filename or 'pooch-registry.txt'))
            if not r.ok:
                raise RuntimeError(f"Could not retrieve {r.url} because {r.reason}")
            registry = self._parse_registry_text(r.text)

        self.pooch = pooch.create(
            path=pooch.os_cache(self.name),
            base_url=base_url,
            version=version,
            version_dev="main",
            registry=registry,
        )

    @staticmethod
    def _parse_registry_text(text: str) -> dict:
        """Convert the text of a registry file into a ``{file name: hash}`` dictionary.

        Blank or whitespace-only lines and ``#`` comment lines are skipped, and
        surrounding whitespace (including a ``\\r`` left by CRLF line endings) is
        removed so that it can not end up inside a hash.

        :raises ValueError: if a remaining line does not hold two fields.
        """
        entries = {}
        for number, raw in enumerate(text.splitlines(), start=1):
            line = raw.strip()
            if not line or line.startswith('#'):
                continue
            fields = line.split(maxsplit=1)
            if len(fields) != 2:
                raise ValueError(f"Malformed registry line {number}: {raw!r}")
            entries[fields[0]] = fields[1]
        return entries

    @property
    def registry(self):
        """The registry-repository URL given at construction, or None if there was none."""
        return self._stashed_registry

    def file_contents(self) -> dict[str, str]:
        """Return the parent's file contents plus a 'registry' entry ('' when unset)."""
        fc = super().file_contents()
        fc['registry'] = self._stashed_registry or ''
        return fc

    @classmethod
    def file_keys(cls) -> tuple[str, ...]:
        """Return the parent's file keys followed by 'registry'."""
        return super().file_keys() + ('registry',)


class LocalRegistry(Registry):
def __init__(self, name: str, root: str):
self.name = name
Expand Down Expand Up @@ -265,54 +344,59 @@ def registry_from_specification(spec: str):
1. {resolvable folder path}
2. {name} {resolvable folder path}
3. {name} {resolvable url} {resolvable file path}
4. {name} {resolvable url} {version} {registry file name}
The first two variants make a LocalRegistry, which searches the provided directory for files.
The last makes a RemoteRegistry using pooch. The resolvable file path should point at a Pooch registry file.
The third makes a ModuleRemoteRegistry using pooch. The resolvable file path should point at a Pooch registry file.
The fourth makes a GitHubRegistry, which uses the specific folder structure of GitHub
"""
from urllib.parse import urlparse

if isinstance(spec, Registry):
return spec
parts = spec.split()
if len(parts) == 0:
return None
elif len(parts) == 1:
p1, p2, p3 = parts[0], parts[0], None
p1, p2, p3, p4 = parts[0], parts[0], None, None
elif len(parts) < 4:
p1, p2, p3, p4 = parts[0], parts[1], None if len(parts) < 3 else parts[2], None
else:
p1, p2, p3 = parts[0], parts[1], None if len(parts) < 3 else parts[2]
print(f"Constructing registry from {p1=} {p2=} {p3=} [{type(p1)=} {type(p2)=} {type(p3)=}]")
p1, p2, p3, p4 = parts[0], parts[1], parts[2], parts[3]
print(f"Constructing registry from {p1=} {p2=} {p3=} {p4=} [{type(p1)=} {type(p2)=} {type(p3)=} {type(p4)=}]")
# convert string literals to strings:
p1 = p1[1:-1] if p1.startswith('"') and p1.endswith('"') else p1
p2 = p2[1:-1] if p2.startswith('"') and p2.endswith('"') else p2
p3 = p3[1:-1] if p3 is not None and p3.startswith('"') and p3.endswith('"') else p3
p4 = p4[1:-1] if p4 is not None and p4.startswith('"') and p4.endswith('"') else p4

if Path(p2).exists() and Path(p2).is_dir():
return LocalRegistry(p1, str(Path(p2).resolve()))

# (simple) URL validation:
if not isinstance(p2, str):
return False
try:
result = urlparse(p2)
except AttributeError:
return False
if not result.scheme:
return False
if result.scheme == 'file':
print("Constructing a RemoteRegistry for a file:// URL will likely duplicate files!")
if result.scheme != 'file' and not result.netloc:
return False
if not simple_url_validator(p2, file_ok=True):
return None

if Path(p3).exists() and Path(p3).is_file():
return RemoteRegistry(p1, p2, Path(p3).resolve())
return ModuleRemoteRegistry(p1, p2, Path(p3).resolve().as_posix())

if p4 is not None:
return GitHubRegistry(p1, p2, p3, p4)

return None


# Pre-defined registry files:
REMOTE_REPOSITORY = 'https://github.com/g5t/mccode-files/raw/main'
# McStas components, instruments, and translation-time include files
MCSTAS_REGISTRY = RemoteRegistry('mcstas', f'{REMOTE_REPOSITORY}/mcstas', 'mcstas-registry.txt')
# McXtrace components, instruments, and translation-time include files
MCXTRACE_REGISTRY = RemoteRegistry('mcxtrace', f'{REMOTE_REPOSITORY}/mcxtrace', 'mcxtrace-registry.txt')
# Common runtime components for C
LIBC_REGISTRY = RemoteRegistry('libc', f'{REMOTE_REPOSITORY}/runtime/libc', 'libc-registry.txt')
def _m_reg(name):
    """Build the GitHubRegistry called ``name`` for the pinned McCode release.

    Files come from the McStasMcXtrace/McCode repository, while the registry
    file listing them comes from the g5t/mccode-pooch repository.
    """
    # TODO update this version with new McStasMcXtrace releases
    release = 'v3.4.0'
    return GitHubRegistry(
        name,
        "https://github.com/McStasMcXtrace/McCode",
        release,
        registry="https://github.com/g5t/mccode-pooch",
    )


# TODO remove this registry once mcstas-d.h and mcxtrace-d.h are in the mccode repo; plus mccode-r.h is not configured
# Stop-gap registry for C runtime files, fetched from the g5t/mccode-files repository.
# The third argument is the registry file name handed to ModuleRemoteRegistry.
FIXED_LIBC_REGISTRY = ModuleRemoteRegistry(
    'fixed-libc',
    'https://github.com/g5t/mccode-files/raw/main/runtime/libc',
    'libc-registry.txt'
)

# One registry per name, each built by _m_reg; these are constructed when this module is imported
MCSTAS_REGISTRY, MCXTRACE_REGISTRY, LIBC_REGISTRY = [_m_reg(name) for name in ('mcstas', 'mcxtrace', 'libc')]
# The factory is only needed to define the constants above, so drop it from the module namespace
del _m_reg
4 changes: 3 additions & 1 deletion mccode_antlr/translators/c.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from zenlog import log
from collections import namedtuple
from dataclasses import dataclass
from ..reader import Registry, LIBC_REGISTRY
from ..reader import Registry, FIXED_LIBC_REGISTRY, LIBC_REGISTRY
from ..instr import Instr, Instance
from .target import TargetVisitor
from .c_listener import extract_c_declared_variables
Expand Down Expand Up @@ -245,6 +245,8 @@ def __post_init__(self):
languages. (A different target language would not include the same libraries in its raw blocks)
"""
# Make sure the registry list contains the C library registry, so that we can find and include files
if not any(reg == FIXED_LIBC_REGISTRY for reg in self.registries):
self.source.registries += (FIXED_LIBC_REGISTRY, )
if not any(reg == LIBC_REGISTRY for reg in self.registries):
self.source.registries += (LIBC_REGISTRY, )

Expand Down
16 changes: 14 additions & 2 deletions mccode_antlr/translators/c_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def jump_line(instance, jump):
uservar_string += '\n'.join([f' {x.type} {x.name};' for x in uservars])

# Shouldn't we exclude array-valued names here?
getvar = '\n'.join([f' if(!str_comp("{x.name}",name)){{rval=*((double*)(&(p->{x.name})));s=0;}}' for x in uservars])
getvar = '\n'.join(
[f' if(!str_comp("{x.name}",name)){{rval=*((double*)(&(p->{x.name})));s=0;}}' for x in uservars])

# Array valued names seem OK here, since a user can type-cast correctly
getvar_void = '\n'.join(
Expand Down Expand Up @@ -89,9 +90,17 @@ def jump_line(instance, jump):
struct _struct_particle {{
double x,y,z; /* position [m] */
{particle_struct}
/* Generic Temporaries: */
/* May be used internally by components e.g. for special */
/* return-values from functions used in trace, thusreturned via */
/* particle struct. (Example: Wolter Conics from McStas, silicon slabs.) */
double _mctmp_a; /* temp a */
double _mctmp_b; /* temp b */
double _mctmp_c; /* temp c */
unsigned long randstate[7];
double t, p; /* time, event weight */
long long _uid; /* event ID */
long long _uid; /* Unique event ID */
long long _loopid; /* inner-loop event ID */
long _index; /* component index where to send this event */
long _absorbed; /* flag set to TRUE when this event is to be removed/ignored */
long _scattered; /* flag set to TRUE when this event has interacted with the last component instance */
Expand Down Expand Up @@ -149,6 +158,9 @@ def jump_line(instance, jump):
if(!str_comp("sz",name)){{rval=p->sz;s=0;}}
if(!str_comp("t",name)){{rval=p->t;s=0;}}
if(!str_comp("p",name)){{rval=p->p;s=0;}}
if(!str_comp("_mctmp_a",name)){{rval=p->_mctmp_a;s=0;}}
if(!str_comp("_mctmp_b",name)){{rval=p->_mctmp_b;s=0;}}
if(!str_comp("_mctmp_c",name)){{rval=p->_mctmp_c;s=0;}}
{getvar}
if (suc!=0x0) {{*suc=s;}}
return rval;
Expand Down
Loading

0 comments on commit c2e1e69

Please sign in to comment.