From 95e8e609e6873de6677ded059cf05d287b8b4ad9 Mon Sep 17 00:00:00 2001 From: Auke Willem Oosterhoff <1565144+OrangeTux@users.noreply.github.com> Date: Sat, 14 Nov 2020 12:33:06 +0100 Subject: [PATCH] Optimize validation by caching `Draft4Validator`s (#155) Creating a `Draft4Validator` instance from a dict representation of a JSON schema is pretty expensive. Therefore caching the instances makes sense. I've done some micro benchmarks. Those showed improvement of roughly 2 times in routing `Call`s and responding with `CallResult`s. --- ocpp/messages.py | 66 ++++++++++++++++++++++++++++++++++++++---- tests/test_messages.py | 34 ++++++++++++++++++++-- 2 files changed, 92 insertions(+), 8 deletions(-) diff --git a/ocpp/messages.py b/ocpp/messages.py index d53709144..6627d0ab5 100644 --- a/ocpp/messages.py +++ b/ocpp/messages.py @@ -3,9 +3,11 @@ import os import json import decimal +import warnings +from typing import Callable, Dict from dataclasses import asdict, is_dataclass -from jsonschema import validate +from jsonschema import Draft4Validator from jsonschema.exceptions import ValidationError as SchemaValidationError from ocpp.exceptions import (OCPPError, FormatViolationError, @@ -13,7 +15,8 @@ ProtocolError, ValidationError, UnknownCallErrorCodeError) -_schemas = {} +_schemas: Dict[str, Dict] = {} +_validators: Dict[str, Draft4Validator] = {} class _DecimalEncoder(json.JSONEncoder): @@ -101,6 +104,10 @@ def get_schema(message_type_id, action, ocpp_version, parse_float=float): is used to parse floats. It must be a callable taking 1 argument. By default it is `float()`, but certain schema's require `decimal.Decimal()`. """ + warnings.warn( + "Depricated as of 0.8.1. Please use `ocpp.messages.get_validator()`." + ) + if ocpp_version not in ["1.6", "2.0", "2.0.1"]: raise ValueError @@ -134,6 +141,55 @@ def get_schema(message_type_id, action, ocpp_version, parse_float=float): return _schemas[relative_path] +def get_validator( + message_type_id: int, + action: str, + ocpp_version: str, + parse_float: Callable = float +) -> Draft4Validator: + """ + Read schema from disk and return as `Draft4Validator`. Instances will be + cached for performance reasons. + + The `parse_float` argument can be used to set the conversion method that + is used to parse floats. It must be a callable taking 1 argument. By + default it is `float()`, but certain schema's require `decimal.Decimal()`. + """ + if ocpp_version not in ["1.6", "2.0", "2.0.1"]: + raise ValueError + + schemas_dir = 'v' + ocpp_version.replace('.', '') + + schema_name = action + if message_type_id == MessageType.CallResult: + schema_name += 'Response' + elif message_type_id == MessageType.Call: + if ocpp_version in ["2.0", "2.0.1"]: + schema_name += 'Request' + + if ocpp_version == "2.0": + schema_name += '_v1p0' + + cache_key = schema_name + '_' + ocpp_version + if cache_key in _validators: + return _validators[cache_key] + + dir, _ = os.path.split(os.path.realpath(__file__)) + relative_path = f'{schemas_dir}/schemas/{schema_name}.json' + path = os.path.join(dir, relative_path) + + # The JSON schemas for OCPP 2.0 start with a byte order mark (BOM) + # character. If no encoding is given, reading the schema would fail with: + # + # Unexpected UTF-8 BOM (decode using utf-8-sig): + with open(path, 'r', encoding='utf-8-sig') as f: + data = f.read() + validator = Draft4Validator(json.loads(data, parse_float=parse_float)) + _validators[cache_key] = validator + + return _validators[cache_key] + + def validate_payload(message, ocpp_version): """ Validate the payload of the message using JSON schemas. """ if type(message) not in [Call, CallResult]: @@ -164,7 +220,7 @@ def validate_payload(message, ocpp_version): (type(message) == CallResult and message.action == ['GetCompositeSchedule']) ): - schema = get_schema( + validator = get_validator( message.message_type_id, message.action, ocpp_version, parse_float=decimal.Decimal ) @@ -173,7 +229,7 @@ def validate_payload(message, ocpp_version): json.dumps(message.payload), parse_float=decimal.Decimal ) else: - schema = get_schema( + validator = get_validator( message.message_type_id, message.action, ocpp_version ) except (OSError, json.JSONDecodeError) as e: @@ -181,7 +237,7 @@ def validate_payload(message, ocpp_version): f"'{message.action}': {e}") try: - validate(message.payload, schema) + validator.validate(message.payload) except SchemaValidationError as e: raise ValidationError(f"Payload '{message.payload} for action " f"'{message.action}' is not valid: {e}") diff --git a/tests/test_messages.py b/tests/test_messages.py index f06a04acc..91daddb42 100644 --- a/tests/test_messages.py +++ b/tests/test_messages.py @@ -8,9 +8,9 @@ FormatViolationError, PropertyConstraintViolationError, UnknownCallErrorCodeError) -from ocpp.messages import (validate_payload, get_schema, _schemas, unpack, - Call, CallError, CallResult, MessageType, - _DecimalEncoder) +from ocpp.messages import (validate_payload, get_schema, _schemas, + get_validator, _validators, unpack, Call, CallError, + CallResult, MessageType, _DecimalEncoder) def test_unpack_with_invalid_json(): @@ -79,6 +79,34 @@ def test_get_schema_with_valid_name(): } +def test_get_validator_with_valid_name(): + """ + Test if correct validator is returned and if validator is added to cache. + """ + schema = get_validator(MessageType.Call, "Reset", ocpp_version="1.6") + + assert schema == _validators["Reset_1.6"] + assert schema.schema == { + "$schema": "http://json-schema.org/draft-04/schema#", + "title": "ResetRequest", + "type": "object", + "properties": { + "type": { + 'additionalProperties': False, + "type": "string", + "enum": [ + "Hard", + "Soft" + ] + } + }, + "additionalProperties": False, + "required": [ + "type" + ] + } + + def test_validate_set_charging_profile_payload(): """" Test if payloads with floats are validated correctly.