From 98b736ec9d23ee140259b63b55ebd534531c4bf4 Mon Sep 17 00:00:00 2001 From: Simen Fivelstad Smaaberg <66635118+simensma-fresh@users.noreply.github.com> Date: Fri, 13 Dec 2024 13:19:28 -0800 Subject: [PATCH] [MDS-6260] Added support to permit condition extraction pipeline to create identified reports (#3339) * [MDS-6260] Added support to permit condition extraction pipeline to create identified reports * Use auto-reloadable celery for core api worker * Add meta column to permit_conditions_version * Update fastapi version * Added testkey for documentintelligence * MDS-6260 Cleanup --- .github/workflows/permit-service.unit.yaml | 4 +- docker-compose.yaml | 2 +- ...13.1__add_permit_condition_meta_column.sql | 1 + ...e_report_definition_report_name_column.sql | 1 + ...d_permit_condition_history_meta_column.sql | 1 + services/core-api/Dockerfile | 1 + .../models/permit_conditions.py | 157 +++++++---- ...ate_permit_condition_report_requirement.py | 141 ++++++++++ .../create_permit_conditions.py | 213 +++++++++----- .../models/permit_condition_result.py | 1 + .../models/mine_report_permit_requirement.py | 9 +- .../mine_report_permit_requirement.py | 77 ++--- .../core-api/app/api/mines/response_models.py | 4 +- services/core-api/celery_dev.sh | 18 ++ ...ate_permit_condition_report_requirement.py | 120 ++++++++ ...test_create_permit_conditions_from_task.py | 262 ++++++++++++++++-- services/permits/.env-example | 2 +- .../permits/app/permit_condition_prompts.yaml | 23 +- .../azure_document_intelligence_converter.py | 30 +- .../filter_conditions_paragraphs.py | 3 - .../converters/metadata_converter.py | 29 +- .../converters/pdf_to_text_converter.py | 1 - .../CachedAzureOpenAIChatGenerator.py | 184 ++++++++---- .../pipelines/PaginatedChatPromptBuilder.py | 98 ++++++- .../permit_conditions/pipelines/chat_data.py | 4 +- .../pipelines/permit_condition_pipeline.py | 16 +- .../resources/permit_condition_resource.py | 2 +- .../permit_conditions/validator/json_fixer.py | 18 +- 
.../validator/permit_condition_model.py | 33 +++ .../permit_condition_section_combiner.py | 10 +- .../validator/permit_condition_validator.py | 2 +- services/permits/requirements.txt | 2 +- ...t_azure_document_intelligence_converter.py | 20 +- ...test_cached_azure_openai_chat_generator.py | 121 +++++--- .../test_conditions_metadata_combiner.py | 113 +++++++- services/permits/tests/test_json_fixer.py | 82 +++++- .../test_paginated_chat_prompt_builder.py | 189 +++++++++++++ .../tests/test_permit_condition_validator.py | 9 +- 38 files changed, 1583 insertions(+), 420 deletions(-) create mode 100644 migrations/sql/V2024.12.03.13.1__add_permit_condition_meta_column.sql create mode 100644 migrations/sql/V2024.12.06.11.1__add_mine_report_definition_report_name_column.sql create mode 100644 migrations/sql/V2024.12.11.15.1__add_permit_condition_history_meta_column.sql create mode 100644 services/core-api/app/api/mines/permits/permit_extraction/create_permit_condition_report_requirement.py create mode 100755 services/core-api/celery_dev.sh create mode 100644 services/core-api/tests/permits/permit_extraction/test_create_permit_condition_report_requirement.py create mode 100644 services/permits/tests/test_paginated_chat_prompt_builder.py diff --git a/.github/workflows/permit-service.unit.yaml b/.github/workflows/permit-service.unit.yaml index 729b12ab61..b351d09fe2 100644 --- a/.github/workflows/permit-service.unit.yaml +++ b/.github/workflows/permit-service.unit.yaml @@ -28,8 +28,8 @@ jobs: env: DOCKER_BUILDKIT: 1 run: | - docker compose -f docker-compose.yaml run -e AZURE_API_KEY=testkey haystack coverage run -m pytest - docker compose -f docker-compose.yaml run -e AZURE_API_KEY=testkey haystack coverage xml + docker compose -f docker-compose.yaml run -e AZURE_API_KEY=testkey -e DOCUMENTINTELLIGENCE_API_KEY=testkey haystack coverage run -m pytest + docker compose -f docker-compose.yaml run -e AZURE_API_KEY=testkey -e DOCUMENTINTELLIGENCE_API_KEY=testkey haystack coverage xml 
sed -i "s/\/code/\/github\/workspace\/services\/permits/g" services/permits/coverage.xml - name: Upload test coverage results uses: actions/upload-artifact@v4 diff --git a/docker-compose.yaml b/docker-compose.yaml index d886b4164b..7f40e59421 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -134,7 +134,7 @@ services: container_name: core_api_celery build: context: services/core-api - entrypoint: ./celery.sh + entrypoint: ./celery_dev.sh ports: - 5556:5555 volumes: diff --git a/migrations/sql/V2024.12.03.13.1__add_permit_condition_meta_column.sql b/migrations/sql/V2024.12.03.13.1__add_permit_condition_meta_column.sql new file mode 100644 index 0000000000..e02455da4f --- /dev/null +++ b/migrations/sql/V2024.12.03.13.1__add_permit_condition_meta_column.sql @@ -0,0 +1 @@ +ALTER TABLE permit_conditions ADD COLUMN IF NOT EXISTS meta JSONB NULL; \ No newline at end of file diff --git a/migrations/sql/V2024.12.06.11.1__add_mine_report_definition_report_name_column.sql b/migrations/sql/V2024.12.06.11.1__add_mine_report_definition_report_name_column.sql new file mode 100644 index 0000000000..fa7811efa3 --- /dev/null +++ b/migrations/sql/V2024.12.06.11.1__add_mine_report_definition_report_name_column.sql @@ -0,0 +1 @@ +ALTER TABLE mine_report_permit_requirement ADD COLUMN report_name VARCHAR(255) NULL; \ No newline at end of file diff --git a/migrations/sql/V2024.12.11.15.1__add_permit_condition_history_meta_column.sql b/migrations/sql/V2024.12.11.15.1__add_permit_condition_history_meta_column.sql new file mode 100644 index 0000000000..308ea64e79 --- /dev/null +++ b/migrations/sql/V2024.12.11.15.1__add_permit_condition_history_meta_column.sql @@ -0,0 +1 @@ +ALTER TABLE permit_conditions_version ADD COLUMN IF NOT EXISTS meta JSONB NULL; \ No newline at end of file diff --git a/services/core-api/Dockerfile b/services/core-api/Dockerfile index 09f1ddc377..2396cc900d 100644 --- a/services/core-api/Dockerfile +++ b/services/core-api/Dockerfile @@ -9,6 +9,7 @@ RUN 
apt-get install build-essential -y # Install the requirements COPY requirements.txt . RUN pip install --no-cache-dir -r requirements.txt +RUN pip install watchdog COPY . . diff --git a/services/core-api/app/api/mines/permits/permit_conditions/models/permit_conditions.py b/services/core-api/app/api/mines/permits/permit_conditions/models/permit_conditions.py index a4312a69d8..bd5c41366c 100644 --- a/services/core-api/app/api/mines/permits/permit_conditions/models/permit_conditions.py +++ b/services/core-api/app/api/mines/permits/permit_conditions/models/permit_conditions.py @@ -1,54 +1,68 @@ -import uuid -from datetime import datetime - from app.api.utils.field_template import FieldTemplate from app.api.utils.list_lettering_helpers import num_to_letter, num_to_roman from app.api.utils.models_mixins import AuditMixin, Base, SoftDeleteMixin from app.extensions import db -from marshmallow import fields, validate -from sqlalchemy.dialects.postgresql import UUID +from marshmallow import fields +from sqlalchemy.dialects.postgresql import JSONB, UUID from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import backref, validates +from sqlalchemy.orm import backref from sqlalchemy.schema import FetchedValue class PermitConditions(SoftDeleteMixin, AuditMixin, Base): - __tablename__ = 'permit_conditions' + __tablename__ = "permit_conditions" class _ModelSchema(Base._ModelSchema): permit_condition_id = fields.Integer(dump_only=True) permit_condition_guid = fields.UUID(dump_only=True) condition_category_code = FieldTemplate( - field=fields.String, one_of='PermitConditionCategory') - condition_type_code = FieldTemplate(field=fields.String, one_of='PermitConditionType') + field=fields.String, one_of="PermitConditionCategory" + ) + condition_type_code = FieldTemplate( + field=fields.String, one_of="PermitConditionType" + ) permit_condition_id = db.Column(db.Integer, primary_key=True) permit_amendment_id = db.Column( - db.Integer, 
db.ForeignKey('permit_amendment.permit_amendment_id'), nullable=False) - permit_amendment = db.relationship('PermitAmendment', lazy='select', back_populates='conditions') + db.Integer, + db.ForeignKey("permit_amendment.permit_amendment_id"), + nullable=False, + ) + permit_amendment = db.relationship( + "PermitAmendment", lazy="select", back_populates="conditions" + ) permit_condition_guid = db.Column(UUID(as_uuid=True), server_default=FetchedValue()) condition = db.Column(db.String, nullable=False) condition_category_code = db.Column( db.String, - db.ForeignKey('permit_condition_category.condition_category_code'), - nullable=False) - - condition_category = db.relationship('PermitConditionCategory', lazy='select') + db.ForeignKey("permit_condition_category.condition_category_code"), + nullable=False, + ) + + condition_category = db.relationship("PermitConditionCategory", lazy="select") condition_type_code = db.Column( - db.String, db.ForeignKey('permit_condition_type.condition_type_code'), nullable=False) - parent_permit_condition_id = db.Column(db.Integer, - db.ForeignKey('permit_conditions.permit_condition_id')) + db.String, + db.ForeignKey("permit_condition_type.condition_type_code"), + nullable=False, + ) + parent_permit_condition_id = db.Column( + db.Integer, db.ForeignKey("permit_conditions.permit_condition_id") + ) display_order = db.Column(db.Integer, nullable=False) - _step = db.Column('step', db.String, nullable=True) + + meta = db.Column(JSONB(astext_type=db.Text()), nullable=True) + + _step = db.Column("step", db.String, nullable=True) __versioned__ = {} all_sub_conditions = db.relationship( - 'PermitConditions', - lazy='joined', - order_by='asc(PermitConditions.display_order)', - backref=backref('parent', remote_side=[permit_condition_id])) + "PermitConditions", + lazy="joined", + order_by="asc(PermitConditions.display_order)", + backref=backref("parent", remote_side=[permit_condition_id]), + ) @hybrid_property def sub_conditions(self): @@ -68,64 +82,80 
@@ def step(self): if self._step: # Format the first level with a trailing dot - A. B. C. and the rest with () - (a), (i), (ii) if depth == 0: - return f'{self._step}.' - return f'({self._step})' - if self._step == '': - return '' + return f"{self._step}." + return f"({self._step})" + if self._step == "": + return "" step_format = depth % 3 if step_format == 0: - return str(self.display_order) + '.' + return str(self.display_order) + "." elif step_format == 1: - return num_to_letter(self.display_order) + '.' + return num_to_letter(self.display_order) + "." elif step_format == 2: - return num_to_roman(self.display_order) + '.' + return num_to_roman(self.display_order) + "." def __repr__(self): - return '' % (self.permit_condition_id, - self.permit_condition_guid, self.display_order) + return "" % ( + self.permit_condition_id, + self.permit_condition_guid, + self.display_order, + ) @classmethod - def create(cls, - condition_category_code, - condition_type_code, - permit_amendment_id, - condition, - display_order, - sub_conditions, - parent=None): + def create( + cls, + condition_category_code, + condition_type_code, + permit_amendment_id, + condition, + display_order, + sub_conditions, + parent=None, + ): permit_condition = cls( condition_category_code=condition_category_code, condition_type_code=condition_type_code, permit_amendment_id=permit_amendment_id, condition=condition, display_order=display_order, - parent=parent) + parent=parent, + ) permit_condition.save(commit=False) for condition in sub_conditions: - PermitConditions.create(condition.condition_category_code, - condition.condition_type_code, permit_amendment_id, - condition.condition, condition.display_order, - condition.sub_conditions, permit_condition) + PermitConditions.create( + condition.condition_category_code, + condition.condition_type_code, + permit_amendment_id, + condition.condition, + condition.display_order, + condition.sub_conditions, + permit_condition, + ) return permit_condition - 
@classmethod def delete_all_by_permit_amendment_id(cls, permit_amendment_id, commit=False): - parent_conditions = cls.query.filter_by( - permit_amendment_id=permit_amendment_id, - parent_permit_condition_id=None, - deleted_ind=False).order_by(cls.display_order).all() + parent_conditions = ( + cls.query.filter_by( + permit_amendment_id=permit_amendment_id, + parent_permit_condition_id=None, + deleted_ind=False, + ) + .order_by(cls.display_order) + .all() + ) for condition in parent_conditions: condition.delete_condition(commit=commit) if commit: condition.save() - def delete_condition(self, commit=False): if self.all_sub_conditions is not None: - subconditions = [c for c in self.all_sub_conditions if c.deleted_ind == False] + subconditions = [ + c for c in self.all_sub_conditions if c.deleted_ind == False + ] if len(subconditions) > 0: for item in subconditions: item.deleted_ind = True @@ -134,25 +164,32 @@ def delete_condition(self, commit=False): item.save() self.deleted_ind = True - @classmethod def find_all_by_permit_amendment_id(cls, permit_amendment_id): - return cls.query.filter_by( - permit_amendment_id=permit_amendment_id, - parent_permit_condition_id=None, - deleted_ind=False).order_by(cls.display_order).all() + return ( + cls.query.filter_by( + permit_amendment_id=permit_amendment_id, + parent_permit_condition_id=None, + deleted_ind=False, + ) + .order_by(cls.display_order) + .all() + ) @classmethod def find_by_permit_condition_guid(cls, permit_condition_guid): return cls.query.filter_by( - permit_condition_guid=permit_condition_guid, deleted_ind=False).first() + permit_condition_guid=permit_condition_guid, deleted_ind=False + ).first() @classmethod def find_by_permit_condition_id(cls, permit_condition_id): return cls.query.filter_by( - permit_condition_id=permit_condition_id, deleted_ind=False).first() + permit_condition_id=permit_condition_id, deleted_ind=False + ).first() @classmethod def find_by_condition_category_code(cls, condition_category_code): 
return cls.query.filter_by( - condition_category_code=condition_category_code, deleted_ind=False).all() + condition_category_code=condition_category_code, deleted_ind=False + ).all() diff --git a/services/core-api/app/api/mines/permits/permit_extraction/create_permit_condition_report_requirement.py b/services/core-api/app/api/mines/permits/permit_extraction/create_permit_condition_report_requirement.py new file mode 100644 index 0000000000..e24f32ed19 --- /dev/null +++ b/services/core-api/app/api/mines/permits/permit_extraction/create_permit_condition_report_requirement.py @@ -0,0 +1,141 @@ +from typing import Optional + +from app.api.mines.reports.models.mine_report_permit_requirement import ( + MineReportPermitRequirement, +) +from dateutil.parser import parse +from flask import current_app + +from .models.permit_condition_result import PermitConditionResult + + +def create_permit_condition_report_requirement( + task, condition: PermitConditionResult, condition_id +) -> Optional[MineReportPermitRequirement]: + """ + Creates a MineReportPermitRequirement based on permit condition details. 
+ Args: + task: Task object containing permit amendment information + condition (PermitConditionResult): Permit condition details + condition_id: ID of the permit condition + Returns: + MineReportPermitRequirement: Created requirement object, or None if report not required + """ + + meta = condition.meta or {} + questions = meta.get("questions", []) + + require_report = False + recurring = False + frequency = None + mention_chief_inspector = False + mention_chief_permitting_officer = False + initial_due_date = None + report_name = None + + for q in questions: + key = q.get("question_key") + answer = q.get("answer") + if key == "require_report": + require_report = answer + elif key == "due_date": + initial_due_date = answer # Parse date if necessary + elif key == "recurring": + recurring = answer + elif key == "frequency": + frequency = answer + elif key == "mention_chief_inspector": + mention_chief_inspector = answer + elif key == "mention_chief_permitting_officer": + mention_chief_permitting_officer = answer + elif key == "report_name": + report_name = answer + + if not require_report: + return None + + initial_due_date = _parse_initial_due_date(condition_id, initial_due_date) + + # Determine cim_or_cpo based on mentions + cim_or_cpo = _parse_cim_cpo( + mention_chief_inspector, mention_chief_permitting_officer + ) + + # Calculate due_date_period_months based on frequency + due_date_period_months = _parse_due_date_period(recurring, frequency) + + # Create the MineReportPermitRequirement + mine_report_permit_requirement = MineReportPermitRequirement( + report_name=report_name, + permit_condition_id=condition_id, + permit_amendment_id=task.permit_amendment.permit_amendment_id, + cim_or_cpo=cim_or_cpo, + due_date_period_months=due_date_period_months or 0, + initial_due_date=initial_due_date, + ministry_recipient=None, # Not specified in permits themselves. 
+ ) + + return mine_report_permit_requirement + + +def _parse_due_date_period(recurring, frequency): + due_date_period_months = None + if recurring and frequency: + frequency_mapping = { + "monthly": 1, + "per month": 1, + "every month": 1, + "quarterly": 3, + "every quarter": 3, + "semiannually": 6, + "semiannual": 6, + "every six months": 6, + "twice yearly": 6, + "annually": 12, + "yearly": 12, + "annual": 12, + "per year": 12, + "every year": 12, + "biannually": 24, + "biannual": 24, + "asneeded": 0, + "as needed": 0, + "as required": 0, + "5 years": 60, + "every 5 years": 60, + "five years": 60, + "every five years": 60, + } + # Clean frequency string by removing spaces and special characters before lookup + frequency = "".join( + e for e in frequency.lower() if e.isalnum() or e.isspace() + ).strip() + frequency = " ".join(frequency.split()) + due_date_period_months = frequency_mapping.get(frequency) + return due_date_period_months + + +def _parse_cim_cpo(mention_chief_inspector, mention_chief_permitting_officer): + cim_or_cpo = None + if mention_chief_inspector and mention_chief_permitting_officer: + cim_or_cpo = "BOTH" + elif mention_chief_inspector: + cim_or_cpo = "CIM" + elif mention_chief_permitting_officer: + cim_or_cpo = "CPO" + return cim_or_cpo + + +def _parse_initial_due_date(condition_id, initial_due_date): + if initial_due_date == "": + initial_due_date = None + + if initial_due_date: + try: + initial_due_date = parse(initial_due_date) + except ValueError: + current_app.logger.error( + f"Could not parse due date for condition {condition_id}: {initial_due_date}" + ) + initial_due_date = None + return initial_due_date diff --git a/services/core-api/app/api/mines/permits/permit_extraction/create_permit_conditions.py b/services/core-api/app/api/mines/permits/permit_extraction/create_permit_conditions.py index 8cddcf9180..f72b6e70a9 100644 --- a/services/core-api/app/api/mines/permits/permit_extraction/create_permit_conditions.py +++ 
b/services/core-api/app/api/mines/permits/permit_extraction/create_permit_conditions.py @@ -1,6 +1,5 @@ import uuid -from difflib import SequenceMatcher -from typing import List, Optional +from typing import Optional from app.api.mines.permits.permit_amendment.models.permit_amendment import ( PermitAmendment, @@ -17,6 +16,9 @@ from app.extensions import db from flask import current_app +from .create_permit_condition_report_requirement import ( + create_permit_condition_report_requirement, +) from .models.permit_condition_result import ( CreatePermitConditionsResult, PermitConditionResult, @@ -24,15 +26,16 @@ indentation_type_code_mapping = { 0: None, - 1: 'SEC', - 2: 'CON', - 3: 'LIS', - 4: 'LIS', - 5: 'LIS', + 1: "SEC", + 2: "CON", + 3: "LIS", + 4: "LIS", + 5: "LIS", } # For conditions that don't match any category, put them in a "Terms and conditions" category -DEFAULT_CATEGORY_TEXT = 'Terms and Conditions' +DEFAULT_CATEGORY_TEXT = "Terms and Conditions" + def create_permit_conditions_from_task(task: PermitExtractionTask): """ @@ -40,67 +43,93 @@ def create_permit_conditions_from_task(task: PermitExtractionTask): """ result = task.task_result last_condition_id_by_hierarchy = {} + display_order_by_parent = {} current_category = None + try: + result = CreatePermitConditionsResult.model_validate(result) - result = CreatePermitConditionsResult.model_validate(result) - - has_category = any([condition.is_top_level_section for condition in result.conditions]) - - conditions = result.conditions - if not has_category: - top_level_section = PermitConditionResult( - section='A', - condition_text=DEFAULT_CATEGORY_TEXT + has_category = any( + [condition.is_top_level_section for condition in result.conditions] ) - for c in conditions: - c.set_section(top_level_section) - conditions = [top_level_section] + conditions - num_categories = 0 + conditions = result.conditions + if not has_category: + top_level_section = PermitConditionResult( + section="A", 
condition_text=DEFAULT_CATEGORY_TEXT + ) + for c in conditions: + c.set_section(top_level_section) + conditions = [top_level_section] + conditions - default_section = None + num_categories = 0 - for idx, condition in enumerate(conditions): - if condition.is_top_level_section: - section_category = _create_permit_condition_category( - condition=condition, - permit_amendment=task.permit_amendment, - display_order=num_categories, - step=condition.step - ) - if condition.condition_text == DEFAULT_CATEGORY_TEXT: - default_section = section_category - current_category = section_category - num_categories += 1 - else: - parent = _determine_parent(condition, last_condition_id_by_hierarchy) - type_code = _map_condition_to_type_code(condition) - - title_cond = None - - if not current_category and not default_section: - default_section = _create_permit_condition_category( - condition=PermitConditionResult( - section='A', - condition_text=DEFAULT_CATEGORY_TEXT - ), + default_section = None + + for idx, condition in enumerate(conditions): + if condition.is_top_level_section: + section_category = _create_permit_condition_category( + condition=condition, permit_amendment=task.permit_amendment, display_order=num_categories, - step='A' + step=condition.step, + ) + if condition.condition_text == DEFAULT_CATEGORY_TEXT: + default_section = section_category + current_category = section_category + num_categories += 1 + else: + parent = _determine_parent(condition, last_condition_id_by_hierarchy) + type_code = _map_condition_to_type_code(condition) + + parent_id = parent.permit_condition_id if parent else None + + if parent_id not in display_order_by_parent: + display_order_by_parent[parent_id] = 0 + display_order_by_parent[parent_id] += 1 + current_display_order = display_order_by_parent[parent_id] + + title_cond = None + + if not current_category and not default_section: + default_section = _create_permit_condition_category( + condition=PermitConditionResult( + section="A", 
condition_text=DEFAULT_CATEGORY_TEXT + ), + permit_amendment=task.permit_amendment, + display_order=num_categories, + step="A", + ) + + category_code = current_category or default_section + + if condition.condition_title: + title_cond = _create_title_condition( + task, + category_code, + condition, + parent, + current_display_order, + type_code, + ) + + parent_condition_id = _get_parent_condition_id(title_cond, parent) + cond = _create_permit_condition( + task, + category_code, + condition, + parent_condition_id, + current_display_order, + type_code, ) - category_code = current_category or default_section - if condition.condition_title: - title_cond = _create_title_condition(task, category_code, condition, parent, idx, type_code) - - parent_condition_id = _get_parent_condition_id(title_cond, parent) - cond = _create_permit_condition(task, category_code, condition, parent_condition_id, idx, type_code) + hierarchy_key = ".".join(condition.numbering_structure) + last_condition_id_by_hierarchy[hierarchy_key] = cond - hierarchy_key = ".".join(condition.numbering_structure) - last_condition_id_by_hierarchy[hierarchy_key] = cond - db.session.commit() + db.session.commit() + except: + db.session.rollback() + raise - def _map_condition_to_type_code(condition: PermitConditionResult): """ @@ -109,15 +138,22 @@ def _map_condition_to_type_code(condition: PermitConditionResult): Example: ['A', '1', '', '', ''] would have an indentation of 2 -> type code is 'CON' Example: ['A', '', '', '', ''] would have an indentation of 1 -> type code is 'SEC' """ - indentation = next((i-1 for i, x in enumerate(condition.numbering_structure) if x == ''), 0) + indentation = next( + (i - 1 for i, x in enumerate(condition.numbering_structure) if x == ""), 0 + ) type_code = indentation_type_code_mapping[indentation] - + if not type_code: - current_app.logger.error(f"Could not determine type code for condition {condition}") + current_app.logger.error( + f"Could not determine type code for condition 
{condition}" + ) - return type_code or 'LIS' + return type_code or "LIS" -def _create_title_condition(task, current_category, condition, parent, idx, type_code) -> PermitConditionResult: + +def _create_title_condition( + task, current_category, condition, parent, idx, type_code +) -> PermitConditionResult: condition = PermitConditions( permit_amendment_id=task.permit_amendment.permit_amendment_id, permit_condition_guid=uuid.uuid4(), @@ -126,6 +162,7 @@ def _create_title_condition(task, current_category, condition, parent, idx, type condition_type_code=type_code, parent_permit_condition_id=parent.permit_condition_id if parent else None, display_order=idx, + meta=condition.meta, _step=condition.step, ) @@ -133,7 +170,10 @@ def _create_title_condition(task, current_category, condition, parent, idx, type db.session.flush() # This assigns an ID to title_cond without committing the transaction return condition -def _get_parent_condition_id(title_cond: PermitConditionResult, parent: PermitConditionResult) -> Optional[str]: + +def _get_parent_condition_id( + title_cond: PermitConditionResult, parent: PermitConditionResult +) -> Optional[str]: if title_cond: # If the condition has a title, the parent is the title condition return title_cond.permit_condition_id @@ -142,7 +182,10 @@ def _get_parent_condition_id(title_cond: PermitConditionResult, parent: PermitCo else: return None -def _create_permit_condition(task, current_category, condition, parent_condition_id, idx, type_code) -> PermitConditions: + +def _create_permit_condition( + task, current_category, condition, parent_condition_id, idx, type_code +) -> PermitConditions: condition = PermitConditions( permit_amendment_id=task.permit_amendment.permit_amendment_id, permit_condition_guid=uuid.uuid4(), @@ -151,15 +194,29 @@ def _create_permit_condition(task, current_category, condition, parent_condition condition_type_code=type_code, parent_permit_condition_id=parent_condition_id, display_order=idx, - _step=condition.step 
if not condition.condition_title else '', # If the condition has a title, the parent is the title condition, which has the numbering associated with it already + meta=condition.meta, + _step=( + condition.step if not condition.condition_title else "" + ), # If the condition has a title, the parent is the title condition, which has the numbering associated with it already ) db.session.add(condition) + db.session.flush() # This assigns an ID to cond without committing the transaction + report_requirement = create_permit_condition_report_requirement( + task, condition, condition.permit_condition_id + ) + + if report_requirement: + db.session.add(report_requirement) + db.session.flush() + return condition -def _determine_parent(condition: PermitConditionResult, last_condition_id_by_number_structure) -> Optional[PermitConditionResult]: +def _determine_parent( + condition: PermitConditionResult, last_condition_id_by_number_structure +) -> Optional[PermitConditionResult]: """ Determine the parent ID based on the hierarchy. 
@@ -170,14 +227,22 @@ def _determine_parent(condition: PermitConditionResult, last_condition_id_by_num parent_number_structure = [item for item in number_structure if item][:-1] if len(parent_number_structure) < len(number_structure): - parent_number_structure += [''] * (len(number_structure) - len(parent_number_structure)) + parent_number_structure += [""] * ( + len(number_structure) - len(parent_number_structure) + ) parent_key = ".".join(parent_number_structure) parent = last_condition_id_by_number_structure.get(parent_key) return parent -def _create_permit_condition_category(condition: PermitConditionResult, permit_amendment: PermitAmendment, display_order: int, step: str) -> Optional[str]: + +def _create_permit_condition_category( + condition: PermitConditionResult, + permit_amendment: PermitAmendment, + display_order: int, + step: str, +) -> Optional[str]: """ Finds the matching PermitConditionCategory code for the given condition based on the title or text it contains. @@ -192,19 +257,21 @@ def _create_permit_condition_category(condition: PermitConditionResult, permit_a Args: condition_categories: List of PermitConditionCategory objects condition: Condition object - + """ - text = condition.condition_title if condition.condition_title else condition.condition_text + text = ( + condition.condition_title + if condition.condition_title + else condition.condition_text + ) cat = PermitConditionCategory.create( condition_category_code=str(uuid.uuid4()), description=text, display_order=display_order, permit_amendment_id=permit_amendment.permit_amendment_id, - step=step + step=step, ) return cat.condition_category_code - - diff --git a/services/core-api/app/api/mines/permits/permit_extraction/models/permit_condition_result.py b/services/core-api/app/api/mines/permits/permit_extraction/models/permit_condition_result.py index cb8a80a32e..7a0d165f0a 100644 --- a/services/core-api/app/api/mines/permits/permit_extraction/models/permit_condition_result.py +++ 
b/services/core-api/app/api/mines/permits/permit_extraction/models/permit_condition_result.py @@ -12,6 +12,7 @@ class PermitConditionResult(BaseModel): subsubclause: Optional[str] = None condition_title: Optional[str] = None condition_text: str + meta: Optional[dict] = None @computed_field def numbering_structure(self) -> List[str]: diff --git a/services/core-api/app/api/mines/reports/models/mine_report_permit_requirement.py b/services/core-api/app/api/mines/reports/models/mine_report_permit_requirement.py index cf2b336740..eb319232d1 100644 --- a/services/core-api/app/api/mines/reports/models/mine_report_permit_requirement.py +++ b/services/core-api/app/api/mines/reports/models/mine_report_permit_requirement.py @@ -2,12 +2,11 @@ from enum import Enum from typing import Optional +from app.api.utils.models_mixins import AuditMixin, Base, SoftDeleteMixin +from app.extensions import db from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.schema import FetchedValue -from app.api.utils.models_mixins import Base, AuditMixin, SoftDeleteMixin -from app.extensions import db - class CimOrCpo(str, Enum): CIM = "CIM" @@ -34,6 +33,8 @@ class MineReportPermitRequirement(SoftDeleteMixin, Base, AuditMixin): mine_report_permit_requirement_id: int = db.Column(db.Integer, primary_key=True, server_default=FetchedValue()) due_date_period_months: int = db.Column(db.Integer, nullable=False) initial_due_date: Optional[date] = db.Column(db.Date, nullable=True) + report_name = db.Column(db.String(512), nullable=True) + active_ind: bool = db.Column(db.Boolean, nullable=False, server_default=FetchedValue()) cim_or_cpo: Optional[CimOrCpo] = db.Column(db.Enum(CimOrCpo, name='cim_or_cpo_type'), nullable=True) ministry_recipient: Optional[list[OfficeDestination]] = db.Column( @@ -67,6 +68,7 @@ def get_all(cls) -> list["MineReportPermitRequirement"]: @classmethod def create(cls, + report_name: Optional[str], due_date_period_months: int, initial_due_date: date, cim_or_cpo: 
Optional[CimOrCpo], @@ -75,6 +77,7 @@ def create(cls, permit_amendment_id: int) -> "MineReportPermitRequirement": mine_report_permit_requirement = cls( + report_name=report_name, due_date_period_months=due_date_period_months, initial_due_date=initial_due_date, cim_or_cpo=cim_or_cpo, diff --git a/services/core-api/app/api/mines/reports/resources/mine_report_permit_requirement.py b/services/core-api/app/api/mines/reports/resources/mine_report_permit_requirement.py index 323db4b4cd..2388e06b2c 100644 --- a/services/core-api/app/api/mines/reports/resources/mine_report_permit_requirement.py +++ b/services/core-api/app/api/mines/reports/resources/mine_report_permit_requirement.py @@ -1,74 +1,85 @@ from datetime import datetime -from flask import current_app -from werkzeug.exceptions import NotFound, BadRequest - from app.api.mines.mine.models.mine import Mine -from app.api.mines.permits.permit.models.permit import Permit -from app.api.mines.permits.permit_amendment.models.permit_amendment import PermitAmendment +from app.api.mines.permits.permit_amendment.models.permit_amendment import ( + PermitAmendment, +) from app.api.mines.permits.permit_conditions.models import PermitConditions +from app.api.mines.reports.models.mine_report_permit_requirement import ( + CimOrCpo, + MineReportPermitRequirement, +) from app.api.mines.response_models import MINE_REPORT_PERMIT_REQUIREMENT -from app.api.utils.access_decorators import requires_any_of, EDIT_REPORT -from app.extensions import api - -from flask_restx import Resource - -from app.api.mines.reports.models.mine_report_permit_requirement import CimOrCpo, MineReportPermitRequirement +from app.api.utils.access_decorators import EDIT_REPORT, requires_any_of from app.api.utils.custom_reqparser import CustomReqparser from app.api.utils.resources_mixins import UserMixin +from app.extensions import api +from flask import current_app +from flask_restx import Resource +from werkzeug.exceptions import BadRequest, NotFound class 
MineReportPermitRequirementResource(Resource, UserMixin): parser = CustomReqparser() - parser.add_argument('due_date_period_months', type=int, location='json') - parser.add_argument('initial_due_date', type=lambda x: datetime.strptime(x, '%Y-%m-%d') if x else None, location='json') - parser.add_argument('cim_or_cpo', type=str, location='json') - parser.add_argument('ministry_recipient', type=list, location='json') - parser.add_argument('permit_condition_id', type=int, location='json') - parser.add_argument('permit_amendment_id', type=int, location='json') + parser.add_argument("due_date_period_months", type=int, location="json") + parser.add_argument( + "initial_due_date", + type=lambda x: datetime.strptime(x, "%Y-%m-%d") if x else None, + location="json", + ) + parser.add_argument("cim_or_cpo", type=str, location="json") + parser.add_argument("ministry_recipient", type=list, location="json") + parser.add_argument("permit_condition_id", type=int, location="json") + parser.add_argument("permit_amendment_id", type=int, location="json") @api.expect(parser) - @api.doc(description='creates a new mine report permit requirement') + @api.doc(description="creates a new mine report permit requirement") @api.marshal_with(MINE_REPORT_PERMIT_REQUIREMENT, code=201) @requires_any_of([EDIT_REPORT]) def post(self, mine_guid): - current_app.logger.debug('CREATING REQUIREMENT') + current_app.logger.debug("CREATING REQUIREMENT") data = self.parser.parse_args() mine = Mine.find_by_mine_guid(mine_guid) if not mine: - raise NotFound('Mine not found') + raise NotFound("Mine not found") - permit_amendment_id = data.get('permit_amendment_id') - permit_amendment = PermitAmendment.find_by_permit_amendment_id(permit_amendment_id) + permit_amendment_id = data.get("permit_amendment_id") + permit_amendment = PermitAmendment.find_by_permit_amendment_id( + permit_amendment_id + ) if permit_amendment is None: - raise NotFound('Permit not found') + raise NotFound("Permit not found") if 
permit_amendment: permit_amendment._context_mine = mine if permit_amendment.mine_guid != mine.mine_guid: - raise BadRequest('The permit must be associated with the selected mine.') + raise BadRequest( + "The permit must be associated with the selected mine." + ) - permit_condition_id = data.get('permit_condition_id') - permit_condition = PermitConditions.find_by_permit_condition_id(permit_condition_id) + permit_condition_id = data.get("permit_condition_id") + permit_condition = PermitConditions.find_by_permit_condition_id( + permit_condition_id + ) if permit_condition is None: - raise NotFound('Permit Condition not found') + raise NotFound("Permit Condition not found") - cim_or_cpo = data.get('cim_or_cpo') - if cim_or_cpo == 'NONE': + cim_or_cpo = data.get("cim_or_cpo") + if cim_or_cpo == "NONE": cim_or_cpo = None else: cim_or_cpo = CimOrCpo(cim_or_cpo) mine_report_permit_requirement = MineReportPermitRequirement.create( - due_date_period_months=data.get('due_date_period_months'), - initial_due_date=data.get('initial_due_date'), + report_name=data.get("report_name"), + due_date_period_months=data.get("due_date_period_months"), + initial_due_date=data.get("initial_due_date"), cim_or_cpo=cim_or_cpo, - ministry_recipient=data.get('ministry_recipient'), + ministry_recipient=data.get("ministry_recipient"), permit_condition_id=permit_condition_id, permit_amendment_id=permit_amendment_id, ) return mine_report_permit_requirement, 201 - diff --git a/services/core-api/app/api/mines/response_models.py b/services/core-api/app/api/mines/response_models.py index cbd117ff97..1a54e52bbb 100644 --- a/services/core-api/app/api/mines/response_models.py +++ b/services/core-api/app/api/mines/response_models.py @@ -252,6 +252,7 @@ def format(self, value): MINE_REPORT_PERMIT_REQUIREMENT = api.model( 'MineReportPermitRequirement', { + 'report_name': fields.String, 'mine_report_permit_requirement_id': fields.Integer, 'due_date_period_months': fields.Integer, 'initial_due_date': 
fields.Date, @@ -895,7 +896,8 @@ def format(self, value): 'parent_permit_condition_id': fields.Integer, 'sub_conditions': fields.List(PermitCondition), 'step': fields.String, - 'display_order': fields.Integer + 'display_order': fields.Integer, + 'meta': fields.Raw, }) PERMIT_CONDITION_TEMPLATE_MODEL = api.model('PermitConditionTemplate', { diff --git a/services/core-api/celery_dev.sh b/services/core-api/celery_dev.sh new file mode 100755 index 0000000000..72b18def44 --- /dev/null +++ b/services/core-api/celery_dev.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +# This script is used to start the Celery worker in development mode with auto-reload on file changes. +# The production entrypoint can be found in celery.sh +# -n is the node name of the worker (%h expands to the hostname) +# -A is the name of the app to run +# -Q is the name of the queue to consume from +# --concurrency is the number of child processes processing the queue +# -B also runs the embedded beat scheduler in the worker process +# --scheduler is the scheduler class to use +# -s Path to the schedule database. 
+# -E Enable sending task-related events that can be captured by monitors +# --pidfile is the location of the pid file + + +cd /app || exit + +watchmedo auto-restart --directory=./ --pattern=*.py --recursive -- celery -A app.tasks.celery_entrypoint worker -n core_tasks@%h -Q core_tasks --loglevel=${CELERY_LOG_LEVEL:-info} --concurrency=1 -B --scheduler redbeat.RedBeatScheduler -E diff --git a/services/core-api/tests/permits/permit_extraction/test_create_permit_condition_report_requirement.py b/services/core-api/tests/permits/permit_extraction/test_create_permit_condition_report_requirement.py new file mode 100644 index 0000000000..4bb0b3d245 --- /dev/null +++ b/services/core-api/tests/permits/permit_extraction/test_create_permit_condition_report_requirement.py @@ -0,0 +1,120 @@ +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest +from app.api.mines.permits.permit_extraction.create_permit_condition_report_requirement import ( + create_permit_condition_report_requirement, +) +from app.api.mines.permits.permit_extraction.models.permit_condition_result import ( + PermitConditionResult, +) + + +@pytest.fixture +def mock_task(): + task = MagicMock() + task.permit_amendment.permit_amendment_id = "test-amendment-id" + return task + + +def test_create_report_requirement_with_no_report_required(mock_task): + condition = PermitConditionResult( + condition_text="Test condition text", + meta={"questions": [{"question_key": "require_report", "answer": False}]}, + ) + + result = create_permit_condition_report_requirement(mock_task, condition, "test-id") + assert result is None + + +def test_create_report_requirement_basic(mock_task): + condition = PermitConditionResult( + condition_text="Test condition text", + meta={ + "questions": [ + {"question_key": "require_report", "answer": True}, + {"question_key": "report_name", "answer": "Test Report"}, + {"question_key": "due_date", "answer": "2023-12-31"}, + {"question_key": "recurring", 
"answer": True}, + {"question_key": "frequency", "answer": "monthly"}, + {"question_key": "mention_chief_inspector", "answer": True}, + {"question_key": "mention_chief_permitting_officer", "answer": False}, + ] + }, + ) + + result = create_permit_condition_report_requirement(mock_task, condition, "test-id") + assert result is not None + assert result.report_name == "Test Report" + assert result.permit_condition_id == "test-id" + assert result.permit_amendment_id == "test-amendment-id" + assert result.cim_or_cpo == "CIM" + assert result.due_date_period_months == 1 + assert result.initial_due_date == datetime(2023, 12, 31) + + +def test_create_report_requirement_both_cim_cpo(mock_task): + condition = PermitConditionResult( + condition_text="Test condition text", + meta={ + "questions": [ + {"question_key": "require_report", "answer": True}, + {"question_key": "mention_chief_inspector", "answer": True}, + {"question_key": "mention_chief_permitting_officer", "answer": True}, + ] + }, + ) + + result = create_permit_condition_report_requirement(mock_task, condition, "test-id") + assert result is not None + assert result.cim_or_cpo == "BOTH" + + +def test_create_report_requirement_various_frequencies(mock_task): + test_cases = [ + ("annually", 12), + ("quarterly", 3), + ("semiannually", 6), + ("as needed", 0), + ("every 5 years", 60), + ] + + for frequency, expected_months in test_cases: + condition = PermitConditionResult( + condition_text="Test condition text", + meta={ + "questions": [ + {"question_key": "require_report", "answer": True}, + {"question_key": "recurring", "answer": True}, + {"question_key": "frequency", "answer": frequency}, + ] + }, + ) + + result = create_permit_condition_report_requirement( + mock_task, condition, "test-id" + ) + assert result is not None + assert result.due_date_period_months == expected_months + + +@patch( + "app.api.mines.permits.permit_extraction.create_permit_condition_report_requirement.current_app" +) +def 
test_create_report_requirement_invalid_date( + mock_current_app, mock_task, test_client +): + condition = PermitConditionResult( + condition_text="Test condition text", + meta={ + "questions": [ + {"question_key": "require_report", "answer": True}, + {"question_key": "due_date", "answer": "invalid-date"}, + ] + }, + ) + + result = create_permit_condition_report_requirement(mock_task, condition, "test-id") + assert result is not None + assert result.initial_due_date is None + mock_current_app.logger.error.assert_called_once() diff --git a/services/core-api/tests/permits/permit_extraction/test_create_permit_conditions_from_task.py b/services/core-api/tests/permits/permit_extraction/test_create_permit_conditions_from_task.py index b6091f6a73..5652a8a6d5 100644 --- a/services/core-api/tests/permits/permit_extraction/test_create_permit_conditions_from_task.py +++ b/services/core-api/tests/permits/permit_extraction/test_create_permit_conditions_from_task.py @@ -8,6 +8,9 @@ from app.api.mines.permits.permit_extraction.models.permit_extraction_task import ( PermitExtractionTask, ) +from app.api.mines.reports.models.mine_report_permit_requirement import ( + MineReportPermitRequirement, +) from tests.factories import create_mine_and_permit @@ -61,7 +64,7 @@ def permit_conditions(permit_amendment): "section": "A", "paragraph": "1", "subparagraph": "1", - "clause": 'a', + "clause": "a", "subclause": None, "subsubclause": None, "condition_title": None, @@ -71,8 +74,8 @@ def permit_conditions(permit_amendment): "section": "A", "paragraph": "1", "subparagraph": "1", - "clause": 'a', - "subclause": 'b', + "clause": "a", + "subclause": "b", "subsubclause": None, "condition_title": "This condition has a title", "condition_text": "This is a subclause", @@ -129,56 +132,101 @@ def permit_conditions(permit_amendment): return permit_conditions -def test_create_permit_conditions_from_task(permit_conditions, permit_amendment, db_session): + +def test_create_permit_conditions_from_task( + 
permit_conditions, permit_amendment, db_session +): ### General Section gen_cat = permit_conditions[0] # Top level sections are not created as a PermitCondition. They are mapped to a PermitConditionCategory instead - assert permit_conditions[0].permit_amendment_id == permit_amendment.permit_amendment_id + assert ( + permit_conditions[0].permit_amendment_id == permit_amendment.permit_amendment_id + ) assert permit_conditions[0].condition_category_code != "GEC" assert permit_conditions[0].condition_category.description == "General" assert permit_conditions[0].condition == "This is a paragraph" - assert permit_conditions[0].condition_type_code == "SEC" # First level is a section + assert permit_conditions[0].condition_type_code == "SEC" # First level is a section assert permit_conditions[0].parent_permit_condition_id is None assert permit_conditions[0]._step == "1" - assert permit_conditions[1].permit_amendment_id == permit_amendment.permit_amendment_id - assert permit_conditions[1].condition_category_code == gen_cat.condition_category_code + assert ( + permit_conditions[1].permit_amendment_id == permit_amendment.permit_amendment_id + ) + assert ( + permit_conditions[1].condition_category_code == gen_cat.condition_category_code + ) assert permit_conditions[1].condition == "This is a subparagraph" - assert permit_conditions[1].condition_type_code == "CON" # Second level is a condition - assert permit_conditions[1].parent_permit_condition_id == gen_cat.permit_condition_id + assert ( + permit_conditions[1].condition_type_code == "CON" + ) # Second level is a condition + assert ( + permit_conditions[1].parent_permit_condition_id == gen_cat.permit_condition_id + ) assert permit_conditions[1]._step == "1" - assert permit_conditions[2].permit_amendment_id == permit_amendment.permit_amendment_id - assert permit_conditions[2].condition_category_code == gen_cat.condition_category_code + assert ( + permit_conditions[2].permit_amendment_id == permit_amendment.permit_amendment_id 
+ ) + assert ( + permit_conditions[2].condition_category_code == gen_cat.condition_category_code + ) assert permit_conditions[2].condition == "This is a clause" - assert permit_conditions[2].condition_type_code == "LIS" # Third level on is a list item - assert permit_conditions[2].parent_permit_condition_id == permit_conditions[1].permit_condition_id + assert ( + permit_conditions[2].condition_type_code == "LIS" + ) # Third level on is a list item + assert ( + permit_conditions[2].parent_permit_condition_id + == permit_conditions[1].permit_condition_id + ) assert permit_conditions[2]._step == "a" # When a condition both has a title and text, they are created as two conditions, with the text as a child of the title # Note: This was an assumption made to make the display more accurately reflect the PDF. May need a revision. - assert permit_conditions[3].permit_amendment_id == permit_amendment.permit_amendment_id - assert permit_conditions[3].condition_category_code == gen_cat.condition_category_code + assert ( + permit_conditions[3].permit_amendment_id == permit_amendment.permit_amendment_id + ) + assert ( + permit_conditions[3].condition_category_code == gen_cat.condition_category_code + ) assert permit_conditions[3].condition == "This condition has a title" assert permit_conditions[3].condition_type_code == "LIS" - assert permit_conditions[3].parent_permit_condition_id == permit_conditions[2].permit_condition_id + assert ( + permit_conditions[3].parent_permit_condition_id + == permit_conditions[2].permit_condition_id + ) assert permit_conditions[3]._step == "b" - assert permit_conditions[4].permit_amendment_id == permit_amendment.permit_amendment_id - assert permit_conditions[4].condition_category_code == gen_cat.condition_category_code + assert ( + permit_conditions[4].permit_amendment_id == permit_amendment.permit_amendment_id + ) + assert ( + permit_conditions[4].condition_category_code == gen_cat.condition_category_code + ) assert permit_conditions[4].condition 
== "This is a subclause" assert permit_conditions[4].condition_type_code == "LIS" - assert permit_conditions[4].parent_permit_condition_id == permit_conditions[3].permit_condition_id - assert permit_conditions[4]._step == "" # This is a child of the title condition - which in the PDFs do not have a step + assert ( + permit_conditions[4].parent_permit_condition_id + == permit_conditions[3].permit_condition_id + ) + assert ( + permit_conditions[4]._step == "" + ) # This is a child of the title condition - which in the PDFs do not have a step -def test_creates_general_conditions_as_unique_for_permit_amendment(permit_conditions, permit_amendment, db_session): +def test_creates_general_conditions_as_unique_for_permit_amendment( + permit_conditions, permit_amendment, db_session +): # Protection of Land and Watercourses Section - assert permit_conditions[5].permit_amendment_id == permit_amendment.permit_amendment_id + assert ( + permit_conditions[5].permit_amendment_id == permit_amendment.permit_amendment_id + ) assert permit_conditions[5].condition_category_code != "ELC" - assert permit_conditions[5].condition_category.description == "Protection of Land and Watercourses" + assert ( + permit_conditions[5].condition_category.description + == "Protection of Land and Watercourses" + ) assert permit_conditions[5].condition == "Another paragraph" assert permit_conditions[5].condition_type_code == "SEC" assert permit_conditions[5].parent_permit_condition_id is None @@ -187,9 +235,169 @@ def test_creates_general_conditions_as_unique_for_permit_amendment(permit_condit def test_creates_custom_conditions(permit_conditions, permit_amendment, db_session): # Can handle custom sections - assert permit_conditions[6].permit_amendment_id == permit_amendment.permit_amendment_id + assert ( + permit_conditions[6].permit_amendment_id == permit_amendment.permit_amendment_id + ) assert permit_conditions[6].condition_category.description == "This is just a test" assert 
permit_conditions[6].condition == "A test paragraph" assert permit_conditions[6].parent_permit_condition_id is None - assert permit_conditions[6].condition_category.permit_amendment_id == permit_amendment.permit_amendment_id - assert permit_conditions[6]._step == "1" \ No newline at end of file + assert ( + permit_conditions[6].condition_category.permit_amendment_id + == permit_amendment.permit_amendment_id + ) + assert permit_conditions[6]._step == "1" + + +def test_report_requirement_exists(permit_amendment, db_session): + task = PermitExtractionTask( + task_result={ + "conditions": [ + { + "section": "D", + "paragraph": "1", + "subparagraph": None, + "clause": None, + "subclause": None, + "subsubclause": None, + "condition_title": None, + "condition_text": "This is a report requirement", + "meta": { + "questions": [ + {"question_key": "require_report", "answer": True}, + {"question_key": "report_name", "answer": "Test Report"}, + {"question_key": "due_date", "answer": "2023-12-31"}, + {"question_key": "recurring", "answer": True}, + {"question_key": "frequency", "answer": "monthly"}, + {"question_key": "mention_chief_inspector", "answer": True}, + { + "question_key": "mention_chief_permitting_officer", + "answer": False, + }, + ] + }, + } + ] + }, + permit_amendment=permit_amendment, + ) + create_permit_conditions_from_task(task) + report_requirements = MineReportPermitRequirement.query.all() + assert len(report_requirements) == 1 + assert report_requirements[0].report_name == "Test Report" + + +def test_nested_display_order(test_client, db_session, permit_amendment): + # Create a task with nested conditions + task = PermitExtractionTask( + permit_amendment=permit_amendment, + task_result={ + "conditions": [ + { + "section": "A", + "condition_text": "General", + }, + { + "section": "A", + "paragraph": "1", + "condition_text": "First sub-condition", + }, + { + "section": "A", + "paragraph": "2", + "condition_text": "Second sub-condition", + }, + { + "section": 
"A", + "paragraph": "2", + "subparagraph": "a", + "condition_text": "Nested condition", + }, + { + "section": "A", + "paragraph": "2", + "subparagraph": "a", + "clause": "i", + "condition_text": "Clause", + }, + { + "section": "A", + "paragraph": "2", + "subparagraph": "a", + "clause": "ii", + "condition_text": "Clause2", + }, + {"section": "B", "condition_text": "Another section"}, + ] + }, + ) + + create_permit_conditions_from_task(task) + + # Query conditions and verify display orders + conditions = db_session.query(PermitConditions).all() + + # Create a map of conditions by their text for easier testing + conditions_map = {c.condition: c for c in conditions} + + assert len(conditions_map.keys()) == 5 + + # Verify sub-conditions are top-level (sections are "Categories", so not part of the tree) + assert conditions_map["First sub-condition"].parent_permit_condition_id is None + assert conditions_map["First sub-condition"].display_order == 1 + + assert conditions_map["Second sub-condition"].parent_permit_condition_id is None + assert conditions_map["Second sub-condition"].display_order == 2 + + assert conditions_map["Nested condition"].display_order == 1 + + assert conditions_map["Clause"].display_order == 1 + assert conditions_map["Clause2"].display_order == 2 + + # Verify nested condition is a child of the second sub-condition + assert ( + conditions_map["Nested condition"].parent_permit_condition_id + == conditions_map["Second sub-condition"].permit_condition_id + ) + assert ( + conditions_map["Clause"].parent_permit_condition_id + == conditions_map["Nested condition"].permit_condition_id + ) + assert ( + conditions_map["Clause2"].parent_permit_condition_id + == conditions_map["Nested condition"].permit_condition_id + ) + + +def test_display_order_with_titles(test_client, db_session, permit_amendment): + task = PermitExtractionTask( + permit_amendment=permit_amendment, + task_result={ + "conditions": [ + { + "section": "A", + "condition_text": "Firstt secion", 
+ }, + { + "section": "A", + "paragraph": "1", + "condition_text": "Sub 1", + }, + { + "section": "A", + "paragraph": "2", + "condition_text": "Sub 2", + }, + ] + }, + ) + + create_permit_conditions_from_task(task) + + conditions = db_session.query(PermitConditions).all() + conditions_map = {c.condition: c for c in conditions} + + assert len(conditions_map.keys()) == 2 + + # Verify display orders with titles + assert conditions_map["Sub 1"].display_order == 1 + assert conditions_map["Sub 2"].display_order == 2 diff --git a/services/permits/.env-example b/services/permits/.env-example index 9cee6516db..8dbb89a3d5 100644 --- a/services/permits/.env-example +++ b/services/permits/.env-example @@ -17,7 +17,7 @@ ELASTICSEARCH_PASSWORD=elastic AZURE_API_KEY= AZURE_API_VERSION=2024-02-01 -AZURE_DEPLOYMENT_NAME=mds-permits-turbo +AZURE_DEPLOYMENT_NAME=gpt-4o AZURE_BASE_URL=https://emli-mdsopenai.openai.azure.com/ DEBUG_MODE=true OAUTHLIB_INSECURE_TRANSPORT=1 diff --git a/services/permits/app/permit_condition_prompts.yaml b/services/permits/app/permit_condition_prompts.yaml index b7dc0f418a..587957551e 100644 --- a/services/permits/app/permit_condition_prompts.yaml +++ b/services/permits/app/permit_condition_prompts.yaml @@ -1,23 +1,27 @@ system_prompt: | - You are a helpful AI assistant that can extract information from text files. You extract what is exactly as it is from the text and return it in a json format. + You are a helpful AI assistant that answer questions about paragraphs in a mining permit. You answer exactly what is asked for and return it in a json format. permit_document_prompt_meta_questions: | ----------- {{documents[0].content}} user_prompt_meta_questions: | - Your task is to answer questions for each of the paragraphs in the input text and return the answers in a JSON format. + You are given a list of paragraphs extracted from a mining permit. These paragraphs may have a reporting requirement. 
Your task is to identify the paragraphs that have a reporting requirement. + A reporting requirement means that a report is mentioned that needs submission, such as an annual report, an entity needs to be notified when something happens, updates or results needs to be submitted or filed to an entity. Please consider synonyms of those terms as well to a reporting requirement. + Some paragraphs may mention "all reports" under the permit, this should not be considered as a reporting requirement as it does not mentions a specific report. - You will be given a mining permit and should answer the questions for each condition of the text in the text. The questions are as follows: + For each of the identified paragraphs that require a report you need to extract the following information: "require_report": Does this paragraph require a report to be submitted? (type: boolean) "due_date": When is the report due date? (type: date) "recurring": Is this a recurring report requirement? (type: boolean) - "frequency": Frequency of the report - "Yearly", "Monthly", "Daily", "Weekly", "Bi-weekly", "Quarterly", "Semi-annually", "Annually", "Bi-annually", "As needed", "Other" (type: string) + "frequency": Frequency of the report - "Yearly", "Monthly", "Quarterly", "Semi-annually", "Annually", "Bi-annually", "5 years", "As needed", "Other" (type: string) "mention_chief_inspector": Does this paragraph mention the Chief Inspector? (type: boolean) "mention_chief_permitting_officer": Does this paragraph mention the Chief Permitting Officer? (type: boolean) + "report_name": Name of the report mentioned (if found) (type: string) - Output should be a json structured as follows: + + Output should be a json blob structured as follows: { "paragraphs": [ { @@ -33,16 +37,9 @@ user_prompt_meta_questions: | ... ] } - - The input is a csv file with the following columns. 
The values in each column are quoted with "" - - id: id of the paragraph - - text: The text of the paragraph - - Each line in the input should have a corresponding output if an answer to any of the questions is found. - Very Important: Output the full json structure without code blocks or other text. Do not output any other explanation or questions that is not in the json format. - Here's the CSV input (delimited by ---------) + Here's the input (delimited by ---------) # The following prompts are currently not in use. They are kept here for future reference. user_prompt: | diff --git a/services/permits/app/permit_conditions/converters/azure_document_intelligence_converter.py b/services/permits/app/permit_conditions/converters/azure_document_intelligence_converter.py index d6b046619a..dbe0b4602b 100644 --- a/services/permits/app/permit_conditions/converters/azure_document_intelligence_converter.py +++ b/services/permits/app/permit_conditions/converters/azure_document_intelligence_converter.py @@ -1,6 +1,4 @@ -import csv import hashlib -import io import json import logging import os @@ -9,7 +7,6 @@ from pathlib import Path from typing import Any, Dict, List, Optional -import pandas as pd from app.permit_conditions.context import context from azure.ai.formrecognizer import AnalyzeResult, DocumentAnalysisClient from azure.core.credentials import AzureKeyCredential @@ -39,7 +36,7 @@ class AzureDocumentIntelligenceConverter: """ @component.output_types( - documents=List[Document], permit_condition_csv=List[Document] + documents=List[Document] ) def run( self, @@ -71,16 +68,13 @@ def run( docs = [] - for idx, p in enumerate(result.paragraphs): + for idx, p in enumerate(result.paragraphs or []): doc = self.add_metadata_to_document(idx, p) docs.append(doc) - permit_condition_csv = _create_csv_representation(docs) - return { "documents": docs, - "permit_condition_csv": [Document(content=permit_condition_csv)], } def add_metadata_to_document(self, idx, p): @@ -112,11 +106,16 
@@ def add_metadata_to_document(self, idx, p): "left": left, }, "role": p.role, + "page": p.bounding_regions[0].page_number, } return Document(content=json.dumps(content, indent=None), meta=meta) def run_document_intelligence(self, file_path): + + assert DOCUMENTINTELLIGENCE_ENDPOINT, "DOCUMENTINTELLIGENCE_ENDPOINT is not set" + assert DOCUMENTINTELLIGENCE_API_KEY, "DOCUMENTINTELLIGENCE_API_KEY is not set" + document_intelligence_client = DocumentAnalysisClient( endpoint=DOCUMENTINTELLIGENCE_ENDPOINT, credential=AzureKeyCredential(DOCUMENTINTELLIGENCE_API_KEY), @@ -149,18 +148,3 @@ def retrieve_cached_result(self, cache_key): logger.info("No cache entry found. Quering Azure Document Intelligence") result = None return result - - -def _create_csv_representation(docs): - content = json.dumps([json.loads(doc.content) for doc in docs]) - jsn = pd.read_json(io.StringIO(content)) - - cs = jsn.to_csv( - index=False, - header=True, - quoting=csv.QUOTE_ALL, - encoding="utf-8", - sep=",", - columns=["id", "text"], - ) - return cs diff --git a/services/permits/app/permit_conditions/converters/filter_conditions_paragraphs.py b/services/permits/app/permit_conditions/converters/filter_conditions_paragraphs.py index 849741a96f..b1a7d72f20 100644 --- a/services/permits/app/permit_conditions/converters/filter_conditions_paragraphs.py +++ b/services/permits/app/permit_conditions/converters/filter_conditions_paragraphs.py @@ -1,11 +1,8 @@ -import csv -import io import json import logging import os from typing import Any, Dict, List, Optional -import pandas as pd from app.permit_conditions.context import context from app.permit_conditions.validator.parse_hierarchy import split_numbering from haystack import Document, component, logging diff --git a/services/permits/app/permit_conditions/converters/metadata_converter.py b/services/permits/app/permit_conditions/converters/metadata_converter.py index 09c21eda47..8c8e18f45d 100644 --- 
a/services/permits/app/permit_conditions/converters/metadata_converter.py +++ b/services/permits/app/permit_conditions/converters/metadata_converter.py @@ -1,18 +1,15 @@ -import csv -import io import json import logging import os -from typing import Any, Dict, List, Optional +from typing import List -import pandas as pd from app.permit_conditions.context import context from app.permit_conditions.pipelines.chat_data import ChatData from app.permit_conditions.validator.permit_condition_model import ( PermitCondition, PermitConditions, ) -from haystack import Document, component, logging +from haystack import component, logging logger = logging.getLogger(__name__) @@ -24,6 +21,7 @@ class ConditionsMetadataCombiner: """ Combines answers given by GPT4 to the questions defined in our prompts with the permit conditions extracted from the permit document. + Responses from GPT4 are stored in the `meta` attribute of the permit condition it relates to. Args: conditions (List[PermitCondition]): List of permit conditions. data (ChatData): Chat data containing messages from GPT4. 
@@ -43,26 +41,23 @@ def run( ) docs_by_id = {doc.id: doc for doc in conditions.conditions} - paragraphs = [] - # Extract paragraphs from the chat data - for msg in data.messages: - cnt = json.loads(msg.content) + flattened_messages = [msg for group in data.messages for msg in group] + content = [json.loads(msg.content) for msg in flattened_messages] - for p in cnt["paragraphs"]: - # Sometimes the paragraphs are nesteded in the output from GPT4 - if "paragraphs" in p: - for p2 in p["paragraphs"]: - paragraphs.append(p2) - else: - paragraphs.append(p) + # sometimes the paragraphs are nested in the output from GPT4 + for paragraph in content: + if "paragraphs" in paragraph: + paragraphs.extend(paragraph["paragraphs"]) + else: + paragraphs.append(paragraph) # Add questions answered by GPT4 to the metadata of the condition in the `questions` property for p in paragraphs: if p["id"] in docs_by_id: docs_by_id[p["id"]].meta = { "questions": p["meta"], - **docs_by_id[p["id"]].meta, + **(docs_by_id[p["id"]].meta or {}), } return {"conditions": conditions} diff --git a/services/permits/app/permit_conditions/converters/pdf_to_text_converter.py b/services/permits/app/permit_conditions/converters/pdf_to_text_converter.py index 04b54facd7..9018fe3d8f 100644 --- a/services/permits/app/permit_conditions/converters/pdf_to_text_converter.py +++ b/services/permits/app/permit_conditions/converters/pdf_to_text_converter.py @@ -2,7 +2,6 @@ import os import shutil from pathlib import Path -from time import sleep from typing import Any, Dict, List, Optional import ocrmypdf diff --git a/services/permits/app/permit_conditions/pipelines/CachedAzureOpenAIChatGenerator.py b/services/permits/app/permit_conditions/pipelines/CachedAzureOpenAIChatGenerator.py index 8bee643f67..7f87588164 100644 --- a/services/permits/app/permit_conditions/pipelines/CachedAzureOpenAIChatGenerator.py +++ b/services/permits/app/permit_conditions/pipelines/CachedAzureOpenAIChatGenerator.py @@ -1,26 +1,36 @@ +import 
concurrent.futures import hashlib import logging import os -import pickle import struct +from typing import List, Optional from app.permit_conditions.pipelines.chat_data import ChatData +from haystack import Document, component +from haystack.components.caching import CacheChecker from haystack.components.generators.chat import AzureOpenAIChatGenerator from haystack.dataclasses import ChatMessage -from haystack import component, Document -from haystack.components.caching import CacheChecker -from haystack_integrations.document_stores.elasticsearch import ElasticsearchDocumentStore +from haystack.document_stores.types import DuplicatePolicy +from haystack_integrations.document_stores.elasticsearch import ( + ElasticsearchDocumentStore, +) ROOT_DIR = os.path.abspath(os.curdir) logger = logging.getLogger(__name__) DEBUG_MODE = os.environ.get("DEBUG_MODE", "False").lower() == "true" AZURE_DEPLOYMENT_NAME = os.environ.get("AZURE_DEPLOYMENT_NAME") + +if not AZURE_DEPLOYMENT_NAME: + raise ValueError("AZURE_DEPLOYMENT_NAME environment variable is not set.") + ca_cert = os.environ.get("ELASTICSEARCH_CA_CERT", None) host = os.environ.get("ELASTICSEARCH_HOST", None) or "https://elasticsearch:9200" username = os.environ.get("ELASTICSEARCH_USERNAME", "") password = os.environ.get("ELASTICSEARCH_PASSWORD", "") +MAX_WORKERS = 5 + def hash_messages(messages): """ @@ -33,7 +43,7 @@ def hash_messages(messages): str: The SHA256 hash digest of the messages. """ - to_hash = messages + [ChatMessage.from_user(AZURE_DEPLOYMENT_NAME)] + to_hash = messages + [ChatMessage.from_user(AZURE_DEPLOYMENT_NAME or "")] hsh = hashlib.sha256() for message in to_hash: @@ -59,7 +69,7 @@ def __init__(self, **kwargs): Note: This currently only caches responses locally on the filesystem and should only be used locally. 
""" - def fetch_result(self, messages, generation_kwargs): + def fetch_result(self, messages: List[ChatMessage], generation_kwargs): """ Fetches the chat generation result from the cache or queries the OpenAI API. @@ -74,58 +84,77 @@ def fetch_result(self, messages, generation_kwargs): existing_reply_found = False cache_key = hash_messages(messages) - document_store = ElasticsearchDocumentStore(hosts=host, - basic_auth=(username, password), - index="permits", - embedding_similarity_function="cosine", - ca_certs=ca_cert if ca_cert else None, - verify_certs=True if ca_cert else False) - - cache_checker = CacheChecker(document_store=document_store, cache_field="cache_key") + document_store = ElasticsearchDocumentStore( + hosts=host, + basic_auth=(username, password), + index="permits", + embedding_similarity_function="cosine", + ca_certs=ca_cert if ca_cert else None, + verify_certs=True if ca_cert else False, + ) + + cache_checker = CacheChecker( + document_store=document_store, cache_field="cache_key" + ) cached_result = cache_checker.run(items=[cache_key]) if len(cached_result["hits"]) > 0: existing_reply_found = True logger.info("cached_result: %s", cached_result) - res = {"replies": [ChatMessage(content=cached_result["hits"][0].content, - name=cached_result["hits"][0].meta["name"], - role=cached_result["hits"][0].meta["role"], - meta=cached_result["hits"][0].meta)]} + res = { + "replies": [ + ChatMessage( + content=cached_result["hits"][0].content, + name=cached_result["hits"][0].meta["name"], + role=cached_result["hits"][0].meta["role"], + meta=cached_result["hits"][0].meta, + ) + ] + } + return res["replies"][0] if not existing_reply_found: try: res = super(CachedAzureOpenAIChatGenerator, self).run( messages=messages, generation_kwargs=generation_kwargs ) + + documents = [ + Document( + content=res["replies"][0].content, + meta={ + "cache_key": cache_key, + "name": res["replies"][0].name, + "role": res["replies"][0].role, + "model": 
res["replies"][0].meta["model"], + "index": res["replies"][0].meta["index"], + "finish_reason": res["replies"][0].meta["finish_reason"], # + "usage": { + "completion_tokens": res["replies"][0].meta["usage"][ + "completion_tokens" + ], + "prompt_tokens": res["replies"][0].meta["usage"][ + "prompt_tokens" + ], + "total_tokens": res["replies"][0].meta["usage"][ + "total_tokens" + ], + }, + }, + ) + ] + document_store.write_documents( + documents, policy=DuplicatePolicy.OVERWRITE + ) + return res["replies"][0] + except Exception as e: logger.error(f"Error while querying OpenAI: {e}") raise - documents = [ - Document(content=res["replies"][0].content, meta={"cache_key": cache_key, - "name": res["replies"][0].name, - "role": res["replies"][0].role, - "model": res["replies"][0].meta["model"], - "index": res["replies"][0].meta["index"], - "finish_reason": res["replies"][0].meta[ - "finish_reason"], # - "usage": {"completion_tokens": - res["replies"][0].meta["usage"][ - "completion_tokens"], - "prompt_tokens": - res["replies"][0].meta["usage"][ - "prompt_tokens"], - "total_tokens": - res["replies"][0].meta["usage"][ - "total_tokens"]} - }) - ] - document_store.write_documents(documents) - return res["replies"][0] - @component.output_types(data=ChatData) def run(self, data: ChatData, generation_kwargs=None, iteration=0): self.it += 1 """ - Runs the chat generation process. + Runs the chat generation process in parallel with max 3 concurrent executions. Args: data (ChatData): The input chat data. @@ -135,7 +164,46 @@ def run(self, data: ChatData, generation_kwargs=None, iteration=0): Returns: dict: The output chat data. 
""" - reply = self.fetch_result(data.messages, generation_kwargs) + results: List[List[ChatMessage]] = [[] for _ in range(len(data.messages))] + + def process_message(args): + idx, messages = args + return idx, self.run_for_message(messages, generation_kwargs, idx + self.it) + + with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: + futures = [ + executor.submit(process_message, (idx, messages)) + for idx, messages in enumerate(data.messages) + ] + + for future in concurrent.futures.as_completed(futures): + idx, result = future.result() + if result is not None: + results[idx].append(result) + + if DEBUG_MODE: + with open("debug/cached_azure_openai_chat_generator_output.txt", "w") as f: + for reply in results: + for r in reply: + f.write(r.content) + f.write("\n") + + return {"data": ChatData(messages=results, documents=data.documents)} + + def run_for_message( + self, messages: List[ChatMessage], generation_kwargs=None, iteration=0 + ) -> Optional[ChatMessage]: + + reply = self.fetch_result(messages, generation_kwargs) + + if reply is None: + return None + + with open( + f"debug/cached_azure_openai_chat_generator_output_{iteration}.txt", + "w", + ) as f: + f.write(reply.content) content = reply.content completion_tokens = reply.meta["usage"]["completion_tokens"] @@ -145,18 +213,29 @@ def run(self, data: ChatData, generation_kwargs=None, iteration=0): # If the response is too long for GPT4 to complete (returned tokens > 4096), ask GPT4 to continue the query # limit the number of iterations to 10 to avoid issues if GPT4 for some reason # keeps returning partial responses - while reply.meta["finish_reason"] == "length" and iteration < 10: + + lp = 0 + while ( + reply is not None + and reply.meta["finish_reason"] == "length" + and iteration < 10 + ): + lp += 1 logger.info( f"Partial json generated continuing query. 
Iteration: {iteration}" ) - messages = data.messages + [ + messages = messages + [ reply, ChatMessage.from_user( "Your response got cut off. Continue from where you left off." ), ] reply = self.fetch_result(messages, generation_kwargs) + + if reply is None: + reply = None + continue content += reply.content # Sum up the usage tokens to make this continuation process transparent to the other components @@ -167,20 +246,19 @@ def run(self, data: ChatData, generation_kwargs=None, iteration=0): iteration += 1 if DEBUG_MODE: with open( - f"debug/cached_azure_openai_chat_generator_output_{self.it}_{iteration}.txt", - "w", + f"debug/cached_azure_openai_chat_generator_output_{iteration}_{lp}.txt", + "w", ) as f: f.write(reply.content) + if reply is None: + reply = ChatMessage( + content=content, role=messages[0].role, name=messages[0].name + ) + reply.content = content reply.meta["usage"]["completion_tokens"] = completion_tokens reply.meta["usage"]["prompt_tokens"] = prompt_tokens reply.meta["usage"]["total_tokens"] = total_tokens - if DEBUG_MODE: - with open( - f"debug/cached_azure_openai_chat_generator_output_{self.it}.txt", "w" - ) as f: - f.write(reply.content) - - return {"data": ChatData([reply], data.documents)} + return reply diff --git a/services/permits/app/permit_conditions/pipelines/PaginatedChatPromptBuilder.py b/services/permits/app/permit_conditions/pipelines/PaginatedChatPromptBuilder.py index df105b34b9..3796eee4fa 100644 --- a/services/permits/app/permit_conditions/pipelines/PaginatedChatPromptBuilder.py +++ b/services/permits/app/permit_conditions/pipelines/PaginatedChatPromptBuilder.py @@ -1,9 +1,13 @@ import logging import os -from typing import Optional +from typing import List, Optional from app.permit_conditions.pipelines.chat_data import ChatData -from haystack import component +from app.permit_conditions.validator.permit_condition_model import ( + PermitCondition, + PermitConditions, +) +from haystack import Document, component from 
haystack.components.builders import ChatPromptBuilder logger = logging.getLogger(__name__) @@ -16,18 +20,23 @@ class PaginatedChatPromptBuilder(ChatPromptBuilder): Component that renders chat prompts using Jinja templates for the use in further steps of the pipeline. - This component extends the ChatPromptBuilder component to support pagination of the chat prompts + This component extends the ChatPromptBuilder component to support pagination of the chat prompts. + + The output of this component is a list of chat prompt "groups" where each group contains a list of chat prompts related to a subset of the permit conditions. """ @component.output_types(data=ChatData) def run( self, + conditions: PermitConditions, iteration: Optional[dict] = None, template=None, template_variables=None, **kwargs, ): + if not template_variables: + template_variables = {} if iteration: logger.info( f"Processing pages starting from page {iteration['start_page']}" @@ -38,12 +47,77 @@ def run( template_variables = {**template_variables, **iteration} else: logger.info("Processing pages starting from page 0") - output = super(PaginatedChatPromptBuilder, self).run( - template=template, template_variables=template_variables, **kwargs - ) - - if DEBUG_MODE: - with open("debug/paginated_chat_prompt_builder_output.txt", "a") as f: - for prompt in output["prompt"]: - f.write(prompt.content + "\n\n") - return {"data": ChatData(output["prompt"], kwargs["documents"])} + + # Split conditions into groups + grouped_conditions = self.split_conditions(conditions.conditions) + + prompts = [] + for idx, group in enumerate(grouped_conditions): + template_variables["documents"] = _format_condition_text_for_prompt(group) + output = super(PaginatedChatPromptBuilder, self).run( + template=template, template_variables=template_variables, **kwargs + ) + prompts.append(output["prompt"]) + + if DEBUG_MODE: + with open( + f"debug/paginated_chat_puilder_output_{idx + 1}.txt", "a" + ) as f: + for prompt in 
output["prompt"]: + f.write(prompt.content + "\n\n") + + return {"data": ChatData(prompts, kwargs["documents"])} + + def split_conditions( + self, conditions: List[PermitCondition] + ) -> List[List[PermitCondition]]: + """ + Splits a list of permit conditions into smaller groups based on section and paragraph numbers. + - A group is created per section + - A new group is created when the paragraph number changes and the current group has 30 or more conditions + - Conditions within the same section and paragraph are kept together unless the group size limit is reached + + Why? Accuracy of GPT4 seems to decrease for long prompts. + Args: + conditions (List[PermitCondition]): A list of PermitCondition objects to be grouped + Returns: + List[List[PermitCondition]]: A list of groups, where each group is a list of PermitCondition objects + that share the same section and (usually) the same paragraph number + """ + + grouped_conditions = [] + current_group = [] + current_section = None + current_subsection = None + + for condition in conditions: + if condition.section != current_section: + if current_group: + grouped_conditions.append(current_group) + current_group = [] + current_section = condition.section + current_subsection = condition.paragraph + elif condition.paragraph != current_subsection and len(current_group) >= 30: + grouped_conditions.append(current_group) + current_group = [] + current_subsection = condition.paragraph + + current_group.append(condition) + + if current_group: + grouped_conditions.append(current_group) + + return grouped_conditions + + +def _format_condition_text_for_prompt( + conditions: List[PermitCondition], +) -> List[Document]: + # Format the conditions for the prompt including the condition text and the condition id, indenting the text based on the section, paragraph, etc. + # A. General (id: abc123) + # 1. This is a test. (id: 123) + # a) This is another test. 
(id: 456) + text_representation = "\n".join( + [f"{c.formatted_text} (id: {c.id})" for c in conditions] + ) + return [Document(content=text_representation)] diff --git a/services/permits/app/permit_conditions/pipelines/chat_data.py b/services/permits/app/permit_conditions/pipelines/chat_data.py index e9d497a60b..3ba84b31d0 100644 --- a/services/permits/app/permit_conditions/pipelines/chat_data.py +++ b/services/permits/app/permit_conditions/pipelines/chat_data.py @@ -7,5 +7,5 @@ @dataclass class ChatData: - messages: List[ChatMessage] - documents: List[Document] + messages: List[List[ChatMessage]] # Groups of message requests and responses from Azure OpenAI. Each group corresponds to one request-response cycle. + documents: List[Document] # Original documents that were used to generate the chat messages. diff --git a/services/permits/app/permit_conditions/pipelines/permit_condition_pipeline.py b/services/permits/app/permit_conditions/pipelines/permit_condition_pipeline.py index 7b87b65b71..2911a4a395 100644 --- a/services/permits/app/permit_conditions/pipelines/permit_condition_pipeline.py +++ b/services/permits/app/permit_conditions/pipelines/permit_condition_pipeline.py @@ -33,12 +33,12 @@ ROOT_DIR = os.path.abspath(os.curdir) -api_key = os.environ.get("AZURE_API_KEY") +api_key = os.environ.get("AZURE_API_KEY", "") deployment_name = os.environ.get("AZURE_DEPLOYMENT_NAME") base_url = os.environ.get("AZURE_BASE_URL") -api_version = os.environ.get("AZURE_API_VERSION") +api_version = os.environ.get("AZURE_API_VERSION","") -assert api_key +assert api_key and api_key is not None assert deployment_name assert base_url assert api_version @@ -74,8 +74,8 @@ def permit_condition_pipeline(): ] ) - temperature = 0.7 - max_tokens = 4096 + temperature = 0 + max_tokens = 16384 llm = CachedAzureOpenAIChatGenerator( azure_endpoint=base_url, @@ -113,7 +113,7 @@ def permit_condition_pipeline(): index_pipeline.connect("filter_paragraphs", "parse_hierarchy") index_pipeline.connect( - 
"pdf_converter.permit_condition_csv", "prompt_builder.documents" + "parse_hierarchy.conditions", "prompt_builder.conditions" ) index_pipeline.connect("prompt_builder", "llm") index_pipeline.connect("llm", "json_fixer") @@ -143,8 +143,8 @@ def permit_condition_gpt_pipeline(): ] ) - temperature = 0.7 - max_tokens = 4096 + temperature = 0 + max_tokens = 16384 llm = CachedAzureOpenAIChatGenerator( azure_endpoint=base_url, diff --git a/services/permits/app/permit_conditions/resources/permit_condition_resource.py b/services/permits/app/permit_conditions/resources/permit_condition_resource.py index cf5315b38e..43003f4e13 100644 --- a/services/permits/app/permit_conditions/resources/permit_condition_resource.py +++ b/services/permits/app/permit_conditions/resources/permit_condition_resource.py @@ -90,7 +90,7 @@ def status(task_id: str) -> JobStatus: response_model=PermitConditions, responses={202: {"model": InProgressJobStatusResponse}}, ) -def results(task_id: str) -> PermitConditions: +def results(task_id: str) -> JSONResponse | PermitConditions: """ Get the results of a permit conditions extraction job. Args: diff --git a/services/permits/app/permit_conditions/validator/json_fixer.py b/services/permits/app/permit_conditions/validator/json_fixer.py index 918085613d..931815ff71 100644 --- a/services/permits/app/permit_conditions/validator/json_fixer.py +++ b/services/permits/app/permit_conditions/validator/json_fixer.py @@ -12,6 +12,7 @@ DEBUG_MODE = os.environ.get("DEBUG_MODE", "False").lower() == "true" + @component class JSONRepair: @component.output_types(data=ChatData) @@ -28,12 +29,21 @@ def run(self, data: ChatData): dict: A dictionary containing the repaired ChatData object. 
""" - for msg in data.messages: - msg.content = json.dumps(json.loads(repair_json(msg.content))) + for group in data.messages: + for msg in group: + msg.content = json.dumps(json.loads(str(repair_json(msg.content)))) - if DEBUG_MODE: with open("debug/json_repair_output.txt", "a") as f: - f.write(json.dumps([json.loads(msg.content) for msg in data.messages], indent=4)) + f.write( + json.dumps( + [ + json.loads(msg.content) + for msg in group + for _ in data.messages + ], + indent=4, + ) + ) return {"data": data} diff --git a/services/permits/app/permit_conditions/validator/permit_condition_model.py b/services/permits/app/permit_conditions/validator/permit_condition_model.py index 8344568899..0f14888253 100644 --- a/services/permits/app/permit_conditions/validator/permit_condition_model.py +++ b/services/permits/app/permit_conditions/validator/permit_condition_model.py @@ -49,7 +49,40 @@ def __init__(self, /, **data: Any): data[key] = data[key].strip() super(PermitCondition, self).__init__(**data) + + @property + def formatted_text(self) -> str: + if not self.condition_text: + return '' + + indent_level = 0 + last_defined = None + for key in [ + "section", + "paragraph", + "subparagraph", + "clause", + "subclause", + "subsubclause", + ]: + if getattr(self, key) is not None: + indent_level += 1 + last_defined = key + + formatted_text = self.condition_text + + if last_defined: + last_value = getattr(self, last_defined) + if last_defined == "section": + formatted_text = f"{last_value}. 
{formatted_text}" + else: + formatted_text = f"({last_value}) {formatted_text}" + + + formatted_text = " " * indent_level + formatted_text + + return formatted_text class PermitConditions(BaseModel): conditions: List[PermitCondition] diff --git a/services/permits/app/permit_conditions/validator/permit_condition_section_combiner.py b/services/permits/app/permit_conditions/validator/permit_condition_section_combiner.py index 5075273a0f..b0a07a7224 100644 --- a/services/permits/app/permit_conditions/validator/permit_condition_section_combiner.py +++ b/services/permits/app/permit_conditions/validator/permit_condition_section_combiner.py @@ -1,7 +1,7 @@ import json import logging import os -from typing import List, Optional +from typing import List from app.permit_conditions.validator.parse_hierarchy import parse_hierarchy from app.permit_conditions.validator.permit_condition_model import ( @@ -24,9 +24,7 @@ class ExtractionIteration(BaseModel): @component class PermitConditionSectionCombiner: - @component.output_types( - conditions=PermitConditions, - ) + @component.output_types(conditions=PermitConditions) def run(self, documents: List[Document]): """ Given a list of documents that have their numbering identified (e.g.) 
a, B, 1, 3, and bounding boxes identified, this step will @@ -117,7 +115,9 @@ def run(self, documents: List[Document]): matching_cond.id = p["id"] else: - matching_cond.condition_text = f"{matching_cond.condition_text}\n{p['text']}" + matching_cond.condition_text = ( + f"{matching_cond.condition_text}\n{p['text']}" + ) if matching_cond.meta["bounding_box"] and p["meta"]["bounding_box"]: self._combine_bounding_boxes(p, matching_cond) diff --git a/services/permits/app/permit_conditions/validator/permit_condition_validator.py b/services/permits/app/permit_conditions/validator/permit_condition_validator.py index e125400e77..145c40a543 100644 --- a/services/permits/app/permit_conditions/validator/permit_condition_validator.py +++ b/services/permits/app/permit_conditions/validator/permit_condition_validator.py @@ -53,7 +53,7 @@ def run(self, data: ChatData): # Parse the replies given and make sure they're valid json conditions: List[PermitCondition] = reduce( - operator.concat, [self._parse_reply(reply) for reply in data.messages] + operator.concat, [self._parse_reply(reply) for reply in data.messages[0]] ) # Find the content of the last condition that was processed diff --git a/services/permits/requirements.txt b/services/permits/requirements.txt index b915e8c0da..48c0fbedb6 100644 --- a/services/permits/requirements.txt +++ b/services/permits/requirements.txt @@ -8,7 +8,7 @@ Authlib~=1.3.0 requests~=2.32.3 python-dotenv==1.0.0 boto3==1.34.139 -fastapi==0.108.0 +fastapi==0.115.6 pdf2image==1.17.0 Pillow==10.3.0 pdfminer.six==20231228 diff --git a/services/permits/tests/test_azure_document_intelligence_converter.py b/services/permits/tests/test_azure_document_intelligence_converter.py index 81b7259bc1..f3f88b31a1 100644 --- a/services/permits/tests/test_azure_document_intelligence_converter.py +++ b/services/permits/tests/test_azure_document_intelligence_converter.py @@ -5,7 +5,6 @@ import pytest from app.permit_conditions.converters.azure_document_intelligence_converter 
import ( AzureDocumentIntelligenceConverter, - _create_csv_representation, ) from app.permit_conditions.tasks.tasks import task_context from tests.mocks import MockContext @@ -45,7 +44,8 @@ def test_run(mock_client, converter, tmp_path): mock.Mock(x=1, y=2), mock.Mock(x=3, y=4), mock.Mock(x=5, y=6), - ] + ], + page_number=1 ) ], ), @@ -58,7 +58,8 @@ def test_run(mock_client, converter, tmp_path): mock.Mock(x=2, y=2), mock.Mock(x=3, y=9), mock.Mock(x=5, y=6), - ] + ], + page_number=2 ) ], ), @@ -71,19 +72,14 @@ def test_run(mock_client, converter, tmp_path): assert isinstance(result, dict) assert "documents" in result - assert "permit_condition_csv" in result documents = result["documents"] - permit_condition_csv = result["permit_condition_csv"] assert isinstance(documents, list) - assert isinstance(permit_condition_csv, list) assert len(documents) == 2 - assert len(permit_condition_csv) == 1 document = documents[0] - csv_document = permit_condition_csv[0] res = json.loads(document.content) @@ -99,12 +95,10 @@ def test_run(mock_client, converter, tmp_path): "bottom": 6, "left": 1, }, + "page": 1, "role": "Test role", } - res2 = json.loads(documents[1].content) - - assert csv_document.content == f'"id","text"\n"{res['id']}","Test paragraph"\n"{res2['id']}","Test paragraph2"\n' def test_add_metadata_to_document(converter): @@ -118,7 +112,8 @@ def test_add_metadata_to_document(converter): mock.Mock(x=1, y=2), mock.Mock(x=3, y=4), mock.Mock(x=5, y=6), - ] + ], + page_number=2 ) ], ) @@ -135,5 +130,6 @@ def test_add_metadata_to_document(converter): "bottom": 6, "left": 1, }, + "page": 2, "role": "Test role", } diff --git a/services/permits/tests/test_cached_azure_openai_chat_generator.py b/services/permits/tests/test_cached_azure_openai_chat_generator.py index ee8d87e5f0..61f1eb351d 100644 --- a/services/permits/tests/test_cached_azure_openai_chat_generator.py +++ b/services/permits/tests/test_cached_azure_openai_chat_generator.py @@ -1,16 +1,13 @@ import os -import 
pickle -from unittest.mock import MagicMock, mock_open, patch +from unittest.mock import MagicMock, patch import pytest -from haystack import Document - from app.permit_conditions.pipelines.CachedAzureOpenAIChatGenerator import ( CachedAzureOpenAIChatGenerator, ) from app.permit_conditions.pipelines.chat_data import ChatData +from haystack import Document from haystack.dataclasses import ChatMessage -from haystack.components.caching import CacheChecker logger = MagicMock() @@ -29,7 +26,7 @@ def set_env(): def test_run_with_valid_data(): - data = ChatData(messages=[ChatMessage.from_user("test_message")], documents=[]) + data = ChatData(messages=[[ChatMessage.from_user("test_message")]], documents=[]) generation_kwargs = {} expected_reply = ChatMessage( content="Mocked reply", @@ -42,16 +39,16 @@ def test_run_with_valid_data(): ) with patch.object( - CachedAzureOpenAIChatGenerator, "fetch_result", return_value=expected_reply + CachedAzureOpenAIChatGenerator, "fetch_result", return_value=expected_reply ): generator = CachedAzureOpenAIChatGenerator() result = generator.run(data, generation_kwargs) - assert result["data"].messages[0].content == expected_reply.content + assert result["data"].messages[0][0].content == expected_reply.content def test_run_with_valid_data_multiple_iterations(): - data = ChatData(messages=[ChatMessage.from_user("test_message")], documents=[]) + data = ChatData(messages=[[ChatMessage.from_user("test_message")]], documents=[]) generation_kwargs = {} # Test a scenario where the response is too long for GPT4 (stops with reason: length) and require @@ -76,9 +73,9 @@ def test_run_with_valid_data_multiple_iterations(): ) with patch.object( - CachedAzureOpenAIChatGenerator, - "fetch_result", - side_effect=[expected_reply, expected_reply2], + CachedAzureOpenAIChatGenerator, + "fetch_result", + side_effect=[expected_reply, expected_reply2], ) as mock_fetch_result: generator = CachedAzureOpenAIChatGenerator() @@ -88,30 +85,42 @@ def 
test_run_with_valid_data_multiple_iterations(): chat_response = result["data"].messages[0] # Response for each continuation request should be concatinated - assert chat_response.content == "Mocked replyreply continued" + assert chat_response[0].content == "Mocked replyreply continued" # and the usage tokens should be summed up - assert chat_response.meta["usage"]["total_tokens"] == 18 - assert chat_response.meta["usage"]["completion_tokens"] == 11 - assert chat_response.meta["usage"]["prompt_tokens"] == 7 + assert chat_response[0].meta["usage"]["total_tokens"] == 18 + assert chat_response[0].meta["usage"]["completion_tokens"] == 11 + assert chat_response[0].meta["usage"]["prompt_tokens"] == 7 # Make sure the second iteration contained the reply from the first iteration # and a command to continue the generation mock_fetch_result.assert_called_with( - data.messages + [expected_reply, ChatMessage.from_user("Your response got cut off. Continue from where you left off.")], {} + data.messages[0] + + [ + expected_reply, + ChatMessage.from_user( + "Your response got cut off. Continue from where you left off." 
+ ), + ], + {}, ) def test_fetch_result_with_cache_hit(): - with patch.dict('os.environ', { - 'DEBUG_MODE': 'false', - 'ELASTICSEARCH_CA_CERT': 'mock_ca_cert', - 'ELASTICSEARCH_HOST': 'mock_host', - 'ELASTICSEARCH_USERNAME': 'mock_username', - 'ELASTICSEARCH_PASSWORD': 'mock_password', - }): - with patch('app.permit_conditions.pipelines.CachedAzureOpenAIChatGenerator.hash_messages', - return_value='mock_cache_key') as mock_hash_messages: + with patch.dict( + "os.environ", + { + "DEBUG_MODE": "false", + "ELASTICSEARCH_CA_CERT": "mock_ca_cert", + "ELASTICSEARCH_HOST": "mock_host", + "ELASTICSEARCH_USERNAME": "mock_username", + "ELASTICSEARCH_PASSWORD": "mock_password", + }, + ): + with patch( + "app.permit_conditions.pipelines.CachedAzureOpenAIChatGenerator.hash_messages", + return_value="mock_cache_key", + ) as mock_hash_messages: mock_document = MagicMock(spec=Document) mock_document.content = "mock_content" mock_document.meta = { @@ -123,31 +132,37 @@ def test_fetch_result_with_cache_hit(): "usage": { "completion_tokens": 10, "prompt_tokens": 5, - "total_tokens": 15 - } + "total_tokens": 15, + }, } expected_reply = ChatMessage( content=mock_document.content, name=mock_document.meta["name"], role=mock_document.meta["role"], - meta=mock_document.meta + meta=mock_document.meta, ) with patch( - 'app.permit_conditions.pipelines.CachedAzureOpenAIChatGenerator.ElasticsearchDocumentStore') as MockElasticsearchDocumentStore: + "app.permit_conditions.pipelines.CachedAzureOpenAIChatGenerator.ElasticsearchDocumentStore" + ) as MockElasticsearchDocumentStore: with patch( - 'app.permit_conditions.pipelines.CachedAzureOpenAIChatGenerator.CacheChecker') as MockCacheChecker: + "app.permit_conditions.pipelines.CachedAzureOpenAIChatGenerator.CacheChecker" + ) as MockCacheChecker: mock_cache_checker_instance = MockCacheChecker.return_value - mock_cache_checker_instance.run.return_value = {"hits": [mock_document]} + mock_cache_checker_instance.run.return_value = { + "hits": 
[mock_document] + } generator = CachedAzureOpenAIChatGenerator() result = generator.fetch_result(messages=[], generation_kwargs={}) mock_hash_messages.assert_called_once() - mock_cache_checker_instance.run.assert_called_once_with(items=['mock_cache_key']) - + mock_cache_checker_instance.run.assert_called_once_with( + items=["mock_cache_key"] + ) + assert result is not None assert result.content == expected_reply.content assert result.name == expected_reply.name assert result.role == expected_reply.role @@ -155,3 +170,41 @@ def test_fetch_result_with_cache_hit(): MockElasticsearchDocumentStore.assert_called_once() + +def test_run_with_multiple_message_groups_calls_gpt_for_both(): + messages1 = [ChatMessage.from_user("test_message_1")] + messages2 = [ChatMessage.from_user("test_message_2")] + data = ChatData(messages=[messages1, messages2], documents=[]) + generation_kwargs = {} + + expected_reply1 = ChatMessage( + content="Mocked reply 1", + role="assistant", + name=None, + meta={ + "usage": {"completion_tokens": 10, "prompt_tokens": 5, "total_tokens": 15}, + "finish_reason": "stop", + }, + ) + + expected_reply2 = ChatMessage( + content="Mocked reply 2", + role="assistant", + name=None, + meta={ + "usage": {"completion_tokens": 8, "prompt_tokens": 4, "total_tokens": 12}, + "finish_reason": "stop", + }, + ) + + with patch.object( + CachedAzureOpenAIChatGenerator, + "fetch_result", + side_effect=[expected_reply1, expected_reply2], + ): + generator = CachedAzureOpenAIChatGenerator() + result = generator.run(data, generation_kwargs) + + assert len(result["data"].messages) == 2 + assert result["data"].messages[0][0].content == expected_reply1.content + assert result["data"].messages[1][0].content == expected_reply2.content diff --git a/services/permits/tests/test_conditions_metadata_combiner.py b/services/permits/tests/test_conditions_metadata_combiner.py index 2b95238005..7c22077d9f 100644 --- a/services/permits/tests/test_conditions_metadata_combiner.py +++ 
b/services/permits/tests/test_conditions_metadata_combiner.py @@ -1,12 +1,19 @@ -import os import json +import os from unittest.mock import MagicMock + import pytest -from haystack import Document -from app.permit_conditions.converters.metadata_converter import ConditionsMetadataCombiner -from app.permit_conditions.validator.permit_condition_model import PermitCondition, PermitConditions +from app.permit_conditions.converters.metadata_converter import ( + ConditionsMetadataCombiner, +) from app.permit_conditions.pipelines.chat_data import ChatData from app.permit_conditions.tasks.tasks import task_context +from app.permit_conditions.validator.permit_condition_model import ( + PermitCondition, + PermitConditions, +) +from haystack import Document +from haystack.dataclasses import ChatMessage from tests.mocks import MockContext logger = MagicMock() @@ -18,13 +25,33 @@ def set_env(): def test_conditions_metadata_combiner(): - conditions = PermitConditions(conditions=[ - PermitCondition(id="abc123", text="condition 1", meta={'bounding_box': {'top': 1}}), - PermitCondition(id="abc234", text="condition 2", meta={'bounding_box': {'top': 2}}), - ]) - data = ChatData(messages=[ - Document(content=json.dumps({"paragraphs": [{"id": "abc123", "meta": {"question": "Answer 1"}}, {"id": "abc234", "meta": {"question": "Answer 2"}}]})) - ], documents=None) + conditions = PermitConditions( + conditions=[ + PermitCondition( + id="abc123", text="condition 1", meta={"bounding_box": {"top": 1}} + ), + PermitCondition( + id="abc234", text="condition 2", meta={"bounding_box": {"top": 2}} + ), + ] + ) + data = ChatData( + messages=[ + [ + ChatMessage.from_system( + content=json.dumps( + { + "paragraphs": [ + {"id": "abc123", "meta": {"question": "Answer 1"}}, + {"id": "abc234", "meta": {"question": "Answer 2"}}, + ] + } + ) + ) + ] + ], + documents=[], + ) with task_context(MockContext()): combiner = ConditionsMetadataCombiner() @@ -34,5 +61,65 @@ def test_conditions_metadata_combiner(): 
assert len(result_conditions) == 2 assert result_conditions[0].meta["questions"] == {"question": "Answer 1"} assert result_conditions[1].meta["questions"] == {"question": "Answer 2"} - assert result_conditions[0].meta["bounding_box"] == {'top': 1} - assert result_conditions[1].meta["bounding_box"] == {'top': 2} \ No newline at end of file + assert result_conditions[0].meta["bounding_box"] == {"top": 1} + assert result_conditions[1].meta["bounding_box"] == {"top": 2} + + +def test_conditions_metadata_combiner_with_multiple_message_groups(): + conditions = PermitConditions( + conditions=[ + PermitCondition( + id="abc123", text="condition 1", meta={"bounding_box": {"top": 1}} + ), + PermitCondition( + id="abc234", text="condition 2", meta={"bounding_box": {"top": 2}} + ), + ] + ) + data = ChatData( + messages=[ + [ + ChatMessage.from_system( + content=json.dumps( + { + "paragraphs": [ + {"id": "abc123", "meta": {"question": "Answer 1"}} + ] + } + ) + ), + ], + [ + ChatMessage.from_system( + content=json.dumps( + { + "paragraphs": [ + { + "id": "abc234", + "meta": { + "question": "Answer 2", + "category": "Important", + }, + } + ] + } + ) + ), + ], + ], + documents=[], + ) + + with task_context(MockContext()): + combiner = ConditionsMetadataCombiner() + result = combiner.run(conditions=conditions, data=data) + + result_conditions = result["conditions"].conditions + assert len(result_conditions) == 2 + assert result_conditions[0].meta["questions"] == {"question": "Answer 1"} + assert result_conditions[1].meta["questions"] == { + "question": "Answer 2", + "category": "Important", + } + assert result_conditions[0].meta["bounding_box"] == {"top": 1} + assert result_conditions[1].meta["bounding_box"] == {"top": 2} diff --git a/services/permits/tests/test_json_fixer.py b/services/permits/tests/test_json_fixer.py index a52bd41735..3f21f808e2 100644 --- a/services/permits/tests/test_json_fixer.py +++ b/services/permits/tests/test_json_fixer.py @@ -8,19 +8,23 @@ def 
test_run_with_valid_json(): repair = JSONRepair() data = ChatData( [ - ChatMessage.from_system( - '{"key": "value", "nested": {"inner_key": "inner_value"}}' - ) + [ + ChatMessage.from_system( + '{"key": "value", "nested": {"inner_key": "inner_value"}}' + ) + ] ], - None, + [], ) expected_data = ChatData( [ - ChatMessage.from_system( - '{"key": "value", "nested": {"inner_key": "inner_value"}}' - ) + [ + ChatMessage.from_system( + '{"key": "value", "nested": {"inner_key": "inner_value"}}' + ) + ] ], - None, + [], ) assert repair.run(data)["data"] == expected_data @@ -29,14 +33,68 @@ def test_run_with_invalid_json(): repair = JSONRepair() data = ChatData( [ - ChatMessage.from_system('{key": "value"}'), # Invalid JSON string + [ + ChatMessage.from_system('{key": "value"}'), # Invalid JSON string + ] ], - None, + [], ) expected_data = ChatData( [ - ChatMessage.from_system('{"key": "value"}'), # Repaired JSON string + [ + ChatMessage.from_system('{"key": "value"}'), # Repaired JSON string + ] ], - None, + [], + ) + assert repair.run(data)["data"] == expected_data + + +def test_run_with_multiple_valid_messages_returns_as_is(): + repair = JSONRepair() + data = ChatData( + [ + [ + ChatMessage.from_system('{key": "value1"}'), # Invalid JSON + ChatMessage.from_system( + '{nested": {"inner": "value2"}}' + ), # Invalid JSON + ], + [ChatMessage.from_system('{third": "value3"}')], # Invalid JSON + ], + [], + ) + expected_data = ChatData( + [ + [ + ChatMessage.from_system('{"key": "value1"}'), + ChatMessage.from_system('{"nested": {"inner": "value2"}}'), + ], + [ChatMessage.from_system('{"third": "value3"}')], + ], + [], + ) + assert repair.run(data)["data"] == expected_data + + +def test_run_with_mixed_valid_invalid_json_fixes_invalid(): + repair = JSONRepair() + data = ChatData( + [ + [ + ChatMessage.from_system('{"valid": "json"}'), # Valid JSON + ChatMessage.from_system('{invalid": true}'), # Invalid JSON + ] + ], + [], + ) + expected_data = ChatData( + [ + [ + 
ChatMessage.from_system('{"valid": "json"}'), + ChatMessage.from_system('{"invalid": true}'), + ] + ], + [], ) assert repair.run(data)["data"] == expected_data diff --git a/services/permits/tests/test_paginated_chat_prompt_builder.py b/services/permits/tests/test_paginated_chat_prompt_builder.py new file mode 100644 index 0000000000..8a0d227602 --- /dev/null +++ b/services/permits/tests/test_paginated_chat_prompt_builder.py @@ -0,0 +1,189 @@ +from unittest.mock import MagicMock + +import pytest +from app.permit_conditions.pipelines.PaginatedChatPromptBuilder import ( + PaginatedChatPromptBuilder, + _format_condition_text_for_prompt, +) +from app.permit_conditions.validator.permit_condition_model import ( + PermitCondition, + PermitConditions, +) +from haystack import Document +from haystack.dataclasses import ChatMessage + + +@pytest.fixture +def sample_conditions(): + return PermitConditions( + conditions=[ + PermitCondition( + id="0", condition_text="General", section="A", paragraph="", meta={} + ), + PermitCondition( + id="1", + condition_text="Condition 1", + section="A", + paragraph="1", + meta={}, + ), + PermitCondition( + id="2", + condition_text="Condition 2", + section="A", + paragraph="1", + meta={}, + ), + PermitCondition( + id="3", + condition_text="Condition 3", + section="B", + paragraph="1", + meta={}, + ), + PermitCondition( + id="4", + condition_text="Condition 4", + section="B", + paragraph="2", + meta={}, + ), + ] + ) + + +@pytest.fixture +def builder(): + return PaginatedChatPromptBuilder() + + +def test_split_conditions_groups_conditions_by_section(builder, sample_conditions): + grouped = builder.split_conditions(sample_conditions.conditions) + assert len(grouped) == 2 + assert len(grouped[0]) == 3 # Section A conditions + assert len(grouped[1]) == 2 # Section B conditions + assert grouped[0][0].id == "0" + assert grouped[1][0].id == "3" + + +def test_split_conditions_large_subsection_splits_conditions_if_more_than_30(builder): + conditions = [ 
+ PermitCondition( + id=str(i), + text=f"Condition {i}", + section="A", + paragraph="1" if i < 31 else "2", + meta={}, + ) + for i in range(35) + ] + grouped = builder.split_conditions(conditions) + assert len(grouped) == 2 + assert len(grouped[0]) == 31 + assert len(grouped[1]) == 4 + + +def test_run_groups_prompts_by_condition_section(builder, sample_conditions): + documents = [Document(content="test doc")] + template = [ + ChatMessage.from_system("System prompt"), + ChatMessage.from_user("User prompt"), + ChatMessage.from_user("Permit document prompt"), + ] + template_vars = {"var": "value"} + + result = builder.run( + conditions=sample_conditions, + template=template, + template_variables=template_vars, + documents=documents, + ) + + assert "data" in result + assert len(result["data"].messages) == 2 # Section A + Section B + + +def test_run_with_iteration(builder, sample_conditions): + documents = [Document(content="test doc")] + template = [ + ChatMessage.from_system("System prompt"), + ChatMessage.from_user("User prompt"), + ChatMessage.from_user("Permit document prompt"), + ] + iteration = {"start_page": 1, "last_condition_text": "Previous condition"} + + result = builder.run( + conditions=sample_conditions, + template=template, + iteration=iteration, + documents=documents, + ) + + assert "data" in result + assert len(result["data"].messages) == 2 + + +def test_format_condition_text_formats_condition_for_prompt(sample_conditions): + formatted = _format_condition_text_for_prompt(sample_conditions.conditions) + assert len(formatted) == 1 + assert isinstance(formatted[0], Document) + assert formatted[0].content is not None + assert "Condition 1 (id: 1)" in formatted[0].content + assert "Condition 4 (id: 4)" in formatted[0].content + + +def test_run_with_template_variables_populates_prompt(builder, sample_conditions): + documents = [Document(content="test doc")] + template = [ + ChatMessage.from_system("System prompt with {{ var }}"), + 
ChatMessage.from_user("User prompt with {{ var }}"), + ChatMessage.from_user("Permit document prompt with {{ var }}"), + ] + template_vars = {"var": "value"} + + result = builder.run( + conditions=sample_conditions, + template=template, + template_variables=template_vars, + documents=documents, + ) + + assert "data" in result + assert len(result["data"].messages) == 2 + assert result["data"].messages[0][0].content == "System prompt with value" + assert result["data"].messages[0][1].content == "User prompt with value" + assert result["data"].messages[0][2].content == "Permit document prompt with value" + + +def test_run_conditions_can_be_used_in_prompt(builder, sample_conditions): + documents = [Document(content="test doc 1"), Document(content="test doc 2")] + template = [ + ChatMessage.from_system("System prompt"), + ChatMessage.from_user("User prompt"), + ChatMessage.from_user("Conditions:\n {{ documents[0].content }}"), + ] + template_vars = {"var": "value"} + + result = builder.run( + conditions=sample_conditions, + template=template, + template_variables=template_vars, + documents=documents, + ) + + assert "data" in result + assert len(result["data"].messages) == 2 + assert len(result["data"].messages[0]) == 3 + assert len(result["data"].messages[1]) == 3 + assert result["data"].messages[0][0].content == "System prompt" + assert result["data"].messages[0][1].content == "User prompt" + assert ( + result["data"].messages[0][2].content + == "Conditions:\n () General (id: 0)\n (1) Condition 1 (id: 1)\n (1) Condition 2 (id: 2)" + ) + assert result["data"].messages[1][0].content == "System prompt" + assert result["data"].messages[1][1].content == "User prompt" + assert ( + result["data"].messages[1][2].content + == "Conditions:\n (1) Condition 3 (id: 3)\n (2) Condition 4 (id: 4)" + ) diff --git a/services/permits/tests/test_permit_condition_validator.py b/services/permits/tests/test_permit_condition_validator.py index ec631a28dc..7aca1647a9 100644 --- 
a/services/permits/tests/test_permit_condition_validator.py +++ b/services/permits/tests/test_permit_condition_validator.py @@ -28,9 +28,10 @@ def test_run_with_valid_replies(): validator = PermitConditionValidator() chat_data = ChatData( - messages=[ + messages=[[ ChatMessage.from_system(valid_reply_content), ChatMessage.from_system(valid_reply_content), + ] ], documents=documents[:2], ) @@ -44,9 +45,9 @@ def test_run_with_invalid_replies(): validator = PermitConditionValidator() chat_data = ChatData( messages=[ - ChatMessage.from_system(invalid_reply_content), + [ChatMessage.from_system(invalid_reply_content)], ], - documents=None, + documents=[], ) with pytest.raises(json.JSONDecodeError): validator.run(chat_data) @@ -57,7 +58,7 @@ def test_run_with_multiple_iterations(): validator.max_pages = 1 chat_data = ChatData( messages=[ - ChatMessage.from_system(valid_reply_content), + [ChatMessage.from_system(valid_reply_content)], ] * 10, documents=documents,