Skip to content

Commit

Permalink
Optimize validation by caching Draft4Validators (#155)
Browse files Browse the repository at this point in the history
Creating a `Draft4Validator` instance from a dict representation of a
JSON schema is pretty expensive. Therefore caching the instances makes
sense.

I've done some micro benchmarks. Those showed improvement of roughly 2 times
in routing `Call`s and responding with `CallResult`s.
  • Loading branch information
OrangeTux authored Nov 14, 2020
1 parent 81b78ac commit 95e8e60
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 8 deletions.
66 changes: 61 additions & 5 deletions ocpp/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,20 @@
import os
import json
import decimal
import warnings
from typing import Callable, Dict
from dataclasses import asdict, is_dataclass

from jsonschema import validate
from jsonschema import Draft4Validator
from jsonschema.exceptions import ValidationError as SchemaValidationError

from ocpp.exceptions import (OCPPError, FormatViolationError,
PropertyConstraintViolationError,
ProtocolError, ValidationError,
UnknownCallErrorCodeError)

_schemas = {}
_schemas: Dict[str, Dict] = {}
_validators: Dict[str, Draft4Validator] = {}


class _DecimalEncoder(json.JSONEncoder):
Expand Down Expand Up @@ -101,6 +104,10 @@ def get_schema(message_type_id, action, ocpp_version, parse_float=float):
is used to parse floats. It must be a callable taking 1 argument. By
default it is `float()`, but certain schema's require `decimal.Decimal()`.
"""
warnings.warn(
"Depricated as of 0.8.1. Please use `ocpp.messages.get_validator()`."
)

if ocpp_version not in ["1.6", "2.0", "2.0.1"]:
raise ValueError

Expand Down Expand Up @@ -134,6 +141,55 @@ def get_schema(message_type_id, action, ocpp_version, parse_float=float):
return _schemas[relative_path]


def get_validator(
message_type_id: int,
action: str,
ocpp_version: str,
parse_float: Callable = float
) -> Draft4Validator:
"""
Read schema from disk and return as `Draft4Validator`. Instances will be
cached for performance reasons.
The `parse_float` argument can be used to set the conversion method that
is used to parse floats. It must be a callable taking 1 argument. By
default it is `float()`, but certain schema's require `decimal.Decimal()`.
"""
if ocpp_version not in ["1.6", "2.0", "2.0.1"]:
raise ValueError

schemas_dir = 'v' + ocpp_version.replace('.', '')

schema_name = action
if message_type_id == MessageType.CallResult:
schema_name += 'Response'
elif message_type_id == MessageType.Call:
if ocpp_version in ["2.0", "2.0.1"]:
schema_name += 'Request'

if ocpp_version == "2.0":
schema_name += '_v1p0'

cache_key = schema_name + '_' + ocpp_version
if cache_key in _validators:
return _validators[cache_key]

dir, _ = os.path.split(os.path.realpath(__file__))
relative_path = f'{schemas_dir}/schemas/{schema_name}.json'
path = os.path.join(dir, relative_path)

# The JSON schemas for OCPP 2.0 start with a byte order mark (BOM)
# character. If no encoding is given, reading the schema would fail with:
#
# Unexpected UTF-8 BOM (decode using utf-8-sig):
with open(path, 'r', encoding='utf-8-sig') as f:
data = f.read()
validator = Draft4Validator(json.loads(data, parse_float=parse_float))
_validators[cache_key] = validator

return _validators[cache_key]


def validate_payload(message, ocpp_version):
""" Validate the payload of the message using JSON schemas. """
if type(message) not in [Call, CallResult]:
Expand Down Expand Up @@ -164,7 +220,7 @@ def validate_payload(message, ocpp_version):
(type(message) == CallResult and
message.action == ['GetCompositeSchedule'])
):
schema = get_schema(
validator = get_validator(
message.message_type_id, message.action,
ocpp_version, parse_float=decimal.Decimal
)
Expand All @@ -173,15 +229,15 @@ def validate_payload(message, ocpp_version):
json.dumps(message.payload), parse_float=decimal.Decimal
)
else:
schema = get_schema(
validator = get_validator(
message.message_type_id, message.action, ocpp_version
)
except (OSError, json.JSONDecodeError) as e:
raise ValidationError("Failed to load validation schema for action "
f"'{message.action}': {e}")

try:
validate(message.payload, schema)
validator.validate(message.payload)
except SchemaValidationError as e:
raise ValidationError(f"Payload '{message.payload} for action "
f"'{message.action}' is not valid: {e}")
Expand Down
34 changes: 31 additions & 3 deletions tests/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
FormatViolationError,
PropertyConstraintViolationError,
UnknownCallErrorCodeError)
from ocpp.messages import (validate_payload, get_schema, _schemas, unpack,
Call, CallError, CallResult, MessageType,
_DecimalEncoder)
from ocpp.messages import (validate_payload, get_schema, _schemas,
get_validator, _validators, unpack, Call, CallError,
CallResult, MessageType, _DecimalEncoder)


def test_unpack_with_invalid_json():
Expand Down Expand Up @@ -79,6 +79,34 @@ def test_get_schema_with_valid_name():
}


def test_get_validator_with_valid_name():
"""
Test if correct validator is returned and if validator is added to cache.
"""
schema = get_validator(MessageType.Call, "Reset", ocpp_version="1.6")

assert schema == _validators["Reset_1.6"]
assert schema.schema == {
"$schema": "http://json-schema.org/draft-04/schema#",
"title": "ResetRequest",
"type": "object",
"properties": {
"type": {
'additionalProperties': False,
"type": "string",
"enum": [
"Hard",
"Soft"
]
}
},
"additionalProperties": False,
"required": [
"type"
]
}


def test_validate_set_charging_profile_payload():
"""" Test if payloads with floats are validated correctly.
Expand Down

0 comments on commit 95e8e60

Please sign in to comment.