Skip to content

Commit

Permalink
Implementation of the Enum handling like proposed in ponyorm#502 A).
Browse files Browse the repository at this point in the history
  • Loading branch information
luckydonald committed Jan 29, 2021
1 parent 8842889 commit 5fa54d6
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 0 deletions.
73 changes: 73 additions & 0 deletions pony/orm/dbapiprovider.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from decimal import Decimal, InvalidOperation
from datetime import datetime, date, time, timedelta
from uuid import uuid4, UUID
from enum import Enum

import pony
from pony.utils import is_utf8, decorator, throw, localbase, deprecated
Expand Down Expand Up @@ -420,6 +421,78 @@ def get_fk_type(converter, sql_type):
sql_type = sql_type.upper()
return fk_types.get(sql_type, sql_type).lower()


class EnumConverter(Converter):
def __init__(self, provider, py_type, attr=None):
super(EnumConverter, self).__init__(provider=provider, py_type=py_type, attr=attr)
self.provider = provider
self.converter_class = self._get_real_converter(self.py_type)
self.converter = self.converter_class(provider=self.provider, py_type=self.py_type, attr=self.attr)
# end if

def _get_real_converter(self, py_type):
"""
Gets a converter for the underlying type.
:return: Type[Converter]
"""
for t, converter_cls in self.provider.converter_classes:
if issubclass(t, EnumConverter):
# skip our own type, otherwise this could get ugly
continue
# end if
if issubclass(py_type, t):
return converter_cls
# end if
# end for
throw(TypeError, 'No database converter found for enum base type %s' % py_type)
# end def

def init(self, kwargs):
self.converter.init(kwargs=kwargs)
# end def

def validate(self, val, obj=None):
assert issubclass(self.py_type, Enum)
assert issubclass(self.py_type, (int, str))
return self.converter.validate(val=val, obj=obj)
# end def

def py2sql(self, val):
return self.converter.py2sql(val=val)
# end def

def sql2py(self, val):
return self.converter.sql2py(val=val)
# end def

def val2dbval(self, val, obj=None):
""" passes on the value to the right converter """
return self.converter.val2dbval(val=val, obj=obj)
# end def

def dbval2val(self, dbval, obj=None):
""" passes on the value to the right converter """
py_val = self.converter.dbval2val(self, dbval=dbval, obj=obj)
if py_val is None:
return None
# end if
return self.py_type(py_val) # SomeEnum(123) => SomeEnum.SOMETHING
# end def

def dbvals_equal(self, x, y):
self.converter.dbvals_equal(self, x=x, y=y)
# end def

def get_sql_type(self, attr=None):
return self.converter.get_sql_type(attr=attr)
# end def

def get_fk_type(self, sql_type):
return self.converter.get_fk_type(sql_type=sql_type)
# end def
# end class


class NoneConverter(Converter): # used for raw_sql() parameters only
def __init__(converter, provider, py_type, attr=None):
if attr is not None: throw(TypeError, 'Attribute %s has invalid type NoneType' % attr)
Expand Down
2 changes: 2 additions & 0 deletions pony/orm/dbproviders/cockroach.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from decimal import Decimal
from datetime import datetime, date, time, timedelta
from uuid import UUID
from enum import Enum

try:
import psycopg2
Expand Down Expand Up @@ -94,6 +95,7 @@ def set_transaction_mode(provider, connection, cache):
cache.in_transaction = True

converter_classes = [
(Enum, dbapiprovider.EnumConverter),
(NoneType, dbapiprovider.NoneConverter),
(bool, dbapiprovider.BoolConverter),
(basestring, PGStrConverter),
Expand Down
2 changes: 2 additions & 0 deletions pony/orm/dbproviders/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from decimal import Decimal
from datetime import datetime, date, time, timedelta
from uuid import UUID
from enum import Enum

NoneType = type(None)

Expand Down Expand Up @@ -210,6 +211,7 @@ class MySQLProvider(DBAPIProvider):
fk_types = { 'SERIAL' : 'BIGINT UNSIGNED' }

converter_classes = [
(Enum, dbapiprovider.EnumConverter),
(NoneType, dbapiprovider.NoneConverter),
(bool, dbapiprovider.BoolConverter),
(basestring, MySQLStrConverter),
Expand Down
2 changes: 2 additions & 0 deletions pony/orm/dbproviders/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from datetime import datetime, date, time, timedelta
from decimal import Decimal
from uuid import UUID
from enum import Enum

import cx_Oracle

Expand Down Expand Up @@ -402,6 +403,7 @@ class OraProvider(DBAPIProvider):
name_before_table = 'owner'

converter_classes = [
(Enum, dbapiprovider.EnumConverter),
(NoneType, dbapiprovider.NoneConverter),
(bool, OraBoolConverter),
(basestring, OraStrConverter),
Expand Down
2 changes: 2 additions & 0 deletions pony/orm/dbproviders/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from decimal import Decimal
from datetime import datetime, date, time, timedelta
from uuid import UUID
from enum import Enum

try:
import psycopg2
Expand Down Expand Up @@ -306,6 +307,7 @@ def drop_table(provider, connection, table_name):
cursor.execute(sql)

converter_classes = [
(Enum, dbapiprovider.EnumConverter),
(NoneType, dbapiprovider.NoneConverter),
(bool, dbapiprovider.BoolConverter),
(basestring, PGStrConverter),
Expand Down
2 changes: 2 additions & 0 deletions pony/orm/dbproviders/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from time import strptime
from threading import Lock
from uuid import UUID
from enum import Enum
from binascii import hexlify
from functools import wraps

Expand Down Expand Up @@ -325,6 +326,7 @@ class SQLiteProvider(DBAPIProvider):
server_version = sqlite.sqlite_version_info

converter_classes = [
(Enum, dbapiprovider.EnumConverter),
(NoneType, dbapiprovider.NoneConverter),
(bool, dbapiprovider.BoolConverter),
(basestring, dbapiprovider.StrConverter),
Expand Down

0 comments on commit 5fa54d6

Please sign in to comment.