Commit
apply black and isort
jrcastro2 committed Jul 19, 2024
1 parent ad40655 commit 9b4f9ca
Showing 12 changed files with 142 additions and 104 deletions.
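This is a formatting-only commit: black normalizes spacing and line wrapping, isort groups and orders imports. As a rough sketch (an assumption — the exact invocation is not recorded in the commit), changes like the ones below are what running both tools over the package would produce:

import subprocess

# Run isort first so black sees already-grouped imports, then black for
# spacing and line wrapping. The target path is an assumption, not part
# of the commit itself.
subprocess.run(["python", "-m", "isort", "invenio_vocabularies"], check=True)
subprocess.run(["python", "-m", "black", "invenio_vocabularies"], check=True)

In the per-file diffs below, lines prefixed with "-" are removed and lines prefixed with "+" are added; unprefixed lines are unchanged context.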
24 changes: 16 additions & 8 deletions invenio_vocabularies/cli.py
@@ -13,11 +13,12 @@
import click
from flask.cli import with_appcontext
from invenio_access.permissions import system_identity
+from invenio_logging.structlog import LoggerFactory
from invenio_pidstore.errors import PIDDeletedError, PIDDoesNotExistError

from .datastreams import DataStreamFactory
from .factories import get_vocabulary_config
-from invenio_logging.structlog import LoggerFactory


@click.group()
def vocabularies():
@@ -35,9 +36,9 @@ def _process_vocab(config, num_samples=None):
cli_logger.info("Starting processing")
success, errored, filtered = 0, 0, 0
left = num_samples or -1
-batch_size=config.get("batch_size", 1000)
-write_many=config.get("write_many", False)
+batch_size = config.get("batch_size", 1000)
+write_many = config.get("write_many", False)

for result in ds.process(batch_size=batch_size, write_many=write_many):
left = left - 1
if result.filtered:
@@ -46,15 +47,22 @@
if result.errors:
for err in result.errors:
click.secho(err, fg="red")
cli_logger.error("Error", entry=result.entry, operation=result.op_type, errors=result.errors)
cli_logger.error(
"Error",
entry=result.entry,
operation=result.op_type,
errors=result.errors,
)
errored += 1
else:
success += 1
cli_logger.info("Success", entry=result.entry, operation=result.op_type)
if left == 0:
click.secho(f"Number of samples reached {num_samples}", fg="green")
break
cli_logger.info("Finished processing", success=success, errored=errored, filtered=filtered)
cli_logger.info(
"Finished processing", success=success, errored=errored, filtered=filtered
)

return success, errored, filtered

@@ -159,7 +167,7 @@ def delete(vocabulary, identifier, all):
if not identifier and not all:
click.secho("An identifier or the --all flag must be present.", fg="red")
exit(1)

vc = get_vocabulary_config(vocabulary)
service = vc.get_service()
if identifier:
@@ -175,4 +183,4 @@ def delete(vocabulary, identifier, all):
if service.delete(system_identity, item["id"]):
click.secho(f"{item['id']} deleted from {vocabulary}.", fg="green")
except (PIDDeletedError, PIDDoesNotExistError):
click.secho(f"PID {item['id']} not found.")
click.secho(f"PID {item['id']} not found.")
10 changes: 5 additions & 5 deletions invenio_vocabularies/config.py
@@ -24,7 +24,7 @@
ZipReader,
)
from .datastreams.transformers import XMLTransformer
-from .datastreams.writers import AsyncWriter, AsyncWriter, ServiceWriter, YamlWriter
+from .datastreams.writers import AsyncWriter, ServiceWriter, YamlWriter
from .resources import VocabulariesResourceConfig
from .services.config import VocabulariesServiceConfig

@@ -156,13 +156,13 @@
}
"""Vocabulary type search configuration."""

-VOCABULARIES_ORCID_ACCESS_KEY="TOD"
+VOCABULARIES_ORCID_ACCESS_KEY = "TOD"
"""ORCID access key to access the s3 bucket."""
-VOCABULARIES_ORCID_SECRET_KEY="TODO"
+VOCABULARIES_ORCID_SECRET_KEY = "TODO"
"""ORCID secret key to access the s3 bucket."""
-VOCABULARIES_ORCID_SUMMARIES_BUCKET="v3.0-summaries"
+VOCABULARIES_ORCID_SUMMARIES_BUCKET = "v3.0-summaries"
"""ORCID summaries bucket name."""
VOCABULARIES_ORCID_SYNC_MAX_WORKERS = 32
"""ORCID max number of simultaneous workers/connections."""
VOCABULARIES_ORCID_SYNC_DAYS = 1
"""ORCID number of days to sync."""
"""ORCID number of days to sync."""
74 changes: 43 additions & 31 deletions invenio_vocabularies/contrib/names/datastreams.py
@@ -8,19 +8,20 @@

"""Names datastreams, transformers, writers and readers."""

+import io
+import tarfile
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from datetime import datetime, timedelta
+
+import s3fs
+from flask import current_app
from invenio_records.dictutils import dict_lookup

from ...datastreams.errors import TransformerError
-from ...datastreams.readers import SimpleHTTPReader, BaseReader
+from ...datastreams.readers import BaseReader, SimpleHTTPReader
from ...datastreams.transformers import BaseTransformer
from ...datastreams.writers import ServiceWriter
-import s3fs
-from flask import current_app
-from datetime import datetime
-from datetime import timedelta
-import tarfile
-import io
-from concurrent.futures import ThreadPoolExecutor, as_completed


class OrcidDataSyncReader(BaseReader):
"""ORCiD Data Sync Reader."""
@@ -29,67 +30,79 @@ def _fetch_orcid_data(self, orcid_to_sync, fs, bucket):
"""Fetches a single ORCiD record from S3."""
# The ORCiD file key is located in a folder which name corresponds to the last three digits of the ORCiD
suffix = orcid_to_sync[-3:]
-key = f'{suffix}/{orcid_to_sync}.xml'
+key = f"{suffix}/{orcid_to_sync}.xml"
try:
-with fs.open(f's3://{bucket}/{key}', 'rb') as f:
+with fs.open(f"s3://{bucket}/{key}", "rb") as f:
file_response = f.read()
return file_response
except Exception as e:
# TODO: log
return None

def _process_lambda_file(self, fileobj):
"""Process the ORCiD lambda file and returns a list of ORCiDs to sync.
The decoded fileobj looks like the following:
orcid,last_modified,created
0000-0001-5109-3700,2021-08-02 15:00:00.000,2021-08-02 15:00:00.000
Yield ORCiDs to sync until the last sync date is reached.
"""
-date_format = '%Y-%m-%d %H:%M:%S.%f'
-date_format_no_millis = '%Y-%m-%d %H:%M:%S'

-last_sync = datetime.now() - timedelta(days=current_app.config["VOCABULARIES_ORCID_SYNC_DAYS"])

-file_content = fileobj.read().decode('utf-8')

+date_format = "%Y-%m-%d %H:%M:%S.%f"
+date_format_no_millis = "%Y-%m-%d %H:%M:%S"

+last_sync = datetime.now() - timedelta(
+days=current_app.config["VOCABULARIES_ORCID_SYNC_DAYS"]
+)

+file_content = fileobj.read().decode("utf-8")

for line in file_content.splitlines()[1:]: # Skip the header line
-elements = line.split(',')
+elements = line.split(",")
orcid = elements[0]

# Lambda file is ordered by last modified date
last_modified_str = elements[3]
try:
last_modified_date = datetime.strptime(last_modified_str, date_format)
except ValueError:
-last_modified_date = datetime.strptime(last_modified_str, date_format_no_millis)
+last_modified_date = datetime.strptime(
+last_modified_str, date_format_no_millis
+)

if last_modified_date >= last_sync:
yield orcid
else:
break


def _iter(self, orcids, fs):
"""Iterates over the ORCiD records yielding each one."""

-with ThreadPoolExecutor(max_workers=current_app.config["VOCABULARIES_ORCID_SYNC_MAX_WORKERS"]) as executor:
-futures = [executor.submit(self._fetch_orcid_data, orcid, fs, current_app.config["VOCABULARIES_ORCID_SUMMARIES_BUCKET"]) for orcid in orcids]
+with ThreadPoolExecutor(
+max_workers=current_app.config["VOCABULARIES_ORCID_SYNC_MAX_WORKERS"]
+) as executor:
+futures = [
+executor.submit(
+self._fetch_orcid_data,
+orcid,
+fs,
+current_app.config["VOCABULARIES_ORCID_SUMMARIES_BUCKET"],
+)
+for orcid in orcids
+]
for future in as_completed(futures):
result = future.result()
if result is not None:
yield result


def read(self, item=None, *args, **kwargs):
"""Streams the ORCiD lambda file, process it to get the ORCiDS to sync and yields it's data."""
fs = s3fs.S3FileSystem(
key=current_app.config["VOCABULARIES_ORCID_ACCESS_KEY"],
-secret=current_app.config["VOCABULARIES_ORCID_SECRET_KEY"]
+secret=current_app.config["VOCABULARIES_ORCID_SECRET_KEY"],
)
# Read the file from S3
-with fs.open('s3://orcid-lambda-file/last_modified.csv.tar', 'rb') as f:
+with fs.open("s3://orcid-lambda-file/last_modified.csv.tar", "rb") as f:
tar_content = f.read()

orcids_to_sync = []
@@ -102,9 +115,8 @@ def read(self, item=None, *args, **kwargs):
if extracted_file:
# Process the file and get the ORCiDs to sync
orcids_to_sync.extend(self._process_lambda_file(extracted_file))

yield from self._iter(orcids_to_sync, fs)



class OrcidHTTPReader(SimpleHTTPReader):
@@ -207,7 +219,7 @@ def _entry_id(self, entry):
{
"type": "async",
"args": {
"writer":{
"writer": {
"type": "names-service",
}
},
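OrcidDataSyncReader._iter above fans the per-ORCiD S3 reads out over a thread pool and yields results as they complete. A self-contained sketch of that fan-out pattern, with fetch() standing in for _fetch_orcid_data (an assumption, not the real S3 call):

from concurrent.futures import ThreadPoolExecutor, as_completed

def fetch(orcid):
    # Stand-in for the S3 read; the real reader returns None on failure.
    return f"<record for {orcid}>"

def iter_records(orcids, max_workers=4):
    # Submit one job per ORCiD and yield each result as soon as it is ready.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(fetch, orcid) for orcid in orcids]
        for future in as_completed(futures):
            result = future.result()
            if result is not None:
                yield result

for record in iter_records(["0000-0001-5109-3700", "0000-0002-1825-0097"]):
    print(record)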
36 changes: 20 additions & 16 deletions invenio_vocabularies/datastreams/datastreams.py
@@ -8,24 +8,26 @@

"""Base data stream."""

-from .errors import ReaderError, TransformerError, WriterError
from invenio_logging.structlog import LoggerFactory

+from .errors import ReaderError, TransformerError, WriterError


class StreamEntry:
"""Object to encapsulate streams processing."""

def __init__(self, entry, errors=None, op_type=None):
"""Constructor for the StreamEntry class.
Args:
entry (object): The entry object, usually a record dict.
errors (list, optional): List of errors. Defaults to None.
op_type (str, optional): The operation type. Defaults to None.
"""
self.entry = entry
self.filtered = False
self.errors = errors or []
self.op_type = op_type
"""Constructor for the StreamEntry class.
:param entry (object): The entry object, usually a record dict.
:param errors (list, optional): List of errors. Defaults to None.
:param op_type (str, optional): The operation type. Defaults to None.
"""
self.entry = entry
self.filtered = False
self.errors = errors or []
self.op_type = op_type


class DataStream:
"""Data stream."""
@@ -44,7 +46,7 @@ def __init__(self, readers, writers, transformers=None, *args, **kwargs):
def filter(self, stream_entry, *args, **kwargs):
"""Checks if an stream_entry should be filtered out (skipped)."""
return False

def process_batch(self, batch, write_many=False):
transformed_entries = []
for stream_entry in batch:
@@ -77,9 +79,11 @@ def process(self, batch_size=100, write_many=False, logger=None, *args, **kwargs
"""
if not logger:
logger = LoggerFactory.get_logger("datastreams")

batch = []
logger.info(f"Start reading datastream with batch_size={batch_size} and write_many={write_many}")
logger.info(
f"Start reading datastream with batch_size={batch_size} and write_many={write_many}"
)
for stream_entry in self.read():
batch.append(stream_entry)
if len(batch) >= batch_size:
@@ -136,7 +140,7 @@ def write(self, stream_entry, *args, **kwargs):
stream_entry.errors.append(f"{writer.__class__.__name__}: {str(err)}")

return stream_entry

def batch_write(self, stream_entries, *args, **kwargs):
"""Apply the transformations to an stream_entry. Errors are handler in the service layer."""
for writer in self._writers:
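DataStream.process above accumulates stream entries into batches of batch_size and hands each full batch to process_batch. A sketch of that batching loop in isolation; the final flush of a partial batch is assumed, since that part of the method is collapsed in this diff, and read/process_batch are stand-ins for the real methods:

def process(read, process_batch, batch_size=100, write_many=False):
    # read() yields stream entries; process_batch() yields processed results.
    batch = []
    for stream_entry in read():
        batch.append(stream_entry)
        if len(batch) >= batch_size:
            yield from process_batch(batch, write_many=write_many)
            batch = []
    if batch:  # flush whatever is left over (assumed behaviour)
        yield from process_batch(batch, write_many=write_many)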
1 change: 0 additions & 1 deletion invenio_vocabularies/datastreams/readers.py
@@ -26,7 +26,6 @@
from .errors import ReaderError
from .xml import etree_to_dict


try:
import oaipmh_scythe
except ImportError:
6 changes: 4 additions & 2 deletions invenio_vocabularies/datastreams/tasks.py
@@ -9,10 +9,11 @@
"""Data Streams Celery tasks."""

from celery import shared_task
+from invenio_logging.structlog import LoggerFactory

from ..datastreams import StreamEntry
from ..datastreams.factories import WriterFactory
-from invenio_logging.structlog import LoggerFactory


@shared_task(ignore_result=True)
def write_entry(writer_config, entry):
@@ -24,6 +25,7 @@ def write_entry(writer_config, entry):
writer = WriterFactory.create(config=writer_config)
writer.write(StreamEntry(entry))


@shared_task(ignore_result=True)
def write_many_entry(writer_config, entries, logger=None):
"""Write many entries.
@@ -41,4 +43,4 @@ def write_many_entry(writer_config, entries, logger=None):
logger.info("Entries written", succeeded=succeeded)
if errored:
for entry in errored:
logger.error("Error writing entry", entry=entry.entry, errors=entry.errors)
logger.error("Error writing entry", entry=entry.entry, errors=entry.errors)
19 changes: 13 additions & 6 deletions invenio_vocabularies/datastreams/writers.py
@@ -50,6 +50,7 @@ def write_many(self, stream_entries, *args, **kwargs):
"""
pass


class ServiceWriter(BaseWriter):
"""Writes the entries to an RDM instance using a Service object."""

@@ -98,17 +99,21 @@ def write(self, stream_entry, *args, **kwargs):
except InvalidRelationValue as err:
# TODO: Check if we can get the error message easier
raise WriterError([{"InvalidRelationValue": err.args[0]}])

def write_many(self, stream_entries, *args, **kwargs):
entries = [entry.entry for entry in stream_entries]
entries_with_id = [(self._entry_id(entry), entry) for entry in entries]
records = self._service.create_or_update_many(self._identity, entries_with_id)
-stream_entries_processed= []
+stream_entries_processed = []
for op_type, record, errors in records:
if errors == []:
-stream_entries_processed.append(StreamEntry(entry=record, op_type=op_type))
+stream_entries_processed.append(
+StreamEntry(entry=record, op_type=op_type)
+)
else:
-stream_entries_processed.append(StreamEntry(entry=record, errors=errors, op_type=op_type))
+stream_entries_processed.append(
+StreamEntry(entry=record, errors=errors, op_type=op_type)
+)

return stream_entries_processed

@@ -154,6 +159,8 @@ def write(self, stream_entry, *args, **kwargs):

def write_many(self, stream_entries, *args, **kwargs):
"""Launches a celery task to write an entry."""
-write_many_entry.delay(self._writer, [stream_entry.entry for stream_entry in stream_entries])
+write_many_entry.delay(
+self._writer, [stream_entry.entry for stream_entry in stream_entries]
+)

-return stream_entries
+return stream_entries
